pretrain_maskedlm

Pretrain a model on a Masked Language Modeling (MLM) task.
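This reference does not document how the module builds its masked inputs. The sketch below shows the widely used BERT-style recipe for one MLM example, as an illustration only. About 15% of non-special tokens are selected. Of those, 80% become a mask token, 10% become a random token, and 10% stay unchanged. Only selected positions get a label. The names mask_tokens, mask_id, special_ids and IGNORE_INDEX are hypothetical. The -100 ignore label is an assumption that follows the default ignore_index of PyTorch's CrossEntropyLoss.

```python
import random
from typing import FrozenSet, List, Optional, Tuple

# Assumed label for positions that must not contribute to the loss
# (matches torch.nn.CrossEntropyLoss's default ignore_index).
IGNORE_INDEX = -100


def mask_tokens(
    token_ids: List[int],
    mask_id: int,
    vocab_size: int,
    special_ids: FrozenSet[int] = frozenset(),
    mlm_probability: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Hypothetical BERT-style masking for a single sequence.

    Returns (inputs, labels). labels holds the original token at selected
    positions and IGNORE_INDEX everywhere else.
    """
    rng = rng or random.Random()
    inputs = list(token_ids)
    labels = [IGNORE_INDEX] * len(token_ids)
    for i, tok in enumerate(token_ids):
        # Never mask special tokens; otherwise select with mlm_probability.
        if tok in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = tok  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_id  # 80%: replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token as input
    return inputs, labels
```

The loss is computed only where labels differs from IGNORE_INDEX. The random and unchanged replacements keep the model from relying on the mask token always being present at prediction time.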

Classes

ProgressBar([refresh_rate, ...])

ValidationCallback(dataloader, tok)

class undertale.models.item.pretrain_maskedlm.ProgressBar(refresh_rate: int = 1, process_position: int = 0, leave: bool = False)

Bases: TQDMProgressBar

get_metrics(trainer, model)

Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. Implement this to override the items displayed in the progress bar.

Here is an example of how to override the defaults:

def get_metrics(self, trainer, model):
    # don't show the version number
    items = super().get_metrics(trainer, model)
    items.pop("v_num", None)
    return items
Returns:

Dictionary with the items to be displayed in the progress bar.

class undertale.models.item.pretrain_maskedlm.ValidationCallback(dataloader, tok)

Bases: Callback

on_validation_end(trainer, pl_module)

Called when the validation loop ends.
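This reference does not say what ValidationCallback computes from its dataloader and tok (presumably a tokenizer) arguments. One plausible job for a hook like on_validation_end in MLM pretraining is to score masked-token predictions on held-out data. That is an assumption, not documented behaviour. The stdlib sketch below shows only that metric. The Lightning wiring is omitted: running pl_module over dataloader batches and taking the argmax over logits. masked_token_accuracy and IGNORE_INDEX are hypothetical names, and the -100 ignore label is assumed.

```python
from typing import Sequence

# Assumed label marking unmasked positions that should not be scored.
IGNORE_INDEX = -100


def masked_token_accuracy(
    predictions: Sequence[Sequence[int]],
    labels: Sequence[Sequence[int]],
    ignore_index: int = IGNORE_INDEX,
) -> float:
    """Fraction of masked positions (label != ignore_index) predicted correctly.

    predictions and labels are per-sequence lists of token ids of equal length.
    Returns 0.0 when no position is masked.
    """
    correct = total = 0
    for pred_row, label_row in zip(predictions, labels):
        if len(pred_row) != len(label_row):
            raise ValueError("prediction and label rows must have equal length")
        for pred, label in zip(pred_row, label_row):
            if label == ignore_index:
                continue  # unmasked position: excluded from the metric
            total += 1
            correct += pred == label  # bool counts as 0 or 1
    return correct / total if total else 0.0
```

Scoring only masked positions matters: unmasked tokens are visible in the input, so including them would inflate the accuracy.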