Quick Start
This is an example of binary classification with the adult census dataset,
using a combination of a wide and a deep model (in this case a so-called
deeptabular model) with default settings.
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
# Load the adult census data and derive the binary target column:
# 1 when the raw income string contains ">50K", 0 otherwise.
df = load_adult(as_frame=True)
df["income_label"] = df["income"].apply(lambda income: ">50K" in income).astype(int)
# The raw income column is no longer needed once the label exists.
df.drop(columns="income", inplace=True)
# Split with test_size=0.2, stratifying on the target column.
df_train, df_test = train_test_split(
    df,
    test_size=0.2,
    stratify=df["income_label"],
)
# Define the 'column set up'
# Columns used for the wide component.
wide_cols = [
    "education",
    "relationship",
    "workclass",
    "occupation",
    "native-country",
    "gender",
]
# Pairs of columns that are crossed for the wide component.
crossed_cols = [
    ("education", "occupation"),
    ("native-country", "occupation"),
]
# Columns treated as categorical and represented through embeddings.
cat_embed_cols = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital-gain",
    "capital-loss",
    "native-country",
]
# Columns treated as continuous.
continuous_cols = [
    "age",
    "hours-per-week",
]
# Target values for the training rows, taken from the "income_label" column.
target = df_train["income_label"].values
# prepare the data
# Each preprocessor is fitted on the training frame and kept around so that
# the fitted object can be reused to transform other frames.
wide_preprocessor = WidePreprocessor(
    wide_cols=wide_cols,
    crossed_cols=crossed_cols,
)
X_wide = wide_preprocessor.fit_transform(df_train)

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols  # type: ignore[arg-type]
)
X_tab = tab_preprocessor.fit_transform(df_train)
# build the model
# The wide component's input dimension is the number of distinct values
# found in X_wide.
wide = Wide(
    input_dim=len(np.unique(X_wide)),
    pred_dim=1,
)
# The deep tabular component is configured from the fitted tab preprocessor.
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
)
# Combine both components into a single model.
model = WideDeep(wide=wide, deeptabular=tab_mlp)
# train and validate
# NOTE: no validation data is passed to fit() in this example.
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=5, batch_size=256)

# predict on test
# Reuse the preprocessors fitted on the training frame; only transform() is
# called on the test frame.
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(
    X_wide=X_wide_te,
    X_tab=X_tab_te,
)
# Save and load
# Option 1: save through the trainer. This will also save the training
# history, and the lr history if the LRHistory callback is used.
trainer.save(path="model_weights", save_state_dict=True)
# Option 2: save as any other torch model.
# NOTE: torch.save does not create directories, so "model_weights" must
# already exist when this option is used on its own.
torch.save(model.state_dict(), "model_weights/wd_model.pt")
# From here on, Option 1 and Option 2 are the same. It is assumed that the
# user has prepared the data and defined the new model components.
# NOTE(review): `wide` and `tab_mlp` are reused below for brevity, so
# `model_new` is built from the same component objects as `model`; in a
# fresh session, build new components before this step.
# 1. Build the model
model_new = WideDeep(wide=wide, deeptabular=tab_mlp)
# map_location="cpu" allows a state dict that was saved from a GPU run to be
# read on a machine without a GPU; load_state_dict then copies the values
# into the model's own parameters.
model_new.load_state_dict(
    torch.load("model_weights/wd_model.pt", map_location="cpu")
)
# 2. Instantiate the trainer
trainer_new = Trainer(model_new, objective="binary")
# 3. Either start the fit or directly predict
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab, batch_size=32)