3rd party integration - RayTune, Weights & Biases¶
This notebook provides guidelines for integrating external library functionality into the model training process through `Callback`
objects, a popular pattern in which objects are passed as arguments to (and invoked by) other objects.
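A `Callback` is simply an object whose hooks (`on_epoch_end`, `on_batch_end`, ...) the `Trainer` invokes during fitting. A minimal sketch of the pattern (the `PrintLoss` callback below is purely illustrative):

from pytorch_widedeep.callbacks import Callback

class PrintLoss(Callback):
    "Hypothetical callback: print the training loss at the end of each epoch"

    def on_epoch_end(self, epoch, logs=None, metric=None):
        logs = logs or {}
        print(f"epoch {epoch}: train_loss={logs.get('train_loss')}")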
[DISCLAIMER]
We show the integration of RayTune (a hyperparameter tuning framework) and Weights & Biases (an experiment tracking and versioning solution for ML projects) into the pytorch_widedeep
model training process. We did not include RayTuneReporter
and WnBReportBest
in the library code itself, to minimize dependencies on libraries that are not directly involved in model design and training.
Initial imports¶
from typing import Optional, Dict
import os
import numpy as np
import pandas as pd
import torch
import wandb
from torch.optim import SGD, lr_scheduler
from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabMlp, WideDeep
from torchmetrics import F1Score as F1_torchmetrics
from torchmetrics import Accuracy as Accuracy_torchmetrics
from torchmetrics import Precision as Precision_torchmetrics
from torchmetrics import Recall as Recall_torchmetrics
from pytorch_widedeep.metrics import Accuracy, Recall, Precision, F1Score, R2Score
from pytorch_widedeep.initializers import XavierNormal
from pytorch_widedeep.callbacks import (
EarlyStopping,
ModelCheckpoint,
Callback,
)
from pytorch_widedeep.datasets import load_bio_kdd04
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune import JupyterNotebookReporter
from ray.air.integrations.wandb import WandbLoggerCallback
# from ray.tune.integration.wandb import wandb_mixin
import tracemalloc
tracemalloc.start()
# increase displayed columns in jupyter notebook
pd.set_option("display.max_columns", 200)
pd.set_option("display.max_rows", 300)
class RayTuneReporter(Callback):
    r"""Callback that reports the ``history`` and ``lr_history`` values to
    RayTune during hyperparameter tuning.

    Callbacks are passed as input parameters to the ``Trainer`` class. See
    :class:`pytorch_widedeep.trainer.Trainer`.

    For an example see the examples folder at:

    .. code-block:: bash

        /examples/12_HyperParameter_tuning_w_RayTune.ipynb
    """
def on_epoch_end(
self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
):
report_dict = {}
for k, v in self.trainer.history.items():
report_dict.update({k: v[-1]})
if hasattr(self.trainer, "lr_history"):
for k, v in self.trainer.lr_history.items():
report_dict.update({k: v[-1]})
tune.report(report_dict)
class WnBReportBest(Callback):
    r"""Callback that reports the best performance of a run to W&B during
    hyperparameter tuning. It is an adjusted
    ``pytorch_widedeep.callbacks.ModelCheckpoint`` with W&B reporting added
    and checkpoint saving removed.

    Callbacks are passed as input parameters to the ``Trainer`` class.

    Parameters
    ----------
    wb: obj
        Weights & Biases API interface used to report the single best result,
        usable for comparison of multiple parameter combinations by, for
        example, `parallel coordinates
        <https://docs.wandb.ai/ref/app/features/panels/parallel-coordinates>`_.
        E.g. the W&B summary report `wandb.run.summary["best"]`.
    monitor: str, default="val_loss"
        quantity to monitor. Typically `'val_loss'` or a metric name
        (e.g. `'val_acc'`)
    mode: str, default="auto"
        the decision to overwrite the current best value is made based on
        either the maximization or the minimization of the monitored
        quantity. For `'acc'` this should be `'max'`, for `'loss'` this
        should be `'min'`, etc. In `'auto'` mode, the direction is
        automatically inferred from the name of the monitored quantity.
    """
def __init__(
self,
wb: object,
monitor: str = "val_loss",
mode: str = "auto",
):
super(WnBReportBest, self).__init__()
self.monitor = monitor
self.mode = mode
self.wb = wb
if self.mode not in ["auto", "min", "max"]:
warnings.warn(
"WnBReportBest mode %s is unknown, "
"fallback to auto mode." % (self.mode),
RuntimeWarning,
)
self.mode = "auto"
        if self.mode == "min":
            self.monitor_op = np.less
            self.best = np.inf
        elif self.mode == "max":
            self.monitor_op = np.greater  # type: ignore[assignment]
            self.best = -np.inf
        else:
            if self._is_metric(self.monitor):
                self.monitor_op = np.greater  # type: ignore[assignment]
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf
def on_epoch_end( # noqa: C901
self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
):
logs = logs or {}
current = logs.get(self.monitor)
if current is not None:
if self.monitor_op(current, self.best):
self.wb.run.summary["best"] = current # type: ignore[attr-defined]
self.best = current
self.best_epoch = epoch
@staticmethod
def _is_metric(monitor: str):
"copied from pytorch_widedeep.callbacks"
if any([s in monitor for s in ["acc", "prec", "rec", "fscore", "f1", "f2"]]):
return True
else:
return False
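For instance, per the docstring, monitoring a quantity whose name contains "acc" makes "auto" mode resolve to maximization; a hypothetical usage sketch:

# "val_acc" matches _is_metric, so "auto" resolves to np.greater / max
wnb_best = WnBReportBest(wb=wandb, monitor="val_acc", mode="auto")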
df = load_bio_kdd04(as_frame=True)
df.head()
EXAMPLE_ID | BLOCK_ID | target | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 279 | 261532 | 0 | 52.0 | 32.69 | 0.30 | 2.5 | 20.0 | 1256.8 | -0.89 | 0.33 | 11.0 | -55.0 | 267.2 | 0.52 | 0.05 | -2.36 | 49.6 | 252.0 | 0.43 | 1.16 | -2.06 | -33.0 | -123.2 | 1.60 | -0.49 | -6.06 | 65.0 | 296.1 | -0.28 | -0.26 | -3.83 | -22.6 | -170.0 | 3.06 | -1.05 | -3.29 | 22.9 | 286.3 | 0.12 | 2.58 | 4.08 | -33.0 | -178.9 | 1.88 | 0.53 | -7.0 | -44.0 | 1987.0 | -5.41 | 0.95 | -4.0 | -57.0 | 722.9 | -3.26 | -0.55 | -7.5 | 125.5 | 1547.2 | -0.36 | 1.12 | 9.0 | -37.0 | 72.5 | 0.47 | 0.74 | -11.0 | -8.0 | 1595.1 | -1.64 | 2.83 | -2.0 | -50.0 | 445.2 | -0.35 | 0.26 | 0.76 |
1 | 279 | 261533 | 0 | 58.0 | 33.33 | 0.00 | 16.5 | 9.5 | 608.1 | 0.50 | 0.07 | 20.5 | -52.5 | 521.6 | -1.08 | 0.58 | -0.02 | -3.2 | 103.6 | -0.95 | 0.23 | -2.87 | -25.9 | -52.2 | -0.21 | 0.87 | -1.81 | 10.4 | 62.0 | -0.28 | -0.04 | 1.48 | -17.6 | -198.3 | 3.43 | 2.84 | 5.87 | -16.9 | 72.6 | -0.31 | 2.79 | 2.71 | -33.5 | -11.6 | -1.11 | 4.01 | 5.0 | -57.0 | 666.3 | 1.13 | 4.38 | 5.0 | -64.0 | 39.3 | 1.07 | -0.16 | 32.5 | 100.0 | 1893.7 | -2.80 | -0.22 | 2.5 | -28.5 | 45.0 | 0.58 | 0.41 | -19.0 | -6.0 | 762.9 | 0.29 | 0.82 | -3.0 | -35.0 | 140.3 | 1.16 | 0.39 | 0.73 |
2 | 279 | 261534 | 0 | 77.0 | 27.27 | -0.91 | 6.0 | 58.5 | 1623.6 | -1.40 | 0.02 | -6.5 | -48.0 | 621.0 | -1.20 | 0.14 | -0.20 | 73.6 | 609.1 | -0.44 | -0.58 | -0.04 | -23.0 | -27.4 | -0.72 | -1.04 | -1.09 | 91.1 | 635.6 | -0.88 | 0.24 | 0.59 | -18.7 | -7.2 | -0.60 | -2.82 | -0.71 | 52.4 | 504.1 | 0.89 | -0.67 | -9.30 | -20.8 | -25.7 | -0.77 | -0.85 | 0.0 | -20.0 | 2259.0 | -0.94 | 1.15 | -4.0 | -44.0 | -22.7 | 0.94 | -0.98 | -19.0 | 105.0 | 1267.9 | 1.03 | 1.27 | 11.0 | -39.5 | 82.3 | 0.47 | -0.19 | -10.0 | 7.0 | 1491.8 | 0.32 | -1.29 | 0.0 | -34.0 | 658.2 | -0.76 | 0.26 | 0.24 |
3 | 279 | 261535 | 0 | 41.0 | 27.91 | -0.35 | 3.0 | 46.0 | 1921.6 | -1.36 | -0.47 | -32.0 | -51.5 | 560.9 | -0.29 | -0.10 | -1.11 | 124.3 | 791.6 | 0.00 | 0.39 | -1.85 | -21.7 | -44.9 | -0.21 | 0.02 | 0.89 | 133.9 | 797.8 | -0.08 | 1.06 | -0.26 | -16.4 | -74.1 | 0.97 | -0.80 | -0.41 | 66.9 | 955.3 | -1.90 | 1.28 | -6.65 | -28.1 | 47.5 | -1.91 | 1.42 | 1.0 | -30.0 | 1846.7 | 0.76 | 1.10 | -4.0 | -52.0 | -53.9 | 1.71 | -0.22 | -12.0 | 97.5 | 1969.8 | -1.70 | 0.16 | -1.0 | -32.5 | 255.9 | -0.46 | 1.57 | 10.0 | 6.0 | 2047.7 | -0.98 | 1.53 | 0.0 | -49.0 | 554.2 | -0.83 | 0.39 | 0.73 |
4 | 279 | 261536 | 0 | 50.0 | 28.00 | -1.32 | -9.0 | 12.0 | 464.8 | 0.88 | 0.19 | 8.0 | -51.5 | 98.1 | 1.09 | -0.33 | -2.16 | -3.9 | 102.7 | 0.39 | -1.22 | -3.39 | -15.2 | -42.2 | -1.18 | -1.11 | -3.55 | 8.9 | 141.3 | -0.16 | -0.43 | -4.15 | -12.9 | -13.4 | -1.32 | -0.98 | -3.69 | 8.8 | 136.1 | -0.30 | 4.13 | 1.89 | -13.0 | -18.7 | -1.37 | -0.93 | 0.0 | -1.0 | 810.1 | -2.29 | 6.72 | 1.0 | -23.0 | -29.7 | 0.58 | -1.10 | -18.5 | 33.5 | 206.8 | 1.84 | -0.13 | 4.0 | -29.0 | 30.1 | 0.80 | -0.24 | 5.0 | -14.0 | 479.5 | 0.68 | -0.59 | 2.0 | -36.0 | -6.9 | 2.02 | 0.14 | -0.23 |
# imbalance of the classes
df["target"].value_counts()
target
0    144455
1      1296
Name: count, dtype: int64
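Fewer than 1% of the rows are positive, which motivates the `binary_focal_loss` objective used in the training function further below. A quick check:

# share of positive examples, using the counts shown above
1296 / (144455 + 1296)  # ≈ 0.0089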
# drop columns we won't need in this example
df.drop(columns=["EXAMPLE_ID", "BLOCK_ID"], inplace=True)
df_train, df_valid = train_test_split(
df, test_size=0.2, stratify=df["target"], random_state=1
)
df_valid, df_test = train_test_split(
df_valid, test_size=0.5, stratify=df_valid["target"], random_state=1
)
Preparing the data¶
continuous_cols = df.drop(columns=["target"]).columns.values.tolist()
# deeptabular
tab_preprocessor = TabPreprocessor(continuous_cols=continuous_cols, scale=True)
X_tab_train = tab_preprocessor.fit_transform(df_train)
X_tab_valid = tab_preprocessor.transform(df_valid)
X_tab_test = tab_preprocessor.transform(df_test)
# target
y_train = df_train["target"].values
y_valid = df_valid["target"].values
y_test = df_test["target"].values
Define the model¶
input_layer = len(tab_preprocessor.continuous_cols)
output_layer = 1
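# hidden layer widths shrink roughly linearly from 2x the input width toward
# the output; with the 74 continuous columns here this evaluates to
# [148, 118, 89, 59, 30] (cf. the model printout below)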
hidden_layers = np.linspace(
input_layer * 2, output_layer, 5, endpoint=False, dtype=int
).tolist()
deeptabular = TabMlp(
mlp_hidden_dims=hidden_layers,
column_idx=tab_preprocessor.column_idx,
continuous_cols=tab_preprocessor.continuous_cols,
)
model = WideDeep(deeptabular=deeptabular)
model
WideDeep(
  (deeptabular): Sequential(
    (0): TabMlp(
      (cont_norm): Identity()
      (encoder): MLP(
        (mlp): Sequential(
          (dense_layer_0): Sequential(
            (0): Linear(in_features=74, out_features=148, bias=True)
            (1): ReLU(inplace=True)
            (2): Dropout(p=0.1, inplace=False)
          )
          (dense_layer_1): Sequential(
            (0): Linear(in_features=148, out_features=118, bias=True)
            (1): ReLU(inplace=True)
            (2): Dropout(p=0.1, inplace=False)
          )
          (dense_layer_2): Sequential(
            (0): Linear(in_features=118, out_features=89, bias=True)
            (1): ReLU(inplace=True)
            (2): Dropout(p=0.1, inplace=False)
          )
          (dense_layer_3): Sequential(
            (0): Linear(in_features=89, out_features=59, bias=True)
            (1): ReLU(inplace=True)
            (2): Dropout(p=0.1, inplace=False)
          )
          (dense_layer_4): Sequential(
            (0): Linear(in_features=59, out_features=30, bias=True)
            (1): ReLU(inplace=True)
            (2): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (1): Linear(in_features=30, out_features=1, bias=True)
  )
)
# Metrics from torchmetrics
accuracy = Accuracy_torchmetrics(average=None, num_classes=1, task="binary")
precision = Precision_torchmetrics(average="micro", num_classes=1, task="binary")
f1 = F1_torchmetrics(average=None, num_classes=1, task="binary")
recall = Recall_torchmetrics(average=None, num_classes=1, task="binary")
Note:
The following cells use both the `RayTuneReporter` and `WnBReportBest` callbacks. If you want to use only `RayTuneReporter`, remove the following:

- `wandb` from `config`
- the `WandbLoggerCallback`
- the `WnBReportBest` callback
- the `@wandb_mixin` decorator (left commented out below, since its legacy import is also commented out)

We do not see a strong reason to use W&B without RayTune for a single parameter-combination run, but it is possible:

- option 1: define the parameters in `config` with only a single value, e.g. `tune.grid_search([1000])` (a single-value RayTune run)
- option 2: define a W&B callback that reports the current validation/training loss, metrics, etc. at the end of each batch, i.e. report to W&B not at `on_epoch_end` (as in `WnBReportBest`) but at `on_batch_end`; see `pytorch_widedeep.callbacks.Callback` and the sketch below
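A minimal sketch of option 2 (hypothetical, not part of the library), assuming the Keras-style `on_batch_end(self, batch, logs=None)` hook of `pytorch_widedeep.callbacks.Callback`:

class WnBBatchLogger(Callback):
    "Hypothetical callback: log running metrics to W&B after every batch"

    def __init__(self, wb: object):
        super().__init__()
        self.wb = wb

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        # `logs` carries the running loss/metric values for this batch
        if logs:
            self.wb.log(logs)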
config = {
"batch_size": tune.grid_search([1000, 5000]),
"wandb": {
"project": "test",
# "api_key_file": os.getcwd() + "/wandb_api.key",
"api_key": "WNB_API_KEY",
},
}
# Optimizers
deep_opt = SGD(model.deeptabular.parameters(), lr=0.1)
# LR Scheduler
deep_sch = lr_scheduler.StepLR(deep_opt, step_size=3)
# @wandb_mixin  # legacy ray.tune.integration.wandb decorator; its import is commented out above
def training_function(config, X_train, X_val):
early_stopping = EarlyStopping()
model_checkpoint = ModelCheckpoint(save_best_only=True)
# Hyperparameters
batch_size = config["batch_size"]
trainer = Trainer(
model,
objective="binary_focal_loss",
callbacks=[
RayTuneReporter(),
WnBReportBest(wb=wandb),
early_stopping,
model_checkpoint,
],
lr_schedulers={"deeptabular": deep_sch},
initializers={"deeptabular": XavierNormal},
optimizers={"deeptabular": deep_opt},
metrics=[accuracy, precision, recall, f1],
verbose=0,
)
trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=batch_size)
X_train = {"X_tab": X_tab_train, "target": y_train}
X_val = {"X_tab": X_tab_valid, "target": y_valid}
asha_scheduler = AsyncHyperBandScheduler(
time_attr="training_iteration",
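    # RayTuneReporter passes the whole history dict positionally to tune.report,
    # so Ray nests the values under "_metric" (hence the metric path below,
    # and the '_metric' key visible in analysis.results further down)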
metric="_metric/val_loss",
mode="min",
max_t=100,
grace_period=10,
reduction_factor=3,
brackets=1,
)
analysis = tune.run(
tune.with_parameters(training_function, X_train=X_train, X_val=X_val),
resources_per_trial={"cpu": 1, "gpu": 0},
progress_reporter=JupyterNotebookReporter(overwrite=True),
scheduler=asha_scheduler,
config=config,
callbacks=[
WandbLoggerCallback(
project=config["wandb"]["project"],
# api_key_file=config["wandb"]["api_key_file"],
api_key=config["wandb"]["api_key"],
log_config=True,
)
],
)
analysis.results
{'fc9a8_00000': {'_metric': {'train_loss': 0.006297602537127896, 'train_Accuracy': 0.9925042986869812, 'train_Precision': 0.9939393997192383, 'train_Recall': 0.15814851224422455, 'train_F1Score': 0.2728785574436188, 'val_loss': 0.005045663565397263, 'val_Accuracy': 0.9946483969688416, 'val_Precision': 1.0, 'val_Recall': 0.39534884691238403, 'val_F1Score': 0.5666667222976685}, 'time_this_iter_s': 2.388202428817749, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 5, 'trial_id': 'fc9a8_00000', 'experiment_id': 'baad1d4f3d924b48b9ece1b9f26c80cc', 'date': '2022-07-31_14-06-51', 'timestamp': 1659276411, 'time_total_s': 12.656474113464355, 'pid': 1813, 'hostname': 'jupyter-5uperpalo', 'node_ip': '10.32.44.172', 'config': {'batch_size': 1000}, 'time_since_restore': 12.656474113464355, 'timesteps_since_restore': 0, 'iterations_since_restore': 5, 'warmup_time': 0.8006253242492676, 'experiment_tag': '0_batch_size=1000'}, 'fc9a8_00001': {'_metric': {'train_loss': 0.02519632239515583, 'train_Accuracy': 0.9910891652107239, 'train_Precision': 0.25, 'train_Recall': 0.0009643201483413577, 'train_F1Score': 0.0019212296465411782, 'val_loss': 0.02578434906899929, 'val_Accuracy': 0.9911492466926575, 'val_Precision': 0.0, 'val_Recall': 0.0, 'val_F1Score': 0.0}, 'time_this_iter_s': 4.113586902618408, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 5, 'trial_id': 'fc9a8_00001', 'experiment_id': 'f2e54a6a5780429fbf0db0746853347e', 'date': '2022-07-31_14-06-56', 'timestamp': 1659276416, 'time_total_s': 12.926990509033203, 'pid': 1962, 'hostname': 'jupyter-5uperpalo', 'node_ip': '10.32.44.172', 'config': {'batch_size': 5000}, 'time_since_restore': 12.926990509033203, 'timesteps_since_restore': 0, 'iterations_since_restore': 5, 'warmup_time': 0.9253025054931641, 'experiment_tag': '1_batch_size=5000'}}
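With the legacy `tune.run` API used here, the best trial's configuration can typically be retrieved from the returned analysis object, e.g.:

# best config according to the (nested) reported validation loss
best_config = analysis.get_best_config(metric="_metric/val_loss", mode="min")
print(best_config)  # here: the batch_size=1000 trial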
Using Weights & Biases logging you can create parallel coordinates graphs that map parameter combinations to the best (lowest) loss achieved during the training of the networks.
Local visualization of RayTune results using TensorBoard¶
%load_ext tensorboard
%tensorboard --logdir ~/ray_results