前言
- 每迭代100次保存一次loss值
- 每迭代1000次保存一次validation值
- 曲线表示平均值，阴影背景表示均值 ± 1 倍标准差
代码
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.ticker import FuncFormatter
def mean_list(valid_ids, valid_values, stride=2, iters=1000):
    """Average a logged metric over consecutive, non-overlapping windows.

    Args:
        valid_ids: Logged step indices. Only its length is used: it decides how
            many full windows are formed (``len(valid_ids) // stride``).
        valid_values: Metric values, one per entry of ``valid_ids``.
        stride: Number of consecutive values averaged per window.
        iters: Iterations represented by one logged entry; window ``k`` is
            placed at x position ``int(k * stride * iters)``.

    Returns:
        ``(used_ids, mean_values, std_values)``: three lists of equal length
        with the x position, mean and population standard deviation
        (``np.std`` default, ddof=0) of each window. A trailing partial window
        is dropped.

    Raises:
        ValueError: if ``stride`` is not positive, or if ``valid_values`` has
            fewer entries than the full windows require.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride!r}")
    nums = len(valid_ids) // stride
    values = np.asarray(valid_values)
    if values.shape[0] < nums * stride:
        raise ValueError(
            f"valid_values has {values.shape[0]} entries, "
            f"need at least {nums * stride} for {nums} windows of {stride}"
        )
    # Row k holds values[k*stride:(k+1)*stride]: disjoint, equal-size windows.
    windows = values[: nums * stride].reshape(nums, stride)
    used_ids = [int(k * stride * iters) for k in range(nums)]
    mean_values = windows.mean(axis=1).tolist()
    std_values = windows.std(axis=1).tolist()
    return used_ids, mean_values, std_values
# Global matplotlib text settings, applied through rcParams: serif family with
# SimSun as the serif face (presumably so Chinese text renders; the font must
# be installed), STIX glyphs for math text, and one shared font size that the
# axis labels and tick labels below reuse.
font_size = 25
config = {}
config["font.family"] = 'serif'
config["font.size"] = font_size
config["mathtext.fontset"] = 'stix'
config["font.serif"] = ['SimSun']
rcParams.update(config)
# X-axis tick formatter: show iteration counts in thousands, matching the
# "Iterations x 10^3" axis label.
def formatnumx(x, pos):
    """Format tick value ``x`` (in iterations) as thousands of iterations.

    Args:
        x: Tick value in iterations.
        pos: Tick position index; required by ``FuncFormatter`` but unused.

    Returns:
        ``x / 1000`` as an integer string when it is whole (``2000`` -> ``'2'``),
        otherwise in ``%g`` form (``2500`` -> ``'2.5'``), so fractional ticks
        are not truncated to a wrong label. Non-finite values give ``'inf'`` /
        ``'nan'`` instead of raising.
    """
    value = x / 1000
    if float(value).is_integer():
        return '%d' % value
    return '%g' % value
# Tick formatter that shows x in thousands of iterations.
formatterx = FuncFormatter(formatnumx)

# ax1 (left y axis): training loss; ax2 (right y axis, created by twinx so it
# shares ax1's x axis): validation result.
fig, ax1 = plt.subplots(figsize=(7, 5), dpi=100)
ax2 = ax1.twinx()


def _plot_mean_std(ax, ids, values, iters, color, label=None):
    """Plot the windowed mean of ``values`` on ``ax`` with a +/-1 std band.

    Windows of 2 consecutive entries are formed by ``mean_list``; ``iters`` is
    the number of iterations one logged entry represents.
    """
    xs, means, stds = mean_list(ids, values, stride=2, iters=iters)
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)
    ax.plot(xs, means, color=color, label=label)
    ax.fill_between(xs, means - stds, means + stds, color=color, alpha=0.3)


# NOTE: train_ids / train_mse / valid_ids / valid_mse are not defined in this
# snippet; they must be loaded before this point.
# train_ids: iteration counter (one unit = 100 iterations);
# train_mse: loss value logged every 100 iterations.
_plot_mean_std(ax1, train_ids, train_mse, iters=100, color='C0')
# valid_ids: iteration counter (one unit = 1000 iterations);
# valid_mse: validation value logged every 1000 iterations.
# The label is only displayed if a legend is drawn (e.g. ax2.legend()).
_plot_mean_std(ax2, valid_ids, valid_mse, iters=1000, color='C1', label='Ours w/o PT')

axis_font = {'family': 'Times New Roman', 'size': font_size}
ax1.set_xlabel(r'Iterations $\times 10^3$', fontdict=axis_font)
ax1.set_ylabel('Training loss (BCE)', color='C0', fontdict=axis_font)
ax2.set_ylabel('Validation result (VOI)', color='C1', fontdict=axis_font)

# Attach the formatter to ax1 explicitly instead of relying on plt.gca()
# happening to return the twin axes; the x axis is shared with ax2.
ax1.xaxis.set_major_formatter(formatterx)

# Tick label size and font for both axes (ax1 first, then ax2).
for ax in (ax1, ax2):
    ax.tick_params(labelsize=font_size)
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname('Times New Roman')

ax2.set_ylim((1, 6.5))
plt.show()
结果展示
参考
https://blog.csdn.net/qq_33757398/article/details/115056483?spm=1001.2014.3001.5501