Callbacks¶
Here are the 4 callbacks available to the user in pytorch-widedeep: LRHistory, ModelCheckpoint, EarlyStopping and RayTuneReporter.
NOTE: other callbacks, like History, always run by default. In particular, the History callback saves the metrics in the history attribute of the Trainer.
LRHistory ¶
LRHistory(n_epochs)
Bases: Callback
Saves the learning rates during training in the lr_history attribute of the Trainer.
Callbacks are passed as input parameters to the Trainer class. See pytorch_widedeep.trainer.Trainer.
Parameters:
- n_epochs (int, default: None) – number of training epochs
Examples:
>>> from pytorch_widedeep.callbacks import LRHistory
>>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
>>> from pytorch_widedeep.training import Trainer
>>>
>>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
>>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
>>> wide = Wide(10, 1)
>>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)
>>> model = WideDeep(wide, deep)
>>> trainer = Trainer(model, objective="regression", callbacks=[LRHistory(n_epochs=10)])
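After training, the recorded rates live in the Trainer's lr_history attribute, keyed per parameter group. The toy sketch below mimics that bookkeeping with a hand-rolled step decay; the class and method names here (ToyLRHistory, on_epoch_end, the "lr_0" key) are illustrative assumptions, not the library's API.

```python
# Minimal sketch of what an LR-history callback records, assuming a
# toy "scheduler" that halves the learning rate each epoch.
class ToyLRHistory:
    def __init__(self, n_epochs):
        self.n_epochs = n_epochs
        self.lr_history = {}

    def on_epoch_end(self, epoch, lr):
        # one list per parameter group; here a single group "lr_0"
        self.lr_history.setdefault("lr_0", []).append(lr)


lr, history = 0.1, ToyLRHistory(n_epochs=3)
for epoch in range(3):
    history.on_epoch_end(epoch, lr)
    lr /= 2  # toy step-decay "scheduler"

print(history.lr_history)  # {'lr_0': [0.1, 0.05, 0.025]}
```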
Source code in pytorch_widedeep/callbacks.py
ModelCheckpoint ¶
ModelCheckpoint(filepath=None, monitor='val_loss', min_delta=0.0, verbose=0, save_best_only=False, mode='auto', period=1, max_save=-1)
Bases: Callback
Saves the model after every epoch.
This class is almost identical to the corresponding Keras class. Therefore, credit to the Keras Team.
Callbacks are passed as input parameters to the Trainer class. See pytorch_widedeep.trainer.Trainer.
Parameters:
- filepath (Optional[str], default: None) – Full path to save the output weights. It must contain only the root of the filenames; the epoch number and the .pt extension (for pytorch) will be added, e.g. filepath="path/to/output_weights/weights_out", and the saved files in that directory will be named 'weights_out_1.pt', 'weights_out_2.pt', .... If set to None, the class just reports the best metric and the best epoch.
- monitor (str, default: 'val_loss') – quantity to monitor. Typically 'val_loss' or a metric name (e.g. 'val_acc')
- min_delta (float, default: 0.0) – minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement.
- verbose (int, default: 0) – verbosity mode
- save_best_only (bool, default: False) – if True, the latest best model according to the quantity monitored will not be overwritten.
- mode (str, default: 'auto') – If save_best_only=True, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For 'acc' this should be 'max', for 'loss' this should be 'min', etc. In 'auto' mode, the direction is automatically inferred from the name of the monitored quantity.
- period (int, default: 1) – Interval (number of epochs) between checkpoints.
- max_save (int, default: -1) – Maximum number of outputs to save. If -1, all outputs will be saved.
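The filepath and max_save parameters together control what ends up on disk: the epoch number and .pt extension are appended to the root path, and with a positive max_save only the most recent files survive. The sketch below illustrates that bookkeeping in plain Python; the helper names (checkpoint_name, saved) are hypothetical, not the library's internals, and no files are actually written.

```python
# Hedged sketch of the checkpoint naming and max_save pruning
# described above, using an in-memory list instead of the filesystem.
def checkpoint_name(filepath, epoch):
    # epoch number and ".pt" extension are appended to the root path
    return f"{filepath}_{epoch}.pt"


saved, max_save = [], 2
for epoch in range(1, 5):
    saved.append(checkpoint_name("weights_out", epoch))
    if max_save > 0 and len(saved) > max_save:
        saved.pop(0)  # the oldest checkpoint would be deleted on disk

print(saved)  # ['weights_out_3.pt', 'weights_out_4.pt']
```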
Attributes:
- best (float) – best metric
- best_epoch (int) – best epoch
- best_state_dict (dict) – best model state dictionary. To restore the model to its best state use Trainer.model.load_state_dict(model_checkpoint.best_state_dict), where model_checkpoint is an instance of the class ModelCheckpoint. See the Examples folder in the repo or the Examples section in this documentation for details
Examples:
>>> from pytorch_widedeep.callbacks import ModelCheckpoint
>>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
>>> from pytorch_widedeep.training import Trainer
>>>
>>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
>>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
>>> wide = Wide(10, 1)
>>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)
>>> model = WideDeep(wide, deep)
>>> trainer = Trainer(model, objective="regression", callbacks=[ModelCheckpoint(filepath='checkpoints/weights_out')])
Source code in pytorch_widedeep/callbacks.py
EarlyStopping ¶
EarlyStopping(monitor='val_loss', min_delta=0.0, patience=10, verbose=0, mode='auto', baseline=None, restore_best_weights=False)
Bases: Callback
Stop training when a monitored quantity has stopped improving.
This class is almost identical to the corresponding Keras class. Therefore, credit to the Keras Team.
Callbacks are passed as input parameters to the Trainer class. See pytorch_widedeep.trainer.Trainer.
Parameters:
- monitor (str, default: 'val_loss') – Quantity to monitor. Typically 'val_loss' or a metric name (e.g. 'val_acc')
- min_delta (float, default: 0.0) – minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement.
- patience (int, default: 10) – Number of epochs with no improvement in the monitored quantity after which training will be stopped.
- verbose (int, default: 0) – verbosity mode.
- mode (str, default: 'auto') – one of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity monitored has stopped decreasing; in 'max' mode it will stop when the quantity monitored has stopped increasing; in 'auto' mode, the direction is automatically inferred from the name of the monitored quantity.
- baseline (Optional[float], default: None) – Baseline value for the monitored quantity to reach. Training will stop if the model does not show improvement over the baseline.
- restore_best_weights (bool, default: False) – Whether to restore the model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used.
Attributes:
- best (float) – best metric
- stopped_epoch (int) – epoch when the training stopped
Examples:
>>> from pytorch_widedeep.callbacks import EarlyStopping
>>> from pytorch_widedeep.models import TabMlp, Wide, WideDeep
>>> from pytorch_widedeep.training import Trainer
>>>
>>> embed_input = [(u, i, j) for u, i, j in zip(["a", "b", "c"][:4], [4] * 3, [8] * 3)]
>>> column_idx = {k: v for v, k in enumerate(["a", "b", "c"])}
>>> wide = Wide(10, 1)
>>> deep = TabMlp(mlp_hidden_dims=[8, 4], column_idx=column_idx, cat_embed_input=embed_input)
>>> model = WideDeep(wide, deep)
>>> trainer = Trainer(model, objective="regression", callbacks=[EarlyStopping(patience=10)])
Source code in pytorch_widedeep/callbacks.py