天天看點

Tensorflow2.0---SSD網絡原理及代碼解析(五)- 損失函數的計算Tensorflow2.0—SSD網絡原理及代碼解析(五)- 損失函數的計算

Tensorflow2.0—SSD網絡原理及代碼解析(五)- 損失函數的計算

前面寫了三篇關于SSD代碼的講解,還有最後一個關鍵代碼—損失函數的計算,廢話不多說,直接上幹貨~

這行代碼是進行損失計算函數的調用。

Tensorflow2.0---SSD網絡原理及代碼解析(五)- 損失函數的計算Tensorflow2.0—SSD網絡原理及代碼解析(五)- 損失函數的計算

損失函數被包裝成一個MultiboxLoss類,最後一個compute_loss方法用于調用計算。

首先,先搞清楚y_true, y_pred分别是個啥shape的,其實二者的shape都是(2,8732,33)。2表示的是batch_size,8732表示的是每張圖檔的錨點框,33表示每個訓練圖檔進行encode之後的結果。

#   分類的loss
        #   batch_size,8732,21 -> batch_size,8732
        # --------------------------------------------- #
        conf_loss = self._softmax_loss(y_true[:, :, 4:-8],
                                       y_pred[:, :, 4:-8])
           

先計算所有真實框(其實就是錨點框基于真實框進行encode之後)與預測框的分類的loss。這裡用的是softmax。

def _softmax_loss(self, y_true, y_pred):
        y_pred = tf.maximum(y_pred, 1e-7)
        softmax_loss = -tf.reduce_sum(y_true * tf.math.log(y_pred),
                                      axis=-1)
        return softmax_loss
           

輸入的shape為(2,8732,21),輸出的shape為(2,8732)。

然後,計算所有真實框(其實就是錨點框基于真實框進行encode之後)與預測框的坐标的loss,這裡用的是l1平滑損失函數。

什麼是L_1損失函數呢???

Tensorflow2.0---SSD網絡原理及代碼解析(五)- 損失函數的計算Tensorflow2.0—SSD網絡原理及代碼解析(五)- 損失函數的計算
https://blog.csdn.net/weixin_41940752/article/details/93159710

代碼實作:

def _l1_smooth_loss(self, y_true, y_pred):
        abs_loss = tf.abs(y_true - y_pred)  # y_1 = |y_t - y_p|
        sq_loss = 0.5 * (y_true - y_pred)**2  #y_2 = 0.5 * (y_t - y_p)^2
        l1_loss = tf.where(tf.less(abs_loss, 1.0), sq_loss, abs_loss - 0.5)
        return tf.reduce_sum(l1_loss, -1)
           

輸入的shape為(2,8732,4),輸出的shape為(2,8732)。

接下來,計算所有正樣本的先驗框的loss:

#   擷取所有的正标簽的loss
        # --------------------------------------------- #
        pos_loc_loss = tf.reduce_sum(loc_loss * y_true[:, :, -8],
                                     axis=1)
        pos_conf_loss = tf.reduce_sum(conf_loss * y_true[:, :, -8],
                                      axis=1)
           

loc_loss和y_true[:, :, -8]的shape為(2,8732),按照最後一維進行先相乘,然後相加,最後得到了shape為(2,)的pos_loc_loss 和pos_conf_loss 。

現在,損失就是一個:負樣本的conf的loss。

# --------------------------------------------- #
        #   每一張圖的正樣本的個數
        #   batch_size,
        # --------------------------------------------- #
        num_pos = tf.reduce_sum(y_true[:, :, -8], axis=-1)  #計算每個批次中每個圖正樣本的個數

        # --------------------------------------------- #
        #   每一張圖的負樣本的個數
        #   batch_size,
        # --------------------------------------------- #
        num_neg = tf.minimum(self.neg_pos_ratio * num_pos,
                             num_boxes - num_pos)
        # 找到了哪些值是大于0的
        pos_num_neg_mask = tf.greater(num_neg, 0)
        # --------------------------------------------- #
        #   如果有些圖,它的正樣本數量為0,
        #   預設負樣本為100
        # --------------------------------------------- #
        has_min = tf.cast(tf.reduce_any(pos_num_neg_mask),tf.float32)
        num_neg = tf.concat(axis=0, values=[num_neg, [(1 - has_min) * self.negatives_for_hard]])

           

以上這麼一大堆代碼,其實任務很簡單,就是在找負樣本,如果有正樣本,那麼就取3倍的負樣本。如果沒有正樣本,那麼就篩選出100個負樣本。

# --------------------------------------------- #
        num_neg_batch = tf.reduce_sum(tf.boolean_mask(num_neg, tf.greater(num_neg, 0)))
        num_neg_batch = tf.cast(num_neg_batch,tf.int32)   #一個批次中所有的負樣本的個數

        # --------------------------------------------- #
        #   對預測結果進行判斷,如果該先驗框沒有包含物體
        #   那麼它的不屬于背景的預測機率過大的話
        #   就是難分類樣本
        # --------------------------------------------- #
        confs_start = 4 + self.background_label_id + 1
        confs_end = confs_start + self.num_classes - 1

        # --------------------------------------------- #
        #   batch_size,8732
        #   把不是背景的機率求和,求和後的機率越大
        #   代表越難分類。
        # --------------------------------------------- #
        max_confs = tf.reduce_sum(y_pred[:, :, confs_start:confs_end], axis=2)

        # --------------------------------------------------- #
        #   隻有沒有包含物體的先驗框才得到保留
        #   我們在整個batch裡面選取最難分類的num_neg_batch個
        #   先驗框作為負樣本。
        # --------------------------------------------------- #
        max_confs = tf.reshape(max_confs * (1 - y_true[:, :, -8]), [-1])
        _, indices = tf.nn.top_k(max_confs, k=num_neg_batch)

        neg_conf_loss = tf.gather(tf.reshape(conf_loss, [-1]), indices)
           

這一步比較難了解,我個人了解為:先找到每個批次中所有負樣本的數量,然後計算所有預測框的不是背景的機率進行求和,求和後的機率越大,代表越難分類。(我認為可以這麼了解:除了背景的機率,其他機率相加越大,說明這個預測框就是屬于那種越難分類的,比如說一隻狗,預測出它為貓為0.3,狗為0.4,老虎為0.35,像這種的就很難差別出到底是哪個動物。)然後,我們在整個batch裡面選取最難分類的num_neg_batch個先驗框作為負樣本。

最後,将三個損失進行相加,并歸一化。

# 進行歸一化
        num_pos     = tf.where(tf.not_equal(num_pos, 0), num_pos, tf.ones_like(num_pos))
        total_loss  = tf.reduce_sum(pos_conf_loss) + tf.reduce_sum(neg_conf_loss) + tf.reduce_sum(self.alpha * pos_loc_loss)
        total_loss /= tf.reduce_sum(num_pos)
        return total_loss
           

繼續閱讀