Production ML question

In PyTorch, what should Dataset do, what should collate_fn do, how does num_workers affect this, and where should .to(device) usually happen?

Short answer

Dataset maps an index to a single CPU example; collate_fn combines a list of fetched examples into a batch; num_workers sets how many worker processes run the fetching and collation in parallel. The device transfer usually belongs in the training loop, after batching.

Full breakdown

A map-style Dataset should expose __len__ and __getitem__, and __getitem__ should return one example. It can read files and apply CPU transforms, but it usually should not know about training state or GPU device placement.
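
As an illustration of that contract, here is a minimal sketch of a map-style dataset. The class name TokenDataset and the in-memory toy data are assumptions made for this example; a real dataset would more likely read files inside __getitem__.

```python
import torch
from torch.utils.data import Dataset


class TokenDataset(Dataset):
    """Hypothetical map-style dataset over in-memory token id lists."""

    def __init__(self, sequences, labels):
        assert len(sequences) == len(labels)
        self.sequences = sequences
        self.labels = labels

    def __len__(self):
        # Needed so samplers and the DataLoader know the index range.
        return len(self.sequences)

    def __getitem__(self, idx):
        # Returns exactly one example, on CPU: no batching, no .to(device).
        tokens = torch.tensor(self.sequences[idx], dtype=torch.long)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return tokens, label


# Toy data, purely illustrative.
dataset = TokenDataset([[1, 2, 3], [4, 5], [6]], [0, 1, 0])
```

The dataset knows nothing about batch size, workers or devices, so the same class can be reused for training, evaluation and debugging.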

collate_fn receives a list of already-fetched examples and turns them into a batch. A custom collate_fn is useful when examples are custom classes, variable-length sequences or nested structures that the default collate cannot stack. It should not fetch examples by index itself; that would mix the fetching and batching responsibilities.
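
A sketch of a custom collate function for variable-length sequences might look like the following. The function name pad_collate, the dictionary keys and the padding value 0 are assumptions for this example, not something the question prescribes.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(examples):
    """Hypothetical collate_fn: pads variable-length token tensors into one batch."""
    # `examples` is a list of (tokens, label) pairs that were already fetched.
    tokens, labels = zip(*examples)
    lengths = torch.tensor([len(t) for t in tokens], dtype=torch.long)
    # Pad on the right with 0 (assumed padding id) up to the longest sequence.
    padded = pad_sequence(list(tokens), batch_first=True, padding_value=0)
    return {
        "tokens": padded,
        "lengths": lengths,
        "labels": torch.stack(list(labels)),
    }


# Two already-fetched examples of different lengths.
examples = [
    (torch.tensor([1, 2, 3]), torch.tensor(0)),
    (torch.tensor([4, 5]), torch.tensor(1)),
]
batch = pad_collate(examples)
```

Such a function would be passed as DataLoader(..., collate_fn=pad_collate). It only reshapes what it is given and never touches the dataset or the device.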

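Setting num_workers above 0 starts that many worker processes for data loading; each worker calls __getitem__ and collate_fn and prefetches batches ahead of the training loop. If __getitem__ moves tensors to the GPU, multiple workers compete for GPU memory, prefetched batches occupy GPU memory before they are needed, and device ownership becomes messy. The common pattern is to keep Dataset and collate_fn on the CPU and move the batch in the training loop with batch = batch.to(device) (for a tuple or dict batch, move each tensor), optionally with pin_memory=True on the DataLoader and non_blocking=True on the transfer.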
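
The pattern can be sketched as below. This is a minimal illustration under stated assumptions: the function name run_one_pass, the toy TensorDataset and the batch size are made up for the example, and the loop only moves data instead of training a model.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def run_one_pass(num_workers=2, batch_size=4):
    """Hypothetical loop: workers build CPU batches, the loop moves them to the device."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Toy CPU data standing in for a real Dataset.
    data = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
    loader = DataLoader(
        data,
        batch_size=batch_size,
        num_workers=num_workers,             # worker processes produce CPU batches
        pin_memory=(device.type == "cuda"),  # page-locked host memory for async copies
    )
    seen = 0
    for features, labels in loader:
        # The only place that touches the device.
        features = features.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        seen += features.shape[0]
    return seen, device


if __name__ == "__main__":  # guard matters on platforms that start workers via spawn
    print(run_one_pass(num_workers=2))
```

non_blocking=True only helps when the source tensor is in pinned memory and the target is a CUDA device; on a CPU-only machine the calls are harmless no-ops.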

Theory

PyTorch data loading is a pipeline with three separate stages: fetching single examples (Dataset), batching them (collate_fn), and moving the batch to the device (training loop).

Common mistakes

  • Making collate_fn fetch from the dataset by index.
  • Moving samples to CUDA inside Dataset.__getitem__.
  • Forgetting __len__ for map-style datasets.

How to answer in an interview

  • Use separation of responsibilities as the organizing principle.
  • Mention num_workers and prefetching when explaining device placement.