天天看點

神經網絡一 動态圖和靜态圖

目錄

  • 一 動态圖和靜态圖
    • TensorFlow
    • PyTorch

一 動态圖和靜态圖

  本節内容來自這裡。

  目前神經網絡架構分為靜态圖架構和動态圖架構,PyTorch和TensorFlow、Caffe等架構最大的差別就是他們擁有不同的計算圖表現形式。TensorFlow1.*使用靜态圖(在TensorFlow2.*中使用的是動态圖),這意味着我們先定義計算圖,然後不斷使用它,而在PyTorch中,每次都會重新建構一個新的計算圖。

  靜态圖和動态圖有各自的優點。動态圖比較友善DEBUG,使用者能夠使用任何他們喜歡的方式進行debug,同時非常直覺,而靜态圖是通過先定義後運作的方式,之後再次運作的時候就不再需要重新建構計算圖,是以速度會比動态圖更快。

import torch
from torch.autograd import Variable
 
x=Variable(torch.randn(1,10))
prev_h=Variable(torch.randn(1,20))
W_h=Variable(torch.randn(20,20))
W_x=Variable(torch.randn(20,10))
 
i2h=torch.mm(W_x,x.t())
h2h=torch.mm(W_h,prev_h.t())
           
神經網絡一 動态圖和靜态圖

  比較while循環語句在TensorFlow和PyTorch中的定義。

TensorFlow

import tensorflow as tf
 
first_counter=tf.constant(0)
second_counter=tf.constant(10)
 
def cond(first_counter,second_counter,*args):
    return first_counter<second_counter
 
def body(first_counter,second_counter):
    first_counter=tf.add(first_counter,2)
    second_counter=tf.add(second_counter,1)
    return first_counter,second_counter
 
c1,c2=tf.while_loop(cond,body,[first_counter,second_counter])
 
with tf.Session() as sess:
    counter_1_res,counter_2_res=sess.run([c1,c2])
 
print(counter_1_res)
print(counter_2_res)
           

  可以看到TensorFlow需要将整個圖構成靜态的,每次運作的時候圖都是一樣的,是不能夠改變的,是以不能直接使用Python的while循環語句,需要使用輔助函數tf.while_loop寫成TensorFlow内部形式。

PyTorch

import torch
first_counter=torch.Tensor([0])
second_counter=torch.Tensor([10])
 
while(first_counter<second_counter):
    first_counter+=2
    second_counter+=1
 
print(first_counter)
print(second_counter)
           
tensor([20.])
tensor([20.])
           

  可以看到PyTorch的寫法和Python的寫法是完全一緻的,沒有任何額外的學習成本。動态圖的方式更加簡單且直覺。

繼續閱讀