天天看點

[Alink漫談之三] AllReduce通信模型

Alink 是阿裡巴巴基于實時計算引擎 Flink 研發的新一代機器學習算法平台,是業界首個同時支援批式算法、流式算法的機器學習平台。本文将帶領大家來分析Alink中通訊模型AllReduce的實作。AllReduce在Alink中應用較多,比如KMeans,LDA,Word2Vec,GD,lbfgs,Newton method,owlqn,SGD,Gbdt, random forest都用到了這個通訊模型。

[Alink漫談之三] AllReduce通信模型

目錄

    • 0x00 摘要
    • 0x01 MPI是什麼
    • 0x02 Alink 實作MPI的思想
    • 0x03 如何實作共享
      • 1. Task相關概念
      • 2. TaskManager
      • 3. 狀态共享
        • 3.1 概念剖析
          • 算法角度:ComContext
          • 架構角度:IterativeComQueue
          • Session角度:SessionSharedObjs
          • Subtask角度:IterTaskObjKeeper
        • 3.2 變量執行個體分析
        • 3.3 ComContext
        • 3.4 SessionSharedObjs
        • 3.5 IterTaskObjKeeper
    • 0x04. 示例代碼
      • KMeansTrainBatchOp調用
      • AllReduce實作
    • 0x05 AllReduce實作
      • 1. KMeansAssignCluster
      • 2. AllReduceSend
      • 3. AllReduceBroadcastRaw
      • 4. AllReduceSum
      • 5. AllReduceBroadcastSum
      • 6. AllReduceRecv
      • 7. KMeansUpdateCentroids
    • 0x06 參考

因為Alink的公開資料太少,是以以下均為自行揣測,肯定會有疏漏錯誤,希望大家指出,我會随時更新。

MPI(Message-Passing Interface)是一個跨語言的通訊協定,用于編寫并行計算,支援點對點和廣播。

MPI的目标是高性能、大規模性和可移植性。MPI在今天仍為高性能計算的主要模型。

其特點是

  • A partitioned address space 每個線程隻能通過調用api去讀取非本地資料。所有的互動(Non-local Memory)都需要協同進行(握手)。
  • Supports only explicit parallelization 隻支援顯性的并行化,使用者必須明确的規定消息傳遞的方式。

AllReduce是MPI提供的一個基本原語,我們需要先了解reduce才能更好了解AllReduce。

  • 規約函數 MPI_Reduce :規約是來自函數式程式設計的一個經典概念。其将通信子内各程序的同一個變量參與規約計算,并向指定的程序輸出計算結果。比如通過一個函數将一批資料分成較小的一批資料。或者将一個數組的元素通過加法函數規約為一個數字。
  • 規約并廣播函數 MPI_Allreduce :在計算規約的基礎上,将計算結果分發到每一個程序中。比如函數在得到歸約結果值之後,将結果值分發給每一個程序,這樣的話,并行中的所有程序值都能知道結果值了。

MPI_Allreduce和MPI_Reduce的一個差別就是,MPI_Reduce函數将最後的結果隻傳給了指定的dest_process 号程序,而MPI_Allreduce函數可以将結果傳遞給所有的程序,是以所有的程序都能接收到結果。MPI_Allreduce函數的原型也是以不需要指定目标程序号。

AllReduce在Alink中應用較多,比如KMeans,LDA,Word2Vec,GD,lbfgs,Newton method,owlqn,SGD,Gbdt, random forest都用到了這個通訊模型。

AllReduce在算法實作中起到了承上啟下的關鍵作用,即把原來串行跑的并行task強制打斷,把計算結果進行彙總再分發,讓串行繼續執行。有一點類似大家熟悉的并發中的Barrier。

對比Flink原生KMeans算法,我們能看到AllReduce對應的是

groupBy(0).reduce

。隻有所有資料都産生之後,才能做groupBy操作。

DataSet<Centroid> newCentroids = points
		// compute closest centroid for each point
		.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
		// count and sum point coordinates for each centroid
		.map(new CountAppender())
        // 這裡如果是Alink,就對應了AllReduce
		.groupBy(0).reduce(new CentroidAccumulator())
		// compute new centroids from point counts and coordinate sums
		.map(new CentroidAverager());
           

從AllReduce的注解中我們可以清晰的看出Alink實作MPI的思想。

* An implement of {@link CommunicateFunction} that do the AllReduce.
 *
 * AllReduce is a communication primitive widely used in MPI. In our implementation, all workers do reduce on a partition of the whole data and they all get the final reduce result.
 *
 * There're mainly three stages:
 *   1. All workers send the there partial data to other workers for reduce.
 *   2. All workers do reduce on all data it received and then send partial results to others.
 *   3. All workers merge partial results into final result and put it into session context with pre-defined object name.
 */
           

翻譯如下:

所有的workers都在部分資料上做reduce操作,所有的workers都可以擷取到reduce最終結果
    
主要有三個階段:
1. 所有workers給其他workers發送需要reduce的部分資料
2. 所有workers在它收到的資料上做reduce,然後把這個部分reduce的結果發送給其他workers
3. 所有workers把部分reduce的結果合并成為最終結果,然後放入預定義的session 上下文變量中
           

"紙上得來終覺淺,絕知此事要躬行。"

Alink為了實作AllReduce,在背後做了大量的工作,下面我們一一剖析。

共享是實作AllReduce的第一要務,因為在歸并/廣播過程中需要中繼資料和輸入輸出,如果有共享變量就可以極大簡化實作。我們下面就看看Alink如何通過task manager實作共享。

  • Task(任務) : Task 是一個階段多個功能相同 subTask 的集合,類似于 Spark 中的 TaskSet。
  • subTask(子任務) :subTask 是 Flink 中任務最小執行單元,是一個 Java 類的執行個體,這個 Java 類中有屬性和方法,完成具體的計算邏輯。
  • 鍊式優化 : 按理說應該是每個算子的一個并行度執行個體就是一個subtask。那麼,帶來很多問題,由于flink的taskmanager運作task的時候是每個task采用一個單獨的線程,這就會帶來很多線程切換開銷,進而影響吞吐量。為了減輕這種情況,flink進行了優化,也即對subtask進行鍊式操作,鍊式操作結束之後得到的task,再作為一個排程執行單元,放到一個線程裡執行。
  • Operator Chains(算子鍊) :Flink 将多個 subTask 合并成一個 Task(任務),這個過程叫做 Operator Chains,每個任務由一個線程執行。使用 Operator Chains(算子鍊) 可以将多個分開的 subTask 拼接成一個任務。類似于 Spark 中的 Pipeline。
  • Slot(插槽) :Flink 中計算資源進行隔離的單元,一個 Slot 中可以運作多個 subTask,但是這些 subTask 必須是來自同一個 application 的不同階段的 subTask。結果就是,每個slot可以執行job的一整個pipeline。

Flink 中的程式本質上是并行的。在執行期間,每一個算子(Transformation)都有一個或多個算子subTask(Operator SubTask),每個算子的 subTask 之間都是彼此獨立,并在不同的線程中執行,并且可能在不同的機器或容器上執行。

同一個application,多個不同 task的 subTask,可以運作在同一個 slot 資源槽中。同一個 task 中的多個的 subTask,不能運作在一個 slot 資源槽中,他們可以分散到其他的資源槽中。對應到後面就是:AllReduceSend的多個并行度執行個體都不能運作在同一個slot中。

Flink 中每一個 TaskManager 都是一個JVM程序,它可能會在獨立的線程上執行一個或多個 subtask。TaskManager 相當于整個叢集的 Slave 節點,負責具體的任務執行和對應任務在每個節點上的資源申請和管理。

TaskManager為了對資源進行隔離和增加允許的task數,引入了slot的概念,這個slot對資源的隔離僅僅是對記憶體進行隔離,政策是均分。一個 TaskManager 至少有一個 slot。如果一個TM有N個Slot,則每個Slot配置設定到的Memory大小為整個TM Memory的1/N,同一個TM内的Slots隻有Memory隔離,CPU是共享的。

用戶端通過将編寫好的 Flink 應用編譯打包,送出到 JobManager,然後 JobManager 會根據已注冊在 JobManager 中 TaskManager 的資源情況,将任務配置設定給有資源的 TaskManager節點,然後啟動并運作任務。

TaskManager 從 JobManager 接收需要部署的任務,然後使用 Slot 資源啟動 Task,建立資料接入的網絡連接配接,接收資料并開始資料處理。同時 TaskManager 之間的資料互動都是通過資料流的方式進行的。

Flink 的任務運作其實是采用多線程的方式,一個TaskManager(TM)在多線程中并發執行多個task。這和 MapReduce 多 JVM 進行的方式有很大的差別,Flink 能夠極大提高 CPU 使用效率,在多個任務和 Task 之間通過 TaskSlot 方式共享系統資源,每個 TaskManager 中通過管理多個 TaskSlot 資源池進行對資源進行有效管理。

對應到後面就是:在一個TaskManager中間運作的多個并行的AllReduceSend執行個體都會共享這個TaskManager中所有靜态變量。

Alink就是利用task manager的靜态變量實作了變量共享。其中有幾個主要類和概念比較複雜。我們從上到下進行講解,能看到随着從上到下,需要的标示和狀态逐漸增加。

從上往下調用層次如下:

使用者代碼調用 : context.getObj(bufferName); 這樣對使用者是最理想的,因為對于使用者來說知道變量名字就可以經過上下文來存取。

但是ComContext則需要知道更多,比如還需要知道 自己對應的sessioin和taskID,具體下面會說明。

ComContext如此向下調用 : SessionSharedObjs.put(objName, sessionId, taskId, obj);

IterativeComQueue 是一個架構概念。以Kmeans為例,就是Kmeans算法對應了若幹IterativeComQueue。

IterativeComQueue上擁有衆多compute/communicate function,每個function都應該知道自己屬于哪一個IterativeComQueue,如何和本Queue上其他function進行通信,不能和其他Queue上搞混了。這樣就需要有一個概念來表标示這個Queue。于是就有了下面Session概念。

為了區分每個IterativeComQueue,就産生了session這個概念。這樣IterativeComQueue上所有compute/communicate function都會綁定同一個session id,同一個IterativeComQueue上的所有function之間可以通信。

一個 IterativeComQueue 對應一個session,是以<"變量名" + sessionId>就對應了這個 session 能通路的某個變量。

SessionSharedObjs 包含靜态成員變量 :

  • int sessionId = 0; 遞增的标示,用來區分session。
  • HashMap<Tuple2<String, Integer>, Long> key2Handle。映射,表示一個session中 某個變量名 對應某個變量handle。

正常來說 "某個名字的變量" 對應 "某個變量handle" 即可。即一個session中某個變量名 對應某個變量handle。但是Flink中,會有多個subtask并行操作的狀态,這樣就需要有一個新的概念來标示subtask對應的變量,這個變量應該和taskId有所關聯。于是就有了下面的state概念。

SessionSharedObjs向下調用 : IterTaskObjKeeper.put(handle, taskId, obj);

這裡就是用靜态變量來實作共享。是task manager中所有的 tasks (threads)都可以通路的共享變量執行個體。

IterTaskObjKeeper 包含靜态成員變量 :

  • long handle = 0L; 遞增的标示,用來區分state。
  • Map <Tuple2.of(handle, taskId), state> states; 是一個映射。即handle代表哪一種變量state,<handle, taskId>表示這種變量中 "哪個task" 對應的state執行個體,是針對subtask的一種細分。

在Flink中,一個算法會被多個subtask并行操作。如果隻有一個handle,那麼多個subtask共同通路,就會有大家都熟知的各種多線程操作問題。是以Alink這裡将handle拆分為多個state。從subtask角度看,每個state用<handle, taskId>來唯一标示。

總結一下,就是對于同樣一個變量名字,每個subtask對應的共享state其實都是獨立的,大家互不幹擾。共享其實就是在這個subtask上跑的各個operator之間共享。

從實際執行的變量中,我們可以有一個更加清楚的認識。

// 能看出來 session 0 中,centroidAllReduce這個變量 對應的handle是 7
SessionSharedObjs.key2Handle = {HashMap@10480}  size = 9
 {Tuple2@10492} "(initCentroid,0)" -> {Long@10493} 1
 {Tuple2@10494} "(statistics,0)" -> {Long@10495} 2
 {Tuple2@10496} "(362158a2-588b-429f-b848-c901a1e15e17,0)" -> {Long@10497} 8
 {Tuple2@10498} "(k,0)" -> {Long@10499} 6
 {Tuple2@10500} "(centroidAllReduce,0)" -> {Long@10501} 7 // 這裡就是所說的
 {Tuple2@10502} "(trainData,0)" -> {Long@10503} 0
 {Tuple2@10504} "(vectorSize,0)" -> {Long@10505} 3
 {Tuple2@10506} "(centroid2,0)" -> {Long@10507} 5
 {Tuple2@10508} "(centroid1,0)" -> {Long@10509} 4

// 下面能看出來,handle 7 這一種變量,因為有 4 個subtask,是以細分為4個state。 
 com.alibaba.alink.common.comqueue.IterTaskObjKeeper.states = {HashMap@10520}  size = 36
 {Tuple2@10571} "(7,0)" -> {double[15]@10572} 
 {Tuple2@10573} "(7,1)" -> {double[15]@10574} 
 {Tuple2@10577} "(7,2)" -> {double[15]@10578} 
 {Tuple2@10581} "(7,3)" -> {double[15]@10582} 

 {Tuple2@10575} "(5,0)" -> {Tuple2@10576} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@29a72fbb)"
 {Tuple2@10579} "(5,1)" -> {Tuple2@10580} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@26c52354)"
 {Tuple2@10585} "(5,2)" -> {Tuple2@10586} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@7c6ed779)"
 {Tuple2@10588} "(5,3)" -> {Tuple2@10589} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@154b8a4d)"
           

下面讓我們結合代碼,一一解析涉及的類。

ComContext 是最上層類,用來擷取runtime資訊和共享變量。IterativeComQueue(BaseComQueue )上所有的compute/communicate function都通過 ComContext 來通路共享變量。比如:

public class BaseComQueue<Q extends BaseComQueue<Q>> implements Serializable {

    // 每一個BaseComQueue都會得到唯一一個sessionId。
    private final int sessionId = SessionSharedObjs.getNewSessionId();

    int taskId = getRuntimeContext().getIndexOfThisSubtask();

    public void mapPartition(Iterable<byte[]> values, Collector<byte[]> out) {
        // 擷取到了一個ComContext
        ComContext context = new ComContext(sessionId, getIterationRuntimeContext());
        if (getIterationRuntimeContext().getSuperstepNumber() == maxIter || criterion) {
            // 利用ComContext繼續通路共享變量
            List<Row> model = completeResult.calc(context);
        }
    }
}

// 使用者類似這麼調用

double[] sendBuf = context.getObj(bufferName);
           

可以看出來,ComContext 就是使用者應該看到的最頂層上下文概念。 taskId, sessionId 是使用關鍵。

  • sessionId 是在 SessionSharedObjs中定義的靜态類成員變量,其會自動遞增。每一個BaseComQueue都會得到唯一一個sessionId,即該Queue保持了唯一session。這樣BaseComQueue中生成的ComContext都有相同的sessionId。
  • taskId是從runtime中獲得。
/**
 * Encapsulates task-specific information: name, index of subtask, parallelism and attempt number.
 */
@Internal
public class TaskInfo {
	/**
	 * Gets the number of this parallel subtask. The numbering starts from 0 and goes up to parallelism-1 (parallelism as returned by {@link #getNumberOfParallelSubtasks()}).
	 *
	 * @return The index of the parallel subtask.
	 */
	public int getIndexOfThisSubtask() {
		return this.indexOfSubtask; // 這裡擷取taskId
	}
}
           

ComContext 具體類定義如下

/**
 * Context used in BaseComQueue to access basic runtime information and shared objects.
 */
public class ComContext {
	private final int taskId;
	private final int numTask;
	private final int stepNo;
	private final int sessionId;

	public ComContext(int sessionId, IterationRuntimeContext runtimeContext) {
		this.sessionId = sessionId;
		this.numTask = runtimeContext.getNumberOfParallelSubtasks();
		this.taskId = runtimeContext.getIndexOfThisSubtask();
		this.stepNo = runtimeContext.getSuperstepNumber();
	}
    
	/**
	 * Put an object into shared objects for access of other QueueItem of the same taskId.
	 *
	 * @param objName object name
	 * @param obj     object itself.
	 */
	public void putObj(String objName, Object obj) {
		SessionSharedObjs.put(objName, sessionId, taskId, obj);
	}
}

// 比如具體舉例如下
this = {ComContext@10578} 
 taskId = 4
 numTask = 8
 stepNo = 1
 sessionId = 0
           

SessionSharedObjs是再下一層的類,維護shared session objects, 這個session 共享是通過 sessionId 做到的。

SessionSharedObjs 維護了一個靜态類變量 sessionId,由此區分各個Session。

SessionSharedObjs核心是

HashMap<Tuple2<String, Integer>, Long> key2Handle

。即 <"變量名" + sessionId> ---> <真實變量 handle> 的一個映射。

一個 IterativeComQueue 對應一個session,是以<"變量名" + sessionId>就對應了這個 IterativeComQueue 能通路的某個變量,正常來說有一個變量handle即可。

但是因為一個 IterativeComQueue會被若幹subtask并行執行,是以為了互斥和區分,是以每個handle又細分為若幹state,每個state用<handle, taskId>來唯一标示。在下面會提到。

/**
 * An static class that manage shared objects for {@link BaseComQueue}s.
 */
class SessionSharedObjs implements Serializable {
	private static HashMap<Tuple2<String, Integer>, Long> key2Handle = new HashMap<>();
	private static int sessionId = 0;
	private static ReadWriteLock rwlock = new ReentrantReadWriteLock();
    
	/**
	 * Get a new session id.
	 * All access operation should bind with a session id. This id is usually shared among compute/communicate function of an {@link IterativeComQueue}.
	 *
	 * @return new session id.
	 */
	synchronized static int getNewSessionId() {
		return sessionId++;
	}    
    
	static void put(String objName, int session, int taskId, Object obj) {
		rwlock.writeLock().lock();
		try {
			Long handle = key2Handle.get(Tuple2.of(objName, session));
			if (handle == null) {
				handle = IterTaskObjKeeper.getNewHandle();
				key2Handle.put(Tuple2.of(objName, session), handle);
			}
      // 這裡進行調用。taskId也是辨識關鍵。
			IterTaskObjKeeper.put(handle, taskId, obj);
		} finally {
			rwlock.writeLock().unlock();
		}
	}    
}
           

這是最底層的共享類,是在task manager程序的堆記憶體上的一個靜态執行個體。task manager的所有task (threads) 都可以分享。

看源碼可知,IterTaskObjKeeper 是通過一個靜态變量states實作了在整個JVM内共享。而具體内容是由 'handle' and 'taskId' 來共同決定。

IterTaskObjKeeper維持了 handle 遞增來作為 “變量state” 的唯一種類辨別。

用<handle, taskId>來作為“變量state”的唯一辨別。這個就是在 task manager process 堆記憶體中被大家共享的變量。

即handle代表哪一種變量state,<handle, taskId>表示這種變量中,對應哪一個task的哪一個變量。 這是針對task的一種細分。

/**
 * A 'state' is an object in the heap memory of task manager process,
 * shared across all tasks (threads) in the task manager.

 * Note that the 'state' is shared by all tasks on the same task manager,
 * users should guarantee that no two tasks modify a 'state' at the same time.

 * A 'state' is identified by 'handle' and 'taskId'.
 */
public class IterTaskObjKeeper implements Serializable {
	private static Map <Tuple2 <Long, Integer>, Object> states;

	/**
	 * A 'handle' is a unique identifier of a state.
	 */
	private static long handle = 0L;

	private static ReadWriteLock rwlock = new ReentrantReadWriteLock();

	static {
		states = new HashMap <>();
	}

	/**
	 * @note Should get a new handle on the client side and pass it to transformers.
	 */
	synchronized public static long getNewHandle() {
		return handle++;
	}

	public static void put(long handle, int taskId, Object state) {
		rwlock.writeLock().lock();
		try {
			states.put(Tuple2.of(handle, taskId), state); 
		} finally {
			rwlock.writeLock().unlock();
		}
	}
}
           

我們示例代碼依然如下。

static DataSet <Row> iterateICQ(...省略...) {
		return new IterativeComQueue()
			.initWithPartitionedData(TRAIN_DATA, data)
			.initWithBroadcastData(INIT_CENTROID, initCentroid)
			.initWithBroadcastData(KMEANS_STATISTICS, statistics)
			.add(new KMeansPreallocateCentroid())
			.add(new KMeansAssignCluster(distance))
			.add(new AllReduce(CENTROID_ALL_REDUCE))
			.add(new KMeansUpdateCentroids(distance))
			.setCompareCriterionOfNode0(new KMeansIterTermination(distance, tol))
			.closeWith(new KMeansOutputModel(distanceType, vectorColName, latitudeColName, longitudeColName))
			.setMaxIter(maxIter)
			.exec();
	}
           

Alink的AllReduce主要代碼摘取如下:

public static <T> DataSet <T> allReduce(
    return input
		.mapPartition(new AllReduceSend <T>(bufferName, lengthName, transferBufferName, sessionId))
		.withBroadcastSet(input, "barrier")
		.returns(
			new TupleTypeInfo <>(Types.INT, Types.INT, PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))
		.name("AllReduceSend")
		.partitionCustom(new Partitioner <Integer>() {
			@Override
			public int partition(Integer key, int numPartitions) {
				return key;
			}
		}, 0)
		.name("AllReduceBroadcastRaw")
		.mapPartition(new AllReduceSum(bufferName, lengthName, sessionId, op))
		.returns(
			new TupleTypeInfo <>(Types.INT, Types.INT, PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))
		.name("AllReduceSum")
		.partitionCustom(new Partitioner <Integer>() {
			@Override
			public int partition(Integer key, int numPartitions) {
				return key;
			}
		}, 0)
		.name("AllReduceBroadcastSum")
		.mapPartition(new AllReduceRecv <T>(bufferName, lengthName, sessionId))
		.returns(input.getType())
		.name("AllReduceRecv");
}
           

結合上面具體代碼,我們先總結AllReduce使用流程如下

  • KMeansAssignCluster :Find the closest cluster for every point and calculate the sums of the points belonging to the same cluster。然後把自己計算出來的cluster 寫入到自己 task manager 的 CENTROID_ALL_REDUCE。
  • 每個AllReduceSend 從自己task manager的CENTROID_ALL_REDUCE中取出之前存入的 cluster(每個AllReduceSend擷取的cluster都是隻有自己能看到的),然後發送給下遊task。發送時根據 "下遊task index 和 資料量" 來決定往哪些task發送。這裡要注意的是:具體給哪一個task發送變量的哪一部分,是依據那個task 的 task index 和資料量 來計算出來的。這個計算機制(如何計算在代碼中,也有部分作為元資訊随着資料一起發送)被後面的AllReduceRecv複用。
  • 每個 AllReduceSum 接收到 AllReduceSend 發送過來的 cluster,計算求和,然後把計算結果再發送出去。每一個AllReduceSum 都是把自己計算求和出來的資料統一發給每一個下遊task。
  • 每個 AllReduceRecv 都接收到 所有 AllReduceSum 發送過來的(求和之後的)cluster。存入到共享變量CENTROID_ALL_REDUCE。具體如何存就複用AllReduceSend計算機制,這樣存到共享變量的什麼地方就互相不會沖突。可以了解為merge操作:比如有5個AllReduce,每個AllReduce的資料都發給了5個AllReduceRecv,每個AllReduceRecv接到這5份資料之後,會根據自己的subtask index寫到自己對應的state中,但是這5份資料分别寫在state什麼地方都是在資料元資訊中指定的,彼此不會有寫的沖突,這樣每個AllReduceRecv就擁有了全部5份資料。
  • KMeansUpdateCentroids :取出CENTROID_ALL_REDUCE變量,然後Update the centroids based on the sum of points and point number belonging to the same cluster

該類的作用是:為每個點(point)計算最近的聚類中心,為每個聚類中心的點坐标的計數和求和。

我們可以看出,KMeansAssignCluster 通過ComContext存儲了CENTROID_ALL_REDUCE,為後續AllReduce使用。假如有5個KMeansAssignCluster,則他們計算出來的結果一般來說各不相同。雖然存儲同一個變量名CENTROID_ALL_REDUCE,但是其state各不相同。

因為這5個KMeansAssignCluster勢必對應了5個subtask,則其在共享變量中的<handle, taskId>必不相同,則對應不同的state,是以分開存儲。

// Find the closest cluster for every point and calculate the sums of the points belonging to the same cluster.
public class KMeansAssignCluster extends ComputeFunction {
        // 存取共享變量
        double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);
        if (sumMatrixData == null) {
            sumMatrixData = new double[k * (vectorSize + 1)];
            context.putObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE, sumMatrixData);
        }  
    
        for (FastDistanceVectorData sample : trainData) {
            // Find the closest centroid from centroids for sample, and add the sample to sumMatrix.
            KMeansUtil.updateSumMatrix(sample, 1, stepNumCentroids.f1, vectorSize, sumMatrixData, k, fastDistance, distanceMatrix);
        }    
}

// 程式中各個變量如下

sample = {FastDistanceVectorData@13274} 
 vector = {DenseVector@13281} "6.3 2.5 4.9 1.5"
 label = {DenseVector@13282} "72.2"
 rows = {Row[1]@13283} 

// 這個就是共享變量。4維向量 + 1 weight ---> 都是"sample和"。
sumMatrixData = {double[15]@10574} 
 0 = 23.6
 1 = 14.9
 2 = 8.7
 3 = 1.7000000000000002
 4 = 5.0
 5 = 52.400000000000006
 6 = 25.1
 7 = 39.699999999999996
 8 = 13.299999999999999
 9 = 9.0
 10 = 33.0
 11 = 16.9
 12 = 28.900000000000002
 13 = 11.4
 14 = 5.0
     
trainData = {ArrayList@10580}  size = 19
 0 = {FastDistanceVectorData@10590} 
  vector = {DenseVector@10595} "7.7 3.8 6.7 2.2"
   data = {double[4]@10601} 
    0 = 7.7
    1 = 3.8
    2 = 6.7
    3 = 2.2
  label = {DenseVector@10596} "123.46000000000001"
  rows = {Row[1]@10597} 
 1 = {FastDistanceVectorData@10603} 
  vector = {DenseVector@10623} "5.7 2.8 4.1 1.3"
  label = {DenseVector@10624} "58.83"
  rows = {Row[1]@10625} 
 2 = {FastDistanceVectorData@10604} 
 3 = {FastDistanceVectorData@10605} 
......
 17 = {FastDistanceVectorData@10619} 
 18 = {FastDistanceVectorData@10620} 
  vector = {DenseVector@10654} "6.5 3.0 5.2 2.0"
  label = {DenseVector@10655} "82.29"
  rows = {Row[1]@10656}      
           

這裡需要再把代碼摘錄一遍,主要是因為有withBroadcastSet。其作用是:

  • 可以了解為是一個公共的共享變量,我們可以把一個dataset 資料集廣播出去,然後不同的task在節點上都能夠擷取到,這個資料在每個節點上隻會存在一份。
  • 如果不使用broadcast,則在每個節點中的每個task中都需要拷貝一份dataset資料集,比較浪費記憶體(也就是一個節點中可能會存在多份dataset資料)。
return input
			.mapPartition(new AllReduceSend <T>(bufferName, lengthName, transferBufferName, sessionId))
			.withBroadcastSet(input, "barrier")
           

KMeansAssignCluster 會往上下文的變量centroidAllReduce中添加資料。是以 AllReduce 其實就是在等待這個變量。

AllReduce的第一步就是從上下文中取出共享變量,然後發送。這部分代碼由AllReduceSend完成。

對于AllReduceSend的每個task來說,bufferName都是 centroidAllReduce。

因為每個AllReduceSend也對應不同的task,是以每個AllReduceSend讀取的centroidAllReduce必然不一樣,是以每個task擷取的sendBuf都不一樣。他們分别把自己<handle, taskId>對應的 "centroidAllReduce" state取出,發送給下遊。

AllReduceSend 發給其下遊時候,是以subtask的序号為基準發送給每一個task,即本task中擷取的共享變量會發送給每一個task,但是具體給哪一個task發送變量的那一部分,是依據那個task 的 task index 和資料量 來計算出來的。如果資料量少,可能隻給某一個或者幾個task發送。

後續中的 taskId ,都是subtask id。

其中,如何計算給哪個task發送多少,是在DefaultDistributedInfo完成的。這裡需要結合 pieces 函數進行分析。需要注意的是:AllReduceSend這麼發送,AllReduceRecv後面也按照這個套路接受。這樣AllReduceRecv就可以merge了。

AllReduceSend這麼發送,AllReduceRecv後面也按照這個套路接受

int pieces = pieces(sendLen);//表示本人這次send的資料分成幾片,比如分成50片。每片大小是TRANSFER_BUFFER_SIZE

// 将要發給 8 個 subtask
for (int i = 0; i < numOfSubTasks; ++i) {
      // 假如第5個subtask,那麼它發送的起始位置就是50/8 * 4
      int startPos = (int) distributedInfo.startPos(i, numOfSubTasks, pieces);
      // 給第5個subtask發送多少片
      int cnt = (int) distributedInfo.localRowCnt(i, numOfSubTasks, pieces);
           

具體代碼如下:

private static int pieces(int len) {
		int div = len / TRANSFER_BUFFER_SIZE; //本人這次send的資料分成幾片,每片大小是TRANSFER_BUFFER_SIZE
		int mod = len % TRANSFER_BUFFER_SIZE;

		return mod == 0 ? div : div + 1;
	}

public class DefaultDistributedInfo implements DistributedInfo {

	public long startPos(long taskId, long parallelism, long globalRowCnt) {
		long div = globalRowCnt / parallelism;
		long mod = globalRowCnt % parallelism;

		if (mod == 0) {
			return div * taskId;
		} else if (taskId >= mod) {
			return div * taskId + mod;
		} else {
			return div * taskId + taskId;
		}
	}
    
	public long localRowCnt(long taskId, long parallelism, long globalRowCnt) {
		long div = globalRowCnt / parallelism;
		long mod = globalRowCnt % parallelism;

		if (mod == 0) {
			return div;
		} else if (taskId >= mod) {
			return div;
		} else {
			return div + 1;
		}
	}     
}
           

具體AllReduceSend代碼如下,注解中有詳細說明。

// 這裡是變量名字定義。	
public static final String CENTROID_ALL_REDUCE = "centroidAllReduce";

private static class AllReduceSend<T> extends RichMapPartitionFunction <T, Tuple3 <Integer, Integer, double[]>> {
        
    	int numOfSubTasks = getRuntimeContext().getNumberOfParallelSubtasks();
		// 與并行度相關,每個task都會執行相同操作
		// bufferName都是 centroidAllReduce,每個task擷取的sendBuf都不一樣
    
        // 計算怎麼發送所需要的資料結構
    	int pieces = pieces(sendLen);
    	DistributedInfo distributedInfo = new DefaultDistributedInfo();

        // 從上下文中擷取需要傳送的資料
		double[] sendBuf = context.getObj(bufferName);
        
			int agg = 0;
    		// 可以看出來,是把需要傳送的資料給每個task都發送。當然這個發送是根據發送資料的大小來确定的,如果資料量小,可能就隻給一個或者幾個task發送。
			for (int i = 0; i < numOfSubTasks; ++i) {
                // startPos : 具體發送變量的那一部分,是依據task index來決定的。
                // cnt : 具體哪一個下遊 task i 發送多少資料由此決定,如果是0,就不給task i發送資料。
				int startPos = (int) distributedInfo.startPos(i, numOfSubTasks, pieces);
				int cnt = (int) distributedInfo.localRowCnt(i, numOfSubTasks, pieces);

				for (int j = 0; j < cnt; ++j) {
                    // 發送哪一個部分
					int bufStart = (startPos + j) * TRANSFER_BUFFER_SIZE;
					// the last
					if (startPos + j == pieces - 1) {
						System.arraycopy(sendBuf, bufStart, transBuf, 0, lastLen(sendLen));
					} else {
						System.arraycopy(sendBuf, bufStart, transBuf, 0, TRANSFER_BUFFER_SIZE);
					}
					agg++;
                    
          // i 是subTasks的index,startPos + j是buffer内的位置,後續分區實際就是按照這個 i 來分區的。本AllReduceSend就是發送到numOfSubTasks這些task中。
					out.collect(Tuple3.of(i, startPos + j, transBuf));
				}
			}
}

	private static int pieces(int len) {
		int div = len / TRANSFER_BUFFER_SIZE; // 4096
		int mod = len % TRANSFER_BUFFER_SIZE;
		return mod == 0 ? div : div + 1;
	}

sendBuf = {double[15]@10602} 
 0 = 40.3
 1 = 18.200000000000003
 2 = 33.6
 3 = 12.5
 4 = 6.0
 5 = 45.3
 6 = 30.599999999999998
 7 = 12.4
 8 = 2.0
 9 = 9.0
 10 = 24.0
 11 = 10.4
 12 = 17.1
 13 = 5.199999999999999
 14 = 4.0

this = {AllReduce$AllReduceSend@10598} 
 bufferName = "centroidAllReduce"
 lengthName = null
 transferBufferName = "3dfb2aae-683d-4497-91fc-30b8d6853bce"
 sessionId = 0
 runtimeContext = {AbstractIterativeTask$IterativeRuntimeUdfContext@10606}       
           

AllReduceSend發送變量給下遊時候,使用了自定義的partition(partitionCustom )。其是用 index of subtask 來作為key分區。這樣就和AllReduceSend那個out.collect對應了。

.partitionCustom(new Partitioner <Integer>() {
				@Override
				public int partition(Integer key, int numPartitions) {
					return key;
				}
			}, 0)
			.name("AllReduceBroadcastRaw")
               
// 調用到這個partition函數的調用棧
                
partition:102, AllReduce$2 (com.alibaba.alink.common.comqueue.communication)
partition:99, AllReduce$2 (com.alibaba.alink.common.comqueue.communication)
customPartition:235, OutputEmitter (org.apache.flink.runtime.operators.shipping)
selectChannel:149, OutputEmitter (org.apache.flink.runtime.operators.shipping)
selectChannel:36, OutputEmitter (org.apache.flink.runtime.operators.shipping)
emit:120, RecordWriter (org.apache.flink.runtime.io.network.api.writer)
collect:65, OutputCollector (org.apache.flink.runtime.operators.shipping)
collect:35, CountingCollector (org.apache.flink.runtime.operators.util.metrics)
mapPartition:257, AllReduce$AllReduceSend (com.alibaba.alink.common.comqueue.communication)
run:103, MapPartitionDriver (org.apache.flink.runtime.operators)
run:504, BatchTask (org.apache.flink.runtime.operators)
run:157, AbstractIterativeTask (org.apache.flink.runtime.iterative.task)
run:107, IterationIntermediateTask (org.apache.flink.runtime.iterative.task)
invoke:369, BatchTask (org.apache.flink.runtime.operators)
doRun:705, Task (org.apache.flink.runtime.taskmanager)
run:530, Task (org.apache.flink.runtime.taskmanager)
run:745, Thread (java.lang)                  
                
                 
 // @AllReduceSend.mapPartition 這裡開始調用   
 for (int i = 0; i < numOfSubTasks; ++i) {   
     // i 是subTasks的index,後續分區實際就是按照這個 i 來分區的。本AllReduceSend就是發送到numOfSubTasks這些task中。
	 out.collect(Tuple3.of(i, startPos + j, transBuf));     
 }
                
 // 從後續調用序列可以看出來,最終是用 index of subtask 來作為key分區。    

// 這裡發送record

 public class CountingCollector<OUT> implements Collector<OUT> {
	public void collect(OUT record) {
		this.numRecordsOut.inc();
		this.collector.collect(record);
	}     
 }
             
 record = {Tuple3@10586} "(0,0,[40.50000000000001, 18.7, 33.300000000000004, 12.8, 6.0, 29.7, 21.0, 8.4, 1.7, 6.0, 48.1, 22.199999999999996, 36.0, 12.200000000000001, 8.0, 0.0,"
 f0 = {Integer@10583} 0
 f1 = {Integer@10583} 0
 f2 = {double[4096]@10598}                
       
// 這裡開始分區

public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T>> {
	private int customPartition(T record, int numberOfChannels) {
		if (extractedKeys == null) {
			extractedKeys = new Object[1];
		}

		if (comparator.extractKeys(record, extractedKeys, 0) == 1) {
            // 是以 key 是 0
			final Object key = extractedKeys[0];
			return partitioner.partition(key, numberOfChannels);
		}            
	}    
}

public final class TupleComparator<T extends Tuple> extends TupleComparatorBase<T> {
	public int extractKeys(Object record, Object[] target, int index) {
		int localIndex = index;
		for(int i = 0; i < comparators.length; i++) {
			localIndex += comparators[i].extractKeys(((Tuple) record).getField(keyPositions[i]), target, localIndex);
		}
		return localIndex - index;
	}    
}

// 就是取出第一個field的數值

key = {Integer@10583} 0
 value = 0
    
extractedKeys = {Object[1]@10587} 
 0 = {Integer@10583} 0
  value = 0
           

所有workers在它收到的資料上做reduce,然後把這個部分reduce的結果(partial results)發送給其他workers。

partial results是因為每個task接受的資料不同,是上遊根據task index計算位置并且發送過來的。

但是AllReduceSum的計算結果會給每一個下遊 task index 發送。

private static class AllReduceSum extends RichMapPartitionFunction <Tuple3 <Integer, Integer, double[]>, Tuple3 <Integer, Integer, double[]>> {
    
    	public void mapPartition(Iterable <Tuple3 <Integer, Integer, double[]>> values,Collector <Tuple3 <Integer, Integer, double[]>> out) {
            
            // 這時候雖然也用到了context取出了sendBuf,但是隻是用來擷取其長度而已。
    		int taskId = getRuntimeContext().getIndexOfThisSubtask();
			int numOfSubTasks = getRuntimeContext().getNumberOfParallelSubtasks();

			double[] sendBuf = context.getObj(bufferName);
			int sendLen = lengthName != null ? context.getObj(lengthName) : sendBuf.length;
			int pieces = pieces(sendLen);
			DistributedInfo distributedInfo = new DefaultDistributedInfo();

            // startPos : 本task接受的資料,startPos 是應該從原始資料的哪個位置開始。是依據task index來決定的。
            // cnt : 具體哪一個下遊 task i 發送多少資料由此決定。   
			int startPos = (int) distributedInfo.startPos(taskId, numOfSubTasks, pieces);
			int cnt = (int) distributedInfo.localRowCnt(taskId, numOfSubTasks, pieces);
    
    		// 這裡進行了reduce SUM工作
			double[][] sum = new double[cnt][];
			double[] agg = new double[cnt];
			do {
				Tuple3 <Integer, Integer, double[]> val = it.next();
				int localPos = val.f1 - startPos;
				if (sum[localPos] == null) {
					sum[localPos] = val.f2;
					agg[localPos]++;
				} else {
					op.accept(sum[localPos], val.f2);
				}
			} while (it.hasNext());    
    
    		// 依然發送給下遊,依然是用subtask index來作為partition key。
            // 注意,這裡是把結果發送給所有的下遊task。
			for (int i = 0; i < numOfSubTasks; ++i) {
				for (int j = 0; j < cnt; ++j) {
          // startPos是本task發送的資料應該從原始資料的哪個位置開始。
          // 但是給每一個 task i 發的都是同樣的資料。但是 startPos + j 很重要,下遊task i 會根據這個知道它應該把接收到的資料存儲在預定義變量的什麼地方。
					out.collect(Tuple3.of(i, startPos + j, sum[j]));
				}
			}   
        }
}

sum = {double[1][]@10605} 
 0 = {double[4096]@10613} 
  0 = 118.50000000000001
  1 = 77.7
  2 = 37.2
  3 = 5.9
  4 = 25.0
  5 = 621.1000000000001
  6 = 284.7
  7 = 487.59999999999997
  8 = 166.5
  9 = 99.0
  10 = 136.9
  11 = 95.7
  12 = 39.0
  13 = 7.4
  14 = 26.0
           

AllReduceSum 發送變量給下遊時候,使用了自定義的partition(partitionCustom )。其是用 index of subtask 來作為key分區。

其意義和之前的 partitionCustom 相同。

All workers merge partial results into final result and put it into session context with pre-defined object name.

每一個下遊 AllReduceRecv 都接收到 每一個上遊 AllReduceSum 發送過來的 cluster(求和之後的),然後把每份資料存入到自己task manager對應的預定義變量state的不同部分(這個不同部分是根據接受到的資料val.f1計算出來的)。

結合前面可知,AllReduceSend發送和AllReduceRecv接受,都是按照同樣的套路計算在共享變量中的資料位置。這樣AllReduceRecv就可以merge了。

這樣就完成了所有workers把部分reduce sum的結果合并成為最終結果,然後放入預定義的上下文變量中。

private static class AllReduceRecv<T> extends RichMapPartitionFunction <Tuple3 <Integer, Integer, double[]>, T> {
		private final String bufferName;
		private final String lengthName;
		private final int sessionId;

		@Override
		public void mapPartition(Iterable <Tuple3 <Integer, Integer, double[]>> values, Collector <T> out) throws Exception {
			ComContext context = new ComContext(sessionId, getIterationRuntimeContext());
			Iterator <Tuple3 <Integer, Integer, double[]>> it = values.iterator();
			if (!it.hasNext()) {
				return;
			}
			double[] recvBuf = context.getObj(bufferName);
			int recvLen = lengthName != null ? context.getObj(lengthName) : recvBuf.length;
			int pieces = pieces(recvLen); // 和之前AllReduceSend一樣的套路計算應該存儲在共享變量什麼位置。
			do {
				Tuple3 <Integer, Integer, double[]> val = it.next();
				if (val.f1 == pieces - 1) {
					System.arraycopy(val.f2, 0, recvBuf, val.f1 * TRANSFER_BUFFER_SIZE, lastLen(recvLen));
				} else {
           // 拷貝到共享變量的相應部位。val.f1 是上遊發送過來的。作為merge功能的起始位置。
					System.arraycopy(val.f2, 0, recvBuf, val.f1 * TRANSFER_BUFFER_SIZE, TRANSFER_BUFFER_SIZE);
				}
			} while (it.hasNext());
		}
	}

val = {Tuple3@10672} "(3,0,[335.3, 150.89999999999998, 277.5, 99.79999999999998, 50.0, 290.9, 136.3, 213.1, 67.8, 50.0, 250.3, 170.89999999999998, 73.2, 12.2, 50.0, 0.0....."
 f0 = {Integer@10682} 3
  value = 3
 f1 = {Integer@10638} 0
  value = 0
 f2 = {double[4096]@10674} 
  0 = 335.3
  1 = 150.89999999999998
  2 = 277.5
  3 = 99.79999999999998
  4 = 50.0
  5 = 290.9
  6 = 136.3
  7 = 213.1
  8 = 67.8
  9 = 50.0
  10 = 250.3
  11 = 170.89999999999998
  12 = 73.2
  13 = 12.2
  14 = 50.0
  15 = 0.0
  ......
      
// 每個task都收到了reduce sum結果。      
recvBuf = {double[15]@10666} 
 0 = 404.3
 1 = 183.1
 2 = 329.3
 3 = 117.2
 4 = 61.0
 5 = 250.3
 6 = 170.89999999999998
 7 = 73.20000000000002
 8 = 12.2
 9 = 50.0
 10 = 221.89999999999998
 11 = 104.1
 12 = 161.29999999999998
 13 = 50.4
 14 = 39.0      
      
           

基于點計數和坐标,計算新的聚類中心。這裡就是從task manager中取出了AllReduce存儲的共享變量CENTROID_ALL_REDUCE。

/**
 * Update the centroids based on the sum of points and point number belonging to the same cluster.
 */
public class KMeansUpdateCentroids extends ComputeFunction {
    public void calc(ComContext context) {

        Integer vectorSize = context.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
        Integer k = context.getObj(KMeansTrainBatchOp.K);

        // 這裡取出AllReduce存儲的共享變量
        double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);

        Tuple2<Integer, FastDistanceMatrixData> stepNumCentroids;
        if (context.getStepNo() % 2 == 0) {
            stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID2);
        } else {
            stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID1);
        }

        stepNumCentroids.f0 = context.getStepNo();

        context.putObj(KMeansTrainBatchOp.K,
            updateCentroids(stepNumCentroids.f1, k, vectorSize, sumMatrixData, distance));
    }
}
           

我的并行計算之路(四)MPI集合通信之Reduce和Allreduce

Message Passing Interface(MPI)

Flink 之 Dataflow、Task、subTask、Operator Chains、Slot 介紹

Flink運作時之TaskManager執行Task