天天看點

基于pytorch對函數進行極值求解

1 import numpy as np
 2 from mpl_toolkits.mplot3d import Axes3D
 3 import matplotlib.pyplot as plt
 4 from matplotlib.colors import LinearSegmentedColormap
 5 
 6 # 待求極值的函數
 7 def himmelblau(t):# t[0]-->X; t[1]-->Y.
 8     return (t[0] ** 2 + t[1] - 11) ** 2 + (t[0] + t[1] ** 2 - 7) ** 2
 9 
10 x = np.arange(-6, 6, 0.1)
11 y = np.arange(-6, 6, 0.1)
12 X, Y = np.meshgrid(x, y)
13 Z = himmelblau([X, Y])
14 fig = plt.figure()
15 ax = fig.add_subplot(projection='3d')# ax = fig.gca(projection='3d') # ---> was deprecated in Matplotlib 3.4
16 ax.plot_surface(X, Y, Z)
17 ax.view_init(60, -30)
18 ax.set_xlabel('x')
19 ax.set_ylabel('y')
20 fig.show()
21 plt.show()
22 
23 # function test
24 def jeshy(t):
25     return t*3+10
26 
27 import torch
28 x = torch.tensor([0., 0.], requires_grad=True)
29 optimizer = torch.optim.Adam([x, ])# optim.Adam([var1, var2], lr=0.0001)# 優化器設定 ,并傳入模型參數和相應的學習率
30 for step in range(20001):
31     f = himmelblau(x)# 前向傳播
32     if step > 0:
33         optimizer.zero_grad()# 反向傳播與優化# 清空上一步的殘餘更新參數值
34         f.backward(retain_graph=True)# 反向傳播與優化# 反向傳播
35         optimizer.step()# 反向傳播與優化# 将參數更新值施加到函數f的parameters上
36     # f = jeshy(f)
37     if step % 1000 == 0:# 每疊代一定步驟,列印結果值
38         print('step:{}, x = {}, value = {}'.format(step, x.tolist(), f))      

 輸出:

step:0, x = [0.0, 0.0], value = 170.0

step:1000, x = [1.270142912864685, 1.1183991432189941], value = 88.53223419189453

step:2000, x = [2.332378387451172, 1.9535712003707886], value = 13.766233444213867

step:3000, x = [2.8519949913024902, 2.114161968231201], value = 0.6711398363113403

step:4000, x = [2.981964111328125, 2.0271568298339844], value = 0.014927156269550323

step:5000, x = [2.9991261959075928, 2.0014777183532715], value = 3.9870232285466045e-05

step:6000, x = [2.999983549118042, 2.0000221729278564], value = 1.1074007488787174e-08

step:7000, x = [2.9999899864196777, 2.000013589859009], value = 4.150251697865315e-09

step:8000, x = [2.9999938011169434, 2.0000083446502686], value = 1.5572823031106964e-09

step:9000, x = [2.9999964237213135, 2.000005006790161], value = 5.256879376247525e-10

step:10000, x = [2.999997854232788, 2.000002861022949], value = 1.8189894035458565e-10

step:11000, x = [2.9999988079071045, 2.0000014305114746], value = 5.547917680814862e-11

step:12000, x = [2.9999992847442627, 2.0000009536743164], value = 1.6370904631912708e-11

step:13000, x = [2.999999523162842, 2.000000476837158], value = 5.6843418860808015e-12

step:14000, x = [2.999999761581421, 2.000000238418579], value = 1.8189894035458565e-12

step:15000, x = [3.0, 2.0], value = 0.0

step:16000, x = [3.0, 2.0], value = 0.0

step:17000, x = [3.0, 2.0], value = 0.0

step:18000, x = [3.0, 2.0], value = 0.0

step:19000, x = [3.0, 2.0], value = 0.0

step:20000, x = [3.0, 2.0], value = 0.0

個人學習記錄