ArcFace是比較新的人臉分類的Loss函數,詳細論文可以看論文:ArcFace: Additive Angular Margin Loss for Deep Face Recognition
論文: https://arxiv.org/abs/1801.07698
官方代碼: https://github.com/deepinsight/insightface
本文主要對代碼進行講解和注釋。
class ArcMarginModel(nn.Module):
def __init__(self, m=0.5,s=64,easy_margin=False,emb_size=512):
super(ArcMarginModel, self).__init__()
self.weight = Parameter(torch.FloatTensor(num_classes, emb_size))
# num_classes 訓練集中總的人臉分類數
# emb_size 特征向量長度
nn.init.xavier_uniform_(self.weight)
# 使用均勻分布來初始化weight
self.easy_margin = easy_margin
self.m = m
# 夾角內插補點 0.5 公式中的m
self.s = s
# 半徑 64 公式中的s
# 二者大小都是論文中推薦值
self.cos_m = math.cos(self.m)
self.sin_m = math.sin(self.m)
# 內插補點的cos和sin
self.th = math.cos(math.pi - self.m)
# 門檻值,避免theta + m >= pi
self.mm = math.sin(math.pi - self.m) * self.m
def forward(self, input, label):
x = F.normalize(input)
W = F.normalize(self.weight)
# 正則化
cosine = F.linear(x, W)
# cos值
sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
# sin
phi = cosine * self.cos_m - sine * self.sin_m
# cos(theta + m) 餘弦公式
if self.easy_margin:
phi = torch.where(cosine > 0, phi, cosine)
# 如果使用easy_margin
else:
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
one_hot = torch.zeros(cosine.size(), device=device)
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
# 将樣本的标簽映射為one hot形式 例如N個标簽,映射為(N,num_classes)
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
# 對于正确類别(1*phi)即公式中的cos(theta + m),對于錯誤的類别(1*cosine)即公式中的cos(theta)
# 這樣對于每一個樣本,比如[0,0,0,1,0,0]屬于第四類,則最終結果為[cosine, cosine, cosine, phi, cosine, cosine]
# 再乘以半徑,經過交叉熵,正好是ArcFace的公式
output *= self.s
# 乘以半徑
return output
其中easy_marign處代碼還有點疑惑,了解下來是為了保持函數的單調,避免theta + m >= pi,細節還沒搞太清楚,需要向大神請教。