Long-context training: почему не помещается attention
При обучении на сотнях тысяч токенов обычный/Flash Attention все равно не помещается в GPU. Что раздувает память и какие классы решений есть?
Ответить самому
Сначала сформулируйте ответ как на собеседовании, затем откройте разбор и оцените себя.
Короткий ответ
Главная проблема - активации и attention-related tensors, растущие с sequence length; помогают sparse/local attention, sequence parallelism и activation checkpointing.
Полный разбор
Flash Attention снижает materialization n x n attention matrix и улучшает IO, но не отменяет фундаментальную стоимость очень длинной последовательности: Q/K/V, activations для backprop, MLP activations и коммуникация становятся огромными.
Решения зависят от цели: local/sparse/sliding attention меняет математическую структуру; sequence/context parallelism распределяет sequence dimension между GPU; activation checkpointing пересчитывает часть forward вместо хранения; tensor/pipeline/FSDP/ZeRO распределяют параметры, gradients и optimizer state. Важно не просто порезать sequence на независимые куски, иначе потеряется attention между сегментами.