18_wide_and_deep_for_recsys_pt1
The goal of this notebook and its companion (part 2) is to illustrate how one could use this library in the context of recommendation systems. In particular, this notebook and the scripts in the wide_deep_for_recsys
dir are a response to this issue. Therefore, we will use the Kaggle notebook referred to in that issue here.
To keep the length of the notebook tractable, we will split this exercise in two. In this first notebook we prepare the data in almost exactly the same way as in the Kaggle notebook and show how one could use pytorch-widedeep
to build a model almost identical to the one in that notebook.
In a second notebook, we will show how one could use this library to implement other models, still following the same problem formulation.
from pathlib import Path
import warnings
import pandas as pd
from sklearn.model_selection import train_test_split
from pytorch_widedeep.datasets import load_movielens100k
warnings.filterwarnings("ignore")
save_path = Path("prepared_data")
if not save_path.exists():
    save_path.mkdir(parents=True, exist_ok=True)
data, users, items = load_movielens100k(as_frame=True)
# Alternatively, as specified in the docs: 'The last 19 fields are the genres' so:
# list_of_genres = items.columns.tolist()[-19:]
list_of_genres = [
"unknown",
"Action",
"Adventure",
"Animation",
"Children's",
"Comedy",
"Crime",
"Documentary",
"Drama",
"Fantasy",
"Film-Noir",
"Horror",
"Musical",
"Mystery",
"Romance",
"Sci-Fi",
"Thriller",
"War",
"Western",
]
Let's start by loading the interaction, user and item data.
data.head()
|   | user_id | movie_id | rating | timestamp |
|---|---------|----------|--------|-----------|
| 0 | 196 | 242 | 3 | 881250949 |
| 1 | 186 | 302 | 3 | 891717742 |
| 2 | 22 | 377 | 1 | 878887116 |
| 3 | 244 | 51 | 2 | 880606923 |
| 4 | 166 | 346 | 1 | 886397596 |
users.head()
|   | user_id | age | gender | occupation | zip_code |
|---|---------|-----|--------|------------|----------|
| 0 | 1 | 24 | M | technician | 85711 |
| 1 | 2 | 53 | F | other | 94043 |
| 2 | 3 | 23 | M | writer | 32067 |
| 3 | 4 | 24 | M | technician | 43537 |
| 4 | 5 | 33 | F | other | 15213 |
items.head()
|   | movie_id | movie_title | release_date | video_release_date | IMDb_URL | unknown | Action | Adventure | Animation | Children's | ... | Fantasy | Film-Noir | Horror | Musical | Mystery | Romance | Sci-Fi | Thriller | War | Western |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | Toy Story (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Toy%20Story%2... | 0 | 0 | 0 | 1 | 1 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
1 | 2 | GoldenEye (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?GoldenEye%20(... | 0 | 1 | 1 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
2 | 3 | Four Rooms (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Four%20Rooms%... | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
3 | 4 | Get Shorty (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Get%20Shorty%... | 0 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 5 | Copycat (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Copycat%20(1995) | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
5 rows × 24 columns
# adding a column with the number of movies watched per user
dataset = data.sort_values(["user_id", "timestamp"]).reset_index(drop=True)
dataset["one"] = 1
dataset["num_watched"] = dataset.groupby("user_id")["one"].cumsum()
dataset.drop("one", axis=1, inplace=True)
dataset.head()
|   | user_id | movie_id | rating | timestamp | num_watched |
|---|---------|----------|--------|-----------|-------------|
| 0 | 1 | 168 | 5 | 874965478 | 1 |
| 1 | 1 | 172 | 5 | 874965478 | 2 |
| 2 | 1 | 165 | 5 | 874965518 | 3 |
| 3 | 1 | 156 | 4 | 874965556 | 4 |
| 4 | 1 | 196 | 5 | 874965677 | 5 |
# adding a column with the mean rating at a point in time per user
dataset["mean_rate"] = (
dataset.groupby("user_id")["rating"].cumsum() / dataset["num_watched"]
)
dataset.head()
|   | user_id | movie_id | rating | timestamp | num_watched | mean_rate |
|---|---------|----------|--------|-----------|-------------|-----------|
| 0 | 1 | 168 | 5 | 874965478 | 1 | 5.00 |
| 1 | 1 | 172 | 5 | 874965478 | 2 | 5.00 |
| 2 | 1 | 165 | 5 | 874965518 | 3 | 5.00 |
| 3 | 1 | 156 | 4 | 874965556 | 4 | 4.75 |
| 4 | 1 | 196 | 5 | 874965677 | 5 | 4.80 |
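As a quick aside, the same two running statistics can also be computed with pandas' built-in helpers. The check below is just an illustrative sketch and is not part of the original Kaggle notebook.
# sanity check (sketch): equivalent ways of computing the two running statistics
assert (dataset["num_watched"] == dataset.groupby("user_id").cumcount() + 1).all()
assert (
    dataset["mean_rate"]
    - dataset.groupby("user_id")["rating"].transform(lambda s: s.expanding().mean())
).abs().max() < 1e-9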
Problem formulation
In this particular exercise the problem is formulated as predicting the next movie a user will watch (as a consequence, the last interaction of each user will be discarded).
dataset["target"] = dataset.groupby("user_id")["movie_id"].shift(-1)
Following the same processing used by the author in the aforementioned Kaggle notebook, we build sequences of previously watched movies.
# Here the author builds the sequences
dataset["prev_movies"] = dataset["movie_id"].apply(lambda x: str(x))
dataset["prev_movies"] = (
dataset.groupby("user_id")["prev_movies"]
.apply(lambda x: (x + " ").cumsum().str.strip())
.reset_index(drop=True)
)
dataset["prev_movies"] = dataset["prev_movies"].apply(lambda x: x.split())
dataset.head()
|   | user_id | movie_id | rating | timestamp | num_watched | mean_rate | target | prev_movies |
|---|---------|----------|--------|-----------|-------------|-----------|--------|-------------|
| 0 | 1 | 168 | 5 | 874965478 | 1 | 5.00 | 172.0 | [168] |
| 1 | 1 | 172 | 5 | 874965478 | 2 | 5.00 | 165.0 | [168, 172] |
| 2 | 1 | 165 | 5 | 874965518 | 3 | 5.00 | 156.0 | [168, 172, 165] |
| 3 | 1 | 156 | 4 | 874965556 | 4 | 4.75 | 196.0 | [168, 172, 165, 156] |
| 4 | 1 | 196 | 5 | 874965677 | 5 | 4.80 | 166.0 | [168, 172, 165, 156, 196] |
And now we add a genre_rate column per genre, computed as the mean rating of the movies of that genre watched so far by each user.
dataset = dataset.merge(items[["movie_id"] + list_of_genres], on="movie_id", how="left")
for genre in list_of_genres:
    dataset[f"{genre}_rate"] = dataset[genre] * dataset["rating"]
    dataset[genre] = dataset.groupby("user_id")[genre].cumsum()
    dataset[f"{genre}_rate"] = (
        dataset.groupby("user_id")[f"{genre}_rate"].cumsum() / dataset[genre]
    )
dataset[list_of_genres] = dataset[list_of_genres].apply(
lambda x: x / dataset["num_watched"]
)
dataset.head()
|   | user_id | movie_id | rating | timestamp | num_watched | mean_rate | target | prev_movies | unknown | Action | ... | Fantasy_rate | Film-Noir_rate | Horror_rate | Musical_rate | Mystery_rate | Romance_rate | Sci-Fi_rate | Thriller_rate | War_rate | Western_rate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 168 | 5 | 874965478 | 1 | 5.00 | 172.0 | [168] | 0.0 | 0.000000 | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
1 | 1 | 172 | 5 | 874965478 | 2 | 5.00 | 165.0 | [168, 172] | 0.0 | 0.500000 | ... | NaN | NaN | NaN | NaN | NaN | 5.0 | 5.0 | NaN | 5.0 | NaN |
2 | 1 | 165 | 5 | 874965518 | 3 | 5.00 | 156.0 | [168, 172, 165] | 0.0 | 0.333333 | ... | NaN | NaN | NaN | NaN | NaN | 5.0 | 5.0 | NaN | 5.0 | NaN |
3 | 1 | 156 | 4 | 874965556 | 4 | 4.75 | 196.0 | [168, 172, 165, 156] | 0.0 | 0.250000 | ... | NaN | NaN | NaN | NaN | NaN | 5.0 | 5.0 | 4.0 | 5.0 | NaN |
4 | 1 | 196 | 5 | 874965677 | 5 | 4.80 | 166.0 | [168, 172, 165, 156, 196] | 0.0 | 0.200000 | ... | NaN | NaN | NaN | NaN | NaN | 5.0 | 5.0 | 4.0 | 5.0 | NaN |
5 rows × 46 columns
Adding user features
dataset = dataset.merge(users, on="user_id", how="left")
dataset.head()
|   | user_id | movie_id | rating | timestamp | num_watched | mean_rate | target | prev_movies | unknown | Action | ... | Mystery_rate | Romance_rate | Sci-Fi_rate | Thriller_rate | War_rate | Western_rate | age | gender | occupation | zip_code |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 168 | 5 | 874965478 | 1 | 5.00 | 172.0 | [168] | 0.0 | 0.000000 | ... | NaN | NaN | NaN | NaN | NaN | NaN | 24 | M | technician | 85711 |
1 | 1 | 172 | 5 | 874965478 | 2 | 5.00 | 165.0 | [168, 172] | 0.0 | 0.500000 | ... | NaN | 5.0 | 5.0 | NaN | 5.0 | NaN | 24 | M | technician | 85711 |
2 | 1 | 165 | 5 | 874965518 | 3 | 5.00 | 156.0 | [168, 172, 165] | 0.0 | 0.333333 | ... | NaN | 5.0 | 5.0 | NaN | 5.0 | NaN | 24 | M | technician | 85711 |
3 | 1 | 156 | 4 | 874965556 | 4 | 4.75 | 196.0 | [168, 172, 165, 156] | 0.0 | 0.250000 | ... | NaN | 5.0 | 5.0 | 4.0 | 5.0 | NaN | 24 | M | technician | 85711 |
4 | 1 | 196 | 5 | 874965677 | 5 | 4.80 | 166.0 | [168, 172, 165, 156, 196] | 0.0 | 0.200000 | ... | NaN | 5.0 | 5.0 | 4.0 | 5.0 | NaN | 24 | M | technician | 85711 |
5 rows × 50 columns
Again, we use the same settings as those in the Kaggle notebook, although note that COLD_START_TRESH is pretty aggressive.
COLD_START_TRESH = 5
filtered_data = dataset[
    (dataset["num_watched"] >= COLD_START_TRESH) & ~(dataset["target"].isna())
].sort_values("timestamp")
train_data, _test_data = train_test_split(filtered_data, test_size=0.2, shuffle=False)
valid_data, test_data = train_test_split(_test_data, test_size=0.5, shuffle=False)
cols_to_drop = [
# "rating",
"timestamp",
"num_watched",
]
df_train = train_data.drop(cols_to_drop, axis=1)
df_valid = valid_data.drop(cols_to_drop, axis=1)
df_test = test_data.drop(cols_to_drop, axis=1)
df_train.to_pickle(save_path / "df_train.pkl")
df_valid.to_pickle(save_path / "df_valid.pkl")
df_test.to_pickle(save_path / "df_test.pkl")
Let's now build a model that is nearly identical to the one used in the Kaggle notebook.
import numpy as np
import torch
from torch import nn
from scipy.sparse import coo_matrix
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.preprocessing import TabPreprocessor
device = "cuda" if torch.cuda.is_available() else "cpu"
save_path = Path("prepared_data")
PAD_IDX = 0
Let's use some of the functions the author of the Kaggle notebook uses to prepare the data.
def get_coo_indexes(lil):
    rows = []
    cols = []
    for i, el in enumerate(lil):
        if type(el) != list:
            el = [el]
        for j in el:
            rows.append(i)
            cols.append(j)
    return rows, cols
def get_sparse_features(series, shape):
    coo_indexes = get_coo_indexes(series.tolist())
    sparse_df = coo_matrix(
        (np.ones(len(coo_indexes[0])), (coo_indexes[0], coo_indexes[1])), shape=shape
    )
    return sparse_df
def sparse_to_idx(data, pad_idx=-1):
    indexes = data.nonzero()
    indexes_df = pd.DataFrame()
    indexes_df["rows"] = indexes[0]
    indexes_df["cols"] = indexes[1]
    mdf = indexes_df.groupby("rows").apply(lambda x: x["cols"].tolist())
    max_len = mdf.apply(lambda x: len(x)).max()
    return mdf.apply(lambda x: pd.Series(x + [pad_idx] * (max_len - len(x)))).values
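To see what these helpers do, here is a quick toy example (the values below are made up purely for illustration):
# toy example (not part of the original notebook): a sparse "watched" matrix and
# its padded index representation
toy_series = pd.Series([[1, 3], [2], [1, 2, 4]])
toy_sparse = get_sparse_features(toy_series, shape=(3, 5))
toy_sparse.todense()
# matrix([[0., 1., 0., 1., 0.],
#         [0., 0., 1., 0., 0.],
#         [0., 1., 1., 0., 1.]])
sparse_to_idx(toy_sparse, pad_idx=PAD_IDX)
# array([[1, 3, 0],
#        [2, 0, 0],
#        [1, 2, 4]])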
For the time being, we will not use a validation set for hyperparameter optimization, and we will simply concatenate the validation and the test sets into one test set. I simply split the data into train/valid/test in case the reader wants to actually do hyperparameter optimization (and because I know that in the future I will).
There is also another caveat worth mentioning, related to the indexing of the movies. To build the matrices of movies watched, we use the entire dataset. A more realistic (and correct) approach would be to use ONLY the movies that appear in the training set and consider as unknown or unseen those movies in the test set that have not been seen during training. Nonetheless, this does not affect the purpose of this notebook, which is to illustrate how one could use pytorch-widedeep to build a recommendation algorithm. However, if one wanted to compare the performance of different algorithms in a "proper" way, these "details" would need to be accounted for.
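For completeness, here is a rough sketch of how one could build the movie vocabulary from the training set only. Everything in it (UNK_IDX, movie2idx, encode_movie) is hypothetical and is not used anywhere else in this notebook.
# hypothetical sketch: restrict the movie vocabulary to the training set and map
# movies never seen during training to a reserved 'unknown' index
UNK_IDX = 1  # 0 is kept for padding

train_movie_ids = set(int(m) for seq in df_train["prev_movies"] for m in seq) | set(
    df_train["target"].astype(int)
)
movie2idx = {m: i for i, m in enumerate(sorted(train_movie_ids), start=2)}

def encode_movie(movie_id) -> int:
    # unseen movies fall back to the 'unknown' index
    return movie2idx.get(int(movie_id), UNK_IDX)

# e.g.: df_test["prev_movies"].apply(lambda seq: [encode_movie(m) for m in seq])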
df_test = pd.concat([df_valid, df_test], ignore_index=True)
id_cols = ["user_id", "movie_id"]
max_movie_index = max(df_train.movie_id.max(), df_test.movie_id.max())
X_train = df_train.drop(id_cols + ["rating", "prev_movies", "target"], axis=1)
y_train = np.array(df_train.target.values, dtype="int64")
train_movies_watched = get_sparse_features(
df_train["prev_movies"], (len(df_train), max_movie_index + 1)
)
X_test = df_test.drop(id_cols + ["rating", "prev_movies", "target"], axis=1)
y_test = np.array(df_test.target.values, dtype="int64")
test_movies_watched = get_sparse_features(
df_test["prev_movies"], (len(df_test), max_movie_index + 1)
)
Let's have a look at the information in each dataset.
X_train.head()
|   | mean_rate | unknown | Action | Adventure | Animation | Children's | Comedy | Crime | Documentary | Drama | ... | Mystery_rate | Romance_rate | Sci-Fi_rate | Thriller_rate | War_rate | Western_rate | age | gender | occupation | zip_code |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
25423 | 4.000000 | 0.0 | 0.400000 | 0.200000 | 0.0 | 0.0 | 0.400000 | 0.0 | 0.0 | 0.200000 | ... | NaN | 4.0 | 4.0 | 4.000000 | 4.0 | NaN | 21 | M | student | 48823 |
25425 | 4.000000 | 0.0 | 0.285714 | 0.142857 | 0.0 | 0.0 | 0.428571 | 0.0 | 0.0 | 0.285714 | ... | NaN | 4.0 | 4.0 | 4.000000 | 4.0 | NaN | 21 | M | student | 48823 |
25424 | 4.000000 | 0.0 | 0.333333 | 0.166667 | 0.0 | 0.0 | 0.333333 | 0.0 | 0.0 | 0.333333 | ... | NaN | 4.0 | 4.0 | 4.000000 | 4.0 | NaN | 21 | M | student | 48823 |
25426 | 3.875000 | 0.0 | 0.250000 | 0.125000 | 0.0 | 0.0 | 0.375000 | 0.0 | 0.0 | 0.250000 | ... | NaN | 4.0 | 4.0 | 3.666667 | 4.0 | NaN | 21 | M | student | 48823 |
25427 | 3.888889 | 0.0 | 0.222222 | 0.111111 | 0.0 | 0.0 | 0.333333 | 0.0 | 0.0 | 0.333333 | ... | NaN | 4.0 | 4.0 | 3.666667 | 4.0 | NaN | 21 | M | student | 48823 |
5 rows × 43 columns
y_train
array([772, 288, 108, ..., 183, 432, 509])
train_movies_watched
<76228x1683 sparse matrix of type '<class 'numpy.float64'>' with 7957390 stored elements in COOrdinate format>
sorted(df_train.prev_movies.tolist()[0])
['173', '185', '255', '286', '298']
np.where(train_movies_watched.todense()[0])
(array([0, 0, 0, 0, 0]), array([173, 185, 255, 286, 298]))
From now on, the specifics related to this library start to appear. The only component that is going to be a bit different is the so-called tabular component, referred to as continuous in the Kaggle notebook.
In the case of pytorch-widedeep we have the TabPreprocessor, which allows for a lot of flexibility as to how we would like to process the tabular component of this Wide and Deep model. In other words, here our tabular component is a bit more elaborate than the one in the notebook, just a bit...
cat_cols = ["gender", "occupation", "zip_code"]
cont_cols = [c for c in X_train if c not in cat_cols]
tab_preprocessor = TabPreprocessor(
cat_embed_cols=cat_cols,
continuous_cols=cont_cols,
)
X_train_tab = tab_preprocessor.fit_transform(X_train.fillna(0))
X_test_tab = tab_preprocessor.transform(X_test.fillna(0))
Now, in the notebook, the author moves the sparse matrices to sparse tensors and then turns them into dense tensors. In reality, this is not necessary: one could feed sparse tensors to nn.Linear layers in PyTorch. Nonetheless, this is not the most efficient implementation, and it is the reason why in our library the wide, linear component is implemented as an embedding layer.
Nonetheless, to reproduce the notebook as closely as we can, and because currently the Wide model in pytorch-widedeep is not designed to receive sparse tensors (we might consider implementing this functionality), we will turn the sparse COO matrices into dense arrays. We will then code a fairly simple, custom Wide component.
X_train_wide = np.array(train_movies_watched.todense())
X_test_wide = np.array(test_movies_watched.todense())
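Just to illustrate the sparse alternative mentioned above, here is a minimal sketch (not used in the rest of the notebook, and not part of the library) that keeps the movies-watched matrix sparse and multiplies it with the weights of a linear layer:
# hypothetical sketch: keep the movies-watched matrix sparse and compute the linear
# output with a sparse-dense matrix multiplication
coo = train_movies_watched.tocoo()
sparse_X = torch.sparse_coo_tensor(
    indices=torch.tensor(np.vstack([coo.row, coo.col]), dtype=torch.int64),
    values=torch.tensor(coo.data, dtype=torch.float32),
    size=coo.shape,
)
sparse_linear = nn.Linear(coo.shape[1], max_movie_index + 1)
sparse_out = torch.sparse.mm(sparse_X, sparse_linear.weight.t()) + sparse_linear.bias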
Finally, the author of the notebook uses a simple Embedding layer to encode the sequences of movies watched, i.e. the prev_movies column. In my opinion, there is an element of information redundancy here, because the wide and text components implicitly carry the same information, just in a different form. Moreover, both of the models used for these two components ignore the sequential nature of the data. Nonetheless, we want to reproduce the Kaggle notebook as closely as possible AND, as one can explore later by performing simple ablation studies, the wide component seems to carry most of the predictive power.
X_train_text = sparse_to_idx(train_movies_watched, pad_idx=PAD_IDX)
X_test_text = sparse_to_idx(test_movies_watched, pad_idx=PAD_IDX)
Let's now build the models
class Wide(nn.Module):
    def __init__(self, input_dim: int, pred_dim: int):
        super().__init__()

        self.input_dim = input_dim
        self.pred_dim = pred_dim

        # When I coded the library I never thought that someone would want to code
        # their own wide component. However, if you do, the wide component must have
        # a 'wide_linear' attribute. In other words, the linear layer must be
        # called 'wide_linear'
        self.wide_linear = nn.Linear(input_dim, pred_dim)

    def forward(self, X):
        out = self.wide_linear(X.type(torch.float32))
        return out
wide = Wide(X_train_wide.shape[1], max_movie_index + 1)
wide
Wide( (wide_linear): Linear(in_features=1683, out_features=1683, bias=True) )
class SimpleEmbed(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, pad_idx: int):
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.pad_idx = pad_idx

        # The sequences of movies watched are simply embedded in the Kaggle
        # notebook. No RNN, Transformer or any other model is used
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)

    def forward(self, X):
        embed = self.embed(X)
        embed_mean = torch.mean(embed, dim=1)
        return embed_mean

    @property
    def output_dim(self) -> int:
        # All deep components in a custom 'pytorch-widedeep' model must have
        # an output_dim property
        return self.embed_dim
# In the notebook the author uses simply embeddings
simple_embed = SimpleEmbed(max_movie_index + 1, 16, 0)
simple_embed
SimpleEmbed( (embed): Embedding(1683, 16, padding_idx=0) )
Maybe one would like to use an RNN to account for the sequential nature of the problem. If that were the case, it would be as easy as:
basic_rnn = BasicRNN(
vocab_size=max_movie_index + 1,
embed_dim=16,
hidden_dim=32,
n_layers=2,
rnn_type="gru",
)
And finally the tabular component, which in the Kaggle notebook is simply a stack of linear + ReLU layers. In our case we have embedding layers before the MLP to encode the categorical columns, which are then concatenated with the continuous columns.
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
cont_norm_layer=None,
mlp_hidden_dims=[1024, 512, 256],
mlp_activation="relu",
)
tab_mlp
TabMlp( (cat_embed): DiffSizeCatEmbeddings( (embed_layers): ModuleDict( (emb_layer_gender): Embedding(3, 2, padding_idx=0) (emb_layer_occupation): Embedding(22, 9, padding_idx=0) (emb_layer_zip_code): Embedding(648, 60, padding_idx=0) ) (embedding_dropout): Dropout(p=0.0, inplace=False) ) (cont_norm): Identity() (encoder): MLP( (mlp): Sequential( (dense_layer_0): Sequential( (0): Linear(in_features=111, out_features=1024, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) (dense_layer_1): Sequential( (0): Linear(in_features=1024, out_features=512, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) (dense_layer_2): Sequential( (0): Linear(in_features=512, out_features=256, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) ) ) )
Finally, we simply wrap up all components with the WideDeep 'collector' class and we are ready to train.
wide_deep_model = WideDeep(
wide=wide, deeptabular=tab_mlp, deeptext=simple_embed, pred_dim=max_movie_index + 1
)
wide_deep_model
WideDeep( (wide): Wide( (wide_linear): Linear(in_features=1683, out_features=1683, bias=True) ) (deeptabular): Sequential( (0): TabMlp( (cat_embed): DiffSizeCatEmbeddings( (embed_layers): ModuleDict( (emb_layer_gender): Embedding(3, 2, padding_idx=0) (emb_layer_occupation): Embedding(22, 9, padding_idx=0) (emb_layer_zip_code): Embedding(648, 60, padding_idx=0) ) (embedding_dropout): Dropout(p=0.0, inplace=False) ) (cont_norm): Identity() (encoder): MLP( (mlp): Sequential( (dense_layer_0): Sequential( (0): Linear(in_features=111, out_features=1024, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) (dense_layer_1): Sequential( (0): Linear(in_features=1024, out_features=512, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) (dense_layer_2): Sequential( (0): Linear(in_features=512, out_features=256, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) ) ) ) (1): Linear(in_features=256, out_features=1683, bias=True) ) (deeptext): Sequential( (0): SimpleEmbed( (embed): Embedding(1683, 16, padding_idx=0) ) (1): Linear(in_features=16, out_features=1683, bias=True) ) )
Note the main difference between this wide and deep model and the one in the Kaggle notebook. In that notebook, the author concatenates the embeddings and the tabular features, passes this concatenation through a stack of linear + ReLU layers with a final output dim of 256, then concatenates this output with the binary features and connects the result to the final linear layer (so the input to that final layer has dim (batch_size, 256 + 1683)).

Our implementation follows the notation of the original paper and, instead of concatenating the tabular, text and wide components and then connecting them to the output neurons, we first compute the output of each component and then add them (see here: https://arxiv.org/pdf/1606.07792.pdf, their Eq 3). Note that this is effectively the same, with the caveat that in one case one initialises a single big weight matrix "at once", whereas in our implementation we initialise a different matrix per component. Anyway, let's give it a go.
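Before doing so, and purely as an illustration of the equivalence just described, here is a tiny numerical check with toy tensors (none of these values are part of the model):
# toy check: adding per-component linear outputs equals one linear layer applied
# to the concatenation of the components' inputs
x_wide_toy, h_deep_toy = torch.randn(4, 6), torch.randn(4, 3)
W_wide, W_deep, b = torch.randn(5, 6), torch.randn(5, 3), torch.randn(5)
sum_of_parts = x_wide_toy @ W_wide.T + h_deep_toy @ W_deep.T + b
one_big_layer = (
    torch.cat([x_wide_toy, h_deep_toy], 1) @ torch.cat([W_wide, W_deep], 1).T + b
)
assert torch.allclose(sum_of_parts, one_big_layer, atol=1e-6)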
trainer = Trainer(
model=wide_deep_model,
objective="multiclass",
custom_loss_function=nn.CrossEntropyLoss(ignore_index=PAD_IDX),
optimizers=torch.optim.Adam(wide_deep_model.parameters(), lr=1e-3),
)
trainer.fit(
X_train={
"X_wide": X_train_wide,
"X_tab": X_train_tab,
"X_text": X_train_text,
"target": y_train,
},
X_val={
"X_wide": X_test_wide,
"X_tab": X_test_tab,
"X_text": X_test_text,
"target": y_test,
},
n_epochs=5,
batch_size=512,
shuffle=False,
)
epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:19<00:00, 7.66it/s, loss=6.66] valid: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:02<00:00, 18.75it/s, loss=6.6] epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:21<00:00, 6.95it/s, loss=5.97] valid: 100%|████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 21.03it/s, loss=6.52] epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:19<00:00, 7.51it/s, loss=5.65] valid: 100%|████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 20.16it/s, loss=6.53] epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:23<00:00, 6.29it/s, loss=5.41] valid: 100%|████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:02<00:00, 13.97it/s, loss=6.57] epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:19<00:00, 7.58it/s, loss=5.2] valid: 100%|████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:02<00:00, 18.82it/s, loss=6.63]
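As mentioned earlier, one could run simple ablation studies to see how much each component contributes. For example, here is a rough sketch (not run in this notebook) of training a model without the wide component, re-instantiating the deep components so they start from fresh weights:
# hypothetical ablation sketch: drop the wide component and train only the
# tabular + text components, reusing the arrays already built above
ablation_tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    cont_norm_layer=None,
    mlp_hidden_dims=[1024, 512, 256],
    mlp_activation="relu",
)
ablation_model = WideDeep(
    deeptabular=ablation_tab_mlp,
    deeptext=SimpleEmbed(max_movie_index + 1, 16, 0),
    pred_dim=max_movie_index + 1,
)
ablation_trainer = Trainer(
    model=ablation_model,
    objective="multiclass",
    custom_loss_function=nn.CrossEntropyLoss(ignore_index=PAD_IDX),
    optimizers=torch.optim.Adam(ablation_model.parameters(), lr=1e-3),
)
ablation_trainer.fit(
    X_train={"X_tab": X_train_tab, "X_text": X_train_text, "target": y_train},
    X_val={"X_tab": X_test_tab, "X_text": X_test_text, "target": y_test},
    n_epochs=5,
    batch_size=512,
    shuffle=False,
)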
Now one could continue to the 'compare metrics' section of the Kaggle notebook. However, for the purposes of illustrating how one could use pytorch-widedeep to build recommendation algorithms, we consider this notebook complete and move on to part 2.