这个项目我们要做一个识别猫狗的模型。和上次的数字识别一样，它也基于深度学习，不过这次的模型更为复杂，需要用到迁移学习：站在巨人的肩膀上，借用前人已经在 ImageNet 上训练好的模型来搭建我们自己的模型，再让它完成我们想做的事。
环境要求：
Python 3
Numpy
Scipy
matplotlib
tensorflow
keras
opencv
数据预处理
def make_label(file_name):
    """Return the binary class label encoded in an image file name.

    The class is the part of the file name before the first '.', e.g.
    ``'cat.123.jpg'`` -> ``'cat'``.

    Args:
        file_name: Base name of the image file (no directory part).

    Returns:
        ``[0]`` for a cat image, ``[1]`` for a dog image -- a single
        binary value per image (not a one-hot vector), shaped to fill one
        row of an ``(N, 1)`` label array.

    Raises:
        ValueError: If the prefix is neither ``'cat'`` nor ``'dog'``.
    """
    label = file_name.split('.')[0]
    if label == 'cat':
        return [0]
    if label == 'dog':
        return [1]
    # Falling through silently would return None, which cannot be a valid
    # 0/1 label and would corrupt the float label array downstream.
    raise ValueError(
        'cannot infer cat/dog label from file name {!r}'.format(file_name))
def make_data(img_path, img_size):
    """Load every image in a directory as a shuffled (images, labels) pair.

    Each file is read with OpenCV, resized to ``img_size x img_size``,
    converted from OpenCV's BGR channel order to RGB, and labelled with
    ``make_label`` from its file name.

    Args:
        img_path: Directory containing only the image files to load.
        img_size: Side length in pixels of the square output images.

    Returns:
        ``(images, labels)``: a uint8 array of shape
        ``(N, img_size, img_size, 3)`` and a float32 array of shape
        ``(N, 1)``, both reordered by the same random permutation.

    Raises:
        ValueError: If a file cannot be decoded as an image.
    """
    # List the directory once so the preallocated size and the loop agree.
    file_names = os.listdir(img_path)
    n = len(file_names)
    images = np.zeros((n, img_size, img_size, 3), dtype=np.uint8)
    labels = np.zeros((n, 1), dtype=np.float32)
    for i, file_name in enumerate(file_names):
        labels[i] = make_label(file_name)
        full_path = os.path.join(img_path, file_name)
        bgr = cv2.imread(full_path)
        if bgr is None:
            # cv2.imread reports unreadable / non-image files by returning
            # None; fail here with the offending path instead of inside resize.
            raise ValueError('could not read image {!r}'.format(full_path))
        resized = cv2.resize(bgr, (img_size, img_size))
        images[i] = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    # One shared permutation keeps every image aligned with its label.
    perm = np.random.permutation(n)
    return images[perm], labels[perm]
（猫的标签为 0，狗的标签为 1，以 float32 存储。）
基于 VGG16 的模型
# Binary cat/dog classifier on top of the ImageNet-pretrained VGG16
# convolutional base (include_top=False drops VGG16's own classifier).
# img_size is defined elsewhere and must match the size given to make_data.
# NOTE(review): VGG16's ImageNet weights expect inputs processed by
# keras.applications.vgg16.preprocess_input; raw 0-255 RGB pixels are fed
# here -- confirm whether that is intended.
inputs = Input(shape=(img_size, img_size, 3))
base_model = VGG16(weights='imagenet', input_tensor=inputs, include_top=False)
x = Flatten()(base_model.output)
x = Dense(2048, activation='relu')(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.7)(x)  # heavy dropout to regularize the new dense head
output = Dense(1, activation='sigmoid')(x)  # single probability: cat=0, dog=1
# `inputs=`/`outputs=` are the Model keyword names; the legacy `input=`/
# `output=` spellings are rejected by current Keras and tf.keras.
model = Model(inputs=inputs, outputs=output)
训练
from keras.callbacks import TensorBoard
from keras.optimizers import SGD

# Freeze the whole pretrained VGG16 base so only the new dense head trains.
# Iterating base_model.layers covers exactly the input layer plus the 13
# conv and 5 pooling layers (19 in total) without a hard-coded count.
for layer in base_model.layers:
    layer.trainable = False

# Small learning rate with momentum. `lr` is the older Keras keyword;
# Keras >= 2.3 / tf.keras name it `learning_rate`.
opt = SGD(lr=0.0001, momentum=0.9)
# Trainable flags are picked up at compile time, so freeze before compiling.
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
# validation_split holds out the *last* 20% of the arrays; this assumes
# train_img/train_label come from make_data, which shuffles them first.
model.fit(train_img, train_label, validation_split=0.2,
          callbacks=[TensorBoard(log_dir='./log')])
model.save('model.h5')
基于Xception的模型
# Xception's default (and ImageNet training) input size is 299x299; use one
# constant so the loaded data and the model input shape always agree.
xception_size = 299
train_img, train_label = make_data(train_path, xception_size)

from keras.applications.xception import Xception
from keras.callbacks import TensorBoard
from keras.optimizers import SGD

# NOTE(review): Xception's weights expect inputs scaled by
# keras.applications.xception.preprocess_input (to [-1, 1]); raw 0-255 RGB
# pixels are fed here -- confirm whether that is intended.
inputs = Input(shape=(xception_size, xception_size, 3))
base_model_2 = Xception(weights='imagenet', input_tensor=inputs, include_top=False)
# NOTE(review): flattening the full feature map makes the Dense(512) layer
# very large; GlobalAveragePooling2D is the usual lighter alternative.
x = Flatten()(base_model_2.output)
x = Dense(512, activation='relu')(x)
x = Dropout(0.85)(x)
output = Dense(1, activation='sigmoid')(x)  # single probability: cat=0, dog=1
model_2 = Model(inputs=inputs, outputs=output)

# Freeze every layer of the pretrained Xception base (no hard-coded layer
# count) so only the new dense head is trained.
for layer in base_model_2.layers:
    layer.trainable = False

opt = SGD(lr=0.0001, momentum=0.9)
model_2.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
model_2.fit(train_img, train_label, validation_split=0.2, batch_size=10,
            callbacks=[TensorBoard(log_dir='./log')])
model_2.save('model_2.h5')
上面是基于 Xception 的模型。Xception 本身结构较为复杂，这里我只替换了它顶部的全连接层（include_top=False，再接上自己的分类头）。
Xception 比 VGG16 层数更深、结构更复杂，在 ImageNet 上的精度也更高。
预测可视化
数据增强
from scipy.ndimage.interpolation import shift
def img_change_brightness(img, low=0.25, high=1.0):
    """Return a copy of ``img`` with its brightness scaled by a random factor.

    The image is converted to HSV, its V (value) channel is multiplied by a
    factor drawn with ``np.random.uniform(low, high)``, and the result is
    converted back. With the default range the image keeps between 25% and
    100% of its original brightness.

    Args:
        img: 3-channel image. OpenCV's BGR<->HSV conversions are used; the
            round trip puts channels back in their original positions and
            only V is changed, so RGB input is effectively handled too.
        low: Lower bound of the brightness factor (default 0.25).
        high: Upper bound of the brightness factor (default 1.0).

    Returns:
        A new image converted back with ``COLOR_HSV2BGR`` (i.e. in the same
        channel order as the input); ``img`` itself is not modified.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    factor = np.random.uniform(low, high)
    scaled = hsv[:, :, 2] * factor
    if hsv.dtype == np.uint8:
        # Keep values representable if a caller passes high > 1; a no-op
        # for the default range, where factor <= 1.
        scaled = np.clip(scaled, 0, 255)
    # Assignment truncates back to the image dtype.
    hsv[:, :, 2] = scaled
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
# Preview one random training image next to its three augmented variants.
# randrange's upper bound is exclusive, so the index is always valid
# (randint(0, n) is inclusive and can return n, raising IndexError).
index = random.randrange(len(train_images))
sample = train_images[index]
previews = [
    sample,                         # original
    np.flipud(sample),              # flipped upside down
    np.fliplr(sample),              # mirrored left-right
    img_change_brightness(sample),  # random brightness change
]
for i, preview in enumerate(previews):
    plt.subplot(2, 2, i + 1)
    plt.imshow(preview)
    plt.axis('off')
# Augment the training set in place. Each transform (brightness change,
# left-right mirror, upside-down flip, in that order) is applied to its own
# random 25% of the images. The augmented images replace the originals;
# the set is not enlarged.
# random.sample draws indices without replacement, so exactly 25% of the
# images receive each transform and no image is flipped twice by the same
# pass (which would undo it). Indices are always in range; random.randint's
# inclusive upper bound could return len(train_images).
# NOTE(review): upside-down cat/dog photos are rare in real data; confirm
# the flipud augmentation actually helps validation accuracy.
n_images = len(train_images)
n_augment = int(n_images * 0.25)
for transform in (img_change_brightness, np.fliplr, np.flipud):
    for index in random.sample(range(n_images), n_augment):
        train_images[index] = transform(train_images[index])
第 n 次训练
from keras.models import load_model
from keras.callbacks import TensorBoard

# Resume from the saved Xception model and keep training it on the
# (augmented) training arrays, validating on the separate validation arrays.
# load_model restores the architecture, the weights and, by default, the
# compile configuration stored in the file.
model_3 = load_model('model_2.h5')
tensorboard = TensorBoard(log_dir='./log')
model_3.fit(
    train_images,
    train_labels,
    validation_data=(valid_images, valid_labels),
    batch_size=16,
    callbacks=[tensorboard],
)
model_3.save('model_3.h5')
我们可以直接载入上面已经训练好的模型，再用增强后的新数据继续训练。
模型效果比较
VGG16
Xception
数据增强后的 Xception
结尾
文章代码/数据地址：ciozhang