天天看点

通过中心点生成heatmap

        使用2D高斯函数来构建学习目标(heatmap)。将某一关键点的ground-truth作为中心点,这样一来,中心点处将具有最高的得分,越远离中心点,得分将越低。公式地表示,则有

通过中心点生成heatmap

         将关键点的grounnd-truth转换为mask,可以使用分割的方式预测k张feature-map对应k个关键点,最后在每个热力图中选择最大值对应的坐标,即为model预测输出的结果。

对应代码段如下所示:

        参考github:hourglass代码:

                https://github.com/raymon-tian/hourglass-facekeypoints-detection

def _putGaussianMap(self, center, visible_flag, crop_size_y, crop_size_x, stride, sigma):
        """
        根据一个中心点,生成一个heatmap
        :param center:
        :return:
        """
        grid_y = crop_size_y / stride
        grid_x = crop_size_x / stride
        if visible_flag == False:
            return np.zeros((grid_y,grid_x))
        start = stride / 2.0 - 0.5
        y_range = [i for i in range(grid_y)]
        x_range = [i for i in range(grid_x)]
        xx, yy = np.meshgrid(x_range, y_range)
        xx = xx * stride + start
        yy = yy * stride + start
        d2 = (xx - center[0]) ** 2 + (yy - center[1]) ** 2
        exponent = d2 / 2.0 / sigma / sigma
        heatmap = np.exp(-exponent)
        return heatmap

    def _putGaussianMaps(self,keypoints,crop_size_y, crop_size_x, stride, sigma):
        """
        :param keypoints: (15,2)
        :param crop_size_y: int
        :param crop_size_x: int
        :param stride: int
        :param sigma: float
        :return:
        """
        all_keypoints = keypoints
        point_num = all_keypoints.shape[0]
        heatmaps_this_img = []
        for k in range(point_num):
            flag = ~np.isnan(all_keypoints[k,0])
            heatmap = self._putGaussianMap(all_keypoints[k],flag,crop_size_y,crop_size_x,stride,sigma)
            heatmap = heatmap[np.newaxis,...]
            heatmaps_this_img.append(heatmap)
        heatmaps_this_img = np.concatenate(heatmaps_this_img,axis=0) # (num_joint,crop_size_y/stride,crop_size_x/stride)
        return heatmaps_this_img

    def visualize_heatmap_target(self,oriImg,heatmap,stride):

        plt.imshow(oriImg)
        plt.imshow(heatmap, alpha=.5)
        plt.show()
           

继续阅读