天天看点

【分布式】基于ps-lite的分布式计算实例解析1. xflow_demo中部分数据结构2. AUC计算方法:3. 参数服务器分配结构4.参数更新5. 梯度计算6 其他知识点介绍7.  总体代码讲解:

【分布式】基于ps-lite的分布式计算实例解析1. xflow_demo中部分数据结构2. AUC计算方法:3. 参数服务器分配结构4.参数更新5. 梯度计算6 其他知识点介绍7.  总体代码讲解:

1. xflow_demo中部分数据结构

【分布式】基于ps-lite的分布式计算实例解析1. xflow_demo中部分数据结构2. AUC计算方法:3. 参数服务器分配结构4.参数更新5. 梯度计算6 其他知识点介绍7.  总体代码讲解:
【分布式】基于ps-lite的分布式计算实例解析1. xflow_demo中部分数据结构2. AUC计算方法:3. 参数服务器分配结构4.参数更新5. 梯度计算6 其他知识点介绍7.  总体代码讲解:

2. AUC计算方法:

    AUC表示正样本排在负样本前面的概率。采用排列组合,找出所有正样本排在负样本前面的情况,然后除以所有的组合,即可得到AUC。

struct auc_key{
    int label;
    float pctr;
  };

  void calculate_auc(std::vector<auc_key>& auc_vec) {
    std::sort(auc_vec.begin(), auc_vec.end(), [](const auc_key& a,
          const auc_key& b){
        return a.pctr > b.pctr;
        });
    float area = 0.0;
    int tp_n = 0;
    for (size_t i = 0; i < auc_vec.size(); ++i) {
      if (auc_vec[i].label == 1) {
        tp_n += 1;
      } else {
        area += tp_n;
      }
      logloss += auc_vec[i].label * std::log2(auc_vec[i].pctr)+
        + (1.0 - auc_vec[i].label) * std::log2(1.0 - auc_vec[i].pctr);
    }
    logloss /= auc_vec.size();
    std::cout << "logloss: " << logloss << "\t";
    if (tp_n == 0 || tp_n == auc_vec.size()) {
      std::cout << "tp_n = " << tp_n << std::endl;
    } else {
      area /= 1.0 * (tp_n * (auc_vec.size() - tp_n));
      std::cout << "auc = " << area
        << "\ttp = " << tp_n
        << " fp = " << auc_vec.size() - tp_n << std::endl;
    }
  }
           

3. 参数服务器分配结构

      server用于更新权重。不同种类的权重可以放在不同的参数服务器上更新,比如w和b可以放在不同的服务器上进行参数更新。或者w的维度太大时候,也可以放在多台服务器上进行更新。

     work主要进行预测或者训练(计算梯度),可以使用一台work进行预测。使用其他的work进行训练。(预测主要用于验证,计算AUC等等,再每一轮训练完毕后,可以进行预测)

#include "ps/ps.h"

int main(int argc, char *argv[]) {
  if (ps::IsScheduler()) {
    std::cout << "start scheduler" << std::endl;
  }
  if (ps::IsServer()) {
    std::cout << "start server" << std::endl;
    xflow::Server* server = new xflow::Server();
  }
  ps::Start();
  if (ps::IsWorker()) {
    std::cout << "start worker" << std::endl;
    int epochs = std::atoi(argv[4]);
    if (*(argv[3]) == '0') {
      std::cout << "start LR " << std::endl;
      xflow::LRWorker* lr_worker = new xflow::LRWorker(argv[1], argv[2]);
      lr_worker->epochs = epochs;
      lr_worker->train();
    }
    if (*(argv[3]) == '1') {
      std::cout << "start FM " << std::endl;
      xflow::FMWorker* fm_worker = new xflow::FMWorker(argv[1], argv[2]);
      fm_worker->epochs = epochs;
      fm_worker->train();
    }
    if (*(argv[3]) == '2') {
      std::cout<< "start MVM " << std::endl;
      xflow::MVMWorker* mvm_worker = new xflow::MVMWorker(argv[1], argv[2]);
      mvm_worker->epochs = epochs;
      mvm_worker->train();
    }
  }
  ps::Finalize();
}
           

 LRWorker部分代码:

void LRWorker::train() {
    rank = ps::MyRank();
    std::cout << "my rank is = " << rank << std::endl;
    snprintf(train_data_path, 1024, "%s-%05d", train_file_path, rank);
    batch_training(pool_);
    if (rank == 0) {
      std::cout << "LR AUC: " << std::endl;
      predict(pool_, rank, 0);
    }
    std::cout << "train end......" << std::endl;
  }
}  // namespace xflow
           

4.参数更新

对于结构体

typedef struct SGDEntry_w {
    SGDEntry_w(int k = w_dim) {
      w.resize(k, 0.0);
    }
    std::vector<float> w;
  } sgdentry_w;
           

    std::unordered_map<ps::Key, sgdentry_w> store;

ps::Key: 特征索引

sgdentry_w:特征对应的w值,如果没有做embeeding,则w_dim = 1;  如果做了Embeeding,则是embedding的大小。

如下是server端SGD优化的梯度更新代码。

namespace xflow {
extern int w_dim;
extern int v_dim;
float learning_rate = 0.001;

class SGD {
 public:
  SGD() {}
  ~SGD() {}

  typedef struct SGDEntry_w {
    SGDEntry_w(int k = w_dim) {
      w.resize(k, 0.0);
    }
    std::vector<float> w;
  } sgdentry_w;

  struct KVServerSGDHandle_w {
    void operator()(const ps::KVMeta& req_meta,
        const ps::KVPairs<float>& req_data,
        ps::KVServer<float>* server) {
      size_t keys_size = req_data.keys.size();
      size_t vals_size = req_data.vals.size();
      ps::KVPairs<float> res;

      if (req_meta.push) {
        w_dim = vals_size / keys_size;
        CHECK_EQ(keys_size, vals_size / w_dim);
      } else {
        res.keys = req_data.keys;
        res.vals.resize(keys_size * w_dim);
      }

      for (size_t i = 0; i < keys_size; ++i) {
        ps::Key key = req_data.keys[i];
        SGDEntry_w& val = store[key];
        for (int j = 0; j < w_dim; ++j) {
          if (req_meta.push) {
            float g = req_data.vals[i * w_dim + j];
            val.w[j] -= learning_rate * g;
          } else {
            for (int j = 0; j < w_dim; ++j) {
              res.vals[i * w_dim + j] = val.w[j];
            }
          }
        }
      }
      server->Response(req_meta, res);
    }

   private:
    std::unordered_map<ps::Key, sgdentry_w> store;
  };

  typedef struct SGDEntry_v {
    SGDEntry_v(int k = v_dim) {
      w.resize(k, 0.001);
    }
    std::vector<float> w;
  } sgdentry_v;

  struct KVServerSGDHandle_v {
    void operator()(const ps::KVMeta& req_meta,
        const ps::KVPairs<float>& req_data,
        ps::KVServer<float>* server) {
      size_t keys_size = req_data.keys.size();
      size_t vals_size = req_data.vals.size();
      ps::KVPairs<float> res;

      if (req_meta.push) {
        v_dim = vals_size / keys_size;
        CHECK_EQ(keys_size, vals_size / v_dim);
      } else {
        res.keys = req_data.keys;
        res.vals.resize(keys_size * v_dim);
      }

      for (size_t i = 0; i < keys_size; ++i) {
        ps::Key key = req_data.keys[i];
        SGDEntry_v& val = store[key];
        for (int j = 0; j < v_dim; ++j) {
          if (req_meta.push) {
            float g = req_data.vals[i * v_dim + j];
            val.w[j] -= learning_rate * g;
          } else {
            for (int j = 0; j < v_dim; ++j) {
              res.vals[i * v_dim + j] = val.w[j];
            }
          }
        }
      }
      server->Response(req_meta, res);
    }

   private:
    std::unordered_map<ps::Key, sgdentry_v> store;
  };

 private:
};
}  // namespace xflow
           

5. 梯度计算

参考:https://github.com/ljzzju/logistic-regression-ftrl-ps

如下mb_g中存储的是梯度。

【分布式】基于ps-lite的分布式计算实例解析1. xflow_demo中部分数据结构2. AUC计算方法:3. 参数服务器分配结构4.参数更新5. 梯度计算6 其他知识点介绍7.  总体代码讲解:

转载:https://zhuanlan.zhihu.com/p/109841554

6 其他知识点介绍

Node的表示

在ps-lite中,各个节点发送消息,自然需要id来标识需要发送到哪些节点。ps-lite采用的编号方式也十分简单:

  • Group id:scheduler=1,kServerGroup=2,kWorkerGroup=4;(定义于头文件base.h中)
  • node id:server id = rank * 2 + 8;worker id = rank * 2 + 9;(定义于postoffice.h中)

为什么要这么定义呢?首先看group id,写为2进制分别为 b001, b010, b100, 这样要给多个group中的node发送信息,只需要发送给 或运算 计算后的号码即可。比如说要同时发送给三个group,就发送给7=b111这个id。因此,node id是从8号开始,且这个算法保证server id为偶数,node id为奇数。

消息封装

这个部分的代码位于ps/internel/message.h。回顾UML图中 Node —— Control ——Meta —— Message 这个依赖链(虽然与其说是依赖关系,我认为合成关系更加贴切)并类比一下http包的结构:基本上可以看出一个消息(Message)具有消息头(Meta)和消息体,消息头中含有控制信息(Control)表示这个消息表示的意义(例如终止,确认ACK,同步等),控制信息中含有节点id(vector<Node>)以及group id表示这个控制命令对谁执行。【控制命令=EMPTY时,代表这个消息是普通消息】

小结一下:

  • Node:存放node的ip,port,类型,id等信息;
  • Control:存放command类型,barrier_group(用于标识哪些节点需要同步,当command=BARRIER时使用),node(Node类,用于标识控制命令对哪些节点使用)等;
  • Meta:时间戳,发送者id,接受者id,控制信息Control,消息类型等;
  • Message:消息头Meta,消息体data;

每次发送消息时,各个node都按这个格式封装好,负责发送消息的类成员(Customer类)就会按照Meta之中的信息将消息送货上门了。

Node之间的协同工作

接下来看一看UML图的上半部分是怎么一起合作,构成一个具有parameter server的系统。在讲解之前,先补全几个小的数据结构:

  • SArray:可以把SArray理解为vector,只不过使用了share_ptr来管理SArray的内存:当对某个SArray的引用为0时,就自动回收该SArray的内存。(sarray.h)
  • KVPairs:包含SArray<Key> keys, SArray<Val> vals, SArray<int> lens的模板类。Key其实是int64的别名,Val是模板变量。lens和keys等长,表示每个key对应的val有多长。每个server其实对应一段连续的key,存放这些key对应的val。(kv_app.h)
  • Range:定义begin()和end()两个位置(uint64),.size()获得size=end-begin。根据这个Range确定要拉取的参数在哪个server上,以及一个server对应的key的range。(range.h)
  • TreadsafeQueue:一个可以供多个线程读取的队列,有Push和WaitAndPop两个方法。通过锁和条件量合作来达到线程安全,用来做消息队列。(threadsafe_queue.h)

再补全一下各个工作类的作用:

  • Environment:一个单例模式类,它通过一个ordered_map保存环境变量名以及值;
  • PostOffice:一个单例模式类,一个node在生命期内具有一个PostOffice,依赖它的类成员对Node进行管理;
  • Van:负责Message的实际收发工作,比PostOffice更加底层,PostOffice持有一个Van成员;
  • SimpleApp:KVServer和KVWorker的父类,它提供了简单的Request, Wait, Response,Process功能;KVServer和KVWorker分别根据自己的使命重写了这些功能;
  • Customer:每个SimpleApp对象持有一个Customer类的成员,且Customer需要在PostOffice进行注册,这个成员负责 1.跟踪由SimpleApp发送出去的消息的回复情况 2. 维护一个Node的消息队列,为Node接收消息;

类比一下http在网络中经过的层次关系,我认为ps-lite中的一个Message经过的层次关系是这样的:(这个层次关系可以说是本文精华了QAQ)

SimpleApp --> PostOffice --> Van --> Van --> PostOffice -->Customer-->SimpleApp

做好了预备工作,下面整体说一下系统的启动流程:

设置环境变量(参考脚本tests/local.sh):

# set global information
export DMLC_NUM_SERVER=$var1
export DMLC_NUM_WORKER=$var2
export DMLC_PS_ROOT_URI='127.0.0.1'
export DMLC_PS_ROOT_PORT=8000
# setting for scheduler
export DMLC_ROLE='scheduler'
# setting for servers
export DMLC_ROLE='server'
# setting for worker
export DMLC_ROLE='worker'
           

除了DMLC_ROLE设置得不同外,其他的变量在每个节点上都相同。

设置好环境变量后,看tests/test simple http://app.cc这个简单example:

Start(0);// 启动网络环境,设置本节点的信息
SimpleApp app(0, 0); // 启动一个工作类
           

第一句话做了什么

ps::Start(custom_id, argv0) or ps::StartAsync(custom id, argv0)
           
  • 调用PostOffice::Start(custom_id, argv0, barrier=true/false)
PostOffice::Start(custom_id, argv0, barrier=true/false)
// 
// PostOffice管理该Node的所有事;
           
  • 如果是第一次调用PostOffice::Start,
    • 先调用InitEnvironment,根据环境变量设置成员num_workers_, num_servers_, is_worker_, is_server_, is_scheduler_, verbose_ 的值
    • 设置node_ids_成员:
  • 调用Van::Start(custom_id);
  • 如果是第一次调用PostOffice::Start,初始化start_time_成员;
  • 如果barrier=true,则所有Node准备并向Scheduler发送要求同步的Message,进行第一次同步;
Van::Start(custom_id)
// Van负责比PostOffice更底层的网络通信;
           
  • 如果是第一次调用Van::Start,则设根据环境变量设置scheduler_的Node信息(ip:port),如果该节点不是scheduler节点,则需要再根据环境变量设置好本节点my_node_的Node信息;之后该node与端口绑定,然后连接scheduler;新建线程receiver_thread_
  • 如果该node不是scheduler,则给scheduler发送一条信息,直到ready_被设为true开始继续运行(此时所有节点已经ready_);
  • 如果是第一次调用Van::Start,则再创建一个Resender对象,和一个heartbeat_thread_;

第二句话做了什么?

SimpleApp(app_id, custom_id)
           
  • 新建一个Custom对象初始化obj_成员;
Customer(app_id, custom_id, recv_handle)
           
  • 分别用传入构造函数的参数初始化app_id_, custom_id_, recv_handle成员
  • 调用PostOffice::AddCustomer将当前Customer注册到PostOffice;
    • PostOffice的customers_成员: 在对应的app_id的元素上添加custom_id;
    • PostOffice的barrier_done_成员将该custom_id的同步状态设为false
  • 新起一个Receiving线程recv_thread_;

理解了SimpleApp,它的两个子类的构造就容易理解了:

KVServer(app_id)

  1. 新建一个Customer对象初始化obj_成员,用KVServer::Process传入Customer构造函数,对于Server来说,app_id=custom_id=server's id;

KVWorker(app_id, custom_id)

  1. 用默认的KVWorker::DefaultSlicer绑定slicer_成员;
  2. 新建一个Customer对象初始化obj_成员,不传入handle参数;

工作类SimpleApp,KVWorker,KVServer的api

SimpleApp

  1. set_request_handle,set_response_handle:设置成员request_handle_, response_handle_。在客户端调用SimpleApp::Process时,根据message.meta中的指示变量判断是request还是response,调用相应handle处理;

KVServer

  1. set_request_handle,在调用KVServer::Process时,该函数使用request_handle处理message

KVWorker

  1. set_slicer:设置slicer_成员,该函数在调用Send函数时,将KVPairs按照每个server的Range切片;
  2. Pull(key_vector, val_vector, option: len_vector, cmd, callback):根据key_vector从Server上拉取val_vector,返回timestamp,该函数不阻塞,可用worker.Wait(timestamp)等待;
  3. ZPull同理Pull,不过在调用内部Pull_函数时,不会copy一个key_vector,所以需要保证在ZPull完成前,调用者没有改变key_vector;
  4. Push(key_vector, val_vector, optional: len_vector, cmd, callback) 以及ZPush:
    1. 由obj_成员准备一个送到ServerGroup的request返回stamp;
    2. 设置好对应timestamp的callback;
    3. 使用传入的参数构造KVPair对象,调用Send送出该对象;

7.  总体代码讲解:

注明:如下LR实现的代码,使用的数据都是类别行数据,然后进行了编码。所以其 value统一是1,在编码过程中会有一些区别。

void LRWorker::calculate_loss(std::vector<float>& w,
                      std::vector<Base::sample_key>& all_keys,
                      std::vector<ps::Key>& unique_keys,
                      size_t start,
                      size_t end,
                      std::vector<float>& loss) {
    auto wx = std::vector<float>(end - start);
    for (int j = 0, i = 0; j < all_keys.size(); ) {
      size_t allkeys_fid = all_keys[j].fid;
      size_t weight_fid = (unique_keys)[i];
      if (allkeys_fid == weight_fid) {
        // sid : 样本ID
        // 该处直接加w[i],相当于默认val的值是1.
        wx[all_keys[j].sid] += (w)[i];
        ++j;
      }
      else if (allkeys_fid > weight_fid) {
        ++i;
      }
    }
    for (int i = 0; i < wx.size(); i++) {
      float pctr = base_->sigmoid(wx[i]);
      loss[i] = pctr - train_data->label[start++];
    }
  }
           

当val的值不是全部设置为1时,正确的写法:

参考:https://github.com/ljzzju/logistic-regression-ftrl-ps

【分布式】基于ps-lite的分布式计算实例解析1. xflow_demo中部分数据结构2. AUC计算方法:3. 参数服务器分配结构4.参数更新5. 梯度计算6 其他知识点介绍7.  总体代码讲解:

转载:https://zhuanlan.zhihu.com/p/109841554

1. ps-lite

ps-lite是参数服务器(ps)的一种轻量实现,用于构建高可用分布式的机器学习应用。通常是多个节点运行在多台物理机器上用于处理机器学习问题,一般包含一个schedule节点和若干个worker/server节点。

ps-lite的三种节点:

  • Worker:负责主要的计算工作,如读取数据,数据预处理,梯度计算等。它通过push和pull的方式和server节点进行通信。例如,worker节点push计算得到的梯度到server,或者从server节点pull最新的参数。
  • Server:用于维护和更新模型权重。每个server节点维护部分模型信息。
  • Scheduler:用于监听其他节点的存活状态。也负责发送控制命令(比如Server和Worker的连接通信),并且收集其它节点的工作进度。

ps-lite 支持同步和异步两种机制

【分布式】基于ps-lite的分布式计算实例解析1. xflow_demo中部分数据结构2. AUC计算方法:3. 参数服务器分配结构4.参数更新5. 梯度计算6 其他知识点介绍7.  总体代码讲解:
  1. 在同步的机制下,系统运行的时间是由最慢的worker节点与通信时间决定的
  2. 在异步的机制下,每个worker不用等待其它workers完成再运行下一次迭代。这样可以提高效率,但从迭代次数的角度来看,会减慢收敛的速度。

ps-lite常应用于推荐系统中的ctr预测中,用以解决数据量较大,特征较多的问题,同时也有很多开发者贡献自己的应用源码。本文就结合基于ps-lite实现的分布式lr为例,记录一下ps-lite的具体使用,代码放在了

https://github.com/peterzhang2029/xflow_demo

https://github.com/xswang/xflow 

2. 代码实例

接下来就以分布式LR为例,简单介绍一下代码的各部分功能:

主函数:

int main(int argc, char *argv[]) {
  if (ps::IsScheduler()) {
    std::cout << "start scheduler" << std::endl;
  }
  if (ps::IsServer()) {
    std::cout << "start server" << std::endl;
    xflow::Server* server = new xflow::Server();
  }
  ps::Start();
  if (ps::IsWorker()) {
    std::cout << "start worker" << std::endl;
    int epochs = std::atoi(argv[4]);
    if (*(argv[3]) == '0') {
      std::cout << "start LR " << std::endl;
      xflow::LRWorker* lr_worker = new xflow::LRWorker(argv[1], argv[2]);
      lr_worker->epochs = epochs;
      lr_worker->train();
    }
    if (*(argv[3]) == '1') {
      std::cout << "start FM " << std::endl;
      xflow::FMWorker* fm_worker = new xflow::FMWorker(argv[1], argv[2]);
      fm_worker->epochs = epochs;
      fm_worker->train();
    }
    if (*(argv[3]) == '2') {
      std::cout<< "start MVM " << std::endl;
      xflow::MVMWorker* mvm_worker = new xflow::MVMWorker(argv[1], argv[2]);
      mvm_worker->epochs = epochs;
      mvm_worker->train();
    }
  }
  ps::Finalize();
}
           
  • 对于 Scheduler 节点,直接调用ps-lite的Start()和Finalize()启动和准备回收;
  • 对于 Server 节点,可以定义一个单独的Server类;
  • 对于 Worker 节点,对于不同的模型方法定义对应的处理类,例如:xflow::LRWorker;

2.1 Scheduler 节点

对于scheduler节点进行初始化,Scheduler机器对应的IP地址设定为$DMLC_PS_ROOT_URI

2.2 Server 节点

Server类的定义:

namespace xflow {
class Server {
 public:
  Server() {
    server_w_ = new ps::KVServer<float>(0);
    server_w_->set_request_handle(SGD::KVServerSGDHandle_w());
    //server_w_->set_request_handle(FTRL::KVServerFTRLHandle_w());

    server_v_ = new ps::KVServer<float>(1);
    //server_v_->set_request_handle(FTRL::KVServerFTRLHandle_v());
    server_v_->set_request_handle(SGD::KVServerSGDHandle_v());
    std::cout << "init server success " << std::endl;
  }
  ~Server() {}
  ps::KVServer<float>* server_w_;
  ps::KVServer<float>* server_v_;
};
}  // namespace xflow
#endif  // SRC_MODEL_SERVER_H_
           

其中定义了两个KVServer类型的权值: server_w_、server_v_对应模型的w和b,然后将在request handler中定义优化策略,这部分是server的核心处理逻辑,以SGD为例:

typedef struct SGDEntry_w {
    SGDEntry_w(int k = w_dim) {
      w.resize(k, 0.0);
    }
    std::vector<float> w;
  } sgdentry_w;

  struct KVServerSGDHandle_w {
    void operator()(const ps::KVMeta& req_meta,
        const ps::KVPairs<float>& req_data,
        ps::KVServer<float>* server) {
      size_t keys_size = req_data.keys.size();
      size_t vals_size = req_data.vals.size();
      ps::KVPairs<float> res;

      if (req_meta.push) {
        w_dim = vals_size / keys_size;
        CHECK_EQ(keys_size, vals_size / w_dim);
      } else {
        res.keys = req_data.keys;
        res.vals.resize(keys_size * w_dim);
      }

      for (size_t i = 0; i < keys_size; ++i) {
        ps::Key key = req_data.keys[i];
        SGDEntry_w& val = store[key];
        for (int j = 0; j < w_dim; ++j) {
          if (req_meta.push) {
            float g = req_data.vals[i * w_dim + j];
            val.w[j] -= learning_rate * g;
          } else {
            for (int j = 0; j < w_dim; ++j) {
              res.vals[i * w_dim + j] = val.w[j];
            }
          }
        }
      }
      server->Response(req_meta, res);
    }

   private:
    std::unordered_map<ps::Key, sgdentry_w> store;
  };
           
  • 对于push操作,接收来自Worker计算的梯度,然后记录到维护的map里
  • 对于pull操作,按照key将上一轮更新的weight拿出来返回给Worker
  • 具体的优化算法,同步或者异步都可以在这里实现,上述实现的是简单的异步SGD

2.3 Worker 节点

Worker类的构造函数:

LRWorker(const char *train_file,
    const char *test_file) :
                           train_file_path(train_file),
                           test_file_path(test_file) {
    kv_w_ = new ps::KVWorker<float>(0);
    base_ = new Base;
    core_num = std::thread::hardware_concurrency();
    pool_ = new ThreadPool(core_num);
  }
           

此处定义了一个线程池,利用多线程来计算梯度,代码中也包含了线程池的具体实现。

读取数据之后,进行训练:

while (1) {
        train_data_loader.load_minibatch_hash_data_fread();
        if (train_data->fea_matrix.size() <= 0) break;
        int thread_size = train_data->fea_matrix.size() / core_num;
        gradient_thread_finish_num = core_num;
        for (int i = 0; i < core_num; ++i) {
          int start = i * thread_size;
          int end = (i + 1)* thread_size;
          pool->enqueue(std::bind(&LRWorker::update, this, start, end));
        }
        while (gradient_thread_finish_num > 0) {
          usleep(5);
        }
        ++block;
      }
           

将数据分块,每个线程处理一份数据,调用LRWorker::update函数对参数进行更新

void LRWorker::update(int start, int end) {
    size_t idx = 0;
    auto all_keys = std::vector<Base::sample_key>();
    auto unique_keys = std::vector<ps::Key>();;
    int line_num = 0;
    for (int row = start; row < end; ++row) {
      int sample_size = train_data->fea_matrix[row].size();
      Base::sample_key sk;
      sk.sid = line_num;
      for (int j = 0; j < sample_size; ++j) {
        idx = train_data->fea_matrix[row][j].fid;
        sk.fid = idx;
        all_keys.push_back(sk);
        unique_keys.push_back(idx);
      }
      ++line_num;
    }
    std::sort(all_keys.begin(), all_keys.end(), base_->sort_finder);
    std::sort(unique_keys.begin(), unique_keys.end());
    unique_keys.erase(unique(unique_keys.begin(), unique_keys.end()),
                        unique_keys.end());
    int keys_size = (unique_keys).size();

    auto w = std::vector<float>(keys_size);
    auto push_gradient = std::vector<float>(keys_size);
    kv_w_->Wait(kv_w_->Pull(unique_keys, &(w)));
    auto loss = std::vector<float>(end - start);
    calculate_loss(w, all_keys, unique_keys, start, end, loss);
    calculate_gradient(all_keys, unique_keys, loss, push_gradient);

    kv_w_->Wait(kv_w_->Push(unique_keys, push_gradient));
    --gradient_thread_finish_num;
  }
           

这部分是Worker的核心处理逻辑,包含以下几个方面:

  • 结合数据组织的Key(ps::Key),去掉没有用到的特征;
  • 根据参数对应的Key从Server那里Pull取最新的参数;
  • 用最新的参数计算loss和gradient,code较长,感兴趣的可以到github了解
  • 计算得到的gradient,然后Push到Server中,进行参数的更新

直到Worker的迭代结束,Server和Scheduler也相继关闭链接,调用ps::Finalize() 进行回收。

简单的分布式LR就这样实现了

如果本文对你有帮助的话,麻烦点个赞支持一下,同时欢迎评论或者私信讨论。

参考

  1. ^xflow https://github.com/xswang/xflow