Deep Attention Recurrent Q-Network
5vision groups
摘要:本文将 DQN 引入了 Attention 機制,使得學習更具有方向性和指導性。(前段時間做一個工作打算就這麼幹,誰想到,這麼快就被這幾個孩子給實作了,自愧不如啊( ⊙ o ⊙ ))
引言:我們知道 DQN 是将連續 4幀的視訊資訊輸入到 CNN 當中,那麼,這麼做雖然取得了不錯的效果,但是,仍然隻是能記住這 4 幀的資訊,之前的就會遺忘。是以就有研究者提出了 Deep Recurrent Q-Network (DRQN),一個結合 LSTM 和 DQN 的工作:
1. the fully connected layer in the latter is replaced for a LSTM one ,
2. only the last visual frame at each time step is used as DQN's input.
作者指出雖然隻是使用了一幀的資訊,但是 DRQN 仍然抓住了幀間的相關資訊。盡管如此,仍然沒有看到在 Atari game上有系統的提升。
另一個缺點是:長時間的訓練時間。據說,在單個 GPU 上訓練時間達到 12-14天。于是,有人就提出了并行版本的算法來提升訓練速度。作者認為并行計算并不是唯一的,最有效的方法來解決這個問題。
最近 visual attention models 在各個任務上都取得了驚人的效果。利用這個機制的優勢在于:僅僅需要選擇然後注意一個較小的圖像區域,可以幫助降低參數的個數,進而幫助加速訓練和測試。對比 DRQN,本文的 LSTM 機制存儲的資料不僅用于下一個 actions 的選擇,也用于 選擇下一個 Attention 區域。此外,除了計算速度上的改進之外,Attention-based models 也可以增加 Deep Q-Learning 的可讀性,提供給研究者一個機會去觀察 agent 的集中區域在哪裡以及是什麼,(where and what)。
Deep Attention Recurrent Q-Network:
如上圖所示,DARQN 結構主要由 三種類型的網絡構成:convolutional (CNN), attention, and recurrent . 在每一個時間步驟 t,CNN 收到目前遊戲狀态 $s_t$ 的一個表示,根據這個狀态産生一組 D feature maps,每一個的次元是 m * m。Attention network 将這些 maps 轉換成一組向量 $v_t = \{ v_t^1, ... , v_t^L \}$,L = m*m,然後輸出其線性組合 $z_t$,稱為 a context vector. 這個 recurrent network,在我們這裡是 LSTM,将 context vector 作為輸入,以及 之前的 hidden state $h_{t-1}$,memory state $c_{t-1}$,産生 hidden state $h_t$ 用于:
1. a linear layer for evaluating Q-value of each action $a_t$ that the agent can take being in state $s_t$ ;
2. the attention network for generating a context vector at the next time step t+1.
Soft attention :
這一小節提到的 "soft" Attention mechanism 假設 the context vector $z_t$ 可以表示為 所有向量 $v_t^i$ 的權重和,每一個對應了從圖像不同區域提取出來的 CNN 特征。權重 和 這個 vector 的重要程度成正比例,并且是通過 Attention network g 衡量的。g network 包含兩個 fc layer 後面是一個 softmax layer。其輸出可以表示為:
其中,Z是一個normalizing constant。W 是權重矩陣,Linear(x) = Ax + b 是一個放射變換,權重矩陣是A,偏差是 b。我們一旦定義出了每一個位置向量的重要性,我們可以計算出 context vector 為:
另一個網絡在第三小節進行詳細的介紹。整個 DARQN model 是通過最小化序列損失函數完成訓練:
其中,$Y_t$ 是一個近似的 target value,為了優化這個損失函數,我們利用标準的 Q-learning 更新規則:
DARQN 中的 functions 都是可微分的,是以每一個參數都有梯度,整個模型可以 end-to-end 的進行訓練。本文的算法也借鑒了 target network 和 experience replay 的技術。
Hard Attention:
此處的 hard attention mechanism 采樣的時候要求僅僅從圖像中采樣一個圖像 patch。
假設 $s_t$ 從環境中采樣的時候,受到了 attention policy 的影響,attention network g 的softmax layer 給出了帶參數的類别分布(categorical distribution)。然後,在政策梯度方法,政策參數的更新可以表示為:
其中 $R_t$ 是将來的折扣的損失。為了估計這個值,另一個網絡 $G_t = Linear(h_t)$ 才引入進來。這個網絡通過朝向 期望值 $Y_t$ 進行網絡訓練。Attention network 參數最終的更新采用如下的方式進行:
其中 $G_t - Y_t$ 是advantage function estimation。
作者提供了源代碼:https://github.com/5vision/DARQN
實驗部分:
總結: