First, a brief introduction: the TensorFlow Object Detection API is an open-source framework built on top of TensorFlow that makes it easy to build, train, and deploy object detection models.
Setting up the basic deep-learning environment on Windows 10 (Anaconda, the CPU or GPU build of TensorFlow, PyCharm, and so on) is not covered here; there are plenty of tutorials for that online.
Two additional Python libraries are required, pillow and lxml, both of which can be installed with pip (e.g. pip install pillow lxml).
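If you want to confirm that the basic packages are in place before going further, a quick check script is enough. This is just a minimal sketch, not part of the original setup steps:

# Quick environment check (a sketch): print the versions of the packages
# this walkthrough relies on.
import tensorflow as tf
import PIL
import lxml.etree

print('TensorFlow:', tf.__version__)  # the detection script below expects >= 1.12
print('Pillow:', getattr(PIL, '__version__', 'unknown'))
print('lxml:', '.'.join(str(v) for v in lxml.etree.LXML_VERSION))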
1. Downloading the TensorFlow Object Detection API
Download the source code directly from GitHub: https://github.com/tensorflow/models
2. Downloading Protoc
Protoc is used to compile the .proto files under object_detection/protos in the downloaded repository into .py files.
On Windows, version 3.4 is recommended (download link).
After downloading, add the bin folder of the extracted directory to the PATH environment variable.
Open a command prompt, type protoc, and if output like the following appears the installation succeeded.
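As an optional check, a small Python sketch using only the standard library can confirm the same thing, that protoc is reachable on the PATH:

# Optional sketch: confirm that protoc is on the PATH and print its version.
import shutil
import subprocess

protoc_path = shutil.which('protoc')
if protoc_path is None:
    print('protoc was not found on PATH - re-check the environment variable')
else:
    print('protoc found at', protoc_path)
    subprocess.run(['protoc', '--version'])  # prints something like "libprotoc 3.4.0"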
3. Compiling the files under object_detection\protos
Extract the TensorFlow Object Detection archive downloaded earlier, cd into the models-master\research directory from the command line, and run:
protoc ./object_detection/protos/*.proto --python_out=.
This compiles the .proto files under object_detection/protos into .py files.
Once the command finishes, look inside object_detection/protos and you will see the generated .py files.
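On some Windows shells the *.proto wildcard is not expanded and protoc complains that it cannot find the file. In that case the protos can be compiled one at a time; the sketch below (my workaround, run from models-master\research, assuming protoc is already on the PATH) does exactly that:

# Fallback sketch for shells that do not expand the *.proto wildcard:
# compile each proto file individually. Run from models-master\research.
import glob
import subprocess

for proto_file in glob.glob('object_detection/protos/*.proto'):
    print('compiling', proto_file)
    subprocess.run(['protoc', proto_file, '--python_out=.'], check=True)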
4. Running object detection with a pre-trained model
First, create a new project in PyCharm (mine is named using_pre-trained_model_to_detect_objects), then copy the models-master\research\object_detection folder from the downloaded TensorFlow Object Detection repository into the new using_pre-trained_model_to_detect_objects project.
Create an object_detection_tutorial.py file in the project to run the detection. The project structure is:
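In outline, the layout looks roughly like this (only the two items mentioned above matter for this walkthrough):

using_pre-trained_model_to_detect_objects/
├── object_detection/            # copied from models-master\research
└── object_detection_tutorial.py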
The prediction script is shown below; pay attention to the paths involved:
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
import matplotlib.pyplot as plt
from PIL import Image
from object_detection.utils import ops as utils_ops
if StrictVersion(tf.__version__) < StrictVersion('1.12.0'):
    raise ImportError('Please upgrade your TensorFlow installation to v1.12.*.')
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
# Location of the frozen .pb model.
PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'
# Label map file for the COCO dataset.
PATH_TO_LABELS = os.path.join('object_detection/data', 'mscoco_label_map.pbtxt')
PATH_TO_TEST_IMAGES_DIR = 'object_detection/test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
# Download and extract the pre-trained model.
def downloadModel():
    opener = urllib.request.URLopener()
    opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
    tar_file = tarfile.open(MODEL_FILE)
    for file in tar_file.getmembers():
        file_name = os.path.basename(file.name)
        if 'frozen_inference_graph.pb' in file_name:
            tar_file.extract(file, os.getcwd())
# Load the frozen model into a graph.
def loadModel():
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    return detection_graph
# Convert a PIL image into a (height, width, 3) uint8 numpy array.
def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)
# Run object detection on a single image (already expanded to a batch of 1).
def run_inference_for_single_image(image, graph):
    with graph.as_default():
        with tf.Session() as sess:
            # Get handles to input and output tensors
            ops = tf.get_default_graph().get_operations()
            all_tensor_names = {output.name for op in ops for output in op.outputs}
            tensor_dict = {}
            for key in [
                'num_detections', 'detection_boxes', 'detection_scores',
                'detection_classes'
            ]:
                tensor_name = key + ':0'
                if tensor_name in all_tensor_names:
                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
                        tensor_name)
            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
            # Run inference
            output_dict = sess.run(tensor_dict,
                                   feed_dict={image_tensor: image})
            # All outputs are float32 numpy arrays, so convert types as appropriate.
            output_dict['num_detections'] = int(output_dict['num_detections'][0])
            output_dict['detection_classes'] = output_dict[
                'detection_classes'][0].astype(np.int64)
            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
            output_dict['detection_scores'] = output_dict['detection_scores'][0]
    return output_dict
def predict(detection_graph):
    for image_path in TEST_IMAGE_PATHS:
        image = Image.open(image_path)
        # The array based representation of the image will be used later in order
        # to prepare the result image with boxes and labels on it.
        image_np = load_image_into_numpy_array(image)
        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # Actual detection.
        output_dict = run_inference_for_single_image(image_np_expanded, detection_graph)
        # Build the mapping from class ids to human-readable category names.
        category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
        # Visualization of the results of a detection.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            output_dict['detection_boxes'],
            output_dict['detection_classes'],
            output_dict['detection_scores'],
            category_index,
            instance_masks=output_dict.get('detection_masks'),
            use_normalized_coordinates=True,
            line_thickness=8)
        plt.figure(figsize=(12, 8))
        plt.imshow(image_np)
        plt.axis('off')
        plt.show()
if __name__ == '__main__':
    # downloadModel()  # uncomment on the first run to download and extract the model
    detection_graph = loadModel()
    predict(detection_graph)
The output looks like this:
As you can see, the objects in the test images are detected successfully.
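Beyond the drawn boxes, the same output_dict can also be inspected programmatically, for example to print every detection above a confidence threshold. The snippet below is a small sketch that reuses loadModel(), run_inference_for_single_image(), and the other definitions from object_detection_tutorial.py above; the 0.5 threshold is only an example:

# Sketch: print the detections above a score threshold instead of (or in
# addition to) drawing them. Reuses the functions defined in the script above.
MIN_SCORE = 0.5  # example threshold, not from the original post

detection_graph = loadModel()
category_index = label_map_util.create_category_index_from_labelmap(
    PATH_TO_LABELS, use_display_name=True)

for image_path in TEST_IMAGE_PATHS:
    image_np = load_image_into_numpy_array(Image.open(image_path))
    output_dict = run_inference_for_single_image(
        np.expand_dims(image_np, axis=0), detection_graph)
    print(image_path)
    for box, cls, score in zip(output_dict['detection_boxes'],
                               output_dict['detection_classes'],
                               output_dict['detection_scores']):
        if score >= MIN_SCORE:
            name = category_index.get(cls, {}).get('name', str(cls))
            # box is [ymin, xmin, ymax, xmax] in normalized coordinates
            print('  {}: {:.2f} at {}'.format(name, score, box))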
Feel free to follow my personal WeChat official account, AI計算機視覺工坊, where I post articles on machine learning, deep learning, and computer vision from time to time. You are welcome to learn and exchange ideas with me.