天天看點

通過中心點生成heatmap

        使用2D高斯函數來建構學習目标(heatmap)。将某一關鍵點的ground-truth作為中心點,這樣一來,中心點處将具有最高的得分,越遠離中心點,得分将越低。公式地表示,則有

通過中心點生成heatmap

         将關鍵點的grounnd-truth轉換為mask,可以使用分割的方式預測k張feature-map對應k個關鍵點,最後在每個熱力圖中選擇最大值對應的坐标,即為model預測輸出的結果。

對應代碼段如下所示:

        參考github:hourglass代碼:

                https://github.com/raymon-tian/hourglass-facekeypoints-detection

def _putGaussianMap(self, center, visible_flag, crop_size_y, crop_size_x, stride, sigma):
        """
        根據一個中心點,生成一個heatmap
        :param center:
        :return:
        """
        grid_y = crop_size_y / stride
        grid_x = crop_size_x / stride
        if visible_flag == False:
            return np.zeros((grid_y,grid_x))
        start = stride / 2.0 - 0.5
        y_range = [i for i in range(grid_y)]
        x_range = [i for i in range(grid_x)]
        xx, yy = np.meshgrid(x_range, y_range)
        xx = xx * stride + start
        yy = yy * stride + start
        d2 = (xx - center[0]) ** 2 + (yy - center[1]) ** 2
        exponent = d2 / 2.0 / sigma / sigma
        heatmap = np.exp(-exponent)
        return heatmap

    def _putGaussianMaps(self,keypoints,crop_size_y, crop_size_x, stride, sigma):
        """
        :param keypoints: (15,2)
        :param crop_size_y: int
        :param crop_size_x: int
        :param stride: int
        :param sigma: float
        :return:
        """
        all_keypoints = keypoints
        point_num = all_keypoints.shape[0]
        heatmaps_this_img = []
        for k in range(point_num):
            flag = ~np.isnan(all_keypoints[k,0])
            heatmap = self._putGaussianMap(all_keypoints[k],flag,crop_size_y,crop_size_x,stride,sigma)
            heatmap = heatmap[np.newaxis,...]
            heatmaps_this_img.append(heatmap)
        heatmaps_this_img = np.concatenate(heatmaps_this_img,axis=0) # (num_joint,crop_size_y/stride,crop_size_x/stride)
        return heatmaps_this_img

    def visualize_heatmap_target(self,oriImg,heatmap,stride):

        plt.imshow(oriImg)
        plt.imshow(heatmap, alpha=.5)
        plt.show()
           

繼續閱讀