先講一下Tensorflow模型的組成,如果是通過tf.train.Saver()儲存的模型,那麼會生成3種檔案:
- .meta是網絡結構,就是深度學習網絡的那些隐層和全連接配接層等的定義
- checkpoint是記錄輸出模型的checkpoint
- 剩下的檔案儲存的是模型網絡中的具體的參數
下面看下訓練代碼:
代碼很簡單,先定義兩個Tensor的變量w1和w2,b1是一個常量2,然後定義一個字典,其中w1是4,w2是8。接着定義op,op指的是Tensorflow計算符,tf.add将w1和w2相加,然後通過tf.multiply将w1和w2相加的結果乘以2。接着生成全局的參數tf.global_variables_intitializer(),就是初始化參數,取第1000次的checkpoint把模型儲存為my_test_model。這個代碼的意思是輸入w1和w2,然後模型會傳回(w1+w2)*b1的結果,b1是常量,等于2。
運作後模型就儲存下來,下面看下怎麼調用:
通過import_meta_graph這個函數加載訓練時的網絡結構,然後用restore方法加載網絡結構中的權重,到了這步預測模型就加載好了。接着設定一組預測值,使得w1=6,w2=7。擷取計算op,也就是當初訓練的時候定義的op名稱‘op_to_restore’。然後就可以把資料傳到op裡進行計算,生成的結果為(6+7)*2=26。
彩蛋:如果想基于已有的模型refine,可以在原有模型上增加計算op,參考第二張圖注釋部分,可以自己試下。