Production ML question
In PyTorch DDP training, which common layer can behave badly across processes and how do teams usually handle it?
Answer it yourself
First formulate your answer as you would in an interview, then read the full breakdown and grade yourself.
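Short answer
BatchNorm is the classic issue: each process computes batch statistics from only its local mini-batch, so the normalization can differ between processes. Common fixes: SyncBatchNorm, larger per-device batches, a normalization layer that does not depend on the batch (LayerNorm, GroupNorm), or simply accepting the approximation.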
Full breakdown
BatchNorm depends on batch statistics. In DDP, each process usually receives only a shard of the global batch, so ordinary BatchNorm computes the mean and variance from that local shard. If the per-GPU batch is small or not representative of the data, these estimates are noisy, and each process normalizes its activations differently.
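To make "not representative" concrete, here is a minimal pure-Python sketch (no PyTorch) of one channel's statistics on a hypothetical global batch of 8 values split contiguously across 4 processes. The data and the contiguous split are illustrative assumptions, chosen to exaggerate what poorly shuffled shards do.

```python
# Minimal sketch: per-shard ("local") BatchNorm statistics vs. the statistics
# of the full global batch, for a single scalar feature. Pure Python.
# The data and the 4-way contiguous split are hypothetical, for illustration.

def mean_var(xs):
    """Biased mean and variance, as BatchNorm uses to normalize in training mode."""
    m = sum(xs) / len(xs)
    v = sum((x - m) ** 2 for x in xs) / len(xs)
    return m, v

# Hypothetical global batch for one channel, split across 4 processes
# (per-process batch size 2) without shuffling.
global_batch = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
world_size = 4
shard_size = len(global_batch) // world_size
shards = [global_batch[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]

global_stats = mean_var(global_batch)
local_stats = [mean_var(s) for s in shards]

print("global:", global_stats)
for rank, (m, v) in enumerate(local_stats):
    print(f"rank {rank}: mean={m}, var={v}")
```

With this split every shard reports a variance of 0.25 while the global batch has 5.25: the local statistics miss the spread between shards entirely, and averaging the local variances cannot recover it.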
One common fix is synchronized batch normalization, such as PyTorch's torch.nn.SyncBatchNorm (usually applied to an existing model with torch.nn.SyncBatchNorm.convert_sync_batchnorm before wrapping it in DDP), which computes the statistics across all processes in the group instead of per shard. Another path is to use normalization layers that compute statistics per sample rather than across the batch, such as LayerNorm or GroupNorm, depending on the architecture. Some teams simply tolerate local BatchNorm when the per-device batch is large enough and training metrics are stable.
The production answer should include the tradeoff: SyncBatchNorm adds collective communication to every BatchNorm layer on every training step, which can slow training, so it is not automatically the best choice.
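The synchronization idea can be sketched in pure Python: if each process contributes its sample count, sum, and sum of squares for a channel, summing those three numbers across processes (what an all-reduce would do) is enough to recover the global mean and variance. The helper names are hypothetical and a list of shards stands in for the processes; this shows the arithmetic, not how SyncBatchNorm is implemented internally.

```python
# Minimal sketch of the arithmetic behind synchronized BatchNorm statistics.
# Assumption: each process contributes (count, sum, sum of squares) for a
# channel, and a collective sums them. Plain Python stands in for the
# collective; all names here are hypothetical.

def local_moments(xs):
    """Per-process partial sums for one channel."""
    return len(xs), sum(xs), sum(x * x for x in xs)

def all_reduce_sum(per_rank):
    """Stand-in for an all-reduce: element-wise sum of the tuples over ranks."""
    return tuple(sum(vals) for vals in zip(*per_rank))

def global_mean_var(n, s, ss):
    """Recover the global (biased) mean and variance from the summed moments."""
    mean = s / n
    var = ss / n - mean * mean
    return mean, var

shards = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]  # hypothetical per-rank data
n, s, ss = all_reduce_sum([local_moments(x) for x in shards])
mean, var = global_mean_var(n, s, ss)
print(n, mean, var)
```

The E[x^2] - mean^2 form is used here for brevity; it can lose precision when values are large relative to their spread, so real implementations usually combine per-process means and variances more carefully.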
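Theory
DDP all-reduces gradients during the backward pass; it does not synchronize the per-batch mean and variance that BatchNorm computes in the forward pass. With its default broadcast_buffers=True, DDP does broadcast buffers such as running_mean and running_var from rank 0, but that only affects the running estimates used in eval mode, not the training-time normalization.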
Common mistakes
- Assuming DDP automatically makes BatchNorm global.
- Forgetting the communication overhead of SyncBatchNorm.
- Ignoring the per-device batch size.
How to answer in an interview
- State exactly which statistics are local: the per-batch mean and variance that BatchNorm computes in the forward pass.
- Mention LayerNorm or GroupNorm as practical alternatives.
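LayerNorm and GroupNorm sidestep the problem because they compute statistics per sample, so a sample's output does not depend on which other samples share its shard. A minimal pure-Python sketch under simplifying assumptions (a sample is a flat feature vector, no learnable scale or shift, hypothetical data and eps) contrasts the two behaviors:

```python
# Minimal sketch: BatchNorm output for one sample depends on which other
# samples share its batch; per-sample normalization (LayerNorm, or GroupNorm
# with a single group) does not. No affine parameters; data and EPS are
# hypothetical.
import math

EPS = 1e-5

def batch_norm(batch):
    """Normalize each feature across the batch (training-mode BatchNorm)."""
    n_feat = len(batch[0])
    out = [[0.0] * n_feat for _ in batch]
    for j in range(n_feat):
        col = [row[j] for row in batch]
        m = sum(col) / len(col)
        v = sum((c - m) ** 2 for c in col) / len(col)
        for i, row in enumerate(batch):
            out[i][j] = (row[j] - m) / math.sqrt(v + EPS)
    return out

def per_sample_norm(batch):
    """Normalize each sample over its own features, independent of the batch."""
    out = []
    for row in batch:
        m = sum(row) / len(row)
        v = sum((x - m) ** 2 for x in row) / len(row)
        out.append([(x - m) / math.sqrt(v + EPS) for x in row])
    return out

# The same sample lands in two different (hypothetical) local shards.
sample = [1.0, 2.0, 3.0, 4.0]
shard_a = [sample, [0.0, 0.0, 0.0, 0.0]]
shard_b = [sample, [10.0, 20.0, 30.0, 40.0]]

bn_a, bn_b = batch_norm(shard_a)[0], batch_norm(shard_b)[0]
ln_a, ln_b = per_sample_norm(shard_a)[0], per_sample_norm(shard_b)[0]
print("BatchNorm, shard A vs shard B:", bn_a, bn_b)
print("per-sample norm, shard A vs shard B:", ln_a, ln_b)
```

Under BatchNorm the same sample normalizes to roughly +1 per feature in one shard and roughly -1 in the other, while the per-sample version returns identical values in both.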