天天看點

faster rcnn源碼解讀(五)之layer(網絡裡的input-data)

轉載自:faster rcnn源碼解讀(五)之layer(網絡裡的input-data) - 野孩子的專欄 - 部落格頻道 - CSDN.NET

http://blog.csdn.net/u010668907/article/details/51945844

faster rcnn用python版本的https://github.com/rbgirshick/py-faster-rcnn

layer源碼位址:https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/roi_data_layer/layer.py

源碼:

[python]  view plain  copy  print ?

faster rcnn源碼解讀(五)之layer(網絡裡的input-data)
faster rcnn源碼解讀(五)之layer(網絡裡的input-data)
  1. # --------------------------------------------------------  
  2. # Fast R-CNN  
  3. # Copyright (c) 2015 Microsoft  
  4. # Licensed under The MIT License [see LICENSE for details]  
  5. # Written by Ross Girshick  
  6. # --------------------------------------------------------  
  7. """The data layer used during training to train a Fast R-CNN network. 
  8. RoIDataLayer implements a Caffe Python layer. 
  9. """  
  10. import caffe  
  11. from fast_rcnn.config import cfg  
  12. from roi_data_layer.minibatch import get_minibatch  
  13. import numpy as np  
  14. import yaml  
  15. from multiprocessing import Process, Queue  
  16. class RoIDataLayer(caffe.Layer):  
  17.     """Fast R-CNN data layer used for training."""  
  18.     def _shuffle_roidb_inds(self):  
  19.         """Randomly permute the training roidb."""  
  20.         if cfg.TRAIN.ASPECT_GROUPING:  
  21.             widths = np.array([r['width'] for r in self._roidb])  
  22.             heights = np.array([r['height'] for r in self._roidb])  
  23.             horz = (widths >= heights)  
  24.             vert = np.logical_not(horz)  
  25.             horz_inds = np.where(horz)[0]  
  26.             vert_inds = np.where(vert)[0]  
  27.             inds = np.hstack((  
  28.                 np.random.permutation(horz_inds),  
  29.                 np.random.permutation(vert_inds)))  
  30.             inds = np.reshape(inds, (-1, 2))  
  31.             row_perm = np.random.permutation(np.arange(inds.shape[0]))  
  32.             inds = np.reshape(inds[row_perm, :], (-1,))  
  33.             self._perm = inds#把roidb的索引打亂,造成的shuffle,打亂的索引存儲的地方  
  34.         else:  
  35.             self._perm = np.random.permutation(np.arange(len(self._roidb)))  
  36.         self._cur = 0  
  37.     def _get_next_minibatch_inds(self):  
  38.         """Return the roidb indices for the next minibatch."""  
  39.         if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):  
  40.             self._shuffle_roidb_inds()  
  41.         db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH]  
  42.         self._cur += cfg.TRAIN.IMS_PER_BATCH#相當于一個指向_perm的指針,每次取走圖檔後,他會跟着變化#cfg.TRAIN.IMS_PER_BATCH: (猜測,每次取圖檔的數量)  
  43.         return db_inds#本次取得圖檔的索引  
  44.     def _get_next_minibatch(self):#取得本次圖檔的索引,即db_inds  
  45.         """Return the blobs to be used for the next minibatch. 
  46.         If cfg.TRAIN.USE_PREFETCH is True, then blobs will be computed in a 
  47.         separate process and made available through self._blob_queue. 
  48.         """  
  49.         if cfg.TRAIN.USE_PREFETCH:  
  50.             return self._blob_queue.get()  
  51.         else:  
  52.             db_inds = self._get_next_minibatch_inds()  
  53.             minibatch_db = [self._roidb[i] for i in db_inds]#本次的roidb  
  54.             return get_minibatch(minibatch_db, self._num_classes)  
  55.     def set_roidb(self, roidb):  
  56.         """Set the roidb to be used by this layer during training."""  
  57.         self._roidb = roidb  
  58.         self._shuffle_roidb_inds()  
  59.         if cfg.TRAIN.USE_PREFETCH:  
  60.             self._blob_queue = Queue(10)  
  61.             self._prefetch_process = BlobFetcher(self._blob_queue,  
  62.                                                  self._roidb,  
  63.                                                  self._num_classes)  
  64.             self._prefetch_process.start()  
  65.             # Terminate the child process when the parent exists  
  66.             def cleanup():  
  67.                 print 'Terminating BlobFetcher'  
  68.                 self._prefetch_process.terminate()  
  69.                 self._prefetch_process.join()  
  70.             import atexit  
  71.             atexit.register(cleanup)  
  72.     def setup(self, bottom, top):  
  73.         """Setup the RoIDataLayer."""  
  74.         # parse the layer parameter string, which must be valid YAML  
  75.         layer_params = yaml.load(self.param_str_)  
  76.         self._num_classes = layer_params['num_classes']#網絡裡的類别數值21  
  77.         self._name_to_top_map = {}#{'gt_boxes': 2, 'data': 0, 'im_info': 1}字典的value值是top的對應索引  
  78.         # data blob: holds a batch of N images, each with 3 channels  
  79.         idx = 0  
  80.         top[idx].reshape(cfg.TRAIN.IMS_PER_BATCH, 3,  
  81.             max(cfg.TRAIN.SCALES), cfg.TRAIN.MAX_SIZE)  
  82.         self._name_to_top_map['data'] = idx  
  83.         idx += 1  
  84.         if cfg.TRAIN.HAS_RPN:  
  85.             top[idx].reshape(1, 3)  
  86.             self._name_to_top_map['im_info'] = idx  
  87.             idx += 1  
  88.             top[idx].reshape(1, 4)  
  89.             self._name_to_top_map['gt_boxes'] = idx  
  90.             idx += 1  
  91.         else: # not using RPN  
  92.             # rois blob: holds R regions of interest, each is a 5-tuple  
  93.             # (n, x1, y1, x2, y2) specifying an image batch index n and a  
  94.             # rectangle (x1, y1, x2, y2)  
  95.             top[idx].reshape(1, 5)  
  96.             self._name_to_top_map['rois'] = idx  
  97.             idx += 1  
  98.             # labels blob: R categorical labels in [0, ..., K] for K foreground  
  99.             # classes plus background  
  100.             top[idx].reshape(1)  
  101.             self._name_to_top_map['labels'] = idx  
  102.             idx += 1  
  103.             if cfg.TRAIN.BBOX_REG:  
  104.                 # bbox_targets blob: R bounding-box regression targets with 4  
  105.                 # targets per class  
  106.                 top[idx].reshape(1, self._num_classes * 4)  
  107.                 self._name_to_top_map['bbox_targets'] = idx  
  108.                 idx += 1  
  109.                 # bbox_inside_weights blob: At most 4 targets per roi are active;  
  110.                 # thisbinary vector sepcifies the subset of active targets  
  111.                 top[idx].reshape(1, self._num_classes * 4)  
  112.                 self._name_to_top_map['bbox_inside_weights'] = idx  
  113.                 idx += 1  
  114.                 top[idx].reshape(1, self._num_classes * 4)  
  115.                 self._name_to_top_map['bbox_outside_weights'] = idx  
  116.                 idx += 1  
  117.         print 'RoiDataLayer: name_to_top:', self._name_to_top_map  
  118.         assert len(top) == len(self._name_to_top_map)  
  119.     def forward(self, bottom, top):  
  120.         """Get blobs and copy them into this layer's top blob vector."""  
  121.         blobs = self._get_next_minibatch()  
  122.         for blob_name, blob in blobs.iteritems():  
  123.             top_ind = self._name_to_top_map[blob_name]  
  124.             # Reshape net's input blobs  
  125.             top[top_ind].reshape(*(blob.shape))  
  126.             # Copy data into net's input blobs  
  127.             top[top_ind].data[...] = blob.astype(np.float32, copy=False)  
  128.     def backward(self, top, propagate_down, bottom):  
  129.         """This layer does not propagate gradients."""  
  130.         pass  
  131.     def reshape(self, bottom, top):  
  132.         """Reshaping happens during the call to forward."""  
  133.         pass  
  134. class BlobFetcher(Process):  
  135.     """Experimental class for prefetching blobs in a separate process."""  
  136.     def __init__(self, queue, roidb, num_classes):  
  137.         super(BlobFetcher, self).__init__()  
  138.         self._queue = queue  
  139.         self._roidb = roidb  
  140.         self._num_classes = num_classes  
  141.         self._perm = None  
  142.         self._cur = 0  
  143.         self._shuffle_roidb_inds()  
  144.         # fix the random seed for reproducibility  
  145.         np.random.seed(cfg.RNG_SEED)  
  146.     def _shuffle_roidb_inds(self):  
  147.         """Randomly permute the training roidb."""  
  148.         # TODO(rbg): remove duplicated code  
  149.         self._perm = np.random.permutation(np.arange(len(self._roidb)))  
  150.         self._cur = 0  
  151.     def _get_next_minibatch_inds(self):  
  152.         """Return the roidb indices for the next minibatch."""  
  153.         # TODO(rbg): remove duplicated code  
  154.         if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):  
  155.             self._shuffle_roidb_inds()  
  156.         db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH]  
  157.         self._cur += cfg.TRAIN.IMS_PER_BATCH  
  158.         return db_inds  
  159.     def run(self):  
  160.         print 'BlobFetcher started'  
  161.         while True:  
  162.             db_inds = self._get_next_minibatch_inds()  
  163.             minibatch_db = [self._roidb[i] for i in db_inds]  
  164.             blobs = get_minibatch(minibatch_db, self._num_classes)  
  165.             self._queue.put(blobs)  

下面的roidb都隻是一次batch的

3.1 setup在caffe.SGDSolver時調用;setup的top(list猜測是c++的vector)的每個項是caffe._caffe.Blob

(猜測,輸出的Top shape就是上面的top,在setup中被shape;top[0],1 3 [600] 1000;top[1],1 3;top[2], 1 4)(疑問,在forward中blob的資料shape被重置,有時大小甚至會不定)

  3.2 name_to_top: {'gt_boxes': 2, 'data': 0, 'im_info': 1}字典的value值是top的對應索引

  3.3 solver.step(1)會調用layer的reshape、forward

  3.4 self._perm: 把roidb的索引打亂,造成圖檔的shuffle,打亂的索引存儲的地方

  3.5 cfg.TRAIN.IMS_PER_BATCH: (猜測,每次取圖檔的數量)

  3.6 self._cur: 相當于一個指向_perm的指針,每次取走圖檔後,他會跟着變化

  3.7 db_inds: 本次取得圖檔的索引

  3.8 def _get_next_minibatch_inds(self): 取得本次圖檔的索引,即db_inds

  3.9 minibatch_db: 本次的roidb

  3.10 _num_classes: 網絡裡的類别數值21

  3.11 forward(): 得到blob并處理放進top

solver.step(1)-》reshape-》forward-》_get_next_minbatch-》_get_next_minbatch_inds-》(前面在layers裡,現在進入minibatch組建真正的blob)get_minibatch

下一篇: SPP-Net