天天看点

【人工智能笔记】第十一节 Tensorflow 2.0 实现指针仪表方向纠正及指针识别(中)

相关资料:

【人工智能笔记】第十节 Tensorflow 2.0 实现指针仪表方向纠正及指针识别(上)

这次创建预测模型,用于预测仪表指针值,与用于方向纠正4个点坐标。预测坐标点采用预测4个基准点偏移值的方式,而不是直接预测实际坐标值。训练前会对图片做数据增强,进行随机颜色、透视变换、增加噪点等操作,防止模型过拟合。

一、创建模型

 1.特征提取模型使用Darknet-53,最后接(9*13*13)卷积,输出:1预测值+4*2个坐标点偏移值。

基准点4个坐标为:

[50, 50]  # 左上

[50, 350]  # 左下

[350, 50]  # 右上

[350, 350]  # 右上

创建模型代码如下:

def build_model(self):
        '''建立模型'''
        # 建立预测模型
        self.build_classes_model()
        # 优化器
        self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=3e-5)
        # 保存模型
        self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer,
                                              classes_model=self.classes_model)
        self.checkpoint_manager = tf.train.CheckpointManager(
            self.checkpoint, self.model_path, max_to_keep=3)

    def conv_layer(self, input, filters, kernel_size, strides=(1, 1), padding='same'):
        x = tf.keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides,
                                   kernel_regularizer=tf.keras.regularizers.l2(5e-4), use_bias=False)(input)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(alpha=0.1)(x)
        return x

    def resnet_layer(self, input, filters, layer_sizes):
        x = tf.keras.layers.ZeroPadding2D(((1, 0), (1, 0)))(input)
        x = self.conv_layer(x, filters, (3, 3),
                            strides=(2, 2), padding='valid')
        for _ in range(layer_sizes):
            x2 = x
            x = self.conv_layer(x, filters // 2, (1, 1))
            x = self.conv_layer(x, filters, (3, 3))
            x = tf.keras.layers.Add()([x, x2])
        return x

    def build_classes_model(self):
        '''建立预测模型'''
        # 所有参数
        input_classes = tf.keras.Input([400, 400, 3], dtype=tf.float32)
        x = tf.pad(input_classes, [[0, 0], [8, 8], [8, 8], [0, 0]], "CONSTANT")
        # (416 * 416)
        x = self.conv_layer(x, 32, (3, 3))
        # (208 * 208)
        x = self.resnet_layer(x, 64, 1)
        # (104 * 104)
        x = self.resnet_layer(x, 128, 2)
        # (52 * 52)
        x = self.resnet_layer(x, 256, 8)
        # (26 * 26)
        x = self.resnet_layer(x, 512, 8)
        y2 = x
        # (13 * 13)
        x = self.resnet_layer(x, 1024, 4)
        x = self.conv_layer(x, 512, (1, 1))
        x = self.conv_layer(x, 1024, (3, 3))
        x = self.conv_layer(x, 512, (1, 1))
        x = self.conv_layer(x, 1024, (3, 3))
        x = self.conv_layer(x, 512, (1, 1))
        x = self.conv_layer(x, 1024, (3, 3))
        x = tf.keras.layers.Conv2D(512, (1, 1), padding='same',
                                   kernel_regularizer=tf.keras.regularizers.l2(5e-4), use_bias=False)(x)
        x2 = tf.keras.layers.MaxPool2D((2, 2))(y2)
        x2 = tf.keras.layers.Conv2D(512, (1, 1), padding='same',
                                   kernel_regularizer=tf.keras.regularizers.l2(5e-4), use_bias=False)(x2)
        x = tf.keras.layers.Concatenate()([x, x2])
        x = tf.keras.layers.Conv2D(
            9, (13, 13), padding='valid', use_bias=False)(x)
        x = tf.keras.layers.Flatten()(x)
        self.classes_model = tf.keras.Model(inputs=input_classes, outputs=x)
    
           

二、数据增强

 对原图片进行数据增强:

def get_random_data(self, image, value, target_points=None):
        '''生成随机图片与标签,用于训练'''
        # 画矩形
        # cv2.rectangle(image, (20, 20), (380, 380), tuple(np.random.randint(0, 30, (3), dtype=np.int32)), thickness=8)
        # 变换图像
        random_offset_x = random.random()*90-45
        random_offset_y = random.random()*90-45
        random_angle_x = random.random()*60-30
        random_angle_y = random.random()*60-30
        random_scale = random.random()*0.6+0.7
        # random_offset_x = 0
        # random_offset_y = 0
        # random_angle_x = 0
        # random_angle_y = 0
        # random_scale = 1
        # 点列表
        points = np.float32([[50, 50], # 左上
                            [50, 350], # 左下
                            [350, 50], # 右上
                            [350, 350]]) # 右下
        if target_points is None:
            target_points = points
        image, org, dst, perspective_points = image_helpler.opencvPerspective(image, offset=(random_offset_x, random_offset_y, 0),
                                                        angle=(random_angle_x, random_angle_y, 0), scale=(random_scale, random_scale, 1), points=target_points)
        # 计算四个角变换差值
        perspective_points = (perspective_points - points)/400
        # 增加噪声
        # image = image_helpler.opencvRandomLines(image, 8)
        image = image_helpler.opencvNoise(image)
        # 颜色抖动
        image = image_helpler.opencvRandomColor(image)

        # cv2.imwrite(path, image)

        # 最后输出图片
        random_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # 调整参数范围
        random_img = random_img.astype(np.float32)
        random_img = random_img / 255
        value = value / 100
        # 标签
        # target_data = np.float32([[value]])
        # target_data = np.float32([[value, random_offset_x/400, random_offset_y/400,
        #                            random_angle_x/180, random_angle_y/180, random_scale]])
        target_data = np.float32([value, perspective_points[0,0], perspective_points[0,1],
                                    perspective_points[1,0], perspective_points[1,1],
                                    perspective_points[2,0], perspective_points[2,1],
                                    perspective_points[3,0], perspective_points[3,1]])
        # print('random_img:', random_img.shape)
        # print('target_data:', target_data.shape)
        return random_img, target_data
    
    def get_random_image(self, image):
        '''生成随机图片,用于测试'''
        # 画矩形
        # cv2.rectangle(image, (20, 20), (380, 380), tuple(np.random.randint(0, 30, (3), dtype=np.int32)), thickness=8)
        # 变换图像
        random_offset_x = random.random()*90-45
        random_offset_y = random.random()*90-45
        random_angle_x = random.random()*60-30
        random_angle_y = random.random()*60-30
        random_scale = random.random()*0.6+0.7
        # random_offset_x = 0
        # random_offset_y = 0
        # random_angle_x = 0
        # random_angle_y = 0
        # random_scale = 1
        random_img, org, dst, perspective_points = image_helpler.opencvPerspective(image, offset=(random_offset_x, random_offset_y, 0),
                                                        angle=(random_angle_x, random_angle_y, 0), scale=(random_scale, random_scale, 1))
        # 增加噪声
        # random_img = image_helpler.opencvRandomLines(random_img, 8)
        random_img = image_helpler.opencvNoise(random_img)
        # 颜色抖动
        random_img = image_helpler.opencvRandomColor(random_img)
        return random_img
           

三、训练模型

1. 分别计算每个点坐标与预测值的Loss:

@tf.function
    def loss_fun(self, y_true, y_pred):
        value_loss = tf.math.reduce_sum(tf.math.abs(y_true[:,0]-y_pred[:,0]))
        value_p1 = tf.math.reduce_sum(tf.math.square(y_true[:,1:3]-y_pred[:,1:3]))
        value_p2 = tf.math.reduce_sum(tf.math.square(y_true[:,3:5]-y_pred[:,3:5]))
        value_p3 = tf.math.reduce_sum(tf.math.square(y_true[:,5:7]-y_pred[:,5:7]))
        value_p4 = tf.math.reduce_sum(tf.math.square(y_true[:,7:9]-y_pred[:,7:9]))
        loss = (value_loss * 3 + value_p1 + value_p2 + value_p3 + value_p4) * 0.2
        return loss
           

2.单步训练与批量训练方法: 

@tf.function(input_signature=(
        tf.TensorSpec(shape=(None, 400, 400, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 9), dtype=tf.float32),
    ))
    def train_step(self, input_image, target_data):
        '''
        单步训练
        input_image:图片(400,400,3)
        target_data:一个指针值(1)+4个点坐标偏移值(8个值)
        '''
        print('Tracing with train_step', type(input_image), type(target_data))
        print('Tracing with train_step', input_image.shape, target_data.shape)
        loss = 0.0
        with tf.GradientTape() as tape:
            # 预测
            output_classes = self.classes_model(input_image)
            # 计算损失
            # loss = self.loss_object(y_true=target_data, y_pred=output_classes)
            loss = self.loss_fun(y_true=target_data, y_pred=output_classes)

        trainable_variables = self.classes_model.trainable_variables
        gradients = tape.gradient(loss, trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))

        return loss

    def fit_generator(self, generator, steps_per_epoch, epochs, initial_epoch=1, auto_save=False):
        '''批量训练'''
        for epoch in range(initial_epoch, epochs+1):
            start = time.process_time()
            epoch_loss = 0
            for steps in range(1, steps_per_epoch+1):
                x, y = next(generator)
                # print('generator', x.shape, y.shape)
                loss = self.train_step(x, y)
                epoch_loss += loss
                print('\rsteps:%d/%d, epochs:%d/%d, loss:%0.4f'
                      % (steps, steps_per_epoch, epoch, epochs, loss), end='')
            end = time.process_time()
            print('\rsteps:%d/%d, epochs:%d/%d, %0.4f S, loss:%0.4f, epoch_loss:%0.4f'
                  % (steps, steps_per_epoch, epoch, epochs, (end - start), loss, epoch_loss))
            if auto_save:
                self.save_model()
           

四、预测 

1.预测方法:

@tf.function(input_signature=(
        tf.TensorSpec(shape=(None, 400, 400, 3), dtype=tf.float32),
    ))
    def predict(self, input_image):
        '''
        预测(编译模式)
        input_image:图片(400,400,3)
        return:两个指针值(2)
        '''
        # 预测
        output_classes = self.classes_model(input_image)
        return output_classes
           

2.加载与保存模型代码:

def save_model(self):
        '''保存模型'''
        save_path = self.checkpoint_manager.save()
        print('保存模型 {}'.format(save_path))

    def load_model(self):
        '''加载模型'''
        self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
        if self.checkpoint_manager.latest_checkpoint:
            print('加载模型 {}'.format(self.checkpoint_manager.latest_checkpoint))

           

 五、编写测试页面

 1.在 ai_api\static\gauge\ 路径下创建 predict_image.html 页面。

测试过程,1.生成一张仪表图,2.随机生成颜色、角度、噪点,3.经过网络预测,显示预测后的值与方向纠正后的效果图。

浏览器打开:http://127.0.0.1:8000/static/predict_image.html

代码如下:

<!DOCTYPE html>

<head>
  <meta charset="utf-8">
  <title>ECharts</title>
</head>

<body>
  <input id="btnSubmit" type="button" title="测试" value="测试"></input> 实际值:<span id="txtTrueValue"></span> 估计值:<span id="txtValue"></span>
  <br />
  <span>
    原图:
    <!-- 为ECharts准备一个具备大小(宽高)的Dom -->
    <div id="main" style="height:400px; width: 400px; vertical-align: top;"></div>
  </span>
  <br />
  <span>
    随机变换角度:
    <img id="random_img" style="height:400px; width: 400px; vertical-align: top;"></img>
  </span>
  <span>
    算法纠正后的图:
    <img id="perspective_img" style="height:400px; width: 400px; vertical-align: top;"></img>
  </span>
  <script src="https://unpkg.com/axios/dist/axios.min.js"></script>
  <!-- ECharts单文件引入 -->
  <script src="http://echarts.baidu.com/build/dist/echarts.js"></script>
  <script type="text/javascript">
    function rgb() {//rgb颜色随机
      var r = Math.floor(Math.random() * 256);
      var g = Math.floor(Math.random() * 256);
      var b = Math.floor(Math.random() * 256);
      var rgb = '(' + r + ',' + g + ',' + b + ')';
      return rgb;
    }
    function color16() {//十六进制颜色随机
      var r = Math.floor(Math.random() * 256);
      var g = Math.floor(Math.random() * 256);
      var b = Math.floor(Math.random() * 256);
      var color = '#' + r.toString(16) + g.toString(16) + b.toString(16);
      return color;
    }
    // 路径配置
    require.config({
      paths: {
        echarts: 'http://echarts.baidu.com/build/dist'
      }
    });

    // 使用
    require(
      [
        'echarts',
        'echarts/chart/gauge' // 按需加载
      ],
      function (ec) {
        // 基于准备好的dom,初始化echarts图表
        let myChart = ec.init(document.getElementById('main'));

        // 为echarts对象加载数据 
        // myChart.setOption(option);
        let updateFun = () => {
          // let startAngle = Math.round(Math.random() * 360)
          // let endAngle = startAngle - Math.round(Math.random() * 300) - 30
          let startAngle = 180 + Math.round(Math.random() * 90) - 45
          let endAngle = 30 + Math.round(Math.random() * 120) - 60
          // let startAngle = 0
          // let endAngle = -360
          let option = {
            animation: false,
            // tooltip: {
            //   formatter: "{a} <br/>{b} : {c}%"
            // },
            // toolbox: {
            //   show: true,
            //   feature: {
            //     mark: { show: true },
            //     restore: { show: true },
            //     saveAsImage: { show: true }
            //   }
            // },
            series: [
              {
                name: '业务指标',
                type: 'gauge',
                legendHoverLink: false,
                splitNumber: Math.round(Math.random() * 10),       // 分割段数,默认为5
                startAngle: startAngle,
                endAngle: endAngle,
                axisLine: {            // 坐标轴线
                  lineStyle: {       // 属性lineStyle控制线条样式
                    color: [[0.2, color16()], [0.8, color16()], [1, color16()]],
                    width: Math.round(Math.random() * 10) + 3
                  }
                },
                axisTick: {            // 坐标轴小标记
                  splitNumber: Math.round(Math.random() * 10),   // 每份split细分多少段
                  length: Math.round(Math.random() * 20),        // 属性length控制线长
                  lineStyle: {       // 属性lineStyle控制线条样式
                    color: color16()
                  }
                },
                axisLabel: {           // 坐标轴文本标签,详见axis.axisLabel
                  textStyle: {       // 其余属性默认使用全局文本样式,详见TEXTSTYLE
                    color: color16()
                  }
                },
                splitLine: {           // 分隔线
                  show: true,        // 默认显示,属性show控制显示与否
                  length: Math.round(Math.random() * 20) + 25,         // 属性length控制线长
                  lineStyle: {       // 属性lineStyle(详见lineStyle)控制线条样式
                    color: color16()
                  }
                },
                pointer: {
                  length: (Math.round(Math.random() * 40) + 60) + '%',
                  width: Math.round(Math.random() * 8) + 1,
                  color: color16()
                },
                title: {
                  show: false,
                  offsetCenter: [0, '-40%'],       // x, y,单位px
                  textStyle: {       // 其余属性默认使用全局文本样式,详见TEXTSTYLE
                    fontWeight: 'bolder'
                  }
                },
                detail: {
                  show: false,
                  formatter: '{value}%',
                  textStyle: {       // 其余属性默认使用全局文本样式,详见TEXTSTYLE
                    color: color16(),
                    fontWeight: 'bolder'
                  }
                },
                data: [{ value: 50, name: '' }]
              }
            ]
          };
          option.series[0].data[0].value = (Math.random() * 100).toFixed(2) - 0;
          myChart.setOption(option, true);
          let img = myChart.getDataURL();
          // console.log('图片数据:', img);
          axios.post('/ai_api/gauge/gauge_predict', {
            img_data: img,
            read: 0,
          })
            .then(function (response) {
              console.log(response);
              // alert('识别值:'+response.data.value[0][0]);
              document.getElementById('random_img').src = 'data:image/jpg;base64,' + response.data.random_img;
              document.getElementById('perspective_img').src = 'data:image/jpg;base64,' + response.data.perspective_img;
              document.getElementById('txtValue').innerText = response.data.value[0][0] * 100;
              document.getElementById('txtTrueValue').innerText = option.series[0].data[0].value;
            })
            .catch(function (error) {
              console.log(error);
            });
        }
        document.getElementById('btnSubmit').addEventListener('click', updateFun);
      }
    );
  </script>
</body>
           

效果图:

【人工智能笔记】第十一节 Tensorflow 2.0 实现指针仪表方向纠正及指针识别(中)
【人工智能笔记】第十一节 Tensorflow 2.0 实现指针仪表方向纠正及指针识别(中)
【人工智能笔记】第十一节 Tensorflow 2.0 实现指针仪表方向纠正及指针识别(中)

下一节,将会讲解如何将训练好的模型应用于真实环境,进行方向纠正与指针识别。源码会在下一节发布,敬请关注!

继续阅读