天天看點

DI-engine強化學習入門(十又二分之一)如何使用RNN

作者:古月居

一、資料處理

用于訓練 RNN 的 mini-batch 資料不同于通常的資料。 這些資料通常應按時間序列排列。 對于 DI-engine, 這個處理是在 collector 階段完成的。 使用者需要在配置檔案中指定 learn_unroll_len 以確定序列資料的長度與算法比對。 對于大多數情況, learn_unroll_len 應該等于 RNN 的曆史長度(a.k.a 時間序列長度),但在某些情況下并非如此。比如,在 r2d2 中, 我們使用burn-in操作, 序列長度等于 learn_unroll_len + burnin_step 。 這裡将在下一節中具體解釋。

什麼是資料處理?

資料處理指的是為循環神經網絡(RNN)訓練準備時間序列資料的過程。這個過程包括将收集到的資料組織成适當格式的小批量(mini-batches),這些批量資料将用于網絡的訓練。這一步驟通常發生在DI-engine的collector階段,也就是資料收集和預處理發生的地方。使用者需要在配置檔案中指定 learn_unroll_len 以確定序列資料的長度與算法比對。 對于大多數情況, learn_unroll_len 應該等于 RNN 的曆史長度(a.k.a 時間序列長度),但在某些情況下并非如此。比如,在 r2d2 中, 我們使用burn-in操作, 序列長度等于 learn_unroll_len + burnin_step 。例如,如果你設定 learn_unroll_len = 10 和 burnin_step = 5,那麼 RNN 實際接收的輸入序列長度将是 15:前 5 步為 burn-in(用于預熱隐藏狀态),接下來的 10 步作為學習的一部分。這樣設定可以幫助 RNN 在計算梯度和進行權重更新時,有一個更加準确的隐藏狀态作為起點。

部分名詞解釋

  • mini-batches:在機器學習中,特别是在訓練神經網絡時,資料一般被分成小的批次進行處理,這些批次被稱為 “mini-batch”。一個 mini-batch 包含了一組樣本,這組樣本用于執行單次疊代的前向傳播和反向傳播,以更新網絡的權重。使用 mini-batches 而不是單個樣本或整個資料集(後者稱為 “batch” 或 “full-batch”)可以平衡計算效率和記憶體限制,有助于提高學習的穩定性和收斂速度。
  • collector階段:在 DI-engine中,collector 階段是指環境與智能體互動并收集經驗資料的過程。在這個階段,智能體根據其目前的政策執行操作,環境則傳回新的狀态、獎勵和其他可能的資訊,如是否達到終止狀态。收集到的資料(經常被稱為經驗或轉換)随後被用于訓練智能體的模型,例如對政策或價值函數進行更新。

為什麼要進行資料處理:

  1. 保持時間依賴性:RNN的核心優勢是處理具有時間序列依賴性的資料,比如語言、視訊幀、股票價格等。正确的資料處理確定了這些時間依賴性在訓練資料中得以保留,使得模型能夠學習到資料中的序列特征。
  2. 提高學習效率:通過将資料劃分為與模型期望的序列長度比對的批次,可以提高模型學習的效率。這樣做可以確定網絡在每次更新時都接收到足夠的上下文資訊。
  3. 适配算法要求:不同的RNN算法可能需要不同形式的輸入資料。例如,标準的RNN隻需要過去的資訊,而一些變體如LSTM或GRU可能會處理更長的序列。特定的算法,如R2D2,還可能需要額外的步驟(如burn-in),以便更好地初始化網絡狀态。
  4. 處理不規則長度:在現實世界的資料集中,序列長度往往是不規則的。資料處理確定了每個mini-batch都有統一的序列長度,這通常通過截斷過長的序列或填充過短的序列來實作。
  5. 優化記憶體和計算資源:通過将資料組織成具有固定時間步長的批次,可以更有效地利用GPU等計算資源,因為這些資源在處理固定大小的資料時通常更高效。
  6. 穩定學習過程:特别是在強化學習中,使用如n-step傳回或經驗回放的技術,可以幫助模型從環境回報中學習,并減少方差,進而穩定學習過程。
  7. 如何進行資料處理
  8. 比如原始采樣資料是 [],每個x表示 [](或者 ,隐藏狀态等),此時 n_sample = 6 。此時根據所需 RNN 的序列長度即 learn_unroll_len 有以下三種情況:
  9. n_sample >= learn_unroll_len 并且 n_sample 可以被 learn_unroll_len 除盡: 例如 n_sample=6 和 learn_unroll_len=3,資料将被排列為:[[],[]]。
  10. n_sample >= learn_unroll_len 并且 n_sample 不可以被 learn_unroll_len 除盡: 預設情況下,殘差資料将由上一個樣本中的一部分資料填充,例如如果 n_sample=6 和 learn_unroll_len=4 ,資料将被排列為[[],[]]。
  11. n_sample < learn_unroll_len:例如如果 n_sample=6 和 learn_unroll_len=7,預設情況下,算法将使用 null_padding 方法,資料将被排列為 [[]]。類似于但它的 done=True 和 reward=0。

這裡以r2d2算法為例,在r2d2中,在方法 _get_train_sample 中通過調用函數 get_nstep_return_data 和 get_train_sample 擷取按時序排列的資料。

def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
    data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
    return get_train_sample(data, self._sequence_len)           

代碼段 def _get_train_sample(self, data: list) 是一個方法,它的作用是從收集到的資料中提取用于訓練 RNN 的樣本。這個方法會在兩個步驟中處理資料:

點選DI-engine強化學習入門(十又二分之一)如何使用RNN——資料處理、隐藏狀态、Burn-in - 古月居可檢視全文