1、概述
solver算是caffe中比較核心的一個概念,在我們訓練train我們的網絡時,就必須要帶上這個參數,
如下例是我要對Lenet進行訓練的時候要調用的程式,現在不知道什麼意思沒關系,隻需要知道這個solver.prototxt是個必不可少的東西就ok
./build/tools/caffe train--solver=examples/mnist/lenet_solver.prototxt
Solver通過協調Net的前向推斷計算和反向梯度計算對參數進行更新,進而達到減小loss的目的
Solver的主要功能:
○ 設計好需要優化的對象,建立train網絡和test網絡。(通過調用另外一個配置檔案prototxt來進行)
○ 通過forward和backward疊代的進行優化來更新新參數。
○ 定期的評價測試網絡。 (可設定多少次訓練後,進行一次測試)。
○ 在優化過程中記錄模型和solver的狀态的快照。
在每一次的疊代過程中,solver做了這幾步工作:
1、調用forward算法來計算最終的輸出值,以及對應的loss
2、調用backward算法來計算每層的梯度
3、根據選用的slover方法,利用梯度進行參數更新
4、根據學習率、曆史資料、求解方法更新solver狀态,使得權重從初始化狀态逐漸更新到最終的狀态。
2、caffe.proto關于solver的描述
雖然内容很多,但是基本上是注釋占的篇幅多,而且我也基本上都翻譯成了中文注釋,建議仔細閱讀,這是一切solver的模版
// NOTE
// Update the next available ID when you add a new SolverParameter field.
// ## 注意,如果你要增加一個新的sovler參數,需要給它更新ID,就是下面内容中的數字
// SolverParameter next available ID: 41 (last added: type) ##下一個可用的ID是41,上一次caffe增加的是type參數,就是下文中的40
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
// ##指定訓練和測試網絡
// Exactly one train net must be specified using one of the following fields:
// train_net_param, train_net, net_param, net
// One or more test nets may be specified using any of the following fields:
// test_net_param, test_net, net_param, net
// If more than one test net field is specified (e.g., both net and
// test_net are specified), they will be evaluated in the field order given
// above: (1) test_net_param, (2) test_net, (3) net_param/net.
// A test_iter must be specified for each test_net.
// A test_level and/or a test_stage may also be specified for each test_net.
//////////////////////////////////////////////////////////////////////////////
// Proto filename for the train net, possibly combined with one or more
// test nets. ##這個訓練網絡的Proto檔案名,可能結合一個或多個測試網絡。
optional string net = 24;
// Inline train net param, possibly combined with one or more test nets. ## 對應的訓練網絡的參數,可能結合一個或多個測試網絡
optional NetParameter net_param = 25;
optional string train_net = 1; // Proto filename for the train net. ## train net的proto檔案名
repeated string test_net = 2; // Proto filenames for the test nets. ## test nets的proto檔案名
optional NetParameter train_net_param = 21; // Inline train net params. ## 與上面train網絡一緻對應的參數
repeated NetParameter test_net_param = 22; // Inline test net params. ## 與上面test網絡一緻對應的參數
// The states for the train/test nets. Must be unspecified or
// specified once per net.
// ## train/test網絡的狀态。 必須是未指定或每個網絡指定一次
// By default, all states will have solver = true; ##預設情況下,所有狀态都将有solver = true;
// train_state will have phase = TRAIN, ##train_state會有phase = TRAIN,
// and all test_state's will have phase = TEST. ##所有的test_state都将進行phase = TEST
// Other defaults are set according to the NetState defaults. ##其他預設值是根據NetState預設設定的。
optional NetState train_state = 26;
repeated NetState test_state = 27;
// The number of iterations for each test net. ## test網絡的疊代次數:
repeated int32 test_iter = 3;
// The number of iterations between two testing phases.
// ## 兩次test之間(train)的疊代次數
//## <訓練test_interval個批次,再測試test_iter個批次,為一個回合(epoch), 合理設定應使得每個回合内,周遊覆寫到全部訓練樣本和測試樣本 >
optional int32 test_interval = 4 [default = 0];
optional bool test_compute_loss = 19 [default = false]; // ## 預設不計算測試時損失
// If true, run an initial test pass before the first iteration,
// ensuring memory availability and printing the starting value of the loss.
// ##如設定為真,則在訓練前運作一次測試,以確定記憶體足夠,并列印初始損失值
optional bool test_initialization = 32 [default = true];
optional float base_lr = 5; // The base learning rate ##基本學習速率
// the number of iterations between displaying info. If display = 0, no info
// will be displayed. ##列印資訊的周遊間隔,周遊多少個批次列印一次資訊。設定為0則不列印。
optional int32 display = 6;
// Display the loss averaged over the last average_loss iterations ## 列印最後一個疊代批次下的平均損失
optional int32 average_loss = 33 [default = 1];
optional int32 max_iter = 7; // the maximum number of iterations ##train的最大疊代次數
// accumulate gradients over `iter_size` x `batch_size` instances
// ## 累積梯度誤差基于“iter_size×batchSize”個樣本執行個體,< “批次數×批量數”=“周遊的批次數×每批的樣本數”個樣本執行個體 >
optional int32 iter_size = 36 [default = 1];
// The learning rate decay policy. The currently implemented learning rate
// policies are as follows:
//##學習率衰退政策.目前實行的學習率政策如下:
// - fixed: always return base_lr. ##保持base_lr不變.
// - step: return base_lr * gamma ^ (floor(iter / step)) ##傳回 base_lr * gamma ^(floor(iter / stepsize)),
// - exp: return base_lr * gamma ^ iter ##傳回base_lr * gamma ^ iter, iter為目前疊代次數
// - inv: return base_lr * (1 + gamma * iter) ^ (- power) ##如果設定為inv,還需設定一個power,傳回return 後的内容
// - multistep: similar to step but it allows non uniform steps defined by ##這個參數和step很相似,還需要設定一個stepvalue。
// stepvalue ##但step是均勻等間隔變化,而此參數根據stepvalue變化
// - poly: the effective learning rate follows a polynomial decay, to be ##學習率進行多項式衰減,由max_iter變為0
// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) , ##傳回 base_lr (1- iter/max_iter) ^ (power)
// - sigmoid: the effective learning rate follows a sigmod decay ##學習率進行sigmod衰減,
// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) ##傳回return 後的内容
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
// ## 在上述參數中,base_lr, max_iter, gamma, step, stepvalue and power 被定義
// 在solver.prototxt檔案中,iter是目前疊代次數。
optional string lr_policy = 8;
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value. ## 動量
optional float weight_decay = 12; // The weight decay. ##權值衰減系數
// regularization types supported: L1 and L2
// controlled by weight_decay
// ## 由權值衰減系數所控制的正則化類型:L1或L2範數,預設L2
optional string regularization_type = 29 [default = "L2"];
// the stepsize for learning rate policy "step" ##"step"政策下,學習率的步長值
optional int32 stepsize = 13;
// the stepsize for learning rate policy "multistep" ## "multistep"政策下的步長值
repeated int32 stepvalue = 34;
// Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
// whenever their actual L2 norm is larger.
optional float clip_gradients = 35 [default = -1];
optional int32 snapshot = 14 [default = 0]; // The snapshot interval ##快照間隔<周遊多少次對模型和求解器狀态儲存一次>
optional string snapshot_prefix = 15; // The prefix for the snapshot.
// whether to snapshot diff in the results or not. Snapshotting diff will help
// debugging but the final protocol buffer size will be much larger.
// ## 是否對diff快照,有助調試,但最終的protocol buffer尺寸會很大
optional bool snapshot_diff = 16 [default = false];
// ## 快照資料儲存格式{ hdf5,binaryproto(預設) }
enum SnapshotFormat {
HDF5 = 0;
BINARYPROTO = 1;
}
optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
// the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. ##選CPU或GPU模式,預設是GPU
enum SolverMode {
CPU = 0;
GPU = 1;
}
optional SolverMode solver_mode = 17 [default = GPU];
// the device_id will that be used in GPU mode. Use device_id = 0 in default. ##如果選了GPU模式,此參數指定哪個GPU,預設是0号GPU
optional int32 device_id = 18 [default = 0];
// If non-negative, the seed with which the Solver will initialize the Caffe
// random number generator -- useful for reproducible results. Otherwise,
// (and by default) initialize using a seed derived from the system clock.
optional int64 random_seed = 20 [default = -1];
// type of the solver ## 求解器類型=SGD(預設),目前一共有6種
optional string type = 40 [default = "SGD"];
// numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
//## RMSProp,AdaGrad和AdaDelta和Adam的數值穩定性
optional float delta = 31 [default = 1e-8];
// parameters for the Adam solver ## Adam類型時的參數
optional float momentum2 = 39 [default = 0.999];
// RMSProp decay value ##RMSProp的衰減值
// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
optional float rms_decay = 38;
// If true, print information about the state of the net that may help with
// debugging learning problems.
//## 此參數預設為false,若為true,則列印網絡狀态資訊,有助于調試問題
optional bool debug_info = 23 [default = false];
// If false, don't save a snapshot after training finishes.
//## 此參數預設為true,若為false,則不會在訓練後儲存快照
optional bool snapshot_after_train = 28 [default = true];
// DEPRECATED: old solver enum types, use string instead ##已經棄用,本來表示6種sovler類型,現在用string type中的string代替
enum SolverType {
SGD = 0;
NESTEROV = 1;
ADAGRAD = 2;
RMSPROP = 3;
ADADELTA = 4;
ADAM = 5;
}
// DEPRECATED: use type instead of solver_type ##已經棄用,用string type中的type代替
optional SolverType solver_type = 30 [default = SGD];
}
複制
3、舉例說明
我仍以caffe/examples/mnist/lenet_solver.prototxt這個檔案為例,下圖是我的截圖
我把上圖的内容複制過來看的清楚一些,并把注釋翻譯了一下:
---------------這一部分可以對照着2中proto中的描述看,你會發現其實solver的編寫也就是對着模版填參數的一個過程,----------
# 我們需要的Net的模型,這個模型定義在另一個prototxt檔案中,這個就是我上一篇博文舉的Net的例子
# 顯然這裡根據需要你可以選擇其他的一些Net
net: "examples/mnist/lenet_train_test.prototxt"
#test_iter 設定了test一共疊代多少次,這裡是100
# 至于test每一次疊代處理多少張圖檔,在Net那個prototxt裡面batch_size規定了的
test_iter: 100
# 訓練每疊代500次,測試一次(這每一次測試要疊代100次).
test_interval:500
#設定學習率。base_lr用于設定基礎學習率,在疊代的過程中,可以對基礎學習率進行調整。怎麼樣進行調整,就是調整的政策,由lr_policy來設定。
#momentum稱為動量,使得權重更新更為平緩
#weight_decay稱為衰減率因子,防止過拟合的一個參數
base_lr: 0.01
momentum: 0.9
# 這裡省略了一個内容 type: SGD ,就是solver方法的選擇,因為預設就是SGD,是以這個 solver. prototxt 中省略沒寫, 如果你想用其他的sovler方法就要指明寫出來
weight_decay:0.0005
# 學習率調整的政策,詳細見我下面的補充
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# train每疊代100次就顯示一次
display: 100
#train最大疊代次數
max_iter: 10000
#快照。将訓練出來的model和solver狀态進行儲存,snapshot用于設定訓練多少次後進行儲存,預設為0,不儲存。snapshot_prefix設定儲存路徑。
還可以設定snapshot_diff,是否儲存梯度值,預設為false,不儲存。
也可以設定snapshot_format,儲存的類型。有兩種選擇:HDF5和BINARYPROTO,預設為BINARYPROTO
#這裡設定train每疊代5000次就存儲一次資料
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
#設定運作模式,預設為GPU,如果你沒有GPU,則需要改成CPU,否則會出錯.
solver_mode: GPU
4、solver方法
Solver方法就是計算最小化損失值(loss)的方法,也就是我上面解析中說的省略掉的一行,其實一共有6種sovler方法:
· Stochastic Gradient Descent (type: "SGD"),
· AdaDelta (type: "AdaDelta"),
· Adaptive Gradient (type: "AdaGrad"),
· Adam (type: "Adam"),
· Nesterov’s Accelerated Gradient (type: "Nesterov") and
· RMSprop (type: "RMSProp")
預設設定的是SGD 随機梯度下降,是以就可以不寫,但是如果想用其他的,就必須要寫出來,比如type:Adam
這個方法對于我這種小白來說暫時沒有研究的必要,而且SGD方法的數學原理至少我是知道的,是以我這裡就隻把這幾種方法列出來了,沒有詳細解讀,如果有興趣可以參考下面這篇部落格:
http://www.cnblogs.com/denny402/p/5074212.html,
這篇關于sovler講解的博文就寫完了,下面一篇就準備來将一下caffe中的hello world--使用Lenet來識别mnist手寫資料。