Xception
網絡
大家好,小編又來啦。在前面的文章中呢我們介紹了關于
Inception
的系列網絡,在
2017
年谷歌也是在
Inception
-
V3
的基礎上推出
Xception
,在性能上超越了原有的
Inception
-
V3
。下面我們帶大家見識一下
Xception
的廬山真面目吧!
簡介
在
Xception
中作者主要提出了以下一些亮點:
- 作者從
Inception
-
V3
的假設出發,解耦通道相關性和空間相關性,進行簡化網絡,推導出深度可分離卷積。
- 提出了一個新的
Xception
網絡。
相信小夥伴們看到這裡一定會發出驚呼,納尼,深度可分離卷積不是在
MobileNet
中提出的麼?在這裡需要注意的是,在
Xception
中提出的深度可分離卷積和
MobileNet
中是有差異的,具體我們會在下文中聊到咯。
Xception
的進化之路
在前面我們說過
Xception
是
繼
Inception
後提出的對
Inception
-
v3
的另一種改進,作者認為跨通道相關性和空間相關性應充分解耦(獨立互不相關),是以最好不要将它們共同映射處理,應分而治之。具體是怎麼做呢?
1、先使用
1×1
的卷積核,将特征圖各個通道映射到一個新空間,學習通道間相關性。
2、再使用
3×3
或
5×5
卷積核,同時學習空間相關性和通道間相關性。
進化
1
在
Inception
-
v3
中使用了如下圖
1
所示的多個這樣的子產品堆疊而成,能夠用較小的參數學習到更豐富的資訊。
圖
1Inception
中的子產品
進化
2
在原論文中作者對于圖
1
中的子產品進行了簡化,去除
Inception
-
v3
中的
AvgPool
後,輸入的下一步操作就都是
1
×
1
卷積,如下圖
2
所示:
圖
2
簡化後的
Inception
子產品
進化
3
更進一步,作者對于簡化後的
Inception
-
v3
子產品中的所有
1
×
1
卷積進行合并,什麼意思呢?就是将
Inception
-
v3
子產品重新構造為
1
×
1
卷積,再進行空間卷積(
3
×
3
是标準的那種多通道的卷積),相當于把
1
×
1
卷積後的輸出拼接起來為一個整體,然後進行分組卷積。如下圖
3
所示:
圖
3
經過進化
3
這種操作後,自然會有以下問題:分組數及大小會産生什麼影響?是否有更一般的假設?空間關系映射和通道關系映射是否能夠完全解耦呢?
進化
4
基于進化
3
中提出的問題,作者提出了“
extreme
“版本的
Inception
子產品,如下圖
4
所示。從圖
4
中我們可以看出,所謂的“
extreme
“版本其實就是首先使用
1
x
1
卷積來映射跨通道相關性,然後分别映射每個輸出通道的空間相關性,即對每個通道分别做
3
×
3
卷積。
圖
4
“
extreme
“版本的
Inception
子產品
在此作者也說明了這種
Inception
子產品的“
extreme
“版本幾乎與深度可分離卷積相同,但是依然是存在以下差別的:
1、通常實作的深度可分離卷積(如
MobileNet
中)首先執行通道空間卷積(
DW
卷積),然後執行
1
×
1
卷積,而
Xception
首先執行
1
×
1
卷積。
2、第一次操作後是否加
ReLU
,
Inception
中
2
個操作後都加入
ReLU
。其中“
extreme
“版本的
Inception
子產品為:
Conv
(
1
×
1
)+
BN
+
ReLU
+
Depthconv
(
3
×
3
)+
BN
+
ReLU
;而普通的深度可分離卷積結構為:
Depthconv
(
3
×
3
)+
B
N+
Conv
(
1
×
1
)+
BN
+
ReLU
。
而作者認為第一個差別不大,因為這些操作都是堆疊在一起的;但第二個影響很大,他發現在“
extreme
“版本的
Inception
中
1
×
1
與
3
×
3
之間不用
ReLU
收斂更快、準确率更高,這個作者是做了實驗得到的結論,後面我們會介紹。
Xception
網絡結構
在提出了上面新的子產品結構後,認識卷積神經網絡的特征圖中跨通道相關性和空間相關性的映射是可以完全解耦的。因為結構是由
Inception
體系結構得到的“
extreme
“版本,是以将這種新的子產品結構命名為
Xception
,表示“
ExtremeInception
”。并且作者還結合了
ResNet
的殘差思想,給出了如下圖
5
所示的基于
Xception
的網絡結構:
圖
5Xception
網絡結構
實驗評估
在訓練驗證階段,作者使用了
ImageNet
和
JFT
這兩個資料集做驗證。精度和參數量對比如下圖所示,從圖中可以看到,在精度上
Xception
在
ImageNet
領先較小,但在
JFT
上領先很多;在參數量和推理速度上,
Xception
參數量少于
Inception
,但速度更快。
圖
6ImageNet
資料集上精度對比
圖
7JFT
資料集上精度對比
圖
8
參數量和推理速度對比
如下圖所示,除此之外,作者還比較了是否使用
Residual
殘差結構、是否在
Xception
子產品中兩個操作(
1
×
1
卷積和
3
×
3
卷積)之間加入
ReLu
下的訓練收斂速度和精度。從圖中可以看出,在使用了
Residual
殘差結構和去掉
Xception
子產品中兩個操作之間的
ReLu
激活函數下訓練收斂的速度更快,精度更高。
圖
9
是否采用
Residual
殘差結構的訓練收斂速度和精度
圖
10
是否在
Xception
子產品中兩個操作加入
ReLu
的訓練收斂速度和精度
總結
在
Xception
網絡中作者解耦通道相關性和空間相關性,提出了“
extreme
“版本的
Inception
子產品,結合
ResNet
的殘差思想設計了新的
Xception
網絡結構,相比于之前的
Inception
-
V3
獲得更高的精度和使用了更少的參數量。
這裡給出
Keras
代碼實作:
from keras.models import Model
from keras import layers
from keras.layers import Dense, Input, BatchNormalization, Activation
from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, GlobalAveragePooling2D, GlobalMaxPooling2D
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.utils.data_utils import get_file
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels.h5'
def Xception():
# Determine proper input shape
input_shape = _obtain_input_shape(None, default_size=299, min_size=71, data_format='channels_last', include_top=False)
img_input = Input(shape=input_shape)
# Block 1
x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False)(img_input)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(64, (3, 3), use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 2
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
# Block 2 Pool
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])
residual = Conv2D(256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 3
x = Activation('relu')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
# Block 3 Pool
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])
residual = Conv2D(728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 4
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])
# Block 5 - 12
for i in range(8):
residual = x
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = layers.add([x, residual])
residual = Conv2D(1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
# Block 13
x = Activation('relu')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
# Block 13 Pool
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])
# Block 14
x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
# Block 14 part 2
x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
# Fully Connected Layer
x = GlobalAveragePooling2D()(x)
x = Dense(1000, activation='softmax')(x)
inputs = img_input
# Create model
model = Model(inputs, x, name='xception')
# Download and cache the Xception weights file
weights_path = get_file('xception_weights.h5', WEIGHTS_PATH, cache_subdir='models')
# load weights
model.load_weights(weights_path)
return model
複制