Using a Hugging Face model
In this notebook we will show how to use an "external" Hugging Face model along with any other model in the library. In particular, we will show how to combine it with a tabular DL model.
Since we are here, we will also compare the performance of a few models on a text classification problem.
The notebook will go as follows:
- Text classification using tf-idf + LightGBM
- Text classification using a basic RNN
- Text classification using Distilbert
In all 3 cases we will add some tabular features to see if these help.
In general, I would not pay much attention to the results, since no effort went into getting the best possible ones (i.e. no hyperparameter optimization and no alternative architectures were tried).
Let's go
import numpy as np
import torch
import lightgbm as lgb
from lightgbm import Dataset as lgbDataset
from scipy.sparse import hstack, csr_matrix
from sklearn.metrics import (
f1_score,
recall_score,
accuracy_score,
precision_score,
confusion_matrix,
)
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from torch import Tensor, nn
from transformers import DistilBertModel, DistilBertTokenizer
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep
from pytorch_widedeep.metrics import F1Score, Accuracy
from pytorch_widedeep.utils import Tokenizer, LabelEncoder
from pytorch_widedeep.preprocessing import TextPreprocessor, TabPreprocessor
from pytorch_widedeep.datasets import load_womens_ecommerce
from pytorch_widedeep.utils.fastai_transforms import (
fix_html,
spec_add_spaces,
rm_useless_spaces,
)
Let's load the data and have a look:
df = load_womens_ecommerce(as_frame=True)
df.columns = [c.replace(" ", "_").lower() for c in df.columns]
# original ratings are 1-5; shift them to start at 0
df["rating"] = (df["rating"] - 1).astype("int64")
# group reviews with scores 1 and 2 (now 0 and 1) into one class
df.loc[df.rating == 0, "rating"] = 1
# and shift back to [0, num_class), leaving 4 classes overall
df["rating"] = (df["rating"] - 1).astype("int64")
# drop short reviews
df = df[~df.review_text.isna()]
df["review_length"] = df.review_text.apply(lambda x: len(x.split(" ")))
df = df[df.review_length >= 5]
df = df.drop("review_length", axis=1).reset_index(drop=True)
df.head()
| | clothing_id | age | title | review_text | rating | recommended_ind | positive_feedback_count | division_name | department_name | class_name |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 767 | 33 | None | Absolutely wonderful - silky and sexy and comf... | 2 | 1 | 0 | Initmates | Intimate | Intimates |
| 1 | 1080 | 34 | None | Love this dress! it's sooo pretty. i happene... | 3 | 1 | 4 | General | Dresses | Dresses |
| 2 | 1077 | 60 | Some major design flaws | I had such high hopes for this dress and reall... | 1 | 0 | 0 | General | Dresses | Dresses |
| 3 | 1049 | 50 | My favorite buy! | I love, love, love this jumpsuit. it's fun, fl... | 3 | 1 | 0 | General Petite | Bottoms | Pants |
| 4 | 847 | 47 | Flattering shirt | This shirt is very flattering to all due to th... | 3 | 1 | 6 | General | Tops | Blouses |
So, we will use the review_text column to predict the rating. Later on, we will try to combine it with some other columns (like division_name and age) to see if these help.
Let's first have a look at the distribution of ratings:
df.rating.value_counts()
rating
3    12515
2     4904
1     2820
0     2369
Name: count, dtype: int64
This shows that we could perhaps have grouped rating scores of 1, 2 and 3 into a single class... but anyway, let's just move on with these 4 classes.
We are not going to carry out any hyperparameter optimization here, so we will only need a train and a test set (i.e. there is no need for a validation set in this notebook).
train, test = train_test_split(df, train_size=0.8, random_state=1, stratify=df.rating)
Let's see what we have to beat. What metrics would we obtain if we always predicted the most common rating (3)?
most_common_pred = [train.rating.value_counts().index[0]] * len(test)
most_common_acc = accuracy_score(test.rating, most_common_pred)
most_common_f1 = f1_score(test.rating, most_common_pred, average="weighted")
print(f"Accuracy: {most_common_acc}. F1 Score: {most_common_f1}")
Accuracy: 0.553516143299425. F1 Score: 0.3944344218301668
OK, these are our "baseline" metrics: the accuracy is simply the proportion of the majority class in the test set.
Let's start with simply tf-idf + LightGBM.
1. Text classification using tf-idf + LightGBM
# ?Tokenizer
# this Tokenizer is part of our utils module, but of course any valid tokenizer can be used here.
# When using notebooks there seems to be an issue related to multiprocessing (and sometimes tqdm)
# that can only be solved by using a single CPU
tok = Tokenizer(n_cpus=1)
tok_reviews_tr = tok.process_all(train.review_text.tolist())
tok_reviews_te = tok.process_all(test.review_text.tolist())
# the reviews are already tokenized, so identity functions are passed as
# preprocessor and tokenizer
vectorizer = TfidfVectorizer(
max_features=5000, preprocessor=lambda x: x, tokenizer=lambda x: x, min_df=5
)
X_text_tr = vectorizer.fit_transform(tok_reviews_tr)
X_text_te = vectorizer.transform(tok_reviews_te)
/Users/javierrodriguezzaurin/.pyenv/versions/3.10.13/envs/widedeep310/lib/python3.10/site-packages/sklearn/feature_extraction/text.py:525: UserWarning: The parameter 'token_pattern' will not be used since 'tokenizer' is not None'
  warnings.warn(
X_text_tr
<18086x4566 sparse matrix of type '<class 'numpy.float64'>' with 884074 stored elements in Compressed Sparse Row format>
We now move our matrices to LightGBM's Dataset format:
lgbtrain_text = lgbDataset(
X_text_tr,
train.rating.values,
free_raw_data=False,
)
lgbtest_text = lgbDataset(
X_text_te,
test.rating.values,
reference=lgbtrain_text,
free_raw_data=False,
)
And off we go. By the way, as we run the next cell, note how fast LightGBM is: yes, the input is a sparse matrix, but still, it trains on an 18086x4566 matrix in a matter of seconds.
lgb_text_model = lgb.train(
{"objective": "multiclass", "num_classes": 4},
lgbtrain_text,
valid_sets=[lgbtest_text, lgbtrain_text],
valid_names=["test", "train"],
)
preds_text = lgb_text_model.predict(X_text_te)
pred_text_class = np.argmax(preds_text, 1)
acc_text = accuracy_score(lgbtest_text.label, pred_text_class)
f1_text = f1_score(lgbtest_text.label, pred_text_class, average="weighted")
cm_text = confusion_matrix(lgbtest_text.label, pred_text_class)
print(f"LightGBM Accuracy: {acc_text}. LightGBM F1 Score: {f1_text}")
LightGBM Accuracy: 0.6444051304732419. LightGBM F1 Score: 0.617154488246181
print(f"LightGBM Confusion Matrix: \n {cm_text}")
LightGBM Confusion Matrix:
 [[ 199  135   61   79]
 [ 123  169  149  123]
 [  30   94  279  578]
 [  16   30  190 2267]]
Ok, so, with no hyperparameter optimization, LightGBM obtains an accuracy of 0.64 and an F1 score of 0.62, significantly better than always predicting the most common rating.
Let's see if, in this implementation, some additional features, like age or class_name, are of any help.
tab_cols = [
"age",
"division_name",
"department_name",
"class_name",
]
for tab_df in [train, test]:
for c in ["division_name", "department_name", "class_name"]:
tab_df[c] = tab_df[c].str.lower()
tab_df[c].fillna("missing", inplace=True)
# This is our LabelEncoder. A class that is designed to work with the models in this library but
# can be used for general purposes
le = LabelEncoder(columns_to_encode=["division_name", "department_name", "class_name"])
train_tab_le = le.fit_transform(train)
test_tab_le = le.transform(test)
train_tab_le.head()
| | clothing_id | age | title | review_text | rating | recommended_ind | positive_feedback_count | division_name | department_name | class_name |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 4541 | 836 | 35 | None | Bought this on sale in my reg size- 10. im 5'9... | 2 | 1 | 2 | 1 | 1 | 1 |
| 18573 | 1022 | 25 | Look like "mom jeans" | Maybe i just have the wrong body type for thes... | 1 | 0 | 0 | 2 | 2 | 2 |
| 1058 | 815 | 39 | Ig brought me here | Love the way this top layers under my jackets ... | 2 | 1 | 0 | 1 | 1 | 1 |
| 12132 | 984 | 47 | Runs small especially the arms | I love this jacket. it's the prettiest and mos... | 3 | 1 | 0 | 1 | 3 | 3 |
| 20756 | 1051 | 42 | True red, true beauty. | These pants are gorgeous--the fabric has a sat... | 3 | 1 | 0 | 2 | 2 | 4 |
Let's, for example, have a look at the encodings for the categorical feature class_name:
le.encoding_dict["class_name"]
{'blouses': 1, 'jeans': 2, 'jackets': 3, 'pants': 4, 'knits': 5, 'dresses': 6, 'skirts': 7, 'sweaters': 8, 'fine gauge': 9, 'legwear': 10, 'lounge': 11, 'shorts': 12, 'outerwear': 13, 'intimates': 14, 'swim': 15, 'trend': 16, 'sleep': 17, 'layering': 18, 'missing': 19, 'casual bottoms': 20, 'chemises': 21}
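If we ever need to map the encodings back to the original labels, inverting the dictionary is straightforward (plain Python, nothing library-specific):
inverse_class_name = {v: k for k, v in le.encoding_dict["class_name"].items()}
inverse_class_name[1]  # 'blouses'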
# tabular training and test sets
X_tab_tr = csr_matrix(train_tab_le[tab_cols].values)
X_tab_te = csr_matrix(test_tab_le[tab_cols].values)
# text + tabular training and test sets
X_tab_text_tr = hstack((X_tab_tr, X_text_tr))
X_tab_text_te = hstack((X_tab_te, X_text_te))
X_tab_tr
<18086x4 sparse matrix of type '<class 'numpy.int64'>' with 72344 stored elements in Compressed Sparse Row format>
X_tab_text_tr
<18086x4570 sparse matrix of type '<class 'numpy.float64'>' with 956418 stored elements in Compressed Sparse Row format>
lgbtrain_tab_text = lgbDataset(
X_tab_text_tr,
train.rating.values,
categorical_feature=[0, 1, 2, 3],
free_raw_data=False,
)
lgbtest_tab_text = lgbDataset(
X_tab_text_te,
test.rating.values,
reference=lgbtrain_tab_text,
free_raw_data=False,
)
lgb_tab_text_model = lgb.train(
{"objective": "multiclass", "num_classes": 4},
lgbtrain_tab_text,
    valid_sets=[lgbtrain_tab_text, lgbtest_tab_text],
    valid_names=["train", "test"],
verbose_eval=False,
)
/opt/conda/envs/wd38/lib/python3.8/site-packages/lightgbm/basic.py:2065: UserWarning: Using categorical_feature in Dataset.
  _log_warning('Using categorical_feature in Dataset.')
/opt/conda/envs/wd38/lib/python3.8/site-packages/lightgbm/basic.py:2068: UserWarning: categorical_feature in Dataset is overridden. New categorical_feature is [0, 1, 2, 3]
  _log_warning('categorical_feature in Dataset is overridden.\n'
/opt/conda/envs/wd38/lib/python3.8/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.
  _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.138280 seconds. You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 143432
[LightGBM] [Info] Number of data points in the train set: 18086, number of used features: 2289
[LightGBM] [Info] Start training from score -2.255919
[LightGBM] [Info] Start training from score -2.081545
[LightGBM] [Info] Start training from score -1.528281
[LightGBM] [Info] Start training from score -0.591354
/opt/conda/envs/wd38/lib/python3.8/site-packages/lightgbm/basic.py:1780: UserWarning: Overriding the parameters from Reference Dataset.
  _log_warning('Overriding the parameters from Reference Dataset.')
/opt/conda/envs/wd38/lib/python3.8/site-packages/lightgbm/basic.py:1513: UserWarning: categorical_column in param dict is overridden.
  _log_warning(f'{cat_alias} in param dict is overridden.')
preds_tab_text = lgb_tab_text_model.predict(X_tab_text_te)
preds_tab_text_class = np.argmax(preds_tab_text, 1)
acc_tab_text = accuracy_score(lgbtest_tab_text.label, preds_tab_text_class)
f1_tab_text = f1_score(lgbtest_tab_text.label, preds_tab_text_class, average="weighted")
cm_tab_text = confusion_matrix(lgbtest_tab_text.label, preds_tab_text_class)
print(
f"LightGBM text + tabular Accuracy: {acc_tab_text}. LightGBM text + tabular F1 Score: {f1_tab_text}"
)
LightGBM text + tabular Accuracy: 0.6382131800088456. LightGBM text + tabular F1 Score: 0.6080251307242649
print(f"LightGBM text + tabular Confusion Matrix:\n {cm_tab_text}")
LightGBM text + tabular Confusion Matrix:
 [[ 193  123   68   90]
 [ 123  146  157  138]
 [  37   90  272  582]
 [  16   37  175 2275]]
So, in this setup, the additional tabular columns do not help performance.
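One quick, rough way to double-check this (a diagnostic, not a rigorous attribution) is to count how often LightGBM actually splits on the four tabular columns, which occupy indices 0-3 of the stacked matrix:
# split-based importances: one entry per column of X_tab_text_tr
importances = lgb_tab_text_model.feature_importance(importance_type="split")
print(f"splits on the 4 tabular cols: {importances[:4].sum()}")
print(f"splits on the tf-idf cols: {importances[4:].sum()}")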
2. Text classification using pytorch-widedeep's built-in models (a basic RNN)
Moving on now to fully using pytorch-widedeep on this dataset, let's have a look at how one could use a simple RNN to predict the ratings with the library.
text_preprocessor = TextPreprocessor(
text_col="review_text", max_vocab=5000, min_freq=5, maxlen=90, n_cpus=1
)
wd_X_text_tr = text_preprocessor.fit_transform(train)
wd_X_text_te = text_preprocessor.transform(test)
The vocabulary contains 4328 tokens
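Before building the model, it may help to peek at what the preprocessor produces: a 2D array of token ids, padded/truncated to maxlen (the shape below assumes the 80/20 split used earlier):
# each row is one review encoded as token ids and padded/truncated to maxlen=90
print(wd_X_text_tr.shape)  # expected: (18086, 90)
print(wd_X_text_tr[0, :10])  # first 10 token ids of the first training review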
basic_rnn = BasicRNN(
vocab_size=len(text_preprocessor.vocab.itos),
embed_dim=300,
hidden_dim=64,
n_layers=3,
rnn_dropout=0.2,
head_hidden_dims=[32],
)
wd_text_model = WideDeep(deeptext=basic_rnn, pred_dim=4)
wd_text_model
WideDeep(
  (deeptext): Sequential(
    (0): BasicRNN(
      (word_embed): Embedding(4328, 300, padding_idx=1)
      (rnn): LSTM(300, 64, num_layers=3, batch_first=True, dropout=0.2)
      (rnn_mlp): MLP(
        (mlp): Sequential(
          (dense_layer_0): Sequential(
            (0): Linear(in_features=64, out_features=32, bias=True)
            (1): ReLU(inplace=True)
          )
        )
      )
    )
    (1): Linear(in_features=32, out_features=4, bias=True)
  )
)
text_trainer = Trainer(
wd_text_model,
objective="multiclass",
metrics=[Accuracy, F1Score(average=True)],
    num_workers=0,  # As in the case of the tokenizer, in a notebook we need to set this to 0 for the Trainer to work
)
text_trainer.fit(
X_text=wd_X_text_tr,
target=train.rating.values,
n_epochs=5,
batch_size=256,
)
epoch 1: 100%|███████████████████████████████| 71/71 [00:01<00:00, 52.39it/s, loss=1.16, metrics={'acc': 0.5349, 'f1': 0.2011}]
epoch 2: 100%|███████████████████████████████| 71/71 [00:01<00:00, 70.35it/s, loss=0.964, metrics={'acc': 0.5827, 'f1': 0.3005}]
epoch 3: 100%|███████████████████████████████| 71/71 [00:01<00:00, 70.33it/s, loss=0.845, metrics={'acc': 0.6252, 'f1': 0.4133}]
epoch 4: 100%|███████████████████████████████| 71/71 [00:01<00:00, 69.99it/s, loss=0.765, metrics={'acc': 0.6575, 'f1': 0.4875}]
epoch 5: 100%|███████████████████████████████| 71/71 [00:01<00:00, 69.55it/s, loss=0.709, metrics={'acc': 0.6879, 'f1': 0.5423}]
wd_pred_text = text_trainer.predict_proba(X_text=wd_X_text_te)
wd_pred_text_class = np.argmax(wd_pred_text, 1)
predict: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 211.51it/s]
wd_acc_text = accuracy_score(test.rating, wd_pred_text_class)
wd_f1_text = f1_score(test.rating, wd_pred_text_class, average="weighted")
wd_cm_text = confusion_matrix(test.rating, wd_pred_text_class)
print(f"Basic RNN Accuracy: {wd_acc_text}. Basic RNN F1 Score: {wd_f1_text}")
Basic RNN Accuracy: 0.6076957098628926. Basic RNN F1 Score: 0.6017335854471788
print(f"Basic RNN Confusion Matrix:\n {wd_cm_text}")
Basic RNN Confusion Matrix:
 [[ 327   76   62    9]
 [ 285  115  117   47]
 [ 131  122  315  413]
 [  42   69  401 1991]]
The performance is very similar to that of simply using tf-idf and LightGBM. Let's see if adding tabular features helps when using pytorch-widedeep.
# ?TabPreprocessor
tab_preprocessor = TabPreprocessor(cat_embed_cols=tab_cols)
wd_X_tab_tr = tab_preprocessor.fit_transform(train)
wd_X_tab_te = tab_preprocessor.transform(test)
# ?TabMlp
tab_model = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
mlp_hidden_dims=[100, 50],
)
tab_model
TabMlp(
  (cat_and_cont_embed): DiffSizeCatAndContEmbeddings(
    (cat_embed): DiffSizeCatEmbeddings(
      (embed_layers): ModuleDict(
        (emb_layer_age): Embedding(78, 18, padding_idx=0)
        (emb_layer_division_name): Embedding(5, 3, padding_idx=0)
        (emb_layer_department_name): Embedding(8, 5, padding_idx=0)
        (emb_layer_class_name): Embedding(22, 9, padding_idx=0)
      )
      (embedding_dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (encoder): MLP(
    (mlp): Sequential(
      (dense_layer_0): Sequential(
        (0): Dropout(p=0.1, inplace=False)
        (1): Linear(in_features=35, out_features=100, bias=True)
        (2): ReLU(inplace=True)
      )
      (dense_layer_1): Sequential(
        (0): Dropout(p=0.1, inplace=False)
        (1): Linear(in_features=100, out_features=50, bias=True)
        (2): ReLU(inplace=True)
      )
    )
  )
)
text_model = BasicRNN(
vocab_size=len(text_preprocessor.vocab.itos),
embed_dim=300,
hidden_dim=64,
n_layers=3,
rnn_dropout=0.2,
head_hidden_dims=[32],
)
wd_tab_and_text_model = WideDeep(deeptabular=tab_model, deeptext=text_model, pred_dim=4)
wd_tab_and_text_model
WideDeep(
  (deeptabular): Sequential(
    (0): TabMlp(
      (cat_and_cont_embed): DiffSizeCatAndContEmbeddings(
        (cat_embed): DiffSizeCatEmbeddings(
          (embed_layers): ModuleDict(
            (emb_layer_age): Embedding(78, 18, padding_idx=0)
            (emb_layer_division_name): Embedding(5, 3, padding_idx=0)
            (emb_layer_department_name): Embedding(8, 5, padding_idx=0)
            (emb_layer_class_name): Embedding(22, 9, padding_idx=0)
          )
          (embedding_dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (encoder): MLP(
        (mlp): Sequential(
          (dense_layer_0): Sequential(
            (0): Dropout(p=0.1, inplace=False)
            (1): Linear(in_features=35, out_features=100, bias=True)
            (2): ReLU(inplace=True)
          )
          (dense_layer_1): Sequential(
            (0): Dropout(p=0.1, inplace=False)
            (1): Linear(in_features=100, out_features=50, bias=True)
            (2): ReLU(inplace=True)
          )
        )
      )
    )
    (1): Linear(in_features=50, out_features=4, bias=True)
  )
  (deeptext): Sequential(
    (0): BasicRNN(
      (word_embed): Embedding(4328, 300, padding_idx=1)
      (rnn): LSTM(300, 64, num_layers=3, batch_first=True, dropout=0.2)
      (rnn_mlp): MLP(
        (mlp): Sequential(
          (dense_layer_0): Sequential(
            (0): Linear(in_features=64, out_features=32, bias=True)
            (1): ReLU(inplace=True)
          )
        )
      )
    )
    (1): Linear(in_features=32, out_features=4, bias=True)
  )
)
tab_and_text_trainer = Trainer(
wd_tab_and_text_model,
objective="multiclass",
metrics=[Accuracy, F1Score(average=True)],
    num_workers=0,  # As in the case of the tokenizer, in a notebook we need to set this to 0 for the Trainer to work
)
tab_and_text_trainer.fit(
X_tab=wd_X_tab_tr,
X_text=wd_X_text_tr,
target=train.rating.values,
n_epochs=5,
batch_size=256,
)
epoch 1: 100%|███████████████████████████████| 71/71 [00:01<00:00, 52.04it/s, loss=1.13, metrics={'acc': 0.538, 'f1': 0.1911}]
epoch 2: 100%|███████████████████████████████| 71/71 [00:01<00:00, 52.28it/s, loss=0.936, metrics={'acc': 0.5887, 'f1': 0.3507}]
epoch 3: 100%|███████████████████████████████| 71/71 [00:01<00:00, 52.26it/s, loss=0.825, metrics={'acc': 0.6394, 'f1': 0.4545}]
epoch 4: 100%|███████████████████████████████| 71/71 [00:01<00:00, 51.33it/s, loss=0.757, metrics={'acc': 0.6696, 'f1': 0.5214}]
epoch 5: 100%|███████████████████████████████| 71/71 [00:01<00:00, 50.39it/s, loss=0.702, metrics={'acc': 0.6963, 'f1': 0.5654}]
wd_pred_tab_and_text = tab_and_text_trainer.predict_proba(
X_tab=wd_X_tab_te, X_text=wd_X_text_te
)
wd_pred_tab_and_text_class = np.argmax(wd_pred_tab_and_text, 1)
predict: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 136.94it/s]
wd_acc_tab_and_text = accuracy_score(test.rating, wd_pred_tab_and_text_class)
wd_f1_tab_and_text = f1_score(
test.rating, wd_pred_tab_and_text_class, average="weighted"
)
wd_cm_tab_and_text = confusion_matrix(test.rating, wd_pred_tab_and_text_class)
print(
    f"Basic RNN + Tabular Accuracy: {wd_acc_tab_and_text}. Basic RNN + Tabular F1 Score: {wd_f1_tab_and_text}"
)
print(f"Basic RNN + Tabular Confusion Matrix:\n {wd_cm_tab_and_text}")
Basic RNN + Tabular Accuracy: 0.6333480760725343. Basic RNN + Tabular F1 Score: 0.6332310089593208
Basic RNN + Tabular Confusion Matrix:
 [[ 267  132   65   10]
 [ 198  168  159   39]
 [  57  113  410  401]
 [  12   58  414 2019]]
3. Text classification using a Hugging Face model as a custom model in pytorch-widedeep
We are going to "manually" code the tokenizer and the model, and see how they can be used as part of the process along with the pytorch-widedeep library.
Tokenizer:
class BertTokenizer(object):
def __init__(
self,
pretrained_tokenizer="distilbert-base-uncased",
do_lower_case=True,
max_length=90,
):
super(BertTokenizer, self).__init__()
self.pretrained_tokenizer = pretrained_tokenizer
self.do_lower_case = do_lower_case
self.max_length = max_length
def fit(self, texts):
self.tokenizer = DistilBertTokenizer.from_pretrained(
self.pretrained_tokenizer, do_lower_case=self.do_lower_case
)
return self
def transform(self, texts):
input_ids = []
for text in texts:
encoded_sent = self.tokenizer.encode_plus(
text=self._pre_rules(text),
add_special_tokens=True,
max_length=self.max_length,
padding="max_length",
truncation=True,
)
input_ids.append(encoded_sent.get("input_ids"))
return np.stack(input_ids)
def fit_transform(self, texts):
return self.fit(texts).transform(texts)
@staticmethod
def _pre_rules(text):
return fix_html(rm_useless_spaces(spec_add_spaces(text)))
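As a quick sanity check, we can run the pre-rules on a raw string (the string below is made up purely for illustration; the pre-rules normalize HTML artifacts and spacing):
raw = "Soft, flattering, &amp; true to size!!  Love it"
print(BertTokenizer._pre_rules(raw))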
Model:
class BertModel(nn.Module):
def __init__(
self,
model_name: str = "distilbert-base-uncased",
freeze_bert: bool = False,
):
super(BertModel, self).__init__()
self.bert = DistilBertModel.from_pretrained(
model_name,
)
if freeze_bert:
for param in self.bert.parameters():
param.requires_grad = False
    def forward(self, X_inp: Tensor) -> Tensor:
        # attend only to the non-padded tokens (DistilBERT's [PAD] token id is 0)
        attn_mask = (X_inp != 0).type(torch.int8)
        outputs = self.bert(input_ids=X_inp, attention_mask=attn_mask)
        # use the embedding of the first ([CLS]) token as the sequence representation
        return outputs[0][:, 0, :]

    @property
    def output_dim(self) -> int:
        # This is THE ONLY requirement for any model to work with pytorch-widedeep:
        # it must have an 'output_dim' property so the WideDeep class knows the
        # incoming dims from the custom model. In this case it is hardcoded
        return 768
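As a side note, if we would rather not hardcode the 768, the hidden size can be read from the loaded checkpoint's config, which DistilBERT exposes as dim (a hedged alternative, not what is used in this notebook):
# read DistilBERT's hidden size from the config instead of hardcoding it
print(DistilBertModel.from_pretrained("distilbert-base-uncased").config.dim)  # 768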
bert_tokenizer = BertTokenizer()
X_bert_tr = bert_tokenizer.fit_transform(train["review_text"].tolist())
X_bert_te = bert_tokenizer.transform(test["review_text"].tolist())
As I have mentioned a number of times in the documentation and examples, pytorch-widedeep is designed for flexibility. For each of the data modes (tabular, text and images) there are components/models available in the library. However, the user can choose to use any model they want, with the only requirement that such a model must have an output_dim property, as in the minimal sketch below.
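Here is that contract made concrete: a minimal sketch of a hypothetical custom text component (a bag-of-embeddings model, not part of the library) that pytorch-widedeep would accept:
class BagOfEmbeddings(nn.Module):
    # hypothetical, minimal custom 'deeptext' component: it embeds token ids
    # and mean-pools them (padding included, for simplicity) into one vector
    def __init__(self, vocab_size: int, embed_dim: int = 64, pad_idx: int = 0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)

    def forward(self, X: Tensor) -> Tensor:
        # X is (batch_size, seq_len) of token ids -> (batch_size, embed_dim)
        return self.embed(X).mean(dim=1)

    @property
    def output_dim(self) -> int:
        # the one attribute WideDeep requires from any custom component
        return self.embed.embedding_dim
Any such module could be passed as deeptext exactly like the BertModel below.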
With that in mind, the BertModel class defined above can be used by pytorch-widedeep like any other of the library's built-in components. In other words, simply pass it to the WideDeep class. In this case we are going to add an FC head as part of the classifier.
bert_model = BertModel(freeze_bert=True)
wd_bert_model = WideDeep(
deeptext=bert_model,
head_hidden_dims=[256, 128, 64],
pred_dim=4,
)
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
wd_bert_model
WideDeep(
  (deeptext): BertModel(
    (bert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0): TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dropout): Dropout(p=0.1, inplace=False)
              (lin1): Linear(in_features=768, out_features=3072, bias=True)
              (lin2): Linear(in_features=3072, out_features=768, bias=True)
              (activation): GELUActivation()
            )
            (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          )
          (1): TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dropout): Dropout(p=0.1, inplace=False)
              (lin1): Linear(in_features=768, out_features=3072, bias=True)
              (lin2): Linear(in_features=3072, out_features=768, bias=True)
              (activation): GELUActivation()
            )
            (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          )
          (2): TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dropout): Dropout(p=0.1, inplace=False)
              (lin1): Linear(in_features=768, out_features=3072, bias=True)
              (lin2): Linear(in_features=3072, out_features=768, bias=True)
              (activation): GELUActivation()
            )
            (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          )
          (3): TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dropout): Dropout(p=0.1, inplace=False)
              (lin1): Linear(in_features=768, out_features=3072, bias=True)
              (lin2): Linear(in_features=3072, out_features=768, bias=True)
              (activation): GELUActivation()
            )
            (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          )
          (4): TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dropout): Dropout(p=0.1, inplace=False)
              (lin1): Linear(in_features=768, out_features=3072, bias=True)
              (lin2): Linear(in_features=3072, out_features=768, bias=True)
              (activation): GELUActivation()
            )
            (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          )
          (5): TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (ffn): FFN(
              (dropout): Dropout(p=0.1, inplace=False)
              (lin1): Linear(in_features=768, out_features=3072, bias=True)
              (lin2): Linear(in_features=3072, out_features=768, bias=True)
              (activation): GELUActivation()
            )
            (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          )
        )
      )
    )
  )
  (deephead): Sequential(
    (0): MLP(
      (mlp): Sequential(
        (dense_layer_0): Sequential(
          (0): Linear(in_features=768, out_features=256, bias=True)
          (1): ReLU(inplace=True)
          (2): Dropout(p=0.1, inplace=False)
        )
        (dense_layer_1): Sequential(
          (0): Linear(in_features=256, out_features=128, bias=True)
          (1): ReLU(inplace=True)
          (2): Dropout(p=0.1, inplace=False)
        )
        (dense_layer_2): Sequential(
          (0): Linear(in_features=128, out_features=64, bias=True)
          (1): ReLU(inplace=True)
          (2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (1): Linear(in_features=64, out_features=4, bias=True)
  )
)
wd_bert_trainer = Trainer(
wd_bert_model,
objective="multiclass",
metrics=[Accuracy, F1Score(average=True)],
    num_workers=0,  # As in the case of the tokenizer, in a notebook we need to set this to 0 for the Trainer to work
)
wd_bert_trainer.fit(
X_text=X_bert_tr,
target=train.rating.values,
n_epochs=3,
batch_size=64,
)
epoch 1: 100%|███████████████████████████████| 283/283 [00:14<00:00, 19.68it/s, loss=0.968, metrics={'acc': 0.5879, 'f1': 0.3591}]
epoch 2: 100%|███████████████████████████████| 283/283 [00:14<00:00, 19.63it/s, loss=0.884, metrics={'acc': 0.6178, 'f1': 0.4399}]
epoch 3: 100%|███████████████████████████████| 283/283 [00:14<00:00, 19.55it/s, loss=0.87, metrics={'acc': 0.6234, 'f1': 0.4527}]
wd_bert_pred_text = wd_bert_trainer.predict_proba(X_text=X_bert_te)
wd_bert_pred_text_class = np.argmax(wd_bert_pred_text, 1)
predict: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:03<00:00, 21.97it/s]
wd_bert_acc = accuracy_score(test.rating, wd_bert_pred_text_class)
wd_bert_f1 = f1_score(test.rating, wd_bert_pred_text_class, average="weighted")
wd_bert_cm = confusion_matrix(test.rating, wd_bert_pred_text_class)
print(f"Distilbert Accuracy: {wd_bert_acc}. Distilbert F1 Score: {wd_bert_f1}")
print(f"Distilbert Confusion Matrix:\n {wd_bert_cm}")
Distilbert Accuracy: 0.6326846528084918. Distilbert F1 Score: 0.5796652991272998
Distilbert Confusion Matrix:
 [[ 287   75   22   90]
 [ 197  136   62  169]
 [  68  119  123  671]
 [  40   64   84 2315]]
Now, adding a tabular model follows the exact same process as the one described in section 2.
tab_model = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
mlp_hidden_dims=[100, 50],
)
wd_tab_bert_model = WideDeep(
deeptabular=tab_model,
deeptext=bert_model,
head_hidden_dims=[256, 128, 64],
pred_dim=4,
)
wd_tab_bert_trainer = Trainer(
wd_tab_bert_model,
objective="multiclass",
metrics=[Accuracy, F1Score(average=True)],
    num_workers=0,  # As in the case of the tokenizer, in a notebook we need to set this to 0 for the Trainer to work
)
wd_tab_bert_trainer.fit(
X_tab=wd_X_tab_tr,
X_text=X_bert_tr,
target=train.rating.values,
n_epochs=3,
batch_size=64,
)
epoch 1: 100%|███████████████████████████████| 283/283 [00:15<00:00, 18.15it/s, loss=0.974, metrics={'acc': 0.5838, 'f1': 0.3404}]
epoch 2: 100%|███████████████████████████████| 283/283 [00:15<00:00, 18.38it/s, loss=0.885, metrics={'acc': 0.618, 'f1': 0.4378}]
epoch 3: 100%|███████████████████████████████| 283/283 [00:15<00:00, 18.40it/s, loss=0.868, metrics={'acc': 0.6252, 'f1': 0.4575}]
wd_tab_bert_pred_text = wd_tab_bert_trainer.predict_proba(
X_tab=wd_X_tab_te, X_text=X_bert_te
)
wd_tab_bert_pred_text_class = np.argmax(wd_tab_bert_pred_text, 1)
predict: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:03<00:00, 21.32it/s]
wd_tab_bert_acc = accuracy_score(test.rating, wd_tab_bert_pred_text_class)
wd_tab_bert_f1 = f1_score(test.rating, wd_tab_bert_pred_text_class, average="weighted")
wd_tab_bert_cm = confusion_matrix(test.rating, wd_tab_bert_pred_text_class)
print(
    f"Distilbert + Tabular Accuracy: {wd_tab_bert_acc}. Distilbert + Tabular F1 Score: {wd_tab_bert_f1}"
)
print(f"Distilbert + Tabular Confusion Matrix:\n {wd_tab_bert_cm}")
Distilbert + Tabular Accuracy: 0.6242812914639541. Distilbert + Tabular F1 Score: 0.5508351761564895
Distilbert + Tabular Confusion Matrix:
 [[ 297   56   11  110]
 [ 229   91   38  206]
 [  86   90   71  734]
 [  49   48   42 2364]]
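To wrap up, here is a recap of the metrics gathered throughout the notebook (bearing in mind the caveat from the introduction: none of these models were tuned):

| Model | Accuracy | F1 (weighted) |
| --- | --- | --- |
| Most common rating (baseline) | 0.554 | 0.394 |
| tf-idf + LightGBM | 0.644 | 0.617 |
| tf-idf + LightGBM + tabular | 0.638 | 0.608 |
| Basic RNN | 0.608 | 0.602 |
| Basic RNN + tabular | 0.633 | 0.633 |
| Distilbert (frozen) | 0.633 | 0.580 |
| Distilbert (frozen) + tabular | 0.624 | 0.551 |

All models comfortably beat the majority-class baseline. In this un-tuned setup, tf-idf + LightGBM remains the one to beat, and the tabular features only help in the Basic RNN case.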