PyTorch Lightning
Абстракция train loop, callbacks, logging, воспроизводимость — индустриальный стандарт обучения.
PyTorch Lightning — фреймворк поверх PyTorch
Чистый PyTorch даёт полный контроль, но каждый проект начинается с одного и того же: training loop, валидация, логирование, чекпоинты, multi-GPU. PyTorch Lightning — это тонкая абстракция, которая стандартизирует boilerplate, не забирая гибкость. Ты пишешь только то, что уникально для твоей задачи: модель, loss, данные. Всё остальное — Trainer.
Три главные причины использовать Lightning: 1. Воспроизводимость — структура кода единообразная, seed_everything фиксирует все генераторы, конфигурация сериализуется. 2. Масштабируемость — переход с CPU на GPU, multi-GPU, TPU — одна строка в Trainer. Не нужно переписывать .to(device), DistributedDataParallel и т.д. 3. Меньше багов — стандартный train loop протестирован тысячами пользователей. Ты не забудешь model.eval(), zero_grad или torch.no_grad.
LightningModule — вся логика в одном месте
LightningModule наследуется от nn.Module, но добавляет структуру: ты определяешь что делать на каждом шаге, а как (цикл, GPU, градиенты) берёт на себя Trainer.
Ключевые методы:
• forward(x) — прямой проход, как в обычном nn.Module. Используется при инференсе: model(x).
• training_step(batch, batch_idx) — один шаг обучения. Принимает батч, возвращает loss. Lightning сам вызовет backward(), zero_grad(), optimizer.step().
• validation_step(batch, batch_idx) — шаг валидации. Lightning автоматически включит model.eval() и torch.no_grad().
• configure_optimizers() — возвращает оптимизатор (и опционально scheduler). Один метод вместо разбросанных по коду optimizer = ... и scheduler.step().
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
class MNISTClassifier(L.LightningModule):
def __init__(self, lr: float = 1e-3):
super().__init__()
self.save_hyperparameters() # автоматически сохраняет lr в чекпоинт
self.net = nn.Sequential(
nn.Flatten(),
nn.Linear(784, 256), nn.ReLU(), nn.Dropout(0.2),
nn.Linear(256, 10),
)
def forward(self, x):
return self.net(x)
def training_step(self, batch, batch_idx):
x, y = batch
loss = F.cross_entropy(self(x), y)
self.log('train_loss', loss, prog_bar=True)
return loss # Lightning сделает backward + step
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.cross_entropy(logits, y)
acc = (logits.argmax(1) == y).float().mean()
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)save_hyperparameters()
Trainer — запуск обучения одной строкой
Trainer управляет всем: циклом эпох, валидацией, коллбэками, логированием, GPU/TPU. Ты настраиваешь поведение через параметры, а не переписываешь цикл.
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
trainer = L.Trainer(
max_epochs=20,
accelerator='auto', # CPU / GPU / TPU — определит сам
devices='auto', # сколько GPU использовать
callbacks=[
EarlyStopping(
monitor='val_loss', # следим за val_loss
patience=3, # 3 эпохи без улучшения → стоп
mode='min',
),
ModelCheckpoint(
monitor='val_loss',
save_top_k=1, # сохраняем только лучшую модель
filename='best-{epoch}-{val_loss:.3f}',
),
],
logger=TensorBoardLogger('logs/', name='mnist'),
deterministic=True, # воспроизводимость (чуть медленнее)
)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)Callbacks — расширение без изменения кода:
• EarlyStopping — останавливает обучение, если метрика не улучшается patience эпох. Самый важный callback — защита от переобучения.
• ModelCheckpoint — сохраняет веса лучшей модели. save_top_k=1 — только лучшая, save_top_k=3 — три лучших. Можно восстановить: MNISTClassifier.load_from_checkpoint("best.ckpt").
• LearningRateMonitor — логирует lr на каждом шаге (полезно со schedulers).
• Свои callbacks — наследуешь от lightning.pytorch.Callback, переопределяешь хуки (on_train_epoch_end, on_validation_end и т.д.).
Логирование — TensorBoard и W&B:
• TensorBoardLogger — встроенный, tensorboard --logdir logs/ для визуализации. Бесплатный, локальный.
• WandbLogger — интеграция с Weights & Biases. Автоматически логирует метрики, гиперпараметры, системные ресурсы. Удобно для командной работы и сравнения экспериментов. logger=WandbLogger(project="mnist").
LightningDataModule — инкапсуляция данных
LightningDataModule группирует всю логику данных: скачивание, предобработка, splits, DataLoader'ы. Это делает код переиспользуемым и воспроизводимым — ты можешь передать DataModule другому человеку, и он получит те же данные в том же формате.
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
class MNISTDataModule(L.LightningDataModule):
def __init__(self, batch_size: int = 64):
super().__init__()
self.batch_size = batch_size
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
])
def prepare_data(self):
# Скачивание (вызывается 1 раз, на 1 процессе)
datasets.MNIST('data', train=True, download=True)
def setup(self, stage=None):
# Создание splits (вызывается на каждом GPU)
full = datasets.MNIST('data', train=True, transform=self.transform)
self.train_data, self.val_data = random_split(full, [55000, 5000])
self.test_data = datasets.MNIST('data', train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=4)
def val_dataloader(self):
return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=4)
# Использование:
dm = MNISTDataModule(batch_size=128)
trainer.fit(model, datamodule=dm)prepare_data vs setup
seed_everything — полная воспроизводимость
Одна строка фиксирует все источники случайности: Python random, NumPy, PyTorch CPU/CUDA, DataLoader workers. В чистом PyTorch пришлось бы писать 4-5 строк.
L.seed_everything(42, workers=True)
# Эквивалент в чистом PyTorch:
# import random; random.seed(42)
# import numpy as np; np.random.seed(42)
# torch.manual_seed(42)
# torch.cuda.manual_seed_all(42)
# + настройка worker_init_fn в DataLoaderДля полной детерминированности добавь Trainer(deterministic=True) — это включит детерминированные алгоритмы CUDA. Цена: ~10-15% замедление на некоторых операциях. В продакшене обычно не нужно, но для дебага и экспериментов — бесценно.
Пример: чистый PyTorch → Lightning
Покажем, как типичный boilerplate PyTorch-кода сжимается в Lightning:
# ❌ Чистый PyTorch — 25+ строк boilerplate
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(20):
model.train()
for batch in train_loader:
x, y = batch[0].to(device), batch[1].to(device)
optimizer.zero_grad()
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
model.eval()
with torch.no_grad():
for batch in val_loader:
x, y = batch[0].to(device), batch[1].to(device)
val_loss = F.cross_entropy(model(x), y)
# + ручной чекпоинт, early stopping, логирование...
# ✅ Lightning — только уникальная логика
class Net(L.LightningModule):
def __init__(self):
super().__init__()
self.net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256),
nn.ReLU(), nn.Linear(256, 10))
def training_step(self, batch, batch_idx):
x, y = batch
return F.cross_entropy(self.net(x), y)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
trainer = L.Trainer(max_epochs=20, accelerator='auto')
trainer.fit(Net(), train_dataloaders=train_loader)Что исчезло: .to(device), model.train()/eval(), optimizer.zero_grad(), loss.backward(), optimizer.step(), torch.no_grad(). Lightning делает всё это автоматически. Ты добавляешь callbacks для EarlyStopping, чекпоинтов, логирования — не меняя код модели.
Когда НЕ использовать Lightning
Lightning — не серебряная пуля. Есть ситуации, где он мешает: • Быстрые эксперименты в ноутбуке — если пишешь 30 строк для проверки идеи, обёртка в LightningModule — оверхед. Чистый PyTorch быстрее для прототипа. • Нестандартный training loop — GAN'ы с чередующимися оптимизаторами, RL с environment step, self-play. Можно реализовать через manual optimization, но код получается сложнее, чем в чистом PyTorch. • Отладка низкоуровневых операций — если дебажишь кастомный autograd, custom CUDA kernel — лишний слой абстракции мешает. • Legacy проекты — если проект на чистом PyTorch уже работает, переписывать ради Lightning не стоит. Правило: если проект живёт дольше недели, есть эксперименты и multi-GPU — используй Lightning. Для одноразовых скриптов — чистый PyTorch.
🎯 На собеседовании
Junior
Middle
Senior
Собираем всё вместе
Lightning — это стандартизация PyTorch-проекта. LightningModule содержит модель и логику шагов, Trainer управляет циклом обучения, LightningDataModule инкапсулирует данные. Callbacks (EarlyStopping, ModelCheckpoint) и логирование (TensorBoard, W&B) подключаются без изменения кода модели. seed_everything + deterministic=True обеспечивают воспроизводимость.
Главное правило: если проект больше одного скрипта — используй Lightning. Экономия времени на boilerplate окупается в первый же день.
Материалы
Полная документация: LightningModule, Trainer, callbacks, стратегии распределённого обучения. Начни с Getting Started.
Минимальный пример: от чистого PyTorch к Lightning за 15 минут. Идеально для первого знакомства.
Как подключить Weights & Biases к Lightning: WandbLogger, автоматическое логирование метрик, гиперпараметров и артефактов.