天天看點

MLSQL是如何內建TensorFlow Cluster的

前言

我們知道MLSQL支援SKLearn,TF等流行的算法架構,不過雖然支援了多個執行個體同時運作,但其實每個模型都需要跑全部資料。有的時候資料太大,确實是個問題,是以這個時候還是需要引入Cluster的。MLSQL基于Spark,是以問題就變成了如何在Spark裡內建TF Cluster了。TFoS 已經實作了類似的功能,但遺憾的是,TFoS完全是用Python編寫的,并且每次都需要啟動一個新的Spark 執行個體來運作,overhead 是比較高的。

MLSQL內建TF Cluster

MLSQL內建TF Cluster 的主要優勢有:

  1. 一個Spark執行個體可以運作多個TF Cluster,互不影響。
  2. 可以local模式運作TF Cluster
  3. 資料互動本地化(也可以消費Kafka),假設你配置了10個worker,資料會被切分成十份,然後同步到對應worker的本地目錄。
  4. 易用,你隻要寫一個python腳本,所有排程相關工作全部由MLSQL來完成。

感興趣的可以參看這個PR,看看具體實作源碼。

一個示例

load libsvm.`/tmp/william/sample_libsvm_data.txt` as data;

train data as DTFAlg.`/tmp/jack`
where
pythonScriptPath="/tmp/tensorflow-distribute.py"
and `kafkaParam.bootstrap.servers`="127.0.0.1:9092"
and `kafkaParam.topic`="test"
and `kafkaParam.group_id`="g_test-1"
and  keepVersion="true"
and  enableDataLocal="true"
and  dataLocalFormat="json"
and distributeEveryExecutor="false"

and  `fitParam.0.jobName`="worker"
and  `fitParam.0.taskIndex`="0"

and  `fitParam.1.jobName`="worker"
and  `fitParam.1.taskIndex`="1"

and  `fitParam.2.jobName`="ps"
and  `fitParam.2.taskIndex`="0"


and `systemParam.pythonPath`="python"
and `systemParam.pythonVer`="2.7"
;           

複制

我們看到,隻要配置一個python腳本,然後通過fitParam指定每個節點的jobName,taskIndex即可。

在python腳本中,你可以通過如下方式拿到這些參數:

jobName = param("jobName", "worker")
taskIndex = int(param("taskIndex", "0"))
clusterSpec = json.loads(mlsql.internal_system_param["clusterSpec"])
checkpoint_dir = mlsql.internal_system_param["checkpointDir"]           

複制

一個大緻的TF腳本如下:

def run():
    # create the cluster configured by `ps_hosts' and 'worker_hosts'
    cluster = tf.train.ClusterSpec(clusterSpec)

    # create a server for local task
    server = tf.train.Server(cluster, job_name=jobName,
                             task_index=taskIndex)

    if jobName == "ps":
        server.join()  # ps hosts only join
    elif jobName == "worker":
       .......           

複制

當然,不可避免的,你可能需要用到MonitoredTrainingSession等和叢集相關的API。

運作後的一些資訊可以查詢到:

MLSQL是如何內建TensorFlow Cluster的

WX20180717-144037.png

圖中顯示了,第一行第二行是worker,第三行是ps, algIndex 0,1都産生模型(其實是checkpoint),但實際上隻有0是有資料的,狀态都是成功,對應的參數為trainParams

難點

這個需求我昨天早上提出,下午開始弄,我一開始以為一個下午就能搞定,但是最後還是做到了晚上十一點多,這裡有幾個問題需要注意:

  1. 使用者可能取消任務,如何及時的殺掉TF cluster.
  2. spark 可能異常退出,如何保證也能退出TF cluster
  3. 如何差別對待PS/Worker角色

實作方式

worker需要能夠和driver 進行互動。為什麼呢?TF啟動Cluster的時候,是需要ClusterSpec,也就是每個節點host和port。

Spark在分發Task的時候是并行的,你不知道會分發到哪個節點,并且分發後,你也不知道TF能夠在對應的節點擷取到哪個端口。為了完成這些資訊的收集,需要走如下幾個流程:

  1. 每個Task在啟動TF Server之前,需要先擷取端口,并且占用住,然後上報給Driver,Driver會記住這些。
  2. 接着Task會等待所有的Task都上報完成,然後釋放占用的端口,啟動對應的TF Server。

    TF Server 完成訓練後會上報Driver。

  3. PS會監聽是不是所有的Worker都已經完成了工作,如果是,則會自己把自己結束掉。
  4. 最後整個訓練結束,并且把訓練好的模型發送到HDFS上。

Executor 和Driver 互動,其實MLSQL裡已經實作了自己的PRC層。不過因為這次比較簡單,隻需要單向通訊即可,是以直接基于Driver 的http接口完成。