tuning¶
Utilities for hyperparameter tuning.
Functions

| find_optimal_batch_size(model, load_dataset, utilization) | Find the maximum per-GPU batch size that fits in memory. |
- undertale.models.tuning.find_optimal_batch_size(model: LightningModule, load_dataset: Callable[[int], DataLoader], utilization: float = 0.95) → int¶
Find the maximum per-GPU batch size that fits in memory.
Performs a doubling search followed by a binary search to find the largest batch size that does not cause an OOM error, then scales it by a utilization factor. Uses a single-device Trainer per probe, which is compatible with DDP training (each GPU in DDP independently holds batch_size samples).
- Parameters:
model – The LightningModule to probe. Its state is restored after probing so the caller receives an unmodified model.
load_dataset – Callable that accepts a batch size and returns a DataLoader configured with that batch size.
utilization – Fraction of the maximum fitting batch size to return. Defaults to 0.95 to leave a small safety margin.
- Returns:
The recommended per-GPU batch size.
- Raises:
RuntimeError – If the model does not fit in memory even at batch_size=1.
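The doubling-then-binary search described above can be sketched in isolation. The helper below is a hypothetical illustration, not the library's implementation: it assumes a `fits(batch_size)` predicate standing in for a single-device Trainer probe that returns False on OOM.

```python
from typing import Callable


def max_fitting_batch_size(fits: Callable[[int], bool], utilization: float = 0.95) -> int:
    """Largest batch size accepted by `fits`, scaled by a utilization margin.

    `fits` stands in for a memory probe (e.g. one training step on a
    single-device Trainer) that returns False when the batch OOMs.
    """
    if not fits(1):
        raise RuntimeError("model does not fit in memory even at batch_size=1")
    # Doubling phase: grow the batch size until a probe fails.
    low, high = 1, 2
    while fits(high):
        low, high = high, high * 2
    # Binary search between the last success (low) and first failure (high).
    while high - low > 1:
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    # Scale down by the utilization factor to leave a safety margin.
    return max(1, int(low * utilization))
```

For example, if memory admits batches up to 100 samples, the search converges on 100 and the default utilization of 0.95 yields a recommendation of 95.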