Multi-head attention as a PyTorch module
Answer it yourself
First formulate your answer as you would in an interview, then open the breakdown and grade yourself.
Short answer
Project the input to Q, K and V first, split the projected tensors into heads, compute scaled dot-product attention per head, concatenate the heads, then apply the output projection. In a full Transformer block, residual connections, LayerNorm and a feed-forward network wrap around the attention layer.
Full breakdown
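A clean answer starts from shapes. For an input x with shape (batch, seq_len, embed_dim), define linear projections Wq, Wk and Wv, each mapping embed_dim to embed_dim. Apply those projections before splitting into heads. Then reshape each projected tensor to (batch, seq_len, num_heads, head_dim) and swap the two middle axes to get (batch, num_heads, seq_len, head_dim), where head_dim = embed_dim / num_heads. This requires embed_dim to be divisible by num_heads.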
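The project-then-split step can be sketched as follows. This is a minimal illustration, assuming small made-up sizes (batch=2, seq_len=5, embed_dim=16, num_heads=4); the variable names are hypothetical and only the Q projection is shown, since K and V follow the same pattern.

```python
import torch
import torch.nn as nn

# Illustrative sizes, chosen only for this sketch.
batch, seq_len, embed_dim, num_heads = 2, 5, 16, 4
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads

x = torch.randn(batch, seq_len, embed_dim)

# Project first, on the full embedding dimension.
w_q = nn.Linear(embed_dim, embed_dim)
q = w_q(x)                                   # (batch, seq_len, embed_dim)

# Then split the last axis into heads and move the head axis next to batch.
q_heads = q.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
print(q_heads.shape)                         # torch.Size([2, 4, 5, 4])
```

Each head ends up with a contiguous slice of the projected features: head 1 holds features head_dim through 2 * head_dim - 1 of q.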
For each head, compute attention_scores = QK^T / sqrt(head_dim), apply softmax over the key sequence dimension (the last axis of the scores), and multiply the resulting weights by V. The per-head results are transposed back and concatenated to shape (batch, seq_len, embed_dim), then passed through an output projection. In a full Transformer block, this is surrounded by residual connections, LayerNorm and a feed-forward network with a nonlinearity.
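Put together, the attention layer described so far might look like the sketch below. It is one possible self-attention implementation without masking or dropout, not a reference one; the class name MultiHeadSelfAttention and the attribute names w_q, w_k, w_v, w_o are assumptions made for illustration.

```python
import math
import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    """Illustrative self-attention layer (no masking, no dropout)."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)
        self.w_o = nn.Linear(embed_dim, embed_dim)   # output projection

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        b, s, _ = t.shape
        return t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor):
        b, s, e = x.shape
        # Project on the full embedding dimension, then split into heads.
        q = self._split_heads(self.w_q(x))
        k = self._split_heads(self.w_k(x))
        v = self._split_heads(self.w_v(x))

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len).
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        # Softmax over the key axis, which is the last one.
        weights = scores.softmax(dim=-1)
        context = weights @ v                        # (b, heads, s, head_dim)

        # Concatenate heads back to (batch, seq_len, embed_dim).
        context = context.transpose(1, 2).reshape(b, s, e)
        return self.w_o(context), weights


torch.manual_seed(0)
mha = MultiHeadSelfAttention(embed_dim=16, num_heads=4)
out, weights = mha(torch.randn(2, 5, 16))
print(out.shape, weights.shape)
# torch.Size([2, 5, 16]) torch.Size([2, 4, 5, 5])
```

Returning the attention weights alongside the output is a choice made here so the softmax axis can be checked: every row of weights should sum to 1 along the last dimension.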
Common mistakes are splitting before projection with incompatible dimensions, applying the feed-forward layer before concatenating heads, softmaxing over the wrong axis, forgetting the scale factor, and forgetting that two linear layers without an activation collapse into one linear layer.
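The last mistake in that list is easy to verify numerically. The sketch below, with arbitrary layer sizes and hypothetical names fc1 and fc2, checks that two stacked linear layers with no activation between them compute the same function as a single linear layer with merged weight W2 W1 and merged bias W2 b1 + b2.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1 = nn.Linear(8, 32)   # weight shape (32, 8)
fc2 = nn.Linear(32, 8)   # weight shape (8, 32)
x = torch.randn(4, 8)

with torch.no_grad():
    # Two linear layers applied back to back, no activation in between.
    stacked = fc2(fc1(x))

    # The equivalent single linear layer.
    merged_weight = fc2.weight @ fc1.weight          # (8, 8)
    merged_bias = fc2.weight @ fc1.bias + fc2.bias   # (8,)
    merged = x @ merged_weight.T + merged_bias

    collapses = torch.allclose(stacked, merged, atol=1e-5)

print(collapses)  # True
```

Inserting a nonlinearity such as ReLU or GELU between fc1 and fc2 breaks this equivalence, which is why the feed-forward block needs one.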
Theory
Multi-head attention is mostly a tensor-shape discipline around projected Q/K/V and per-head scaled dot-product attention.
Common mistakes
- Splitting the raw input into heads before the Q/K/V projections in a way that breaks dimensions.
- Applying softmax over the head or batch dimension instead of the key sequence dimension.
- Forgetting the output projection, the residual connections, LayerNorm or the activation in the feed-forward block.
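To make the components from the last bullet concrete, here is a sketch of a full block built around PyTorch's own nn.MultiheadAttention, which already contains the output projection. The post-LayerNorm arrangement, the GELU activation, and the names TransformerBlock and ff_dim are assumptions made for this example; pre-LayerNorm variants are also common.

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """Illustrative encoder-style block: attention + FFN, each with residual and LayerNorm."""

    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int):
        super().__init__()
        # batch_first=True makes inputs (batch, seq_len, embed_dim).
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),                      # the nonlinearity between the two linears
            nn.Linear(ff_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention: query, key and value are all x.
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out)        # residual, then LayerNorm
        x = self.norm2(x + self.ff(x))      # residual, then LayerNorm
        return x


block = TransformerBlock(embed_dim=16, num_heads=4, ff_dim=64)
y = block(torch.randn(2, 5, 16))
print(y.shape)  # torch.Size([2, 5, 16])
```

The block preserves the input shape, which is what allows several such blocks to be stacked.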
How to answer in an interview
- Say the shapes out loud after every reshape.
- Mention the full block components even if only attention is implemented.