天天看點

Torch nn.Concat,nn.ConcatTable,nn.JoinTable.

1、根據字面意思,nn.Concat,nn.ConcatTable都是将輸出concat在一起,那麼二者有什麼差別呢。concat将最終的concat的結果變成一個整體,ConcatTable将輸出儲存在一個table裡面。

mlp=nn.Concat();
mlp:add(nn.SpatialConvolution(,,,,,,,))
mlp:add(nn.SpatialConvolution(,,,,,,,))
print(mlp:forward(torch.randn(,,,)))          --輸出結果為x64x128x128
           
mlp=nn.ConcatTable();
mlp:add(nn.SpatialConvolution(,,,,,,,))
mlp:add(nn.SpatialConvolution(,,,,,,,))
print(mlp:forward(torch.randn(,,,)))
           

第二段代碼輸出結果為

Torch nn.Concat,nn.ConcatTable,nn.JoinTable.

二者的共同點則是接受同一個輸入,對多個輸出進行操作,是以這兩個操作無法進行将多個輸入連在一起的操作,如果想要執行這個操作,就用到了nn.JoinTable().

2、JoinTable()将多個輸入concat在一起,并且生成一個整體

h1 = nn.SpatialConvolution(3,3,7,7,2,2,3,3)()
h2 = nn.SpatialConvolution(3,64,7,7,2,2,3,3)(h1)
h3 = nn.SpatialConvolution(3,64,7,7,2,2,3,3)(h1)
h4 = nn.JoinTable(2)({h3,h2})
mlp = nn.gModule({h1}, {h4})

x = torch.rand(2,3,256,256)
output = mlp:forward(x)
print(output:size())
           

輸出為2x128x64x64

3、/gmodule.lua:135: expecting only one start

遇見這種錯誤,檢視這句話input和output是否都是一個table

local model = nn.gModule({inp}, {tmpOut})
           

4、nn.sigmoid()之前是不能加bn和relu的,否則損失極大無法下降