天天看点

Matplotlib画论文图之loss和validation曲线(目录:前言 | 码代码 | 结果展示 | 参考)

前言

  1. 每迭代100次保存一次loss值
  2. 每迭代1000次保存一次validation值
  3. 每相邻两个记录点取一组,曲线表示该组的平均值,背景阴影表示平均值 ± 1 个标准差(np.std)

码代码

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.ticker import FuncFormatter

def mean_list(valid_ids, valid_values, stride=2, iters=1000):
    """Average a logged metric over consecutive, non-overlapping windows.

    Args:
        valid_ids: Logged iteration ids. Only ``len(valid_ids)`` is used, to
            decide how many full windows of ``stride`` samples exist.
        valid_values: Metric values, one per entry of ``valid_ids``.
        stride: Number of consecutive samples averaged into one point.
        iters: Iteration spacing between consecutive samples; the x position
            reported for window ``k`` is ``k * stride * iters``.

    Returns:
        ``(used_ids, mean_values, std_values)``: the x position (int) of each
        window, plus the mean and standard deviation (``np.std``, ddof=0) of
        the samples in that window. Trailing samples that do not fill a whole
        window are dropped; empty input yields three empty lists.
    """
    values = np.asarray(valid_values)
    num_windows = len(valid_ids) // stride
    used_ids = []
    mean_values = []
    std_values = []
    for k in range(num_windows):
        # Window k covers samples [k*stride, (k+1)*stride) so windows never overlap.
        window = values[k * stride:(k + 1) * stride]
        used_ids.append(int(k * stride * iters))
        mean_values.append(np.mean(window))
        std_values.append(np.std(window))
    return used_ids, mean_values, std_values

# Global text style for every figure drawn afterwards: serif family with
# SimSun as the serif face, STIX glyphs for $...$ math text, and one shared
# font size (font_size is reused for axis labels and tick labels below).
font_size = 25
config = {}
config["font.family"] = 'serif'
config["font.size"] = font_size
config["mathtext.fontset"] = 'stix'
config["font.serif"] = ['SimSun']
rcParams.update(config)

# x-axis tick formatter: display iteration counts in thousands, matching the
# "Iterations x 10^3" axis label, e.g. a tick at 2500 is labelled "2.5".
def formatnumx(x, pos):
    """Format tick value ``x`` (iterations) in thousands; ``pos`` is unused.

    Whole thousands render as integers ("2"). Fractional thousands keep up to
    three decimals ("0.5") instead of being truncated to "0" by ``%d``.
    """
    thousands = x / 1000
    if float(thousands).is_integer():
        return '%d' % thousands
    return ('%.3f' % thousands).rstrip('0').rstrip('.')
formatterx = FuncFormatter(formatnumx)


def _plot_mean_std(ax, ids, values, iters, color, label=None):
    """Draw the windowed mean of ``values`` on ``ax`` with a +/- 1 std band.

    ``ids``/``values`` are averaged by ``mean_list`` with ``stride=2``;
    ``iters`` is the iteration spacing between logged samples.
    """
    used_ids, mean_values, std_values = mean_list(ids, values, stride=2, iters=iters)
    mean_values = np.asarray(mean_values)
    std_values = np.asarray(std_values)
    line_kwargs = {'color': color}
    if label is not None:
        line_kwargs['label'] = label
    ax.plot(used_ids, mean_values, **line_kwargs)
    ax.fill_between(used_ids, mean_values - std_values, mean_values + std_values,
                    color=color, alpha=0.3)


fig, ax1 = plt.subplots(figsize=(7, 5), dpi=100)
ax2 = ax1.twinx()

# NOTE: train_ids / train_mse (loss logged every 100 iterations) and
# valid_ids / valid_mse (validation metric logged every 1000 iterations) are
# not defined in this script; load them before running it.
_plot_mean_std(ax1, train_ids, train_mse, iters=100, color='C0')
_plot_mean_std(ax2, valid_ids, valid_mse, iters=1000, color='C1', label='Ours w/o PT')

label_font = {'family': 'Times New Roman', 'size': font_size}
ax1.set_xlabel(r'Iterations $\times 10^3$', fontdict=label_font)
ax1.set_ylabel('Training loss (BCE)', color='C0', fontdict=label_font)
ax2.set_ylabel('Validation result (VOI)', color='C1', fontdict=label_font)

# ax1 and its twinx() twin share the x-axis tickers, so set the formatter on
# ax1 explicitly rather than relying on plt.gca() pointing at the twin.
ax1.xaxis.set_major_formatter(formatterx)

for ax in (ax1, ax2):
    ax.tick_params(labelsize=font_size)
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname('Times New Roman')

ax2.set_ylim((1, 6.5))
plt.show()
           

结果展示

(图:Matplotlib画论文图之loss和validation曲线——运行结果)

参考

https://blog.csdn.net/qq_33757398/article/details/115056483?spm=1001.2014.3001.5501