Custom components¶
As I mentioned earlier in the example notebooks, and also in the README, it is possible to customise almost every component in pytorch-widedeep.
Let's now go through a couple of simple examples to illustrate how that could be done.
First let's load and process the data "as usual", let's start with a regression and the airbnb dataset.
import os

import numpy as np
import pandas as pd
import torch
from torch import Tensor, nn

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import (
    WidePreprocessor,
    TabPreprocessor,
    TextPreprocessor,
    ImagePreprocessor,
)
from pytorch_widedeep.models import (
    Wide,
    TabMlp,
    Vision,
    BasicRNN,
    WideDeep,
)
from pytorch_widedeep.losses import RMSELoss
from pytorch_widedeep.initializers import *
from pytorch_widedeep.callbacks import *
from pytorch_widedeep.datasets import load_adult
/Users/javierrodriguezzaurin/.pyenv/versions/3.10.13/envs/widedeep310/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
df = pd.read_csv("../tmp_data/airbnb/airbnb_sample.csv")
df.head()
id | host_id | description | host_listings_count | host_identity_verified | neighbourhood_cleansed | latitude | longitude | is_location_exact | property_type | ... | amenity_wide_entrance | amenity_wide_entrance_for_guests | amenity_wide_entryway | amenity_wide_hallways | amenity_wifi | amenity_window_guards | amenity_wine_cooler | security_deposit | extra_people | yield | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 13913.jpg | 54730 | My bright double bedroom with a large window h... | 4.0 | f | Islington | 51.56802 | -0.11121 | t | apartment | ... | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 100.0 | 15.0 | 12.00 |
1 | 15400.jpg | 60302 | Lots of windows and light. St Luke's Gardens ... | 1.0 | t | Kensington and Chelsea | 51.48796 | -0.16898 | t | apartment | ... | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 150.0 | 0.0 | 109.50 |
2 | 17402.jpg | 67564 | Open from June 2018 after a 3-year break, we a... | 19.0 | t | Westminster | 51.52098 | -0.14002 | t | apartment | ... | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 350.0 | 10.0 | 149.65 |
3 | 24328.jpg | 41759 | Artist house, bright high ceiling rooms, priva... | 2.0 | t | Wandsworth | 51.47298 | -0.16376 | t | other | ... | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 250.0 | 0.0 | 215.60 |
4 | 25023.jpg | 102813 | Large, all comforts, 2-bed flat; first floor; ... | 1.0 | f | Wandsworth | 51.44687 | -0.21874 | t | apartment | ... | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 250.0 | 11.0 | 79.35 |
5 rows × 223 columns
# There are a number of columns that are already binary. Therefore, no need to one hot encode them
crossed_cols = [("property_type", "room_type")]
already_dummies = [c for c in df.columns if "amenity" in c] + ["has_house_rules"]
wide_cols = [
"is_location_exact",
"property_type",
"room_type",
"host_gender",
"instant_bookable",
] + already_dummies
cat_embed_cols = [(c, 16) for c in df.columns if "catg" in c] + [
("neighbourhood_cleansed", 64),
("cancellation_policy", 16),
]
continuous_cols = ["latitude", "longitude", "security_deposit", "extra_people"]
# it does not make sense to standarised Latitude and Longitude
already_standard = ["latitude", "longitude"]
# text and image colnames
text_col = "description"
img_col = "id"
# path to pretrained word embeddings and the images
word_vectors_path = "../tmp_data/glove.6B/glove.6B.100d.txt"
img_path = "../tmp_data/airbnb/property_picture"
# target
target_col = "yield"
target = df[target_col].values
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df)
tab_preprocessor = TabPreprocessor(
embed_cols=cat_embed_cols, continuous_cols=continuous_cols
)
X_tab = tab_preprocessor.fit_transform(df)
text_preprocessor = TextPreprocessor(
word_vectors_path=word_vectors_path, text_col=text_col
)
X_text = text_preprocessor.fit_transform(df)
image_processor = ImagePreprocessor(img_col=img_col, img_path=img_path)
X_images = image_processor.fit_transform(df)
/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:358: UserWarning: Continuous columns will not be normalised warnings.warn("Continuous columns will not be normalised")
The vocabulary contains 2192 tokens Indexing word vectors... Loaded 400000 word vectors Preparing embeddings matrix... 2175 words in the vocabulary had ../tmp_data/glove.6B/glove.6B.100d.txt vectors and appear more than 5 times Reading Images from ../tmp_data/airbnb/property_picture Resizing
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 1001/1001 [00:02<00:00, 497.80it/s]
Computing normalisation metrics
Now we are ready to build a wide and deep model. Three of the four components we will use are included in this package, and they will be combined with a custom deeptext
component. Then the fit process will run with a custom loss function.
Let's have a look
# Linear model
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
# DeepDense: 2 Dense layers
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
cat_embed_dropout=0.1,
continuous_cols=continuous_cols,
mlp_hidden_dims=[128, 64],
mlp_dropout=0.1,
)
# Pretrained Resnet 18
resnet = Vision(pretrained_model_name="resnet18", n_trainable=0)
Custom deeptext¶
Standard Pytorch model
class MyDeepText(nn.Module):
def __init__(self, vocab_size, padding_idx=1, embed_dim=100, hidden_dim=64):
super(MyDeepText, self).__init__()
# word/token embeddings
self.word_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
# stack of RNNs
self.rnn = nn.GRU(
embed_dim,
hidden_dim,
num_layers=2,
bidirectional=True,
batch_first=True,
)
# Remember, this MUST be defined. If not WideDeep will through an error
self.output_dim = hidden_dim * 2
def forward(self, X):
embed = self.word_embed(X.long())
o, h = self.rnn(embed)
return torch.cat((h[-2], h[-1]), dim=1)
mydeeptext = MyDeepText(vocab_size=len(text_preprocessor.vocab.itos))
model = WideDeep(wide=wide, deeptabular=tab_mlp, deeptext=mydeeptext, deepimage=resnet)
Custom loss function¶
Loss functions must simply inherit pytorch's nn.Module. For example, let's say we want to use RMSE (note that this is already available in the package, but I will pass it here as a custom loss for illustration purposes).
class RMSELoss(nn.Module):
def __init__(self):
"""root mean squared error"""
super().__init__()
self.mse = nn.MSELoss()
def forward(self, input: Tensor, target: Tensor) -> Tensor:
return torch.sqrt(self.mse(input, target))
and now we just instantiate the Trainer
as usual. Needless to say, but this runs with 1000 random observations, so loss and metric values are meaningless. This is just an example
trainer = Trainer(model, objective="regression", custom_loss_function=RMSELoss())
trainer.fit(
X_wide=X_wide,
X_tab=X_tab,
X_text=X_text,
X_img=X_images,
target=target,
n_epochs=1,
batch_size=32,
val_split=0.2,
)
epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:23<00:00, 1.07it/s, loss=126] valid: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:05<00:00, 1.24it/s, loss=97.4]
In addition to model components and loss functions, we can also use custom callbacks or custom metrics. The former need to be of type Callback
and the latter need to be of type Metric
. See:
pytorch-widedeep.callbacks
and
pytorch-widedeep.metrics
For this example let me use the adult dataset. Again, we first prepare the data as usual
df = load_adult(as_frame=True)
df.head()
age | workclass | fnlwgt | education | educational-num | marital-status | occupation | relationship | race | gender | capital-gain | capital-loss | hours-per-week | native-country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 25 | Private | 226802 | 11th | 7 | Never-married | Machine-op-inspct | Own-child | Black | Male | 0 | 0 | 40 | United-States | <=50K |
1 | 38 | Private | 89814 | HS-grad | 9 | Married-civ-spouse | Farming-fishing | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
2 | 28 | Local-gov | 336951 | Assoc-acdm | 12 | Married-civ-spouse | Protective-serv | Husband | White | Male | 0 | 0 | 40 | United-States | >50K |
3 | 44 | Private | 160323 | Some-college | 10 | Married-civ-spouse | Machine-op-inspct | Husband | Black | Male | 7688 | 0 | 40 | United-States | >50K |
4 | 18 | ? | 103497 | Some-college | 10 | Never-married | ? | Own-child | White | Female | 0 | 0 | 30 | United-States | <=50K |
# For convenience, we'll replace '-' with '_'
df.columns = [c.replace("-", "_") for c in df.columns]
# binary target
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df.head()
age | workclass | fnlwgt | education | educational_num | marital_status | occupation | relationship | race | gender | capital_gain | capital_loss | hours_per_week | native_country | income_label | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 25 | Private | 226802 | 11th | 7 | Never-married | Machine-op-inspct | Own-child | Black | Male | 0 | 0 | 40 | United-States | 0 |
1 | 38 | Private | 89814 | HS-grad | 9 | Married-civ-spouse | Farming-fishing | Husband | White | Male | 0 | 0 | 50 | United-States | 0 |
2 | 28 | Local-gov | 336951 | Assoc-acdm | 12 | Married-civ-spouse | Protective-serv | Husband | White | Male | 0 | 0 | 40 | United-States | 1 |
3 | 44 | Private | 160323 | Some-college | 10 | Married-civ-spouse | Machine-op-inspct | Husband | Black | Male | 7688 | 0 | 40 | United-States | 1 |
4 | 18 | ? | 103497 | Some-college | 10 | Never-married | ? | Own-child | White | Female | 0 | 0 | 30 | United-States | 0 |
# Define wide, crossed and deep tabular columns
wide_cols = [
"workclass",
"education",
"marital_status",
"occupation",
"relationship",
"race",
"gender",
"native_country",
]
crossed_cols = [("education", "occupation"), ("native_country", "occupation")]
cat_embed_cols = [
"workclass",
"education",
"marital_status",
"occupation",
"relationship",
"race",
"gender",
"capital_gain",
"capital_loss",
"native_country",
]
continuous_cols = ["age", "hours_per_week"]
target_col = "income_label"
target = df[target_col].values
# wide
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df)
# deeptabular
tab_preprocessor = TabPreprocessor(
embed_cols=cat_embed_cols, continuous_cols=continuous_cols
)
X_tab = tab_preprocessor.fit_transform(df)
/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:358: UserWarning: Continuous columns will not be normalised warnings.warn("Continuous columns will not be normalised")
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=continuous_cols,
mlp_hidden_dims=[128, 64],
mlp_dropout=0.2,
mlp_activation="leaky_relu",
)
model = WideDeep(wide=wide, deeptabular=tab_mlp)
Custom metric¶
Let's say we want to use our own accuracy metric (again, this is already available in the package, but I will pass it here as a custom metric for illustration purposes).
This could be done as:
from pytorch_widedeep.metrics import Metric
class Accuracy(Metric):
def __init__(self, top_k: int = 1):
super(Accuracy, self).__init__()
self.top_k = top_k
self.correct_count = 0
self.total_count = 0
# metric name needs to be defined
self._name = "acc"
def reset(self):
self.correct_count = 0
self.total_count = 0
def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
num_classes = y_pred.size(1)
if num_classes == 1:
y_pred = y_pred.round()
y_true = y_true
elif num_classes > 1:
y_pred = y_pred.topk(self.top_k, 1)[1]
y_true = y_true.view(-1, 1).expand_as(y_pred)
self.correct_count += y_pred.eq(y_true).sum().item()
self.total_count += len(y_pred)
accuracy = float(self.correct_count) / float(self.total_count)
return np.array(accuracy)
Custom Callback¶
Let's code a callback that records the current epoch at the beginning and the end of each epoch (silly, but you know, this is just an example)
# have a look to the class
from pytorch_widedeep.callbacks import Callback
class SillyCallback(Callback):
def on_train_begin(self, logs=None):
# recordings will be the trainer object attributes
self.trainer.silly_callback = {}
self.trainer.silly_callback["beginning"] = []
self.trainer.silly_callback["end"] = []
def on_epoch_begin(self, epoch, logs=None):
self.trainer.silly_callback["beginning"].append(epoch + 1)
def on_epoch_end(self, epoch, logs=None, metric=None):
self.trainer.silly_callback["end"].append(epoch + 1)
and now, as usual:
trainer = Trainer(
model, objective="binary", metrics=[Accuracy], callbacks=[SillyCallback]
)
trainer.fit(
X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=5, batch_size=64, val_split=0.2
)
epoch 1: 100%|███████████████████████████████████████████████████████████| 611/611 [00:06<00:00, 94.39it/s, loss=0.411, metrics={'acc': 0.814}] valid: 100%|███████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 121.91it/s, loss=0.327, metrics={'acc': 0.8449}] epoch 2: 100%|██████████████████████████████████████████████████████████| 611/611 [00:07<00:00, 85.39it/s, loss=0.324, metrics={'acc': 0.8495}] valid: 100%|████████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 88.68it/s, loss=0.298, metrics={'acc': 0.8612}] epoch 3: 100%|██████████████████████████████████████████████████████████| 611/611 [00:08<00:00, 74.35it/s, loss=0.302, metrics={'acc': 0.8593}] valid: 100%|████████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 100.51it/s, loss=0.29, metrics={'acc': 0.8665}] epoch 4: 100%|██████████████████████████████████████████████████████████| 611/611 [00:08<00:00, 73.83it/s, loss=0.292, metrics={'acc': 0.8637}] valid: 100%|███████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 105.98it/s, loss=0.286, metrics={'acc': 0.8695}] epoch 5: 100%|███████████████████████████████████████████████████████████| 611/611 [00:08<00:00, 72.15it/s, loss=0.286, metrics={'acc': 0.866}] valid: 100%|████████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 92.27it/s, loss=0.284, metrics={'acc': 0.8698}]
trainer.silly_callback
{'beginning': [1, 2, 3, 4, 5], 'end': [1, 2, 3, 4, 5]}