天天看點

使用object detection訓練并識别自己的模型

1.安裝tensorflow(version>=1.4.0)

2.部署tensorflow models

  - 在這裡下載下傳

  - 解壓并安裝

    - 解壓後重命名為models複制到tensorflow/目錄下

    - 在linux下

      - 進入tensorflow/models/research/目錄,運作protoc object_detection/protos/*.proto --python_out=.

      - 在~/.bashrc file.中添加slim和models/research路徑

      export PYTHONPATH=$PYTHONPATH:/path/to/slim:/path/to/research

    - 在windows下

      - 下載下傳protoc-3.3.0-win32.zip(version==3.3,已知3.5版本會報錯) 

      - 解壓後将protoc.exe放入C:\Windows下

      - 在tensorflow/models/research/打開powershell,運作protoc object_detection/protos/*.proto --python_out=.

3.訓練資料準備(标記分類的圖檔)

  - 安裝labelImg 用來手動标注圖檔 ,圖檔需要是png或者jpg格式

  - 标注資訊會被儲存為xml檔案,使用 這個腳本 将所有xml檔案轉換為一個csv檔案(xml檔案路徑識别在29行,根據情況自己修改)

  - 把生成的csv檔案分成訓練集和測試集

4.生成TFRecord檔案

  - 使用 這個腳本 将兩個csv檔案生成出兩個TFRecord檔案(訓練自己的模型,必須使用TFRecord格式檔案。圖檔路徑識别在86行,根據情況自己修改)

5.建立label map檔案

  id需要從1開始,class-N便是自己需要識别的物體類别名,檔案字尾為.pbtxt

    item{

      id:1

      name: 'class-1'

      }

    item{

      id:2

      name: 'class-2'

      }

6.下載下傳模型并配置檔案

  - 下載下傳一個模型(檔案字尾.tar.gz)

  - 修改對應的訓練pipline配置檔案 

    - 查找檔案中的PATH_TO_BE_CONFIGURED字段,并做相應修改

      - num_classes 改為你模型中包含類别的數量

      - fine_tune_checkpoint 解壓.tar.gz檔案後的路徑 + /model.ckpt

      - from_detection_checkpoint:true

      - train_input_reader

        - input_path 由train.csv生成的record格式訓練資料

        - label_map_path 第5步建立的pbtxt檔案路徑

      - eval_input_reader

        - input_path 由test.csv生成的record格式訓練資料

        - label_map_path 第5步建立的pbtxt檔案路徑

7. 訓練模型

  - 進入tensorflow/models/research/目錄,運作

  python object_detection/train.py --logtostderr  --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} //第六步中修改的pipline配置檔案路徑// --train_dir=${PATH_TO_TRAIN_DIR} //生成的模型儲存路徑//

8.導出模型

  - 在第7步中,--train_dir指向的路徑中會生成一系列訓練中自動儲存的checkpoint,一個checkpoint由三個檔案組成,字尾分别是.data-00000-of-00001 .index和.meta,任然在第7步的路徑中,運作

    python object_detection/export_inference_graph.py \

--input_type image_tensor  \

--pipeline_config_path ${PIPELINE_CONFIG_PATH} //第六步中修改的pipline配置檔案路徑\--trained_checkpoint_prefix ${TRAIN_PATH} //上述的一個checkpoint,例如model.ckpt-112254 \ --output_directory ${OUTPUT_PATH} //輸出模型檔案的路徑//

使用object detection訓練并識别自己的模型

9.使用新模型識别圖檔

  調用predict.py

首先導入包

import time
import cv2
import numpy as np
import tensorflow as tf
import pandas as pd
import math
import os

from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util      

然後定義類和函數

class TOD(object):
    def __init__(self):
        self.PATH_TO_CKPT = r'D:/xiangchuang/new_train_model/result/frozen_inference_graph.pb'
        self.PATH_TO_LABELS = r'D:/xiangchuang/pig.pbtxt'
        self.NUM_CLASSES = 1
        self.detection_graph = self._load_model()
        self.category_index = self._load_label_map()


    def _load_model(self):
        global detection_graph
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph


    def _load_label_map(self):
        label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(label_map,
                                                                    max_num_classes=self.NUM_CLASSES,
                                                                    use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        return category_index

    def detect(self, image):
        image_np_expanded = np.expand_dims(image, axis=0)
        image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
        boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
        scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
        classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
        # Actual detection.
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # Visualization of the results of a detection.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            self.category_index,
            use_normalized_coordinates=True,
            line_thickness=8)
        cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
        cv2.imshow("detection", image)
        cv2.waitKey(1)      

最後執行

if __name__ == '__main__':
    detector = TOD()
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            cap = cv2.VideoCapture(r'Your Vedio Path')
            n = 1
            success = True
            while  (success) :
                success, frame = cap.read()
                t1=time.clock()
                print('正在預測第%d張' % n)
                n = n + 1
                if success == True:
                    detector.detect(frame)
                t2=time.clock()
                t = t2-t1
                print('cost time %f s'%t)

    cv2.destroyAllWindows()
      

即可以實作基于視訊的目标目标檢測

參考文檔

https://gist.github.com/douglasrizzo/c70e186678f126f1b9005ca83d8bd2ce

https://towardsdatascience.com/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9

繼續閱讀