天天看點

【分布式】基于ps-lite的分布式計算執行個體解析1. xflow_demo中部分資料結構2. AUC計算方法:3. 參數伺服器配置設定結構4.參數更新5. 梯度計算6 其他知識點介紹7.  總體代碼講解:

【分布式】基于ps-lite的分布式計算執行個體解析1. xflow_demo中部分資料結構2. AUC計算方法:3. 參數伺服器配置設定結構4.參數更新5. 梯度計算6 其他知識點介紹7.  總體代碼講解:

1. xflow_demo中部分資料結構

【分布式】基于ps-lite的分布式計算執行個體解析1. xflow_demo中部分資料結構2. AUC計算方法:3. 參數伺服器配置設定結構4.參數更新5. 梯度計算6 其他知識點介紹7.  總體代碼講解:
【分布式】基于ps-lite的分布式計算執行個體解析1. xflow_demo中部分資料結構2. AUC計算方法:3. 參數伺服器配置設定結構4.參數更新5. 梯度計算6 其他知識點介紹7.  總體代碼講解:

2. AUC計算方法:

    AUC表示正樣本排在負樣本前面的機率。采用排列組合,找出所有正樣本排在負樣本前面的情況,然後除以所有的組合,即可得到AUC。

struct auc_key{
    int label;
    float pctr;
  };

  void calculate_auc(std::vector<auc_key>& auc_vec) {
    std::sort(auc_vec.begin(), auc_vec.end(), [](const auc_key& a,
          const auc_key& b){
        return a.pctr > b.pctr;
        });
    float area = 0.0;
    int tp_n = 0;
    for (size_t i = 0; i < auc_vec.size(); ++i) {
      if (auc_vec[i].label == 1) {
        tp_n += 1;
      } else {
        area += tp_n;
      }
      logloss += auc_vec[i].label * std::log2(auc_vec[i].pctr)+
        + (1.0 - auc_vec[i].label) * std::log2(1.0 - auc_vec[i].pctr);
    }
    logloss /= auc_vec.size();
    std::cout << "logloss: " << logloss << "\t";
    if (tp_n == 0 || tp_n == auc_vec.size()) {
      std::cout << "tp_n = " << tp_n << std::endl;
    } else {
      area /= 1.0 * (tp_n * (auc_vec.size() - tp_n));
      std::cout << "auc = " << area
        << "\ttp = " << tp_n
        << " fp = " << auc_vec.size() - tp_n << std::endl;
    }
  }
           

3. 參數伺服器配置設定結構

      server用于更新權重。不同種類的權重可以放在不同的參數伺服器上更新,比如w和b可以放在不同的伺服器上進行參數更新。或者w的次元太大時候,也可以放在多台伺服器上進行更新。

     work主要進行預測或者訓練(計算梯度),可以使用一台work進行預測。使用其他的work進行訓練。(預測主要用于驗證,計算AUC等等,再每一輪訓練完畢後,可以進行預測)

#include "ps/ps.h"

int main(int argc, char *argv[]) {
  if (ps::IsScheduler()) {
    std::cout << "start scheduler" << std::endl;
  }
  if (ps::IsServer()) {
    std::cout << "start server" << std::endl;
    xflow::Server* server = new xflow::Server();
  }
  ps::Start();
  if (ps::IsWorker()) {
    std::cout << "start worker" << std::endl;
    int epochs = std::atoi(argv[4]);
    if (*(argv[3]) == '0') {
      std::cout << "start LR " << std::endl;
      xflow::LRWorker* lr_worker = new xflow::LRWorker(argv[1], argv[2]);
      lr_worker->epochs = epochs;
      lr_worker->train();
    }
    if (*(argv[3]) == '1') {
      std::cout << "start FM " << std::endl;
      xflow::FMWorker* fm_worker = new xflow::FMWorker(argv[1], argv[2]);
      fm_worker->epochs = epochs;
      fm_worker->train();
    }
    if (*(argv[3]) == '2') {
      std::cout<< "start MVM " << std::endl;
      xflow::MVMWorker* mvm_worker = new xflow::MVMWorker(argv[1], argv[2]);
      mvm_worker->epochs = epochs;
      mvm_worker->train();
    }
  }
  ps::Finalize();
}
           

 LRWorker部分代碼:

void LRWorker::train() {
    rank = ps::MyRank();
    std::cout << "my rank is = " << rank << std::endl;
    snprintf(train_data_path, 1024, "%s-%05d", train_file_path, rank);
    batch_training(pool_);
    if (rank == 0) {
      std::cout << "LR AUC: " << std::endl;
      predict(pool_, rank, 0);
    }
    std::cout << "train end......" << std::endl;
  }
}  // namespace xflow
           

4.參數更新

對于結構體

typedef struct SGDEntry_w {
    SGDEntry_w(int k = w_dim) {
      w.resize(k, 0.0);
    }
    std::vector<float> w;
  } sgdentry_w;
           

    std::unordered_map<ps::Key, sgdentry_w> store;

ps::Key: 特征索引

sgdentry_w:特征對應的w值,如果沒有做embeeding,則w_dim = 1;  如果做了Embeeding,則是embedding的大小。

如下是server端SGD優化的梯度更新代碼。

namespace xflow {
extern int w_dim;
extern int v_dim;
float learning_rate = 0.001;

class SGD {
 public:
  SGD() {}
  ~SGD() {}

  typedef struct SGDEntry_w {
    SGDEntry_w(int k = w_dim) {
      w.resize(k, 0.0);
    }
    std::vector<float> w;
  } sgdentry_w;

  struct KVServerSGDHandle_w {
    void operator()(const ps::KVMeta& req_meta,
        const ps::KVPairs<float>& req_data,
        ps::KVServer<float>* server) {
      size_t keys_size = req_data.keys.size();
      size_t vals_size = req_data.vals.size();
      ps::KVPairs<float> res;

      if (req_meta.push) {
        w_dim = vals_size / keys_size;
        CHECK_EQ(keys_size, vals_size / w_dim);
      } else {
        res.keys = req_data.keys;
        res.vals.resize(keys_size * w_dim);
      }

      for (size_t i = 0; i < keys_size; ++i) {
        ps::Key key = req_data.keys[i];
        SGDEntry_w& val = store[key];
        for (int j = 0; j < w_dim; ++j) {
          if (req_meta.push) {
            float g = req_data.vals[i * w_dim + j];
            val.w[j] -= learning_rate * g;
          } else {
            for (int j = 0; j < w_dim; ++j) {
              res.vals[i * w_dim + j] = val.w[j];
            }
          }
        }
      }
      server->Response(req_meta, res);
    }

   private:
    std::unordered_map<ps::Key, sgdentry_w> store;
  };

  typedef struct SGDEntry_v {
    SGDEntry_v(int k = v_dim) {
      w.resize(k, 0.001);
    }
    std::vector<float> w;
  } sgdentry_v;

  struct KVServerSGDHandle_v {
    void operator()(const ps::KVMeta& req_meta,
        const ps::KVPairs<float>& req_data,
        ps::KVServer<float>* server) {
      size_t keys_size = req_data.keys.size();
      size_t vals_size = req_data.vals.size();
      ps::KVPairs<float> res;

      if (req_meta.push) {
        v_dim = vals_size / keys_size;
        CHECK_EQ(keys_size, vals_size / v_dim);
      } else {
        res.keys = req_data.keys;
        res.vals.resize(keys_size * v_dim);
      }

      for (size_t i = 0; i < keys_size; ++i) {
        ps::Key key = req_data.keys[i];
        SGDEntry_v& val = store[key];
        for (int j = 0; j < v_dim; ++j) {
          if (req_meta.push) {
            float g = req_data.vals[i * v_dim + j];
            val.w[j] -= learning_rate * g;
          } else {
            for (int j = 0; j < v_dim; ++j) {
              res.vals[i * v_dim + j] = val.w[j];
            }
          }
        }
      }
      server->Response(req_meta, res);
    }

   private:
    std::unordered_map<ps::Key, sgdentry_v> store;
  };

 private:
};
}  // namespace xflow
           

5. 梯度計算

參考:https://github.com/ljzzju/logistic-regression-ftrl-ps

如下mb_g中存儲的是梯度。

【分布式】基于ps-lite的分布式計算執行個體解析1. xflow_demo中部分資料結構2. AUC計算方法:3. 參數伺服器配置設定結構4.參數更新5. 梯度計算6 其他知識點介紹7.  總體代碼講解:

轉載:https://zhuanlan.zhihu.com/p/109841554

6 其他知識點介紹

Node的表示

在ps-lite中,各個節點發送消息,自然需要id來辨別需要發送到哪些節點。ps-lite采用的編号方式也十分簡單:

  • Group id:scheduler=1,kServerGroup=2,kWorkerGroup=4;(定義于頭檔案base.h中)
  • node id:server id = rank * 2 + 8;worker id = rank * 2 + 9;(定義于postoffice.h中)

為什麼要這麼定義呢?首先看group id,寫為2進制分别為 b001, b010, b100, 這樣要給多個group中的node發送資訊,隻需要發送給 或運算 計算後的号碼即可。比如說要同時發送給三個group,就發送給7=b111這個id。是以,node id是從8号開始,且這個算法保證server id為偶數,node id為奇數。

消息封裝

這個部分的代碼位于ps/internel/message.h。回顧UML圖中 Node —— Control ——Meta —— Message 這個依賴鍊(雖然與其說是依賴關系,我認為合成關系更加貼切)并類比一下http包的結構:基本上可以看出一個消息(Message)具有消息頭(Meta)和消息體,消息頭中含有控制資訊(Control)表示這個消息表示的意義(例如終止,确認ACK,同步等),控制資訊中含有節點id(vector<Node>)以及group id表示這個控制指令對誰執行。【控制指令=EMPTY時,代表這個消息是普通消息】

小結一下:

  • Node:存放node的ip,port,類型,id等資訊;
  • Control:存放command類型,barrier_group(用于辨別哪些節點需要同步,當command=BARRIER時使用),node(Node類,用于辨別控制指令對哪些節點使用)等;
  • Meta:時間戳,發送者id,接受者id,控制資訊Control,消息類型等;
  • Message:消息頭Meta,消息體data;

每次發送消息時,各個node都按這個格式封裝好,負責發送消息的類成員(Customer類)就會按照Meta之中的資訊将消息送貨上門了。

Node之間的協同工作

接下來看一看UML圖的上半部分是怎麼一起合作,構成一個具有parameter server的系統。在講解之前,先補全幾個小的資料結構:

  • SArray:可以把SArray了解為vector,隻不過使用了share_ptr來管理SArray的記憶體:當對某個SArray的引用為0時,就自動回收該SArray的記憶體。(sarray.h)
  • KVPairs:包含SArray<Key> keys, SArray<Val> vals, SArray<int> lens的模闆類。Key其實是int64的别名,Val是模闆變量。lens和keys等長,表示每個key對應的val有多長。每個server其實對應一段連續的key,存放這些key對應的val。(kv_app.h)
  • Range:定義begin()和end()兩個位置(uint64),.size()獲得size=end-begin。根據這個Range确定要拉取的參數在哪個server上,以及一個server對應的key的range。(range.h)
  • TreadsafeQueue:一個可以供多個線程讀取的隊列,有Push和WaitAndPop兩個方法。通過鎖和條件量合作來達到線程安全,用來做消息隊列。(threadsafe_queue.h)

再補全一下各個工作類的作用:

  • Environment:一個單例模式類,它通過一個ordered_map儲存環境變量名以及值;
  • PostOffice:一個單例模式類,一個node在生命期内具有一個PostOffice,依賴它的類成員對Node進行管理;
  • Van:負責Message的實際收發工作,比PostOffice更加底層,PostOffice持有一個Van成員;
  • SimpleApp:KVServer和KVWorker的父類,它提供了簡單的Request, Wait, Response,Process功能;KVServer和KVWorker分别根據自己的使命重寫了這些功能;
  • Customer:每個SimpleApp對象持有一個Customer類的成員,且Customer需要在PostOffice進行注冊,這個成員負責 1.跟蹤由SimpleApp發送出去的消息的回複情況 2. 維護一個Node的消息隊列,為Node接收消息;

類比一下http在網絡中經過的層次關系,我認為ps-lite中的一個Message經過的層次關系是這樣的:(這個層次關系可以說是本文精華了QAQ)

SimpleApp --> PostOffice --> Van --> Van --> PostOffice -->Customer-->SimpleApp

做好了預備工作,下面整體說一下系統的啟動流程:

設定環境變量(參考腳本tests/local.sh):

# set global information
export DMLC_NUM_SERVER=$var1
export DMLC_NUM_WORKER=$var2
export DMLC_PS_ROOT_URI='127.0.0.1'
export DMLC_PS_ROOT_PORT=8000
# setting for scheduler
export DMLC_ROLE='scheduler'
# setting for servers
export DMLC_ROLE='server'
# setting for worker
export DMLC_ROLE='worker'
           

除了DMLC_ROLE設定得不同外,其他的變量在每個節點上都相同。

設定好環境變量後,看tests/test simple http://app.cc這個簡單example:

Start(0);// 啟動網絡環境,設定本節點的資訊
SimpleApp app(0, 0); // 啟動一個工作類
           

第一句話做了什麼

ps::Start(custom_id, argv0) or ps::StartAsync(custom id, argv0)
           
  • 調用PostOffice::Start(custom_id, argv0, barrier=true/false)
PostOffice::Start(custom_id, argv0, barrier=true/false)
// 
// PostOffice管理該Node的所有事;
           
  • 如果是第一次調用PostOffice::Start,
    • 先調用InitEnvironment,根據環境變量設定成員num_workers_, num_servers_, is_worker_, is_server_, is_scheduler_, verbose_ 的值
    • 設定node_ids_成員:
  • 調用Van::Start(custom_id);
  • 如果是第一次調用PostOffice::Start,初始化start_time_成員;
  • 如果barrier=true,則所有Node準備并向Scheduler發送要求同步的Message,進行第一次同步;
Van::Start(custom_id)
// Van負責比PostOffice更底層的網絡通信;
           
  • 如果是第一次調用Van::Start,則設根據環境變量設定scheduler_的Node資訊(ip:port),如果該節點不是scheduler節點,則需要再根據環境變量設定好本節點my_node_的Node資訊;之後該node與端口綁定,然後連接配接scheduler;建立線程receiver_thread_
  • 如果該node不是scheduler,則給scheduler發送一條資訊,直到ready_被設為true開始繼續運作(此時所有節點已經ready_);
  • 如果是第一次調用Van::Start,則再建立一個Resender對象,和一個heartbeat_thread_;

第二句話做了什麼?

SimpleApp(app_id, custom_id)
           
  • 建立一個Custom對象初始化obj_成員;
Customer(app_id, custom_id, recv_handle)
           
  • 分别用傳入構造函數的參數初始化app_id_, custom_id_, recv_handle成員
  • 調用PostOffice::AddCustomer将目前Customer注冊到PostOffice;
    • PostOffice的customers_成員: 在對應的app_id的元素上添加custom_id;
    • PostOffice的barrier_done_成員将該custom_id的同步狀态設為false
  • 新起一個Receiving線程recv_thread_;

了解了SimpleApp,它的兩個子類的構造就容易了解了:

KVServer(app_id)

  1. 建立一個Customer對象初始化obj_成員,用KVServer::Process傳入Customer構造函數,對于Server來說,app_id=custom_id=server's id;

KVWorker(app_id, custom_id)

  1. 用預設的KVWorker::DefaultSlicer綁定slicer_成員;
  2. 建立一個Customer對象初始化obj_成員,不傳入handle參數;

工作類SimpleApp,KVWorker,KVServer的api

SimpleApp

  1. set_request_handle,set_response_handle:設定成員request_handle_, response_handle_。在用戶端調用SimpleApp::Process時,根據message.meta中的訓示變量判斷是request還是response,調用相應handle處理;

KVServer

  1. set_request_handle,在調用KVServer::Process時,該函數使用request_handle處理message

KVWorker

  1. set_slicer:設定slicer_成員,該函數在調用Send函數時,将KVPairs按照每個server的Range切片;
  2. Pull(key_vector, val_vector, option: len_vector, cmd, callback):根據key_vector從Server上拉取val_vector,傳回timestamp,該函數不阻塞,可用worker.Wait(timestamp)等待;
  3. ZPull同理Pull,不過在調用内部Pull_函數時,不會copy一個key_vector,是以需要保證在ZPull完成前,調用者沒有改變key_vector;
  4. Push(key_vector, val_vector, optional: len_vector, cmd, callback) 以及ZPush:
    1. 由obj_成員準備一個送到ServerGroup的request傳回stamp;
    2. 設定好對應timestamp的callback;
    3. 使用傳入的參數構造KVPair對象,調用Send送出該對象;

7.  總體代碼講解:

注明:如下LR實作的代碼,使用的資料都是類别行資料,然後進行了編碼。是以其 value統一是1,在編碼過程中會有一些差別。

void LRWorker::calculate_loss(std::vector<float>& w,
                      std::vector<Base::sample_key>& all_keys,
                      std::vector<ps::Key>& unique_keys,
                      size_t start,
                      size_t end,
                      std::vector<float>& loss) {
    auto wx = std::vector<float>(end - start);
    for (int j = 0, i = 0; j < all_keys.size(); ) {
      size_t allkeys_fid = all_keys[j].fid;
      size_t weight_fid = (unique_keys)[i];
      if (allkeys_fid == weight_fid) {
        // sid : 樣本ID
        // 該處直接加w[i],相當于預設val的值是1.
        wx[all_keys[j].sid] += (w)[i];
        ++j;
      }
      else if (allkeys_fid > weight_fid) {
        ++i;
      }
    }
    for (int i = 0; i < wx.size(); i++) {
      float pctr = base_->sigmoid(wx[i]);
      loss[i] = pctr - train_data->label[start++];
    }
  }
           

當val的值不是全部設定為1時,正确的寫法:

參考:https://github.com/ljzzju/logistic-regression-ftrl-ps

【分布式】基于ps-lite的分布式計算執行個體解析1. xflow_demo中部分資料結構2. AUC計算方法:3. 參數伺服器配置設定結構4.參數更新5. 梯度計算6 其他知識點介紹7.  總體代碼講解:

轉載:https://zhuanlan.zhihu.com/p/109841554

1. ps-lite

ps-lite是參數伺服器(ps)的一種輕量實作,用于建構高可用分布式的機器學習應用。通常是多個節點運作在多台實體機器上用于處理機器學習問題,一般包含一個schedule節點和若幹個worker/server節點。

ps-lite的三種節點:

  • Worker:負責主要的計算工作,如讀取資料,資料預處理,梯度計算等。它通過push和pull的方式和server節點進行通信。例如,worker節點push計算得到的梯度到server,或者從server節點pull最新的參數。
  • Server:用于維護和更新模型權重。每個server節點維護部分模型資訊。
  • Scheduler:用于監聽其他節點的存活狀态。也負責發送控制指令(比如Server和Worker的連接配接通信),并且收集其它節點的工作進度。

ps-lite 支援同步和異步兩種機制

【分布式】基于ps-lite的分布式計算執行個體解析1. xflow_demo中部分資料結構2. AUC計算方法:3. 參數伺服器配置設定結構4.參數更新5. 梯度計算6 其他知識點介紹7.  總體代碼講解:
  1. 在同步的機制下,系統運作的時間是由最慢的worker節點與通信時間決定的
  2. 在異步的機制下,每個worker不用等待其它workers完成再運作下一次疊代。這樣可以提高效率,但從疊代次數的角度來看,會減慢收斂的速度。

ps-lite常應用于推薦系統中的ctr預測中,用以解決資料量較大,特征較多的問題,同時也有很多開發者貢獻自己的應用源碼。本文就結合基于ps-lite實作的分布式lr為例,記錄一下ps-lite的具體使用,代碼放在了

https://github.com/peterzhang2029/xflow_demo

https://github.com/xswang/xflow 

2. 代碼執行個體

接下來就以分布式LR為例,簡單介紹一下代碼的各部分功能:

主函數:

int main(int argc, char *argv[]) {
  if (ps::IsScheduler()) {
    std::cout << "start scheduler" << std::endl;
  }
  if (ps::IsServer()) {
    std::cout << "start server" << std::endl;
    xflow::Server* server = new xflow::Server();
  }
  ps::Start();
  if (ps::IsWorker()) {
    std::cout << "start worker" << std::endl;
    int epochs = std::atoi(argv[4]);
    if (*(argv[3]) == '0') {
      std::cout << "start LR " << std::endl;
      xflow::LRWorker* lr_worker = new xflow::LRWorker(argv[1], argv[2]);
      lr_worker->epochs = epochs;
      lr_worker->train();
    }
    if (*(argv[3]) == '1') {
      std::cout << "start FM " << std::endl;
      xflow::FMWorker* fm_worker = new xflow::FMWorker(argv[1], argv[2]);
      fm_worker->epochs = epochs;
      fm_worker->train();
    }
    if (*(argv[3]) == '2') {
      std::cout<< "start MVM " << std::endl;
      xflow::MVMWorker* mvm_worker = new xflow::MVMWorker(argv[1], argv[2]);
      mvm_worker->epochs = epochs;
      mvm_worker->train();
    }
  }
  ps::Finalize();
}
           
  • 對于 Scheduler 節點,直接調用ps-lite的Start()和Finalize()啟動和準備回收;
  • 對于 Server 節點,可以定義一個單獨的Server類;
  • 對于 Worker 節點,對于不同的模型方法定義對應的處理類,例如:xflow::LRWorker;

2.1 Scheduler 節點

對于scheduler節點進行初始化,Scheduler機器對應的IP位址設定為$DMLC_PS_ROOT_URI

2.2 Server 節點

Server類的定義:

namespace xflow {
class Server {
 public:
  Server() {
    server_w_ = new ps::KVServer<float>(0);
    server_w_->set_request_handle(SGD::KVServerSGDHandle_w());
    //server_w_->set_request_handle(FTRL::KVServerFTRLHandle_w());

    server_v_ = new ps::KVServer<float>(1);
    //server_v_->set_request_handle(FTRL::KVServerFTRLHandle_v());
    server_v_->set_request_handle(SGD::KVServerSGDHandle_v());
    std::cout << "init server success " << std::endl;
  }
  ~Server() {}
  ps::KVServer<float>* server_w_;
  ps::KVServer<float>* server_v_;
};
}  // namespace xflow
#endif  // SRC_MODEL_SERVER_H_
           

其中定義了兩個KVServer類型的權值: server_w_、server_v_對應模型的w和b,然後将在request handler中定義優化政策,這部分是server的核心處理邏輯,以SGD為例:

typedef struct SGDEntry_w {
    SGDEntry_w(int k = w_dim) {
      w.resize(k, 0.0);
    }
    std::vector<float> w;
  } sgdentry_w;

  struct KVServerSGDHandle_w {
    void operator()(const ps::KVMeta& req_meta,
        const ps::KVPairs<float>& req_data,
        ps::KVServer<float>* server) {
      size_t keys_size = req_data.keys.size();
      size_t vals_size = req_data.vals.size();
      ps::KVPairs<float> res;

      if (req_meta.push) {
        w_dim = vals_size / keys_size;
        CHECK_EQ(keys_size, vals_size / w_dim);
      } else {
        res.keys = req_data.keys;
        res.vals.resize(keys_size * w_dim);
      }

      for (size_t i = 0; i < keys_size; ++i) {
        ps::Key key = req_data.keys[i];
        SGDEntry_w& val = store[key];
        for (int j = 0; j < w_dim; ++j) {
          if (req_meta.push) {
            float g = req_data.vals[i * w_dim + j];
            val.w[j] -= learning_rate * g;
          } else {
            for (int j = 0; j < w_dim; ++j) {
              res.vals[i * w_dim + j] = val.w[j];
            }
          }
        }
      }
      server->Response(req_meta, res);
    }

   private:
    std::unordered_map<ps::Key, sgdentry_w> store;
  };
           
  • 對于push操作,接收來自Worker計算的梯度,然後記錄到維護的map裡
  • 對于pull操作,按照key将上一輪更新的weight拿出來傳回給Worker
  • 具體的優化算法,同步或者異步都可以在這裡實作,上述實作的是簡單的異步SGD

2.3 Worker 節點

Worker類的構造函數:

LRWorker(const char *train_file,
    const char *test_file) :
                           train_file_path(train_file),
                           test_file_path(test_file) {
    kv_w_ = new ps::KVWorker<float>(0);
    base_ = new Base;
    core_num = std::thread::hardware_concurrency();
    pool_ = new ThreadPool(core_num);
  }
           

此處定義了一個線程池,利用多線程來計算梯度,代碼中也包含了線程池的具體實作。

讀取資料之後,進行訓練:

while (1) {
        train_data_loader.load_minibatch_hash_data_fread();
        if (train_data->fea_matrix.size() <= 0) break;
        int thread_size = train_data->fea_matrix.size() / core_num;
        gradient_thread_finish_num = core_num;
        for (int i = 0; i < core_num; ++i) {
          int start = i * thread_size;
          int end = (i + 1)* thread_size;
          pool->enqueue(std::bind(&LRWorker::update, this, start, end));
        }
        while (gradient_thread_finish_num > 0) {
          usleep(5);
        }
        ++block;
      }
           

将資料分塊,每個線程處理一份資料,調用LRWorker::update函數對參數進行更新

void LRWorker::update(int start, int end) {
    size_t idx = 0;
    auto all_keys = std::vector<Base::sample_key>();
    auto unique_keys = std::vector<ps::Key>();;
    int line_num = 0;
    for (int row = start; row < end; ++row) {
      int sample_size = train_data->fea_matrix[row].size();
      Base::sample_key sk;
      sk.sid = line_num;
      for (int j = 0; j < sample_size; ++j) {
        idx = train_data->fea_matrix[row][j].fid;
        sk.fid = idx;
        all_keys.push_back(sk);
        unique_keys.push_back(idx);
      }
      ++line_num;
    }
    std::sort(all_keys.begin(), all_keys.end(), base_->sort_finder);
    std::sort(unique_keys.begin(), unique_keys.end());
    unique_keys.erase(unique(unique_keys.begin(), unique_keys.end()),
                        unique_keys.end());
    int keys_size = (unique_keys).size();

    auto w = std::vector<float>(keys_size);
    auto push_gradient = std::vector<float>(keys_size);
    kv_w_->Wait(kv_w_->Pull(unique_keys, &(w)));
    auto loss = std::vector<float>(end - start);
    calculate_loss(w, all_keys, unique_keys, start, end, loss);
    calculate_gradient(all_keys, unique_keys, loss, push_gradient);

    kv_w_->Wait(kv_w_->Push(unique_keys, push_gradient));
    --gradient_thread_finish_num;
  }
           

這部分是Worker的核心處理邏輯,包含以下幾個方面:

  • 結合資料組織的Key(ps::Key),去掉沒有用到的特征;
  • 根據參數對應的Key從Server那裡Pull取最新的參數;
  • 用最新的參數計算loss和gradient,code較長,感興趣的可以到github了解
  • 計算得到的gradient,然後Push到Server中,進行參數的更新

直到Worker的疊代結束,Server和Scheduler也相繼關閉連結,調用ps::Finalize() 進行回收。

簡單的分布式LR就這樣實作了

如果本文對你有幫助的話,麻煩點個贊支援一下,同時歡迎評論或者私信讨論。

參考

  1. ^xflow https://github.com/xswang/xflow