tuning

Utilities for hyperparameter tuning.

Functions

find_optimal_batch_size(model, load_dataset)

Find the maximum per-GPU batch size that fits in memory.

undertale.models.tuning.find_optimal_batch_size(model: LightningModule, load_dataset: Callable[[int], DataLoader], utilization: float = 0.95) → int

Find the maximum per-GPU batch size that fits in memory.

Performs a doubling search followed by a binary search to find the largest batch size that does not cause an out-of-memory (OOM) error, then scales it by a utilization factor. Uses a single-device Trainer per probe, which is compatible with DDP training (each GPU in DDP independently holds batch_size samples).

Parameters:
  • model – The LightningModule to probe. Its state is restored after probing so the caller receives an unmodified model.

  • load_dataset – Callable that accepts a batch size and returns a DataLoader configured with that batch size.

  • utilization – Fraction of the maximum fitting batch size to return. Defaults to 0.95 to leave a small safety margin.
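For illustration, a `load_dataset` callable with the expected shape might look like the sketch below. The toy `train_set` tensors are invented here for the example; only the signature (takes a batch size, returns a DataLoader built with it) is what the function requires.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in dataset: 128 samples of 8 features with binary labels.
train_set = TensorDataset(torch.randn(128, 8), torch.randint(0, 2, (128,)))

def load_dataset(batch_size: int) -> DataLoader:
    # The probe calls this with each candidate batch size it tries.
    return DataLoader(train_set, batch_size=batch_size, shuffle=True)
```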

Returns:

The recommended per-GPU batch size.

Raises:

RuntimeError – If the model does not fit in memory even at batch_size=1.
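The contract described above (doubling search, binary search, utilization scaling, and the batch_size=1 failure case) can be illustrated with a self-contained simulation. `recommend_batch_size`, the `fits` predicate, and the memory numbers below are hypothetical stand-ins for the real OOM probe, not the library implementation.

```python
def recommend_batch_size(fits, utilization=0.95):
    # fits(b) -> bool stands in for "a training step at batch size b fits in memory".
    if not fits(1):
        raise RuntimeError("model does not fit in memory even at batch_size=1")
    # Phase 1: doubling search to bracket the answer.
    b = 1
    while fits(b * 2):
        b *= 2
    # Phase 2: binary search on (low, high), where low fits and high does not.
    low, high = b, b * 2
    while high - low > 1:
        mid = (low + high) // 2
        low, high = (mid, high) if fits(mid) else (low, mid)
    # Scale down by the utilization factor to leave a safety margin.
    return max(1, int(low * utilization))

# Simulated GPU: 10 GiB capacity, ~0.4 GiB per sample, so at most 25 samples fit.
fits = lambda b: b * 0.4 <= 10.0
print(recommend_batch_size(fits))  # largest fitting size is 25; 25 * 0.95 -> 23
```

The real function additionally restores the model's state after probing, so the caller's model is unmodified.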