Attention

In this post, we focus on the attention mechanism, a key component of the Transformer architecture.

Self-Attention (SA, Scaled Dot-Product Attention)

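Assume we have a sequence of \(n\) token embeddings \(X=[x_1, x_2, \cdots, x_n]\), where \(x_i \in \mathbb{R}^d\); stacking the embeddings as rows gives \(X \in \mathbb{R}^{n \times d}\). Self-attention computes, for every token, how strongly it should attend to every token in the sequence (including itself). It starts by projecting \(X\) into three matrices: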

  • Query Matrix: \(Q=XW^Q\)
  • Key Matrix: \(K=XW^K\)
  • Value Matrix: \(V=XW^V\)

where \(W^Q, W^K, W^V \in \mathbb{R}^{d \times d_k}\) are learned projection matrices, so \(Q, K, V \in \mathbb{R}^{n \times d_k}\).
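As a minimal sketch of these projections, the snippet below uses plain Python lists with made-up values for \(X\) and the weights (in a real model \(W^Q, W^K, W^V\) are learned). The helper `matmul` is written here for illustration and assumes matrices are stored as lists of rows.

```python
def matmul(A, B):
    """Multiply an (m x p) matrix by a (p x q) matrix, both stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

# Illustrative values only: n = 3 tokens, each of dimension d = 2
X = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]

# Hypothetical projection weights of shape d x d_k, with d_k = 2
W_Q = [[1.0, 0.0], [0.0, 1.0]]
W_K = [[0.0, 1.0], [1.0, 0.0]]
W_V = [[2.0, 0.0], [0.0, 3.0]]

Q = matmul(X, W_Q)  # n x d_k: row i is the query of token i
K = matmul(X, W_K)  # n x d_k: row i is the key of token i
V = matmul(X, W_V)  # n x d_k: row i is the value of token i
print(Q, K, V, sep="\n")
```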

Self-attention is then defined as:

\[\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
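Written row by row, the same formula shows that the output for token \(i\) is a weighted average of the value vectors. Here \(q_i\) denotes the \(i\)-th row of \(Q\) and \(k_j, v_j\) the \(j\)-th rows of \(K\) and \(V\); this notation and the weights \(\alpha_{ij}\) are introduced only for this restatement:

\[\text{Attention}(Q,K,V)_i = \sum_{j=1}^{n} \alpha_{ij} v_j, \qquad \alpha_{ij} = \frac{\exp\left(q_i \cdot k_j / \sqrt{d_k}\right)}{\sum_{l=1}^{n} \exp\left(q_i \cdot k_l / \sqrt{d_k}\right)}.\]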

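The calculation can be broken down into four parts:

  • \(QK^T\): entry \((i, j)\) is the dot product between the query of token \(i\) and the key of token \(j\), so this \(n \times n\) matrix scores how relevant every token is to every other token (including itself).
  • \(\frac{1}{\sqrt{d_k}}\): a scaling factor. If the components of queries and keys are roughly independent with zero mean and unit variance, each dot product has variance \(d_k\); without scaling, the scores grow with \(d_k\) and push the softmax into saturated regions where its gradients are extremely small.
  • \(\text{softmax}\): applied row-wise, it turns each row of scores into attention weights that are non-negative and sum to 1.
  • \(V\): multiplying the attention weights by \(V\) makes each output row a weighted average of the value vectors.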
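The four steps can be traced on a toy input. This is a framework-free sketch using plain Python lists, not the PyTorch implementation that follows; the inputs are illustrative values chosen so the result is easy to check by hand.

```python
import math

def matmul(A, B):
    """Multiply matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    d_k = len(Q[0])
    # Step 1: raw similarity scores, one row per query token
    scores = matmul(Q, transpose(K))
    # Step 2: scale by 1/sqrt(d_k)
    scores = [[s / math.sqrt(d_k) for s in row] for row in scores]
    # Step 3: row-wise softmax turns each row into attention weights
    weights = [softmax(row) for row in scores]
    # Step 4: each output row is a weighted average of the rows of V
    return matmul(weights, V), weights

# Illustrative inputs: two tokens whose queries each match their own key
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
out, weights = attention(Q, K, V)
print(weights)
print(out)
```

Each query matches its own key most strongly, so about 0.67 of its weight goes to its own value row and about 0.33 to the other; the output rows are pulled towards, but not equal to, the matching value vectors.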

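The cost of self-attention is quadratic in the sequence length \(n\): computing \(QK^T\) (and multiplying the weights by \(V\)) takes \(O(n^2 d_k)\) time, and the \(n \times n\) attention matrix needs \(O(n^2)\) memory.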
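To make the quadratic growth concrete, the helper below gives a back-of-the-envelope count for a single head and a single sequence. It assumes \(d_v = d_k\) and 4-byte (fp32) scores, counts only the two matrix products, and ignores the projections, the softmax, and any memory-saving kernels; `attention_cost` is a hypothetical helper, not a library function.

```python
def attention_cost(n, d_k, bytes_per_elem=4):
    """Rough cost of one attention head on one sequence (assumes d_v == d_k).

    Returns (multiply-adds for QK^T and weights @ V, bytes for the n x n scores).
    """
    score_macs = n * n * d_k    # QK^T: n * n dot products of length d_k
    output_macs = n * n * d_k   # (n x n weights) @ (n x d_k values)
    score_bytes = n * n * bytes_per_elem
    return score_macs + output_macs, score_bytes

for n in (512, 1024, 2048):
    macs, mem = attention_cost(n, d_k=64)
    print(f"n={n:5d}  multiply-adds={macs:,}  score matrix={mem / 2**20:.1f} MiB")
```

Doubling \(n\) quadruples both the multiply-add count and the size of the score matrix, whereas doubling \(d_k\) only doubles the multiply-adds.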

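import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def self_attention(query, key, value, mask=None):
    """
    Scaled dot-product attention.

    query: [batch_size, num_heads, seq_len, d_k]
    key:   [batch_size, num_heads, seq_len, d_k]
    value: [batch_size, num_heads, seq_len, d_v]
    mask:  optional, broadcastable to [batch_size, num_heads, seq_len, seq_len],
           e.g. [batch_size, 1, seq_len, seq_len]; positions where mask == 0
           are blocked.
    Returns the output [batch_size, num_heads, seq_len, d_v] and the
    attention weights [batch_size, num_heads, seq_len, seq_len].
    """
    d_k = query.size(-1)
    # Similarity of every query with every key, scaled by 1/sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        # Blocked positions get -inf so they receive zero weight after softmax
        scores = scores.masked_fill(mask == 0, float("-inf"))

    # Row-wise softmax over the key dimension
    p_attn = F.softmax(scores, dim=-1)

    # Weighted sum of the value vectors
    output = torch.matmul(p_attn, value)

    return output, p_attn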
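The `mask` branch is easiest to see without tensors. The sketch below assumes the same convention as `self_attention` (a mask value of 0 means the position is blocked) and uses a causal mask, where token \(i\) may only attend to tokens \(0, \dots, i\); `causal_mask` and `masked_softmax` are illustrative helpers written for this example.

```python
import math

def causal_mask(n):
    """Lower-triangular mask: token i may attend to tokens 0..i (1 = allowed)."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

def masked_softmax(scores, mask):
    """Row-wise softmax where positions with mask == 0 are set to -inf first."""
    out = []
    for row, mrow in zip(scores, mask):
        filled = [s if m != 0 else float("-inf") for s, m in zip(row, mrow)]
        top = max(filled)
        exps = [math.exp(s - top) for s in filled]  # exp(-inf) == 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

scores = [[0.0] * 3 for _ in range(3)]  # equal scores for every pair
weights = masked_softmax(scores, causal_mask(3))
for row in weights:
    print(row)
```

With equal scores, each row spreads its weight uniformly over the positions it is allowed to see. In PyTorch an equivalent mask can be built with `torch.tril(torch.ones(n, n))`. One caveat applies to both versions: if every position in a row is masked, the softmax receives only \(-\infty\) values and returns NaN, so a mask should leave each query at least one visible key.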

Multi-Head Attention (MHA)

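An MHA layer contains \(h\) attention heads, which let the input embeddings attend to one another in \(h\) different ways, in parallel.

\[\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \text{head}_2, \cdots, \text{head}_h)W^O,\]

where

\[\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]

and \(W_i^Q, W_i^K\in \mathbb{R}^{d \times d_k}\), \(W_i^V\in \mathbb{R}^{d \times d_v}\), \(W^O\in \mathbb{R}^{hd_v \times d}\). A common choice, used in the code below, is \(d_k = d_v = d/h\), which keeps the total cost similar to that of single-head attention with full dimensionality.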
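The `MultiHeadAttention` class packs the \(h\) per-head matrices \(W_i^Q\) into a single \(d \times hd_k\) projection and then splits its output into heads with `view`. The plain-Python sketch below checks why that is legitimate: projecting once and slicing the output columns gives the same result as projecting separately with each head's column slice of the weight. It ignores biases and the fact that `nn.Linear` stores its weight transposed; the sizes and random values are arbitrary.

```python
import random

def matmul(A, B):
    """Multiply matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def cols(A, start, stop):
    """Columns start..stop-1 of a matrix stored as a list of rows."""
    return [row[start:stop] for row in A]

random.seed(0)
n, d_model, h = 4, 6, 3  # arbitrary sizes for the check
d_k = d_model // h

X = [[random.uniform(-1, 1) for _ in range(d_model)] for _ in range(n)]
W_Q = [[random.uniform(-1, 1) for _ in range(d_model)] for _ in range(d_model)]

# Fused: one d_model x d_model projection, output columns split into h heads
fused = matmul(X, W_Q)
heads_from_fused = [cols(fused, i * d_k, (i + 1) * d_k) for i in range(h)]

# Separate: head i uses W_i^Q = columns i*d_k .. (i+1)*d_k - 1 of W_Q
heads_separate = [matmul(X, cols(W_Q, i * d_k, (i + 1) * d_k)) for i in range(h)]
```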

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension (d_k = d_v here)

        # Each projection packs all heads' W_i into one d_model x d_model layer
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection W^O
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """
        query, key, value: [batch_size, seq_len, d_model]
        mask: optional, [batch_size, seq_len, seq_len] or broadcastable to
              [batch_size, num_heads, seq_len, seq_len]
        """
        batch_size = query.size(0)

        if mask is not None and mask.dim() == 3:
            # Add a head dimension so the same mask applies to every head
            mask = mask.unsqueeze(1)

        # [batch_size, seq_len, d_model] -> [batch_size, seq_len, num_heads, d_k]
        # -> [batch_size, num_heads, seq_len, d_k]
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        output, p_attn = self_attention(Q, K, V, mask)

        # [batch_size, num_heads, seq_len, d_k] -> [batch_size, seq_len, num_heads, d_k]
        # -> [batch_size, seq_len, d_model]
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        output = self.W_o(output)

        return output, p_attn
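A quick sanity check on the class above is to count its parameters by hand. The sketch assumes the `nn.Linear` default of `bias=True` for all four projections; `mha_param_count` is a hypothetical helper, not part of PyTorch.

```python
def mha_param_count(d_model, num_heads, bias=True):
    """Parameters of the MultiHeadAttention class: four d_model x d_model
    linear layers (W_q, W_k, W_v, W_o), each with an optional bias vector."""
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    per_layer = d_model * d_model + (d_model if bias else 0)
    return 4 * per_layer

for h in (1, 8, 16):
    print(f"num_heads={h:2d}  params={mha_param_count(512, h):,}")
```

The count is \(4(d^2 + d)\) regardless of `num_heads`: adding heads changes how the projections are split, not how many weights there are. The number can be compared against `sum(p.numel() for p in MultiHeadAttention(512, 8).parameters())`.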