18_wide_and_deep_for_recsys_pt2
This is the second of two notebooks where we illustrate how one could use this library to build recommendation algorithms, using the example in this Kaggle notebook as guidance. In the previous notebook we used pytorch-widedeep
to build a model that replicated almost exactly the one in that notebook. In this shorter notebook we show how one could use the library to explore other models, keeping the same problem formulation, that is: given the state of a user at a certain point in time, having watched a series of movies, our goal is to predict which movie the user will watch next.
Assuming that one has read (and run) the previous notebook, the required data will be stored in a local directory called prepared_data
, so let's read it:
from pathlib import Path
import numpy as np
import torch
import pandas as pd
from torch import nn
from pytorch_widedeep import Trainer
from pytorch_widedeep.utils import pad_sequences
from pytorch_widedeep.models import TabMlp, WideDeep, Transformer
from pytorch_widedeep.preprocessing import TabPreprocessor
save_path = Path("prepared_data")
PAD_IDX = 0
id_cols = ["user_id", "movie_id"]
df_train = pd.read_pickle(save_path / "df_train.pkl")
df_valid = pd.read_pickle(save_path / "df_valid.pkl")
df_test = pd.read_pickle(save_path / "df_test.pkl")
Remember that in the previous notebook we explained that we are not going to use a validation set here (in a real-world, or simply more realistic, example one should always use one).
df_test = pd.concat([df_valid, df_test], ignore_index=True)
Also remember that in the previous notebook we discussed that the 'maxlen'
and 'max_movie_index'
parameters should be computed using only the train set. In particular, to do the tokenization properly, one would have to use ONLY train tokens and add a token for new 'unknown'/'unseen' movies in the test set. This can be done with this library or manually, so I will leave it to the reader to implement that tokenization approach (a minimal sketch is shown right after the next cell).
maxlen = max(
df_train.prev_movies.apply(lambda x: len(x)).max(),
df_test.prev_movies.apply(lambda x: len(x)).max(),
)
max_movie_index = max(df_train.movie_id.max(), df_test.movie_id.max())
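For reference, a minimal sketch of such a train-only tokenization could look as follows. The names UNK_IDX, train_movie_vocab and tokenize are hypothetical helpers, and this code is not used in the remainder of the notebook:
# build the vocabulary from the TRAIN sequences only; movies that appear
# only in the test set are mapped to a reserved 'unknown' index.
# Index 0 is kept for padding, as elsewhere in the notebook.
UNK_IDX = 1  # hypothetical choice: 0 = padding, 1 = unknown movie
train_movie_vocab = {
    movie: idx + 2  # shift by 2 to leave room for PAD (0) and UNK (1)
    for idx, movie in enumerate(
        sorted(set(int(m) for seq in df_train.prev_movies for m in seq))
    )
}
def tokenize(seq):
    # map each movie id to its train-vocabulary index,
    # falling back to UNK_IDX for movies never seen during training
    return [train_movie_vocab.get(int(m), UNK_IDX) for m in seq]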
From now on things are pretty simple, especially bearing in mind that in this example we are not going to use a wide component since, in principle, one would expect the information in that component to also be 'carried' by the movie sequences (although, if one performs ablation studies on the previous notebook, these suggest that most of the predictive power comes from the linear, wide model).
In this example we are going to explore one (of many) possibilities: we simply encode the triplet (user, item, rating)
and use it as the deeptabular
component, and the sequences of previously watched movies as the deeptext
component. For the deeptext
component we are going to use a basic encoder-only transformer model.
Let's start with the tabular data preparation:
df_train_user_item = df_train[["user_id", "movie_id", "rating"]]
train_movies_sequences = df_train.prev_movies.apply(
lambda x: [int(el) for el in x]
).to_list()
y_train = df_train.target.values.astype(int)
df_test_user_item = df_test[["user_id", "movie_id", "rating"]]
test_movies_sequences = df_test.prev_movies.apply(
lambda x: [int(el) for el in x]
).to_list()
y_test = df_test.target.values.astype(int)
tab_preprocessor = TabPreprocessor(
cat_embed_cols=["user_id", "movie_id", "rating"],
)
X_train_tab = tab_preprocessor.fit_transform(df_train_user_item)
X_test_tab = tab_preprocessor.transform(df_test_user_item)
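As an optional sanity check, one can inspect the output of the preprocessor: X_train_tab should have one encoded column per categorical column and as many rows as df_train_user_item.
print(X_train_tab.shape)  # expected: (len(df_train_user_item), 3)
print(tab_preprocessor.cat_embed_input)  # per-column (column, n_unique, embed_dim) info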
And now the text component, simply padding the sequences:
X_train_text = np.array(
[
pad_sequences(
s,
maxlen=maxlen,
pad_first=False,
pad_idx=PAD_IDX,
)
for s in train_movies_sequences
]
)
X_test_text = np.array(
[
pad_sequences(
s,
maxlen=maxlen,
pad_first=False,
            pad_idx=PAD_IDX,
)
for s in test_movies_sequences
]
)
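To see on a toy sequence what the padding step produces (assuming, as used above, that pad_first=False places the padding at the end of the sequence):
toy_padded = pad_sequences([1, 2, 3], maxlen=6, pad_first=False, pad_idx=PAD_IDX)
print(toy_padded)  # expected something like: [1 2 3 0 0 0]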
We now define the model components and the wide and deep model.
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
mlp_hidden_dims=[1024, 512, 256],
mlp_activation="relu",
)
# plenty of options here, see the docs
transformer = Transformer(
vocab_size=max_movie_index + 1,
embed_dim=32,
n_heads=2,
n_blocks=2,
seq_length=maxlen,
)
wide_deep_model = WideDeep(
deeptabular=tab_mlp, deeptext=transformer, pred_dim=max_movie_index + 1
)
wide_deep_model
WideDeep(
  (deeptabular): Sequential(
    (0): TabMlp(
      (cat_embed): DiffSizeCatEmbeddings(
        (embed_layers): ModuleDict(
          (emb_layer_user_id): Embedding(749, 65, padding_idx=0)
          (emb_layer_movie_id): Embedding(1612, 100, padding_idx=0)
          (emb_layer_rating): Embedding(6, 4, padding_idx=0)
        )
        (embedding_dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): MLP(
        (mlp): Sequential(
          (dense_layer_0): Sequential(
            (0): Linear(in_features=169, out_features=1024, bias=True)
            (1): ReLU(inplace=True)
            (2): Dropout(p=0.1, inplace=False)
          )
          (dense_layer_1): Sequential(
            (0): Linear(in_features=1024, out_features=512, bias=True)
            (1): ReLU(inplace=True)
            (2): Dropout(p=0.1, inplace=False)
          )
          (dense_layer_2): Sequential(
            (0): Linear(in_features=512, out_features=256, bias=True)
            (1): ReLU(inplace=True)
            (2): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (1): Linear(in_features=256, out_features=1683, bias=True)
  )
  (deeptext): Sequential(
    (0): Transformer(
      (embedding): Embedding(1683, 32, padding_idx=0)
      (pos_encoder): PositionalEncoding(
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): Sequential(
        (transformer_block0): TransformerEncoder(
          (attn): MultiHeadedAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_proj): Linear(in_features=32, out_features=32, bias=False)
            (kv_proj): Linear(in_features=32, out_features=64, bias=False)
            (out_proj): Linear(in_features=32, out_features=32, bias=False)
          )
          (ff): FeedForward(
            (w_1): Linear(in_features=32, out_features=128, bias=True)
            (w_2): Linear(in_features=128, out_features=32, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (activation): GELU(approximate='none')
          )
          (attn_addnorm): AddNorm(
            (dropout): Dropout(p=0.1, inplace=False)
            (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          )
          (ff_addnorm): AddNorm(
            (dropout): Dropout(p=0.1, inplace=False)
            (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          )
        )
        (transformer_block1): TransformerEncoder(
          (attn): MultiHeadedAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_proj): Linear(in_features=32, out_features=32, bias=False)
            (kv_proj): Linear(in_features=32, out_features=64, bias=False)
            (out_proj): Linear(in_features=32, out_features=32, bias=False)
          )
          (ff): FeedForward(
            (w_1): Linear(in_features=32, out_features=128, bias=True)
            (w_2): Linear(in_features=128, out_features=32, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (activation): GELU(approximate='none')
          )
          (attn_addnorm): AddNorm(
            (dropout): Dropout(p=0.1, inplace=False)
            (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          )
          (ff_addnorm): AddNorm(
            (dropout): Dropout(p=0.1, inplace=False)
            (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
    )
    (1): Linear(in_features=23552, out_features=1683, bias=True)
  )
)
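Before training, one can optionally check the size of the model. Counting trainable parameters is plain PyTorch and not specific to the library:
n_params = sum(p.numel() for p in wide_deep_model.parameters() if p.requires_grad)
print(f"trainable parameters: {n_params:,}")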
And, as in the previous notebook, let's train (you will need a GPU for this):
trainer = Trainer(
model=wide_deep_model,
objective="multiclass",
custom_loss_function=nn.CrossEntropyLoss(ignore_index=PAD_IDX),
optimizers=torch.optim.Adam(wide_deep_model.parameters(), lr=1e-3),
)
trainer.fit(
X_train={
"X_tab": X_train_tab,
"X_text": X_train_text,
"target": y_train,
},
X_val={
"X_tab": X_test_tab,
"X_text": X_test_text,
"target": y_test,
},
n_epochs=10,
batch_size=521,
shuffle=False,
)
epoch 1: 0%| | 0/147 [00:34<?, ?it/s]
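Once trained, obtaining next-movie predictions follows the usual pattern. A minimal sketch, assuming that the Trainer's predict_proba accepts the same X_tab / X_text keyword arguments used during fit (probs and top_5 are hypothetical variable names):
# class probabilities over the movie catalogue for each test example
probs = trainer.predict_proba(X_tab=X_test_tab, X_text=X_test_text)
top_5 = np.argsort(-probs, axis=1)[:, :5]  # 5 most likely next movies per test example
print(top_5[:3])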