Глубина
~20 мин

PyTorch Lightning

Абстракция train loop, callbacks, logging, воспроизводимость — индустриальный стандарт обучения.

PyTorch Lightning — фреймворк поверх PyTorch

Чистый PyTorch даёт полный контроль, но каждый проект начинается с одного и того же: training loop, валидация, логирование, чекпоинты, multi-GPU. PyTorch Lightning — это тонкая абстракция, которая стандартизирует boilerplate, не забирая гибкость. Ты пишешь только то, что уникально для твоей задачи: модель, loss, данные. Всё остальное — Trainer.

Три главные причины использовать Lightning: 1. Воспроизводимость — структура кода единообразная, seed_everything фиксирует все генераторы, конфигурация сериализуется. 2. Масштабируемость — переход с CPU на GPU, multi-GPU, TPU — одна строка в Trainer. Не нужно переписывать .to(device), DistributedDataParallel и т.д. 3. Меньше багов — стандартный train loop протестирован тысячами пользователей. Ты не забудешь model.eval(), zero_grad или torch.no_grad.

LightningModule — вся логика в одном месте

LightningModule наследуется от nn.Module, но добавляет структуру: ты определяешь что делать на каждом шаге, а как (цикл, GPU, градиенты) берёт на себя Trainer.

Ключевые методы:forward(x) — прямой проход, как в обычном nn.Module. Используется при инференсе: model(x). • training_step(batch, batch_idx) — один шаг обучения. Принимает батч, возвращает loss. Lightning сам вызовет backward(), zero_grad(), optimizer.step(). • validation_step(batch, batch_idx) — шаг валидации. Lightning автоматически включит model.eval() и torch.no_grad(). • configure_optimizers() — возвращает оптимизатор (и опционально scheduler). Один метод вместо разбросанных по коду optimizer = ... и scheduler.step().

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F

class MNISTClassifier(L.LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()          # автоматически сохраняет lr в чекпоинт
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log('train_loss', loss, prog_bar=True)
        return loss                          # Lightning сделает backward + step

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

save_hyperparameters()

Вызов self.save_hyperparameters() в __init__ автоматически сохраняет все аргументы конструктора в self.hparams и в чекпоинт. При загрузке модели из чекпоинта гиперпараметры восстанавливаются. Не нужно вручную передавать lr, hidden_dim и т.д.

Trainer — запуск обучения одной строкой

Trainer управляет всем: циклом эпох, валидацией, коллбэками, логированием, GPU/TPU. Ты настраиваешь поведение через параметры, а не переписываешь цикл.

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

trainer = L.Trainer(
    max_epochs=20,
    accelerator='auto',            # CPU / GPU / TPU — определит сам
    devices='auto',                # сколько GPU использовать
    callbacks=[
        EarlyStopping(
            monitor='val_loss',    # следим за val_loss
            patience=3,            # 3 эпохи без улучшения → стоп
            mode='min',
        ),
        ModelCheckpoint(
            monitor='val_loss',
            save_top_k=1,          # сохраняем только лучшую модель
            filename='best-{epoch}-{val_loss:.3f}',
        ),
    ],
    logger=TensorBoardLogger('logs/', name='mnist'),
    deterministic=True,            # воспроизводимость (чуть медленнее)
)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Callbacks — расширение без изменения кода:EarlyStopping — останавливает обучение, если метрика не улучшается patience эпох. Самый важный callback — защита от переобучения. • ModelCheckpoint — сохраняет веса лучшей модели. save_top_k=1 — только лучшая, save_top_k=3 — три лучших. Можно восстановить: MNISTClassifier.load_from_checkpoint("best.ckpt"). • LearningRateMonitor — логирует lr на каждом шаге (полезно со schedulers). • Свои callbacks — наследуешь от lightning.pytorch.Callback, переопределяешь хуки (on_train_epoch_end, on_validation_end и т.д.).

Логирование — TensorBoard и W&B:TensorBoardLogger — встроенный, tensorboard --logdir logs/ для визуализации. Бесплатный, локальный. • WandbLogger — интеграция с Weights & Biases. Автоматически логирует метрики, гиперпараметры, системные ресурсы. Удобно для командной работы и сравнения экспериментов. logger=WandbLogger(project="mnist").

LightningDataModule — инкапсуляция данных

LightningDataModule группирует всю логику данных: скачивание, предобработка, splits, DataLoader'ы. Это делает код переиспользуемым и воспроизводимым — ты можешь передать DataModule другому человеку, и он получит те же данные в том же формате.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

class MNISTDataModule(L.LightningDataModule):
    def __init__(self, batch_size: int = 64):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def prepare_data(self):
        # Скачивание (вызывается 1 раз, на 1 процессе)
        datasets.MNIST('data', train=True, download=True)

    def setup(self, stage=None):
        # Создание splits (вызывается на каждом GPU)
        full = datasets.MNIST('data', train=True, transform=self.transform)
        self.train_data, self.val_data = random_split(full, [55000, 5000])
        self.test_data = datasets.MNIST('data', train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=4)

# Использование:
dm = MNISTDataModule(batch_size=128)
trainer.fit(model, datamodule=dm)

prepare_data vs setup

prepare_data() — скачивание и запись на диск. Вызывается один раз на главном процессе. Нельзя сохранять в self (multi-GPU поломается). • setup(stage) — создание датасетов, splits, трансформации. Вызывается на каждом GPU. Здесь сохраняй в self.

seed_everything — полная воспроизводимость

Одна строка фиксирует все источники случайности: Python random, NumPy, PyTorch CPU/CUDA, DataLoader workers. В чистом PyTorch пришлось бы писать 4-5 строк.

L.seed_everything(42, workers=True)

# Эквивалент в чистом PyTorch:
# import random; random.seed(42)
# import numpy as np; np.random.seed(42)
# torch.manual_seed(42)
# torch.cuda.manual_seed_all(42)
# + настройка worker_init_fn в DataLoader

Для полной детерминированности добавь Trainer(deterministic=True) — это включит детерминированные алгоритмы CUDA. Цена: ~10-15% замедление на некоторых операциях. В продакшене обычно не нужно, но для дебага и экспериментов — бесценно.

Пример: чистый PyTorch → Lightning

Покажем, как типичный boilerplate PyTorch-кода сжимается в Lightning:

# ❌ Чистый PyTorch — 25+ строк boilerplate
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(20):
    model.train()
    for batch in train_loader:
        x, y = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch[0].to(device), batch[1].to(device)
            val_loss = F.cross_entropy(model(x), y)
    # + ручной чекпоинт, early stopping, логирование...

# ✅ Lightning — только уникальная логика
class Net(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256),
                                 nn.ReLU(), nn.Linear(256, 10))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

trainer = L.Trainer(max_epochs=20, accelerator='auto')
trainer.fit(Net(), train_dataloaders=train_loader)

Что исчезло: .to(device), model.train()/eval(), optimizer.zero_grad(), loss.backward(), optimizer.step(), torch.no_grad(). Lightning делает всё это автоматически. Ты добавляешь callbacks для EarlyStopping, чекпоинтов, логирования — не меняя код модели.

Когда НЕ использовать Lightning

Lightning — не серебряная пуля. Есть ситуации, где он мешает: • Быстрые эксперименты в ноутбуке — если пишешь 30 строк для проверки идеи, обёртка в LightningModule — оверхед. Чистый PyTorch быстрее для прототипа. • Нестандартный training loop — GAN'ы с чередующимися оптимизаторами, RL с environment step, self-play. Можно реализовать через manual optimization, но код получается сложнее, чем в чистом PyTorch. • Отладка низкоуровневых операций — если дебажишь кастомный autograd, custom CUDA kernel — лишний слой абстракции мешает. • Legacy проекты — если проект на чистом PyTorch уже работает, переписывать ради Lightning не стоит. Правило: если проект живёт дольше недели, есть эксперименты и multi-GPU — используй Lightning. Для одноразовых скриптов — чистый PyTorch.

🎯 На собеседовании

Junior

Зачем Lightning, если есть PyTorch? Убирает boilerplate (training loop, device management, logging). Код стандартизирован, легче читать и поддерживать. • Какие методы нужно определить в LightningModule? training_step (возвращает loss), configure_optimizers (возвращает оптимизатор). forward — для инференса. • Что делает Trainer? Управляет циклом обучения, валидацией, callbacks, логированием, GPU/TPU — одной строкой.

Middle

EarlyStopping — как работает? Следит за метрикой (val_loss), если не улучшается patience эпох — останавливает обучение. Предотвращает переобучение. • prepare_data vs setup — в чём разница? prepare_data — 1 раз на главном процессе (скачивание). setup — на каждом GPU (создание датасетов). В prepare_data нельзя сохранять в self. • Как перейти на multi-GPU? Trainer(devices=4, strategy="ddp"). Код модели не меняется. Lightning сам обернёт модель в DDP.

Senior

Manual optimization в Lightning — когда? GAN'ы (2 оптимизатора поочерёдно), custom gradient accumulation, RL. Устанавливаешь self.automatic_optimization = False и вручную вызываешь opt.step(). • Lightning vs Hugging Face Trainer — что выбрать? HF Trainer — для NLP с HF-моделями (заточен на Transformers). Lightning — универсальный, для любой PyTorch-архитектуры. Если CV/Audio/custom — Lightning. • Как Lightning работает с DDP? Каждый GPU получает копию модели, данные шардируются через DistributedSampler. Градиенты усредняются через all-reduce. self.log() автоматически sync_dist.

Собираем всё вместе

Lightning — это стандартизация PyTorch-проекта. LightningModule содержит модель и логику шагов, Trainer управляет циклом обучения, LightningDataModule инкапсулирует данные. Callbacks (EarlyStopping, ModelCheckpoint) и логирование (TensorBoard, W&B) подключаются без изменения кода модели. seed_everything + deterministic=True обеспечивают воспроизводимость.

Главное правило: если проект больше одного скрипта — используй Lightning. Экономия времени на boilerplate окупается в первый же день.