天天看點

torch.nn.functional.binary_cross_entropy_with_logits

torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, reduce=None, reduction=‘mean’, pos_weight=None)

參數說明

input

神經網絡預測結果(未經過sigmoid),任意形狀的tensor

target

标簽值,與input形狀相同

weight

權重值,可用于mask的作用, 具體作用下面講解,形狀同input

size_average

棄用,見reduction參數

reduce

棄用,見reduction參數

reduction

指定對輸出結果的操作,可以設定為

none

mean

sum

; none将不對結果進行任何處理,mean對結果求均值, sum對結果求和, 預設是mean

函數功能

對神經網絡的輸出結果進行sigmoid操作,然後求交叉熵,當參數reduction=mean時其實際效果和下面的程式相同; weight參數實際是對交叉熵結果的權重

x = torch.sigmoid(pred)   # pred是神經網絡的輸出
result = -torch.mean((label*torch.log(x)+(1-label)*torch.log(1-x))*weight)
print(result)
           

注意

torch.nn.BCEWithLogitsLoss()功能及其中參數和上面的函數一樣

例子

pred = torch.tensor([[-1.0, 2],[0.5, 1]])
label = torch.tensor([[1, 0.0],[2,5]])
mask = torch.tensor([[1,0], [1,1]])
l = nn.functional.binary_cross_entropy_with_logits(pred, target=label, reduction="mean", weight=mask)
print(l)
x = torch.sigmoid(pred)
result = -torch.mean((label*torch.log(x)+(1-label)*torch.log(1-x))*mask)
print("result:",result)
           

輸出

tensor(-0.5998)
result: tensor(-0.5999)
           

繼續閱讀