天天看點

花朵分類(一)

本次教程的目的是帶領大家學會基本的花朵圖像分類

首先我們來介紹下資料集,該資料集有5種花,一共有3670張圖檔,分别是daisy、dandelion、roses、sunflowers、tulips,資料存放結構如下所示

花朵分類(一)
我們可以展示下roses的幾張圖檔
花朵分類(一)

接下來我們需要加載資料集,然後對資料集進行劃分,最後形成訓練集、驗證集、測試集,注意此處的驗證集是從訓練集切分出來的,比例是8:2

對資料進行探索的時候,我們發現原始的像素值是0-255,為了模型訓練更穩定以及更容易收斂,我們需要标準化資料集,一般來說就是把像素值縮放到0-1,可以用下面的layer來實作

normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)           

為了使訓練的時候I/O不成為瓶頸,我們可以進行如下設定

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)           

下一步就是模型搭建,然後對模型進行訓練

num_classes = 5

model = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.Rescaling(1./255),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(num_classes)
])

model.compile(
  optimizer='adam',
  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['accuracy'])

model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=3
)           
花朵分類(一)

從上圖的訓練記錄可以發現,該模型處于欠拟合狀态,我們可以通過多訓練幾輪來解決這個問題,而且為了快速實驗,我們這裡用了一個非常簡單的模型,我們可以通過更換更強的模型,來提升模型的表現

代碼連結:

https://codechina.csdn.net/csdn_codechina/enterprise_technology/-/blob/master/load_preprocess_images.ipynb

繼續閱讀