天天看點

ArcFace(InsightFace)pytorch代碼實作

ArcFace是比較新的人臉分類的Loss函數,詳細論文可以看論文:ArcFace: Additive Angular Margin Loss for Deep Face Recognition

論文: https://arxiv.org/abs/1801.07698

官方代碼: https://github.com/deepinsight/insightface

本文主要對代碼進行講解和注釋。

class ArcMarginModel(nn.Module):
    def __init__(self, m=0.5,s=64,easy_margin=False,emb_size=512):
        super(ArcMarginModel, self).__init__()

        self.weight = Parameter(torch.FloatTensor(num_classes, emb_size))
        # num_classes 訓練集中總的人臉分類數
        # emb_size 特征向量長度
        nn.init.xavier_uniform_(self.weight)
        # 使用均勻分布來初始化weight

        self.easy_margin = easy_margin
        self.m = m
        # 夾角內插補點 0.5 公式中的m
        self.s = s
        # 半徑 64 公式中的s
        # 二者大小都是論文中推薦值

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        # 內插補點的cos和sin
        self.th = math.cos(math.pi - self.m)
        # 門檻值,避免theta + m >= pi
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, input, label):
        x = F.normalize(input)
        W = F.normalize(self.weight)
        # 正則化
        cosine = F.linear(x, W)
        # cos值
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        # sin
        phi = cosine * self.cos_m - sine * self.sin_m
        # cos(theta + m) 餘弦公式
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
            # 如果使用easy_margin
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = torch.zeros(cosine.size(), device=device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # 将樣本的标簽映射為one hot形式 例如N個标簽,映射為(N,num_classes)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        # 對于正确類别(1*phi)即公式中的cos(theta + m),對于錯誤的類别(1*cosine)即公式中的cos(theta)
        # 這樣對于每一個樣本,比如[0,0,0,1,0,0]屬于第四類,則最終結果為[cosine, cosine, cosine, phi, cosine, cosine]
        # 再乘以半徑,經過交叉熵,正好是ArcFace的公式
        output *= self.s
        # 乘以半徑
        return output
           

其中easy_marign處代碼還有點疑惑,了解下來是為了保持函數的單調,避免theta + m >= pi,細節還沒搞太清楚,需要向大神請教。

繼續閱讀