天天看点

AI-DSW 上编辑嵌套式模型实现Resnet手势识别

AI-DSW 上编辑嵌套式模型实现Resnet手势识别

AI-DSW(Data science workshop)是专门为算法开发者准备的云端深度学习开发环境,

进入DSW,目前只有KerasCode和KerasGraph两个Kernel实现了FastNeuralNetwork功能。

  • KerasCode:先写深度学习网络代码,然后将代码转成图
  • KerasGraph:直接通过画布构建深度学习网络,并且将图转成代码

接下来我们通过实现Resnet18实现手势识别为例,体验AI-DSW的使用

我们的任务为,手语英文字母数据集中包含用手语表示的26个英文字母的信息,我们通过建立ResNet18模型进行手语英文字母识别

在AI-DSW 的官方文档中推荐我们采用序贯式(sequential)的方式构建模型,但是嵌套式封装来构建模型可以使结构更清晰,一些内容可以复用,我们来具体看下代码:

def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same'):

    x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides)(x)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)
    return x           

首先我们将最常见的CNN模块封装,包括卷积,BN,激活函数;用于Resnet模型的复用;

def identity_Block(inpt, nb_filter, kernel_size, strides=(1, 1), with_conv_shortcut=False):
    x = Conv2d_BN(inpt, nb_filter=nb_filter, kernel_size=kernel_size, strides=strides, padding='same')
    x = Conv2d_BN(x, nb_filter=nb_filter, kernel_size=kernel_size, padding='same')
    if with_conv_shortcut:#shortcut的含义是:将输入层x与最后的输出层y进行连接,如上图所示
        shortcut = Conv2d_BN(inpt, nb_filter=nb_filter, strides=strides, kernel_size=kernel_size)
        x = add([x, shortcut])
        return x
    else:
        x = add([x, inpt])
        return x           

接下来我们实现Resnet用于Residual Block的模块,即残差块,基于残差块可以有效提升网络性能,提升模型泛化能力,如图所示:

AI-DSW 上编辑嵌套式模型实现Resnet手势识别

有了核心模块后,我们可着手搭建模型的核心结构,包括输入,卷积,残差,池化,全连接,输出等一系列步骤

def resnet_18(width,height,channel,classes):
    inpt = Input(shape=(width, height, channel))
    # x = ZeroPadding2D((3, 3))(inpt)
 
    #conv1
    x = Conv2d_BN(inpt, nb_filter=64, kernel_size=(7, 7), strides=(2, 2), padding='valid')
    x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
 
    #conv2_x
    x = identity_Block(x, nb_filter=64, kernel_size=(3, 3))
    x = identity_Block(x, nb_filter=64, kernel_size=(3, 3))
 
    #conv3_x
    x = identity_Block(x, nb_filter=128, kernel_size=(3, 3), strides=(2, 2), with_conv_shortcut=True)
    x = identity_Block(x, nb_filter=128, kernel_size=(3, 3))
 
    #conv4_x     
    x = identity_Block(x, nb_filter=256, kernel_size=(3, 3), strides=(2, 2), with_conv_shortcut=True)
    x = identity_Block(x, nb_filter=256, kernel_size=(3, 3))
 
    #conv5_x
    x = identity_Block(x, nb_filter=512, kernel_size=(3, 3), strides=(2, 2), with_conv_shortcut=True)
    x = identity_Block(x, nb_filter=512, kernel_size=(3, 3))

    x = GlobalAvgPool2D()(x)
    x = Dense(classes, activation='softmax')(x)
 
    model = Model(inputs=inpt, outputs=x)
    return model           

在 DSW的官方介绍

https://www.alibabacloud.com/help/zh/doc-detail/126303.htm

采用的是序贯式来做模型展示,这里我们发现,基于嵌套式策略同样可以做生成模型结构,如图所示:

AI-DSW 上编辑嵌套式模型实现Resnet手势识别

同样的,我们按照官方文档介绍的,也可做模型可视化编辑,调整参数等

AI-DSW 上编辑嵌套式模型实现Resnet手势识别

有了模型后,我们定义损失函数,加入训练集验证集来训练优化模型,最终得到结果。

AI-DSW 上编辑嵌套式模型实现Resnet手势识别

综上,体验了KerasGraph后,个人感觉它代表了最新的ai开发环境演进方向——类似轻代码(low code)编辑器,可以快速构建模型结构并验证模型效果,提升了我们对模型结构的实现效率,避免纠结与在TF过于繁琐的源码,而是Focus在模型结构优化本身,总体来说还是不错的。

当然KerasGraph当前使用也存在一些问题:

  • 暂不支持各类预训练模型,比如keras_bert,resnet这些,不过在支持了预训练模型,甚至支持对预训练模型最后几层做编辑,将大大提升实用性
  • KerasGraph图形化界面前端占用过多内存,有的时候会导致页面卡塞
  • KerasGraph对于各层参数编辑和定义易用性还需要提升,目前并不比查阅文档方便多少

当然这不妨碍KerasGraph已经是个较为出色的模型展示工具,我也相信假以时日KerasGraph在模型编辑上取得突破

继续阅读