天天看点

ArcFace(InsightFace)pytorch代码实现

ArcFace是比较新的人脸分类的Loss函数,详细论文可以看论文:ArcFace: Additive Angular Margin Loss for Deep Face Recognition

论文: https://arxiv.org/abs/1801.07698

官方代码: https://github.com/deepinsight/insightface

本文主要对代码进行讲解和注释。

class ArcMarginModel(nn.Module):
    def __init__(self, m=0.5,s=64,easy_margin=False,emb_size=512):
        super(ArcMarginModel, self).__init__()

        self.weight = Parameter(torch.FloatTensor(num_classes, emb_size))
        # num_classes 训练集中总的人脸分类数
        # emb_size 特征向量长度
        nn.init.xavier_uniform_(self.weight)
        # 使用均匀分布来初始化weight

        self.easy_margin = easy_margin
        self.m = m
        # 夹角差值 0.5 公式中的m
        self.s = s
        # 半径 64 公式中的s
        # 二者大小都是论文中推荐值

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        # 差值的cos和sin
        self.th = math.cos(math.pi - self.m)
        # 阈值,避免theta + m >= pi
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, input, label):
        x = F.normalize(input)
        W = F.normalize(self.weight)
        # 正则化
        cosine = F.linear(x, W)
        # cos值
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        # sin
        phi = cosine * self.cos_m - sine * self.sin_m
        # cos(theta + m) 余弦公式
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
            # 如果使用easy_margin
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = torch.zeros(cosine.size(), device=device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # 将样本的标签映射为one hot形式 例如N个标签,映射为(N,num_classes)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        # 对于正确类别(1*phi)即公式中的cos(theta + m),对于错误的类别(1*cosine)即公式中的cos(theta)
        # 这样对于每一个样本,比如[0,0,0,1,0,0]属于第四类,则最终结果为[cosine, cosine, cosine, phi, cosine, cosine]
        # 再乘以半径,经过交叉熵,正好是ArcFace的公式
        output *= self.s
        # 乘以半径
        return output
           

其中easy_marign处代码还有点疑惑,理解下来是为了保持函数的单调,避免theta + m >= pi,细节还没搞太清楚,需要向大神请教。

继续阅读