
Callbacks

Here are the four callbacks available to the user in pytorch-widedeep: LRHistory, ModelCheckpoint, EarlyStopping and RayTuneReporter.

ℹ️ NOTE: other callbacks, like History, always run by default. In particular, the History callback saves the metrics in the history attribute of the Trainer.
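
For instance, once a Trainer has been fit, those recorded losses and metrics can be read back from that attribute. A minimal, illustrative sketch (it assumes trainer.fit(...) has already been run; the exact keys depend on the objective and on any metrics passed to the Trainer):

>>> # illustrative only: assumes a Trainer instance that has already been fit
>>> # keys are loss/metric names (e.g. 'train_loss'), values are per-epoch lists
>>> for name, values in trainer.history.items():
...     print(name, values)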

LRHistory

LRHistory(n_epochs)

Bases: Callback

Saves the learning rates during training in the lr_history attribute of the Trainer.

Callbacks are passed as input parameters to the Trainer class. See pytorch_widedeep.trainer.Trainer

Parameters:

  • n_epochs (int) –

    number of training epochs

Examples:

>>> from pytorch_widedeep.callbacks import LRHistory
>>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
>>> from pytorch_widedeep.training import Trainer
>>>
>>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
>>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
>>> wide = Wide(10, 1)
>>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)
>>> model = WideDeep(wide, deep)
>>> trainer = Trainer(model, objective="regression", callbacks=[LRHistory(n_epochs=10)])
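
Once trainer.fit(...) has run, the recorded learning rates can be read back from the lr_history attribute. A minimal, illustrative sketch (the dictionary keys depend on the model components and parameter groups, so the names printed below are not guaranteed):

>>> # illustrative only: assumes trainer.fit(...) has already been called
>>> # lr_history maps each parameter group to its per-epoch learning rates
>>> for group_name, lrs in trainer.lr_history.items():
...     print(group_name, lrs)
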
Source code in pytorch_widedeep/callbacks.py
def __init__(self, n_epochs: int):
    super(LRHistory, self).__init__()
    self.n_epochs = n_epochs

ModelCheckpoint

ModelCheckpoint(filepath=None, monitor='val_loss', min_delta=0.0, verbose=0, save_best_only=False, mode='auto', period=1, max_save=-1)

Bases: Callback

Saves the model after every epoch.

This class is almost identical to the corresponding keras class. Therefore, credit to the Keras Team.

Callbacks are passed as input parameters to the Trainer class. See pytorch_widedeep.trainer.Trainer

Parameters:

  • filepath (Optional[str], default: None ) –

Full path to save the output weights. It must contain only the root of the filenames; the epoch number and the .pt extension (for pytorch) will be added, e.g. filepath="path/to/output_weights/weights_out", and the saved files in that directory will be named 'weights_out_1.pt', 'weights_out_2.pt', .... If set to None the class just reports the best metric and best epoch.

  • monitor (str, default: 'val_loss' ) –

    quantity to monitor. Typically 'val_loss' or metric name (e.g. 'val_acc')

  • min_delta (float, default: 0.0 ) –

minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement.

  • verbose (int, default: 0 ) –

    verbosity mode

  • save_best_only (bool, default: False ) –

if True, the latest best model according to the monitored quantity will not be overwritten.

  • mode (str, default: 'auto' ) –

    If save_best_only=True, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For 'acc', this should be 'max', for 'loss' this should be 'min', etc. In 'auto' mode, the direction is automatically inferred from the name of the monitored quantity.

  • period (int, default: 1 ) –

    Interval (number of epochs) between checkpoints.

  • max_save (int, default: -1 ) –

    Maximum number of outputs to save. If -1 will save all outputs

Attributes:

  • best (float) –

    best metric

  • best_epoch (int) –

    best epoch

  • best_state_dict (dict) –

    best model state dictionary.
    To restore the model to its best state use Trainer.model.load_state_dict(model_checkpoint.best_state_dict), where model_checkpoint is an instance of the class ModelCheckpoint (see the sketch after the example below). See the Examples folder in the repo or the Examples section in this documentation for details

Examples:

>>> from pytorch_widedeep.callbacks import ModelCheckpoint
>>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
>>> from pytorch_widedeep.training import Trainer
>>>
>>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
>>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
>>> wide = Wide(10, 1)
>>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)
>>> model = WideDeep(wide, deep)
>>> trainer = Trainer(model, objective="regression", callbacks=[ModelCheckpoint(filepath='checkpoints/weights_out')])
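
After training, the best metric, best epoch and best weights can be recovered from the callback instance, as described in the attributes above. A minimal, illustrative sketch (it assumes trainer.fit(...) has been run with validation data so that 'val_loss' is available):

>>> # keep a reference to the callback so its attributes can be read after training
>>> model_checkpoint = ModelCheckpoint(filepath='checkpoints/weights_out', save_best_only=True)
>>> trainer = Trainer(model, objective="regression", callbacks=[model_checkpoint])
>>> # ...after trainer.fit(...):
>>> # restore the model to its best state and inspect the best metric/epoch
>>> trainer.model.load_state_dict(model_checkpoint.best_state_dict)
>>> print(model_checkpoint.best, model_checkpoint.best_epoch)
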
Source code in pytorch_widedeep/callbacks.py
def __init__(
    self,
    filepath: Optional[str] = None,
    monitor: str = "val_loss",
    min_delta: float = 0.0,
    verbose: int = 0,
    save_best_only: bool = False,
    mode: str = "auto",
    period: int = 1,
    max_save: int = -1,
):
    super(ModelCheckpoint, self).__init__()

    self.filepath = filepath
    self.monitor = monitor
    self.min_delta = min_delta
    self.verbose = verbose
    self.save_best_only = save_best_only
    self.mode = mode
    self.period = period
    self.max_save = max_save

    self.epochs_since_last_save = 0

    if self.filepath:
        if len(self.filepath.split("/")[:-1]) == 0:
            raise ValueError(
                "'filepath' must be the full path to save the output weights,"
                " including the root of the filenames. e.g. 'checkpoints/weights_out'"
            )

        root_dir = ("/").join(self.filepath.split("/")[:-1])
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)

    if self.max_save > 0:
        self.old_files: List[str] = []

    if self.mode not in ["auto", "min", "max"]:
        warnings.warn(
            "ModelCheckpoint mode %s is unknown, "
            "fallback to auto mode." % (self.mode),
            RuntimeWarning,
        )
        self.mode = "auto"
    if self.mode == "min":
        self.monitor_op = np.less
        self.best = np.Inf
    elif self.mode == "max":
        self.monitor_op = np.greater  # type: ignore[assignment]
        self.best = -np.Inf
    else:
        if _is_metric(self.monitor):
            self.monitor_op = np.greater  # type: ignore[assignment]
            self.best = -np.Inf
        else:
            self.monitor_op = np.less
            self.best = np.Inf

    if self.monitor_op == np.greater:
        self.min_delta *= 1
    else:
        self.min_delta *= -1

EarlyStopping

EarlyStopping(monitor='val_loss', min_delta=0.0, patience=10, verbose=0, mode='auto', baseline=None, restore_best_weights=False)

Bases: Callback

Stop training when a monitored quantity has stopped improving.

This class is almost identical to the corresponding keras class. Therefore, credit to the Keras Team.

Callbacks are passed as input parameters to the Trainer class. See pytorch_widedeep.trainer.Trainer

Parameters:

  • monitor (str, default: 'val_loss' ) –

    Quantity to monitor. Typically 'val_loss' or metric name (e.g. 'val_acc')

  • min_delta (float, default: 0.0 ) –

minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement.

  • patience (int, default: 10 ) –

Number of epochs with no improvement in the monitored quantity after which training will be stopped.

  • verbose (int, default: 0 ) –

    verbosity mode.

  • mode (str, default: 'auto' ) –

    one of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity monitored has stopped decreasing; in 'max' mode it will stop when the quantity monitored has stopped increasing; in 'auto' mode, the direction is automatically inferred from the name of the monitored quantity.

  • baseline (Optional[float], default: None ) –

    Baseline value for the monitored quantity to reach. Training will stop if the model does not show improvement over the baseline.

  • restore_best_weights (bool, default: False ) –

    Whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used.

Attributes:

  • best (float) –

    best metric

  • stopped_epoch (int) –

    epoch when the training stopped

Examples:

>>> from pytorch_widedeep.callbacks import EarlyStopping
>>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
>>> from pytorch_widedeep.training import Trainer
>>>
>>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
>>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
>>> wide = Wide(10, 1)
>>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)
>>> model = WideDeep(wide, deep)
>>> trainer = Trainer(model, objective="regression", callbacks=[EarlyStopping(patience=10)])
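
As with ModelCheckpoint, keeping a reference to the callback makes its attributes available once training finishes. A minimal, illustrative sketch (it assumes trainer.fit(...) has been run with validation data):

>>> early_stopping = EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True)
>>> trainer = Trainer(model, objective="regression", callbacks=[early_stopping])
>>> # ...after trainer.fit(...):
>>> # stopped_epoch is initialised to 0 and is only updated if training stops early
>>> print(early_stopping.best, early_stopping.stopped_epoch)
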
Source code in pytorch_widedeep/callbacks.py
def __init__(
    self,
    monitor: str = "val_loss",
    min_delta: float = 0.0,
    patience: int = 10,
    verbose: int = 0,
    mode: str = "auto",
    baseline: Optional[float] = None,
    restore_best_weights: bool = False,
):
    super(EarlyStopping, self).__init__()

    self.monitor = monitor
    self.min_delta = min_delta
    self.patience = patience
    self.verbose = verbose
    self.mode = mode
    self.baseline = baseline
    self.restore_best_weights = restore_best_weights

    self.wait = 0
    self.stopped_epoch = 0
    self.state_dict = None

    if self.mode not in ["auto", "min", "max"]:
        warnings.warn(
            "EarlyStopping mode %s is unknown, "
            "fallback to auto mode." % self.mode,
            RuntimeWarning,
        )
        self.mode = "auto"

    if self.mode == "min":
        self.monitor_op = np.less
    elif self.mode == "max":
        self.monitor_op = np.greater  # type: ignore[assignment]
    else:
        if _is_metric(self.monitor):
            self.monitor_op = np.greater  # type: ignore[assignment]
        else:
            self.monitor_op = np.less

    if self.monitor_op == np.greater:
        self.min_delta *= 1
    else:
        self.min_delta *= -1