Attention
In this blog, we will focus on the attention mechanism, which is a key component of the Transformer architecture.
Self-Attention (SA, Scaled Dot-Product Attention)
Assume we have a sequence of token embeddings \(X=[x_1, x_2, \cdots, x_n]\), where \(x_i \in \mathbb{R}^d\). Self-attention lets every token attend to every token in the sequence (including itself). First, \(X\) is projected into three matrices:
- Query Matrix: \(Q=XW^Q\)
- Key Matrix: \(K=XW^K\)
- Value Matrix: \(V=XW^V\)
where \(W^Q, W^K, W^V \in \mathbb{R}^{d \times d_k}\).
Self-attention is then defined as:
\[\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
The calculation can be broken down into four parts:
- \(QK^T\): the dot product of every query with every key, i.e. an \(n \times n\) matrix of pairwise similarity scores between tokens in the sequence.
- \(\frac{1}{\sqrt{d_k}}\): a scaling factor that keeps the entries of \(QK^T\) from growing too large. For vectors with unit-variance components, the dot products have standard deviation about \(\sqrt{d_k}\); without scaling, the softmax saturates and its gradients vanish.
- \(\text{softmax}\): converts each row of similarity scores into a probability distribution over the sequence (each row sums to 1).
- \(V\): the output is a weighted sum of the value vectors, with weights given by the attention distribution.
The complexity of the self-attention operation is \(O(n^2)\) in the input sequence length \(n\), since the score matrix \(QK^T\) alone has \(n^2\) entries.
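The effect of the \(1/\sqrt{d_k}\) scaling can be seen with a small numeric sketch (the dimensions and seed below are illustrative): for random vectors with unit-variance components, the raw dot products spread out with standard deviation about \(\sqrt{d_k}\), while the scaled scores stay near unit scale.

```python
import torch

torch.manual_seed(0)
d_k = 64
# 10,000 pairs of random vectors with unit-variance components
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)

raw = (q * k).sum(dim=-1)      # unscaled dot products
scaled = raw / d_k ** 0.5      # scaled as in the attention formula

print(raw.std().item())        # roughly sqrt(64) = 8
print(scaled.std().item())     # roughly 1
```

With scores of standard deviation 8, a softmax row is dominated by its largest entry, which is exactly the saturation the scaling avoids.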
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def self_attention(query, key, value, mask=None):
    """
    query: [batch_size, num_heads, seq_len, d_k]
    key:   [batch_size, num_heads, seq_len, d_k]
    value: [batch_size, num_heads, seq_len, d_k]
    mask:  broadcastable to [batch_size, num_heads, seq_len, seq_len] (optional)
    """
    d_k = query.size(-1)
    # pairwise similarity scores, scaled by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # masked positions get -inf, so they receive zero attention weight
        scores = scores.masked_fill(mask == 0, float('-inf'))
    p_attn = F.softmax(scores, dim=-1)
    output = torch.matmul(p_attn, value)
    return output, p_attn
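As a quick sanity check, we can run the function on random tensors with a causal mask (the function is repeated here so the snippet runs standalone; the shapes are illustrative):

```python
import math
import torch
import torch.nn.functional as F

def self_attention(query, key, value, mask=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    p_attn = F.softmax(scores, dim=-1)
    return torch.matmul(p_attn, value), p_attn

torch.manual_seed(0)
q = k = v = torch.randn(2, 4, 5, 16)   # [batch, heads, seq_len, d_k]

# causal mask: position i may only attend to positions <= i
causal = torch.tril(torch.ones(5, 5))
out, p_attn = self_attention(q, k, v, mask=causal)

print(out.shape)        # torch.Size([2, 4, 5, 16])
print(p_attn[0, 0, 0])  # first token attends only to itself: [1, 0, 0, 0, 0]
```

Each row of `p_attn` sums to 1, and the upper-triangular (future) positions get exactly zero weight, confirming the mask behaves as intended.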
Multi-Head Attention (MHA)
An MHA layer contains \(h\) attention heads, which let the input embeddings attend to one another in \(h\) different ways, in parallel.
\[\text{MultiHead}(Q,K,V) = \text{Concat}(h_1, h_2, \cdots, h_h)W^O,\]
where
\[h_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]
and \(W_i^Q, W_i^K\in \mathbb{R}^{d \times d_k}\), \(W_i^V\in \mathbb{R}^{d \times d_v}\), \(W^O\in \mathbb{R}^{hd_v \times d}\).
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # [batch_size, seq_len, d_model] -> [batch_size, seq_len, num_heads, d_k]
        # -> [batch_size, num_heads, seq_len, d_k]
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        if mask is not None:
            # [batch_size, seq_len, seq_len] -> [batch_size, 1, seq_len, seq_len],
            # so the same mask broadcasts over every head
            mask = mask.unsqueeze(1)
        output, p_attn = self_attention(Q, K, V, mask)
        # [batch_size, num_heads, seq_len, d_k] -> [batch_size, seq_len, num_heads, d_k]
        # -> [batch_size, seq_len, d_model]
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(output)
        return output, p_attn
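The only subtle part of `forward` is the head split and merge. A small sketch (with illustrative dimensions) confirms that the two reshapes are exact inverses of each other:

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 5, 32, 4
d_k = d_model // num_heads

x = torch.randn(batch_size, seq_len, d_model)

# split: [batch, seq, d_model] -> [batch, heads, seq, d_k]
heads = x.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 8])

# merge: undo the split, recovering the original tensor exactly
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
print(torch.equal(merged, x))  # True
```

The `.contiguous()` call before the final `.view` is required because `.transpose` returns a non-contiguous view, which `.view` cannot reshape directly.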
