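Validation during training (called "cross-validation" here, although fit() performs simple hold-out validation rather than k-fold) is set up through arguments to fit(). Full parameter walkthrough:
https://blog.csdn.net/Forrest97/article/details/106635664
Validation-related arguments of fit():
validation_data=test: the test set that you have split off yourself.
validation_steps: the number of validation batches to run in each validation pass. For a full pass this is total validation samples / validation batch size, rounded up. Each time validation runs, val_loss and val_accuracy are printed alongside the training metrics.
validation_batch_size: the number of samples per validation batch. If unspecified it defaults to batch_size, which is usually 32. Per the Keras documentation it should not be specified when the data is already batched (for example a tf.data.Dataset).
validation_freq: only takes effect when validation_data is set; the number of training epochs to run between validation passes.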
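As a small worked example of the validation_steps arithmetic described above, the sketch below computes how many batches one full pass over a validation set takes. The helper name full_pass_steps is made up for illustration and is not part of Keras; the sample count and the batch size of 128 are the ones used in the example later in this post.

```python
import math

def full_pass_steps(num_samples: int, batch_size: int) -> int:
    """Number of batches needed to cover num_samples once (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)

# 10,000 validation images in batches of 128: 78 full batches plus 1 partial batch.
print(full_pass_steps(10000, 128))  # 79
# The same images in batches of 32.
print(full_pass_steps(10000, 32))   # 313
```

Any validation_steps value up to this number selects a prefix of the validation set; a larger value asks for more batches than one pass can supply.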
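Here is a code example first. The dataset (CIFAR-10) contains 60,000 images: 50,000 for training and 10,000 for testing. Each image is a 32*32 color picture with three values (RGB) per pixel, and there are 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.
A DNN is built and trained first, its weights are saved, and then the model is re-created and the weights are loaded back.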
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

def preprocess(x, y):
    # Scale pixel values from [0, 255] to [-1, 1].
    x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1.
    y = tf.cast(y, dtype=tf.int32)
    return x, y

batch_size = 128
tf.random.set_seed(1)
(x,y),(test_x,test_y) = datasets.cifar10.load_data()
y = tf.squeeze(y)
test_y = tf.squeeze(test_y)
y = tf.one_hot(y,depth=10)
test_y = tf.one_hot(test_y,depth=10)
train = tf.data.Dataset.from_tensor_slices((x,y))
train = train.map(preprocess).shuffle(60000).batch(batch_size=batch_size)
test = tf.data.Dataset.from_tensor_slices((test_x,test_y))
test = test.map(preprocess).batch(batch_size=batch_size)
sample = next(iter(train))

class mydense(layers.Layer):
    """Fully connected layer: outputs = inputs @ kernel + bias."""
    def __init__(self, input_dim, output_dim):
        super(mydense, self).__init__()
        self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim])
        self.bias = self.add_weight(name='b', shape=[1, output_dim])

    def call(self, inputs, training=None):
        x = inputs @ self.kernel + self.bias
        return x

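class my_network(keras.Model):
    """Five fully connected layers mapping a 32*32*3 image to 10 class logits."""
    def __init__(self):
        super(my_network, self).__init__()
        self.fc1 = mydense(32*32*3, 256)
        self.fc2 = mydense(256, 128)
        self.fc3 = mydense(128, 64)
        self.fc4 = mydense(64, 32)
        self.fc5 = mydense(32, 10)

    def call(self, inputs, training=None):
        # Flatten each image into a vector of length 32*32*3.
        x = tf.reshape(inputs, [-1, 32*32*3])
        x = tf.nn.relu(self.fc1(x))
        x = tf.nn.relu(self.fc2(x))
        x = tf.nn.relu(self.fc3(x))
        x = tf.nn.relu(self.fc4(x))
        # No activation on the last layer: the loss is configured with from_logits=True.
        x = self.fc5(x)
        return x
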
network = my_network()
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# validation_steps: total validation samples / validation batch size, i.e. how many validation batches are run in each validation pass.
# validation_freq: only takes effect when validation_data is set; the number of training epochs between validation passes.
network.fit(train,epochs=5,validation_data=test,validation_freq=2)
#print("validation_batch_size",network.fit.validation_freq)
network.evaluate(test)
network.save_weights('weights/mynetwork')
del network
print('saved weights')
network = my_network()
network.compile(optimizer=optimizers.Adam(learning_rate=1e-4),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.load_weights('weights/mynetwork')
network.fit(train,epochs=5,validation_data=test,validation_batch_size=80)
network.evaluate(test)
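Output. With validation_freq=2, the first fit() call prints validation results for the test set only every other epoch, that is, after each epoch whose number is a multiple of 2: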
Epoch 1/5
391/391 [==============================] - 3s 6ms/step - loss: 1.7262 - accuracy: 0.3867
Epoch 2/5
391/391 [==============================] - 4s 9ms/step - loss: 1.4883 - accuracy: 0.4739 - val_loss: 1.4526 - val_accuracy: 0.4891
Epoch 3/5
391/391 [==============================] - 3s 7ms/step - loss: 1.3768 - accuracy: 0.5139
Epoch 4/5
391/391 [==============================] - 4s 10ms/step - loss: 1.2961 - accuracy: 0.5428 - val_loss: 1.4032 - val_accuracy: 0.5021
Epoch 5/5
391/391 [==============================] - 3s 9ms/step - loss: 1.2186 - accuracy: 0.5733
79/79 [==============================] - 0s 5ms/step - loss: 1.3872 - accuracy: 0.5197
saved weights
Epoch 1/5
391/391 [==============================] - 4s 11ms/step - loss: 1.1536 - accuracy: 0.5942 - val_loss: 1.3735 - val_accuracy: 0.5234
Epoch 2/5
391/391 [==============================] - 4s 11ms/step - loss: 1.0945 - accuracy: 0.6120 - val_loss: 1.3713 - val_accuracy: 0.5252
Epoch 3/5
391/391 [==============================] - 4s 11ms/step - loss: 1.0385 - accuracy: 0.6360 - val_loss: 1.3727 - val_accuracy: 0.5346
Epoch 4/5
391/391 [==============================] - 4s 11ms/step - loss: 0.9848 - accuracy: 0.6520 - val_loss: 1.4310 - val_accuracy: 0.5297
Epoch 5/5
391/391 [==============================] - 5s 12ms/step - loss: 0.9356 - accuracy: 0.6687 - val_loss: 1.4377 - val_accuracy: 0.5266
79/79 [==============================] - 0s 5ms/step - loss: 1.4377 - accuracy: 0.5266
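The pattern in the first log above (val_loss and val_accuracy only on epochs 2 and 4) can be reproduced with a few lines of plain Python. This is a sketch of the rule for an integer validation_freq, namely that validation runs after every epoch whose 1-based number is a multiple of validation_freq. The function name is hypothetical and is not a Keras API.

```python
def epochs_with_validation(epochs: int, validation_freq: int) -> list:
    """Return the 1-based epoch numbers after which validation is run."""
    return [e for e in range(1, epochs + 1) if e % validation_freq == 0]

print(epochs_with_validation(5, 2))  # [2, 4]
print(epochs_with_validation(5, 1))  # [1, 2, 3, 4, 5]
print(epochs_with_validation(5, 6))  # []
```

The last call shows what happens when validation_freq is larger than the number of epochs: no epoch number is a multiple of it, so validation never runs.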
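When only validation_batch_size is set (88, 88 and 30 were tried), the second fit() call prints val_loss and val_accuracy after every epoch, and the results are the same for every value. This is consistent with test already being batched (batch_size=128) before it is passed in, so validation_batch_size does not change how it is split.
When only validation_steps is set (30 and 80 were tried): with validation_steps=30, the second fit() call prints val_loss and val_accuracy after every epoch, as follows: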
saved weights
Epoch 1/5
391/391 [==============================] - 3s 9ms/step - loss: 1.1536 - accuracy: 0.5942 - val_loss: 1.3933 - val_accuracy: 0.5185
Epoch 2/5
391/391 [==============================] - 4s 10ms/step - loss: 1.0945 - accuracy: 0.6120 - val_loss: 1.3774 - val_accuracy: 0.5286
Epoch 3/5
391/391 [==============================] - 4s 11ms/step - loss: 1.0385 - accuracy: 0.6360 - val_loss: 1.3880 - val_accuracy: 0.5365
Epoch 4/5
391/391 [==============================] - 4s 11ms/step - loss: 0.9848 - accuracy: 0.6520 - val_loss: 1.4451 - val_accuracy: 0.5326
Epoch 5/5
391/391 [==============================] - 4s 10ms/step - loss: 0.9356 - accuracy: 0.6687 - val_loss: 1.4410 - val_accuracy: 0.5318
79/79 [==============================] - 0s 5ms/step - loss: 1.4377 - accuracy: 0.5266
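With validation_steps=80, which is more than the 79 batches in the validation dataset (10,000 images in batches of 128), no validation was performed.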
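To make the two validation_steps settings concrete, the sketch below models the validation set as a finite stream of batches and tries to draw validation_steps batches from it. It only illustrates the counting involved, under the assumption that a validation pass needs all requested batches to be available. The helper names are invented here and nothing in it calls Keras.

```python
from itertools import islice

def batch_sizes(num_samples, batch_size):
    """Yield the size of each batch in one pass over the data."""
    for start in range(0, num_samples, batch_size):
        yield min(batch_size, num_samples - start)

def samples_validated(num_samples, batch_size, validation_steps):
    """Samples covered by validation_steps batches, or None if the data runs out first."""
    drawn = list(islice(batch_sizes(num_samples, batch_size), validation_steps))
    if len(drawn) < validation_steps:
        return None  # fewer batches available than requested
    return sum(drawn)

print(samples_validated(10000, 128, 30))  # 3840
print(samples_validated(10000, 128, 79))  # 10000
print(samples_validated(10000, 128, 80))  # None
```

With 30 steps only the first 3,840 of the 10,000 test images are scored, which would explain why the val_loss and val_accuracy values in that run differ slightly from the final evaluate() result on the whole test set.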
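Conclusion:
Parameter settings that worked well:
Use validation_data together with validation_batch_size; validation_batch_size should not exceed the size of the test set.
Or use validation_data together with validation_freq; validation_freq should not exceed the number of epochs.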