天天看点

torch.nn.functional.binary_cross_entropy_with_logits

torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, reduce=None, reduction=‘mean’, pos_weight=None)

参数说明

input

神经网络预测结果(未经过sigmoid),任意形状的tensor

target

标签值,与input形状相同

weight

权重值,可用于mask的作用, 具体作用下面讲解,形状同input

size_average

弃用,见reduction参数

reduce

弃用,见reduction参数

reduction

指定对输出结果的操作,可以设定为

none

mean

sum

; none将不对结果进行任何处理,mean对结果求均值, sum对结果求和, 默认是mean

函数功能

对神经网络的输出结果进行sigmoid操作,然后求交叉熵,当参数reduction=mean时其实际效果和下面的程序相同; weight参数实际是对交叉熵结果的权重

x = torch.sigmoid(pred)   # pred是神经网络的输出
result = -torch.mean((label*torch.log(x)+(1-label)*torch.log(1-x))*weight)
print(result)
           

注意

torch.nn.BCEWithLogitsLoss()功能及其中参数和上面的函数一样

例子

pred = torch.tensor([[-1.0, 2],[0.5, 1]])
label = torch.tensor([[1, 0.0],[2,5]])
mask = torch.tensor([[1,0], [1,1]])
l = nn.functional.binary_cross_entropy_with_logits(pred, target=label, reduction="mean", weight=mask)
print(l)
x = torch.sigmoid(pred)
result = -torch.mean((label*torch.log(x)+(1-label)*torch.log(1-x))*mask)
print("result:",result)
           

输出

tensor(-0.5998)
result: tensor(-0.5999)
           

继续阅读