Multi-head Attention
Contents
詳解 Multi-head Attention 實作。
Attention
Attention 的算法公式列在下方:
$$A(Q, K, V) = \text{softmax} ( \frac{QK^T}{\sqrt{d_k}})V$$其中,Q, K, V 矩陣是 input tensor 經過分別三個 linear projection 得到。
上述的 attention 是 single head 的,只有計算一組 Q, K, V。所謂 Multi-head Attention,則是希望同時有多組的 Q, K, V,讓模型可以學習到不同的 feature。
Multi-head Attention
先上程式碼,以 Huggingface 的 LlamaAttention 實作舉例,只留下核心算法的部份:
|
|
以下分段解析 forward 的部份。
Step 1
在實作方法中,Q, K, V 的 linear projection 仍然只用各一個 linear projection 完成 (q_proj
, k_proj
, v_proj
)。
|
|
Step 2
info
先提一下整個 MHA 實作的中心思想:
把原本 hidden dim 平均拆成 num_heads 塊,這些子塊代表不同的 head,各自獨立計算 attention,最後再將結果拼回一起。
此步驟的目的是要將 Q, K, V 沿著 hidden dim 維度平均拆成 num_heads 份。
對 Q, K, V 做 reshape,把原本 token 的 hidden dim 切成 num_heads x head_dim,如下方中間的圖。每個 token 的 hidden dim 被分成 num_heads 組了。 再來,做 transpose 把 num_heads 和 seq_len 這兩個維度對調,把同的 head 部份拼在一起。
到這邊可以發現,這兩個步驟其實等同於把原本的矩陣依照 hidden dim 維度分割成 num_heads 塊。
因此我們得到了 Multi-head Attention 需要的多組 Q, K, V (即 \(q_i, k_i, v_i, i=0, 1, \dots \text{,num_heads-1}\))。
好像有點難想像?
舉間單例子:seq_len = 4, hidden_dim = 6, num_heads = 2, head_dim = 3
|
|
很巧妙對吧。看到這裡我不禁驚嘆出聲。
Step 3
|
|
對各組 \(q_i, k_i, v_i\) 個別計算 attention,得到單個 head 的 attention output \(s_i\),這邊利用高維矩陣相乘實作。
高維矩陣相乘,其實就是對於高維矩陣中的每個二維矩陣做矩陣乘法。剛剛處理完的 Q, K, V 裡含有 num_heads 個二維矩陣,因此對對應的 \(q_i,k_i^T,v_i\) 計算 attention 這件事在編寫程式上就可以用兩次矩陣乘法完成。
Step 4
計算完所有 heads 的 attention 得到 num_heads 個 \(s_i\) 後,實行 Step 2 的逆操作。
先將 seq_len 和 num_heads 維度對調 (transpose) ,相當於把屬於個別 token 的 hidden dim 子塊拼起來,再把矩陣 reshape 成 (seq_len, hidden dim)。
至此,原本各自計算的 \(s_i\) 被合回一個二維矩陣,完成 Multi-head Attention 的計算。
Step 5
在 Multi-head Attention 運算最後,常會把結果再做一次 linear projection,稱為 output projection。得到的結果為最終 Multi-head Attention 的 output。
後記
看了很多 Multihead attention 的解說但一直沒能懂,那些圖完全沒能輔助我解讀計算過程。徹底把程式碼拆出來看後我才徹徹底底的理解 MHA 到底是個什麼樣的機制,並把理論很緊密和實作的關聯起來。因此決定寫一篇 blog 把我的理解過程寫下來備忘,並順手畫一些圖輔助理解。