1 import numpy as np
2 from mpl_toolkits.mplot3d import Axes3D
3 import matplotlib.pyplot as plt
4 from matplotlib.colors import LinearSegmentedColormap
5
6 # 待求極值的函數
7 def himmelblau(t):# t[0]-->X; t[1]-->Y.
8 return (t[0] ** 2 + t[1] - 11) ** 2 + (t[0] + t[1] ** 2 - 7) ** 2
9
10 x = np.arange(-6, 6, 0.1)
11 y = np.arange(-6, 6, 0.1)
12 X, Y = np.meshgrid(x, y)
13 Z = himmelblau([X, Y])
14 fig = plt.figure()
15 ax = fig.add_subplot(projection='3d')# ax = fig.gca(projection='3d') # ---> was deprecated in Matplotlib 3.4
16 ax.plot_surface(X, Y, Z)
17 ax.view_init(60, -30)
18 ax.set_xlabel('x')
19 ax.set_ylabel('y')
20 fig.show()
21 plt.show()
22
23 # function test
24 def jeshy(t):
25 return t*3+10
26
27 import torch
28 x = torch.tensor([0., 0.], requires_grad=True)
29 optimizer = torch.optim.Adam([x, ])# optim.Adam([var1, var2], lr=0.0001)# 優化器設定 ,并傳入模型參數和相應的學習率
30 for step in range(20001):
31 f = himmelblau(x)# 前向傳播
32 if step > 0:
33 optimizer.zero_grad()# 反向傳播與優化# 清空上一步的殘餘更新參數值
34 f.backward(retain_graph=True)# 反向傳播與優化# 反向傳播
35 optimizer.step()# 反向傳播與優化# 将參數更新值施加到函數f的parameters上
36 # f = jeshy(f)
37 if step % 1000 == 0:# 每疊代一定步驟,列印結果值
38 print('step:{}, x = {}, value = {}'.format(step, x.tolist(), f))
輸出:
step:0, x = [0.0, 0.0], value = 170.0
step:1000, x = [1.270142912864685, 1.1183991432189941], value = 88.53223419189453
step:2000, x = [2.332378387451172, 1.9535712003707886], value = 13.766233444213867
step:3000, x = [2.8519949913024902, 2.114161968231201], value = 0.6711398363113403
step:4000, x = [2.981964111328125, 2.0271568298339844], value = 0.014927156269550323
step:5000, x = [2.9991261959075928, 2.0014777183532715], value = 3.9870232285466045e-05
step:6000, x = [2.999983549118042, 2.0000221729278564], value = 1.1074007488787174e-08
step:7000, x = [2.9999899864196777, 2.000013589859009], value = 4.150251697865315e-09
step:8000, x = [2.9999938011169434, 2.0000083446502686], value = 1.5572823031106964e-09
step:9000, x = [2.9999964237213135, 2.000005006790161], value = 5.256879376247525e-10
step:10000, x = [2.999997854232788, 2.000002861022949], value = 1.8189894035458565e-10
step:11000, x = [2.9999988079071045, 2.0000014305114746], value = 5.547917680814862e-11
step:12000, x = [2.9999992847442627, 2.0000009536743164], value = 1.6370904631912708e-11
step:13000, x = [2.999999523162842, 2.000000476837158], value = 5.6843418860808015e-12
step:14000, x = [2.999999761581421, 2.000000238418579], value = 1.8189894035458565e-12
step:15000, x = [3.0, 2.0], value = 0.0
step:16000, x = [3.0, 2.0], value = 0.0
step:17000, x = [3.0, 2.0], value = 0.0
step:18000, x = [3.0, 2.0], value = 0.0
step:19000, x = [3.0, 2.0], value = 0.0
step:20000, x = [3.0, 2.0], value = 0.0
個人學習記錄