Назад к подготовке

Long-context training: почему не помещается attention

При обучении на сотнях тысяч токенов обычный/Flash Attention все равно не помещается в GPU. Что раздувает память и какие классы решений есть?

Ответить самому

Сначала сформулируйте ответ как на собеседовании, затем откройте разбор и оцените себя.

Загрузка

Короткий ответ

Главная проблема - активации и attention-related tensors, растущие с sequence length; помогают sparse/local attention, sequence parallelism и activation checkpointing.

Полный разбор

Flash Attention снижает materialization n x n attention matrix и улучшает IO, но не отменяет фундаментальную стоимость очень длинной последовательности: Q/K/V, activations для backprop, MLP activations и коммуникация становятся огромными.

Решения зависят от цели: local/sparse/sliding attention меняет математическую структуру; sequence/context parallelism распределяет sequence dimension между GPU; activation checkpointing пересчитывает часть forward вместо хранения; tensor/pipeline/FSDP/ZeRO распределяют параметры, gradients и optimizer state. Важно не просто порезать sequence на независимые куски, иначе потеряется attention между сегментами.