相关资料:
【人工智能笔记】第十节 Tensorflow 2.0 实现指针仪表方向纠正及指针识别(上)
这次创建预测模型,用于预测仪表指针值,以及用于方向纠正的 4 个点坐标。预测坐标点时采用预测 4 个基准点偏移值的方式,而不是直接预测实际坐标值。训练前会对图片做数据增强(随机颜色、透视变换、增加噪点等),以防止模型过拟合。
一、创建模型
1. 特征提取模型使用 Darknet-53,最后接一个卷积核为 13×13、输出通道数为 9 的卷积层,输出:1 个指针预测值 + 4×2 个坐标点偏移值。
基准点4个坐标为:
[50, 50] # 左上
[50, 350] # 左下
[350, 50] # 右上
[350, 350] # 右下
创建模型代码如下:
def build_model(self):
    """Create the prediction network, its optimizer and checkpoint helpers.

    Side effects: builds self.classes_model (via self.build_classes_model)
    and sets self.optimizer, self.checkpoint and self.checkpoint_manager.
    self.model_path is used as the checkpoint directory.
    """
    # The network has to exist before a checkpoint can track its weights.
    self.build_classes_model()
    self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=3e-5)
    # Track optimizer state together with the model so training can resume.
    ckpt = tf.train.Checkpoint(
        optimizer=self.optimizer,
        classes_model=self.classes_model,
    )
    self.checkpoint = ckpt
    # Only the three most recent checkpoints are kept on disk.
    self.checkpoint_manager = tf.train.CheckpointManager(
        ckpt, directory=self.model_path, max_to_keep=3
    )
def conv_layer(self, input, filters, kernel_size, strides=(1, 1), padding='same'):
    """Apply Conv2D -> BatchNormalization -> LeakyReLU(alpha=0.1) to `input`.

    The convolution has no bias term and an L2 kernel penalty of 5e-4.
    Returns the activated output tensor.
    """
    # `input` keeps its (builtin-shadowing) name for interface compatibility.
    convolution = tf.keras.layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        use_bias=False,  # BatchNormalization follows, so a bias would be redundant.
        kernel_regularizer=tf.keras.regularizers.l2(5e-4),
    )
    features = convolution(input)
    features = tf.keras.layers.BatchNormalization()(features)
    features = tf.keras.layers.LeakyReLU(alpha=0.1)(features)
    return features
def resnet_layer(self, input, filters, layer_sizes):
    """Darknet-style stage: stride-2 downsampling conv plus residual units.

    filters: channel count of the stage output.
    layer_sizes: number of residual units (1x1 with filters // 2 channels,
        then 3x3 with `filters` channels, added to the unit's input).
    Returns the stage output tensor.
    """
    # Pad one row on top and one column on the left so the 3x3 stride-2
    # 'valid' conv halves an even-sized input exactly.
    padded = tf.keras.layers.ZeroPadding2D(((1, 0), (1, 0)))(input)
    out = self.conv_layer(padded, filters, (3, 3),
                          strides=(2, 2), padding='valid')
    for _ in range(layer_sizes):
        shortcut = out
        squeezed = self.conv_layer(shortcut, filters // 2, (1, 1))
        expanded = self.conv_layer(squeezed, filters, (3, 3))
        out = tf.keras.layers.Add()([expanded, shortcut])
    return out
def build_classes_model(self):
    '''Build the regression network and store it as self.classes_model.

    Backbone: Darknet-53-style stages (conv_layer + resnet_layer). The head
    fuses the 13 x 13 output of the last stage with a max-pooled copy of the
    26 x 26 stage and reduces it with a single 13x13 'valid' convolution.

    Input:  float32 batch of shape (batch, 400, 400, 3).
    Output: (batch, 9) -- a 9-filter 13x13 'valid' conv on a 13 x 13 map
    gives 1 x 1 x 9, which Flatten turns into 9 values. The training labels
    (built in get_random_data, consumed by loss_fun) use index 0 as the
    gauge value and indices 1-8 as four (x, y) reference-point offsets.

    NOTE(review): self.checkpoint tracks this model's weights, so changing
    the layers here will most likely make existing checkpoints unrestorable.
    '''
    # Model input: 400 x 400 x 3 float32 image.
    input_classes = tf.keras.Input([400, 400, 3], dtype=tf.float32)
    # Zero-pad 8 pixels per side: 400 -> 416, so the five stride-2 stages end at 13.
    x = tf.pad(input_classes, [[0, 0], [8, 8], [8, 8], [0, 0]], "CONSTANT")
    # -> 416 x 416 x 32
    x = self.conv_layer(x, 32, (3, 3))
    # -> 208 x 208 x 64
    x = self.resnet_layer(x, 64, 1)
    # -> 104 x 104 x 128
    x = self.resnet_layer(x, 128, 2)
    # -> 52 x 52 x 256
    x = self.resnet_layer(x, 256, 8)
    # -> 26 x 26 x 512
    x = self.resnet_layer(x, 512, 8)
    # Keep the 26 x 26 features for the skip branch below.
    y2 = x
    # -> 13 x 13 x 1024
    x = self.resnet_layer(x, 1024, 4)
    # Alternating 1x1 (512) / 3x3 (1024) convs at 13 x 13.
    x = self.conv_layer(x, 512, (1, 1))
    x = self.conv_layer(x, 1024, (3, 3))
    x = self.conv_layer(x, 512, (1, 1))
    x = self.conv_layer(x, 1024, (3, 3))
    x = self.conv_layer(x, 512, (1, 1))
    x = self.conv_layer(x, 1024, (3, 3))
    # Plain 1x1 projection to 512 channels (no BatchNorm / activation).
    x = tf.keras.layers.Conv2D(512, (1, 1), padding='same',
                               kernel_regularizer=tf.keras.regularizers.l2(5e-4), use_bias=False)(x)
    # Skip branch: 26 x 26 x 512 -> max-pool to 13 x 13 -> 1x1 projection to 512 channels.
    x2 = tf.keras.layers.MaxPool2D((2, 2))(y2)
    x2 = tf.keras.layers.Conv2D(512, (1, 1), padding='same',
                                kernel_regularizer=tf.keras.regularizers.l2(5e-4), use_bias=False)(x2)
    # -> 13 x 13 x 1024
    x = tf.keras.layers.Concatenate()([x, x2])
    # 13x13 'valid' conv over the whole 13 x 13 map -> 1 x 1 x 9.
    # NOTE(review): neither 1x1 projection above has an activation, so they
    # compose linearly with this conv; this conv also has no kernel_regularizer.
    x = tf.keras.layers.Conv2D(
        9, (13, 13), padding='valid', use_bias=False)(x)
    # -> (batch, 9)
    x = tf.keras.layers.Flatten()(x)
    self.classes_model = tf.keras.Model(inputs=input_classes, outputs=x)
二、数据增强
对原图片进行数据增强:
def get_random_data(self, image, value, target_points=None):
    '''Build one randomly augmented training sample.

    The image is perspective-warped with a random offset (+-45), rotation
    about x/y (+-30) and scale (0.7-1.3), then noised and colour-jittered
    by image_helpler.

    image: BGR image accepted by image_helpler.opencvPerspective
        (presumably 400x400 uint8 to match the model input -- TODO confirm).
    value: gauge reading; stored in the label divided by 100.
    target_points: points carried through the warp; defaults to the four
        reference points (50, 50), (50, 350), (350, 50), (350, 350).

    Returns (random_img, target_data):
        random_img: RGB float32 image divided by 255.
        target_data: float32 vector of 9 values -- value / 100, then the
            (x, y) displacement of each warped point from its reference
            point, divided by 400.
    '''
    # Random draws, in order: offset x, offset y, angle x, angle y, scale.
    shift_x, shift_y = (random.random() * 90 - 45 for _ in range(2))
    tilt_x, tilt_y = (random.random() * 60 - 30 for _ in range(2))
    zoom = random.random() * 0.6 + 0.7

    # Reference points: top-left, bottom-left, top-right, bottom-right.
    anchors = np.float32([[50, 50], [50, 350], [350, 50], [350, 350]])
    tracked = anchors if target_points is None else target_points

    warped, _org, _dst, moved = image_helpler.opencvPerspective(
        image,
        offset=(shift_x, shift_y, 0),
        angle=(tilt_x, tilt_y, 0),
        scale=(zoom, zoom, 1),
        points=tracked,
    )
    # Offsets are always measured against the fixed reference points.
    deltas = (moved - anchors) / 400

    # Noise first, then colour jitter.
    warped = image_helpler.opencvRandomColor(image_helpler.opencvNoise(warped))

    random_img = cv2.cvtColor(warped, cv2.COLOR_BGR2RGB).astype(np.float32) / 255
    # Label layout: [value, x0, y0, x1, y1, x2, y2, x3, y3].
    target_data = np.float32(
        [value / 100] + [deltas[row, col] for row in range(4) for col in range(2)]
    )
    return random_img, target_data
def get_random_image(self, image):
    '''Return a randomly warped, noised and colour-jittered copy of image (for testing).

    Random ranges: offset +-45, rotation about x/y +-30, scale 0.7-1.3.
    The result is whatever image_helpler.opencvRandomColor returns
    (presumably still BGR like the input -- no colour conversion is done here).
    '''
    # Random draws, in order: offset x, offset y, angle x, angle y, scale.
    shift_x, shift_y = (random.random() * 90 - 45 for _ in range(2))
    tilt_x, tilt_y = (random.random() * 60 - 30 for _ in range(2))
    zoom = random.random() * 0.6 + 0.7
    warped, _org, _dst, _points = image_helpler.opencvPerspective(
        image,
        offset=(shift_x, shift_y, 0),
        angle=(tilt_x, tilt_y, 0),
        scale=(zoom, zoom, 1),
    )
    noisy = image_helpler.opencvNoise(warped)
    return image_helpler.opencvRandomColor(noisy)
三、训练模型
1. 分别计算指针预测值与每个坐标点偏移值的 Loss(指针值使用绝对误差,坐标偏移使用平方误差):
@tf.function
def loss_fun(self, y_true, y_pred):
    '''Weighted loss for a batch of 9-value predictions.

    y_true / y_pred: batch on axis 0; column 0 is the gauge value and
    columns 1-8 hold four (x, y) point offsets (layout taken from the
    slicing below).

    loss = 0.2 * (3 * sum|value error| + sum over the four points of the
    squared (x, y) error). Sums run over the whole batch, so the loss
    scales with batch size.
    '''
    # Gauge value: absolute error, weighted by 3.
    total = tf.math.reduce_sum(tf.math.abs(y_true[:, 0] - y_pred[:, 0])) * 3
    # Each reference point: squared error over its (x, y) pair.
    for first in (1, 3, 5, 7):
        pair = slice(first, first + 2)
        total = total + tf.math.reduce_sum(
            tf.math.square(y_true[:, pair] - y_pred[:, pair]))
    return total * 0.2
2.单步训练与批量训练方法:
@tf.function(input_signature=(
    tf.TensorSpec(shape=(None, 400, 400, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(None, 9), dtype=tf.float32),
))
def train_step(self, input_image, target_data):
    '''Run one optimisation step on a batch and return the loss.

    input_image: float32 batch of shape (batch, 400, 400, 3).
    target_data: float32 batch of shape (batch, 9) -- one gauge value
        followed by 4 reference-point offsets (8 values).

    Returns the optimised scalar loss: self.loss_fun on the batch plus the
    model's regularisation losses (the L2 kernel penalties), when present.
    '''
    # Python prints inside tf.function only run while the graph is traced.
    print('Tracing with train_step', type(input_image), type(target_data))
    print('Tracing with train_step', input_image.shape, target_data.shape)
    with tf.GradientTape() as tape:
        # training=True: BatchNormalization must normalise with batch
        # statistics and update its moving averages. Without it the layers
        # run in inference mode even while training.
        output_classes = self.classes_model(input_image, training=True)
        loss = self.loss_fun(y_true=target_data, y_pred=output_classes)
        # A custom loop must add kernel_regularizer penalties itself; they
        # are only exposed through model.losses.
        regularization_losses = self.classes_model.losses
        if regularization_losses:
            loss = loss + tf.math.add_n(regularization_losses)
    trainable_variables = self.classes_model.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
    return loss
def fit_generator(self, generator, steps_per_epoch, epochs, initial_epoch=1, auto_save=False):
    '''Train for epochs initial_epoch..epochs, pulling batches from generator.

    generator: iterator yielding (images, labels) batches accepted by
        self.train_step.
    steps_per_epoch: batches per epoch. An epoch with no steps is reported
        as 0 steps with zero loss instead of raising.
    epochs: last epoch number (inclusive).
    initial_epoch: first epoch number (1-based).
    auto_save: if True, call self.save_model() after every epoch.
    '''
    for epoch in range(initial_epoch, epochs + 1):
        # Wall-clock timing: time.process_time() sums CPU time over all
        # threads of the process, which is not the elapsed epoch time.
        start = time.perf_counter()
        epoch_loss = 0.0
        loss = 0.0
        steps = 0  # Keeps the summary line valid when steps_per_epoch <= 0.
        for steps in range(1, steps_per_epoch + 1):
            x, y = next(generator)
            # float() turns the scalar tensor into a plain number for logging.
            loss = float(self.train_step(x, y))
            epoch_loss += loss
            print('\rsteps:%d/%d, epochs:%d/%d, loss:%0.4f'
                  % (steps, steps_per_epoch, epoch, epochs, loss), end='')
        elapsed = time.perf_counter() - start
        # epoch_loss is the sum of this epoch's per-step losses.
        print('\rsteps:%d/%d, epochs:%d/%d, %0.4f S, loss:%0.4f, epoch_loss:%0.4f'
              % (steps, steps_per_epoch, epoch, epochs, elapsed, loss, epoch_loss))
        if auto_save:
            self.save_model()
四、预测
1.预测方法:
@tf.function(input_signature=(
    tf.TensorSpec(shape=(None, 400, 400, 3), dtype=tf.float32),
))
def predict(self, input_image):
    '''Run the network in inference mode (compiled with tf.function).

    input_image: float32 batch of shape (batch, 400, 400, 3).
    Returns a (batch, 9) tensor: index 0 is the gauge value (trained on
    value / 100) and indices 1-8 are the (x, y) offsets of the four
    reference points (trained divided by 400) -- not "two pointer values".
    '''
    # Explicit inference mode: BatchNormalization uses its moving statistics.
    output_classes = self.classes_model(input_image, training=False)
    return output_classes
2.加载与保存模型代码:
def save_model(self):
    '''Write a new checkpoint via self.checkpoint_manager and print its path.'''
    manager = self.checkpoint_manager
    written_path = manager.save()
    print('保存模型 {}'.format(written_path))
def load_model(self):
    '''Restore from the newest checkpoint known to self.checkpoint_manager.

    self.checkpoint.restore is called even when no checkpoint exists
    (latest_checkpoint is then falsy); the path is printed only when one does.
    '''
    latest = self.checkpoint_manager.latest_checkpoint
    self.checkpoint.restore(latest)
    if not latest:
        return
    print('加载模型 {}'.format(latest))
五、编写测试页面
1.在 ai_api\static\gauge\ 路径下创建 predict_image.html 页面。
测试过程:1. 生成一张仪表图;2. 随机生成颜色、角度、噪点;3. 经过网络预测,显示预测值与方向纠正后的效果图。
浏览器打开:http://127.0.0.1:8000/static/predict_image.html
代码如下:
<!DOCTYPE html>
<head>
<meta charset="utf-8">
<title>ECharts</title>
</head>
<body>
<!-- Test controls: the button runs one test round; the true value and the model's estimate are written into the two spans. -->
<input id="btnSubmit" type="button" title="测试" value="测试"> 实际值:<span id="txtTrueValue"></span> 估计值:<span id="txtValue"></span>
<br />
<!-- Wrapper is a div: a block-level div may not be nested inside an inline span. -->
<div>
原图:
<!-- DOM container with an explicit size (width/height) for ECharts to draw into. -->
<div id="main" style="height:400px; width: 400px; vertical-align: top;"></div>
</div>
<br />
<span>
随机变换角度:
<!-- img is a void element: no closing tag. Filled with the randomly transformed image returned by the server. -->
<img id="random_img" style="height:400px; width: 400px; vertical-align: top;">
</span>
<span>
算法纠正后的图:
<!-- Filled with the perspective-corrected image returned by the server. -->
<img id="perspective_img" style="height:400px; width: 400px; vertical-align: top;">
</span>
<script src="https://unpkg.com/axios/dist/axios.min.js"></script>
<!-- ECharts single-file build; presumably the legacy ECharts 2 bundle whose AMD loader provides the global require()/require.config used below. TODO confirm this URL still resolves. -->
<script src="http://echarts.baidu.com/build/dist/echarts.js"></script>
<script type="text/javascript">
function rgb() {// random RGB colour
    // Returns '(r,g,b)' with each channel a random integer in 0..255.
    // NOTE: there is no leading 'rgb', so callers must prepend it before
    // using the result as a CSS colour.
    var channels = [];
    for (var i = 0; i < 3; i++) {
        channels.push(Math.floor(Math.random() * 256));
    }
    return '(' + channels.join(',') + ')';
}
function color16() {// random hexadecimal colour
    // Returns '#rrggbb'. Each channel is zero-padded to two hex digits:
    // without padding, a channel below 16 becomes one digit, producing a
    // 3-, 4- or 5-digit string that is either a different colour (#rgb /
    // #rgba shorthand) or invalid CSS.
    var r = Math.floor(Math.random() * 256);
    var g = Math.floor(Math.random() * 256);
    var b = Math.floor(Math.random() * 256);
    var hex2 = function (n) {
        return (n < 16 ? '0' : '') + n.toString(16);
    };
    return '#' + hex2(r) + hex2(g) + hex2(b);
}
// Module path for the ECharts build directory.
require.config({
    paths: {
        echarts: 'http://echarts.baidu.com/build/dist'
    }
});
// Load ECharts and the gauge chart module, then attach the test handler.
require(
    [
        'echarts',
        'echarts/chart/gauge' // gauge chart type, loaded on demand
    ],
    function (ec) {
        // Create the chart inside the prepared #main container.
        let myChart = ec.init(document.getElementById('main'));
        // One test round: draw a randomly styled gauge, send the rendered
        // image to the prediction API, then show the returned images and value.
        let updateFun = () => {
            // Random sweep: start angle in 135..225 degrees, end angle in -30..90 degrees.
            let startAngle = 180 + Math.round(Math.random() * 90) - 45
            let endAngle = 30 + Math.round(Math.random() * 120) - 60
            let option = {
                // No animation, so getDataURL() below presumably captures the final drawing.
                animation: false,
                series: [
                    {
                        name: '业务指标',
                        type: 'gauge',
                        legendHoverLink: false,
                        // Number of major segments: 1..10 (0 would give a degenerate axis).
                        splitNumber: Math.floor(Math.random() * 10) + 1,
                        startAngle: startAngle,
                        endAngle: endAngle,
                        axisLine: {            // axis line
                            lineStyle: {       // three colour bands, random width
                                color: [[0.2, color16()], [0.8, color16()], [1, color16()]],
                                width: Math.round(Math.random() * 10) + 3
                            }
                        },
                        axisTick: {            // minor ticks
                            splitNumber: Math.floor(Math.random() * 10) + 1, // ticks per segment: 1..10
                            length: Math.round(Math.random() * 20),          // tick length
                            lineStyle: {
                                color: color16()
                            }
                        },
                        axisLabel: {           // axis text labels
                            textStyle: {
                                color: color16()
                            }
                        },
                        splitLine: {           // major split lines
                            show: true,
                            length: Math.round(Math.random() * 20) + 25,
                            lineStyle: {
                                color: color16()
                            }
                        },
                        pointer: {
                            length: (Math.round(Math.random() * 40) + 60) + '%',
                            width: Math.round(Math.random() * 8) + 1,
                            color: color16()
                        },
                        title: {
                            show: false,
                            offsetCenter: [0, '-40%'],
                            textStyle: {
                                fontWeight: 'bolder'
                            }
                        },
                        detail: {
                            show: false,
                            formatter: '{value}%',
                            textStyle: {
                                color: color16(),
                                fontWeight: 'bolder'
                            }
                        },
                        data: [{ value: 50, name: '' }]
                    }
                ]
            };
            // True reading: random 0..100 with two decimals.
            option.series[0].data[0].value = Number((Math.random() * 100).toFixed(2));
            myChart.setOption(option, true);
            let img = myChart.getDataURL();
            axios.post('/ai_api/gauge/gauge_predict', {
                img_data: img,
                read: 0,
            })
                .then(function (response) {
                    console.log(response);
                    // 'image/jpeg' is the registered MIME type ('image/jpg' is not).
                    // NOTE(review): assumes the backend returns JPEG-encoded base64 -- TODO confirm.
                    document.getElementById('random_img').src = 'data:image/jpeg;base64,' + response.data.random_img;
                    document.getElementById('perspective_img').src = 'data:image/jpeg;base64,' + response.data.perspective_img;
                    // Presumably the raw model output (trained on value / 100): scale back
                    // and round to the same two decimals as the true value.
                    document.getElementById('txtValue').innerText = (response.data.value[0][0] * 100).toFixed(2);
                    document.getElementById('txtTrueValue').innerText = option.series[0].data[0].value;
                })
                .catch(function (error) {
                    console.log(error);
                });
        }
        document.getElementById('btnSubmit').addEventListener('click', updateFun);
    }
);
</script>
</body>
效果图:
下一节,将会讲解如何将训练好的模型应用于真实环境,进行方向纠正与指针识别。源码会在下一节发布,敬请关注!