天天看點

tensorflow标準差方差_CNN基礎知識 || 均方差損失函數

定義

在深度學習中,做回歸任務時使用的loss多為均方差。公式為

tensorflow标準差方差_CNN基礎知識 || 均方差損失函數

其中:Batch為樣本數量,M為網絡輸出層的元素的個數

實作

loss = tf.nn.l2_loss(x, x')  *2.0/ (Batch*M)

loss = tf.losses.mean_squared_error(x, x')

loss = tf.reduce_mean((x - x')**2)

loss = tf.reduce_mean(tf.aquare(x - x'))    (與上面的同意)

對于tf.nn.l2_loss(),數學表達式為 output = sum(t**2)/2

測試

認為a中有3個資料,計算a和b的均方差

import tensorflow as tf

a= [[1.0,2.0,3.0,5.0],[1.0,2.0,3.0,7.0],[1.0,2.0,3.0,2.0]]

b= [[2.0,4.0,3.0,7.0],[2.0,4.0,3.0,5.0],[2.0,4.0,3.0,5.0]]

c = tf.convert_to_tensor(a)

d = tf.convert_to_tensor(b)

loss1 = tf.nn.l2_loss(c-d) *2/(3*4)

loss2 = tf.reduce_mean((c-d)**2)

loss3 = tf.reduce_mean(tf.square(c-d))

loss4 = tf.losses.mean_squared_error(c,d)

with tf.Session() as sess:

print(sess.run(loss1)) # 2.6666667

print(sess.run(loss2)) # 2.6666667

print(sess.run(loss3)) # 2.6666667

print(sess.run(loss4)) # 2.6666667

上面四種表達結果是一樣的