天天看点

keras.utils.Sequence使用注意事项

1)在实现自己的DataLoader过程中一般都是继承自keras.utils.Sequence,继承该类必须要实现__len__与__getitem__两个函数。

2)在调用fit_generator进行训练时,如果设置了step_per_epoch参数,则每个epoch训练step_per_epoch个step,每个step有batch_size数据,因此每个epoch共训练step_per_epoch*batch_size的数据。如果没有设置step_per_epoch参数,则每个epoch训练的step个数由__len__决定。

3)在训练过程中step_per_epoch的个数可以大于 ceil(float(数据集图片数量)/batch_size) ,这个数字可以认为是遍历一遍数据集需要的实际step数量,__len__一般也实现为这个数字。在每遍历过一次数据集后(确切的说是调用__len__次),会调用一次on_epoch_end()。

4)__getitem__在调用时会有一个index参数,这个参数的取值范围就是range(__len__)的结果,index参数的值是在这个范围内随机给定的。因为__len__实现的时候使用的是ceil向上取整,因此很有可能最后一个index就无法取到一组满batch数据,因为数据集图片数量能够正好整除batch_size的情况很少。如果没有取到一组满batch数据,此时可以返回None,或者干脆什么都不返回。fit_generator在检查到是None的时候会再调用__getitem__一次。

5)所以这个地方要特别注意一点,图片无论是训练集还是验证集的数量一定不能小于batch_size,因为如果图片数量小于batch_size,则永远不能取到一组满batch,程序就会进入无限循环。另一方面在计算__len__的时候,使用了ceil,那么__len__至少大于等于1,也不存在不进入__getitem__的情况。除非数据集图片数量是0。

继续阅读