15_Self-Supervised Pre-Training pt 1
Self Supervised Pretraining for Tabular Data
We have implemented two Self Supervised Pre-training routines that allow the user to pre-train all tabular models in the library with the exception of the TabPerceiver (which is a special monster).
The two routines implemented are illustrated in the figures below. The first is from TabNet: Attentive Interpretable Tabular Learning and is designed for models that do not use transformer-based architectures, while the second is from SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training, and is designed for models that use transformer-based architectures.
Fig 1. Figure 2 in their paper. I have included the original caption in case it is useful, although the figure itself is pretty self-explanatory.
Fig 2. Figure 1 in their paper. Here the caption is necessary 😏
It is beyond the scope of this notebook to explain those implementations in detail. Therefore, we strongly recommend that users read the papers if this functionality is of interest to them.
One thing is worth noticing, however. As seen in Fig 1 (the TabNet paper's Fig 2), the masking of the input features happens in the feature space. However, the implementation in this library is inspired by the one in the dreamquark-ai repo, which is itself inspired by the original implementation (by the way, at this point I will write it once again: all TabNet-related things in this library are inspired by, when not directly based on, the code in that repo, therefore ALL CREDIT TO THE GUYS AT dreamquark-ai).
In that implementation the masking happens in the embedding space, and currently it does not mask the entire embedding (i.e. the categorical feature). We decided to release it as is in this version and we will implement the exact process described in the paper in future releases.
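To make the distinction concrete, here is a minimal, purely illustrative PyTorch sketch of the two masking strategies. None of this is the library's internal code, and the names (X_cat, X_cont, p_mask) are made up for the example: masking in the feature space obfuscates raw column values before they are embedded, whereas masking in the embedding space zeroes individual entries of the already-embedded representation.

import torch
import torch.nn as nn

torch.manual_seed(0)

# toy batch: 4 rows, 3 label-encoded categorical cols and 2 continuous cols
X_cat = torch.randint(1, 10, (4, 3))
X_cont = torch.randn(4, 2)
p_mask = 0.2  # fraction of values to obfuscate

# 1) masking in the FEATURE space (as in the TabNet paper's Fig 2): the binary
#    mask acts on the raw column values, before any embedding
X_raw = torch.cat([X_cat.float(), X_cont], dim=1)          # shape (4, 5)
feature_mask = (torch.rand_like(X_raw) > p_mask).float()
X_raw_masked = X_raw * feature_mask                         # whole raw values are zeroed

# 2) masking in the EMBEDDING space (closer to what this library currently does):
#    the categorical columns are embedded first and the mask acts on the embedded
#    tensor, so individual embedding entries (not whole features) get zeroed
embed = nn.Embedding(10, 8, padding_idx=0)
X_emb = torch.cat([embed(X_cat).flatten(1), X_cont], dim=1)  # shape (4, 3*8 + 2)
embed_mask = (torch.rand_like(X_emb) > p_mask).float()
X_emb_masked = X_emb * embed_mask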
Having said all of the above, let's see how to use self-supervision for tabular data with pytorch-widedeep. We will concentrate in this notebook on the first of the two approaches (the 'TabNet approach'). For details on the second approach please see 16_Self_Supervised_Pretraning_pt2.
Self Supervision for non-transformer-based models...
...or, in general, for models where the embeddings can all have different dimensions. In this library, these are: TabMlp, TabResnet and TabNet.
As shown in Fig 1, this is an encoder-decoder approach where we learn to predict values in the incoming data that have been masked. However, as mentioned before, our implementation is a bit different and the masking occurs in the embedding space. Nonetheless, the code below illustrates how to use this encoder-decoder approach with pytorch-widedeep.
import torch
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.self_supervised_training import EncoderDecoderTrainer
df = load_adult(as_frame=True)
df.columns = [c.replace("-", "_") for c in df.columns]
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
# one could choose to use a validation set for early stopping, hyperparam
# optimization, etc. This is just an example, so we simply use train/test
# split
df_tr, df_te = train_test_split(df, test_size=0.2, stratify=df.income_label)
df_tr.head(2)
| | age | workclass | fnlwgt | education | educational_num | marital_status | occupation | relationship | race | gender | capital_gain | capital_loss | hours_per_week | native_country | income_label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 9042 | 26 | Local-gov | 250551 | HS-grad | 9 | Married-civ-spouse | Craft-repair | Own-child | Black | Male | 0 | 0 | 40 | United-States | 0 |
| 25322 | 50 | Private | 34832 | Bachelors | 13 | Married-civ-spouse | Tech-support | Husband | White | Male | 15024 | 0 | 40 | United-States | 1 |
# As always, we need to define which cols will be represented as embeddings
# and which ones will be continuous features
cat_embed_cols = [
"workclass",
"education",
"marital_status",
"occupation",
"relationship",
"race",
"gender",
"capital_gain",
"capital_loss",
"native_country",
]
continuous_cols = ["age", "hours_per_week"]
target_col = "income_label"
# We prepare the data to be passed to the model
tab_preprocessor = TabPreprocessor(
cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols
)
X_tab = tab_preprocessor.fit_transform(df_tr)
target = df_tr[target_col].values
/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:358: UserWarning: Continuous columns will not be normalised warnings.warn("Continuous columns will not be normalised")
X_tab[:5]
array([[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 26, 40], [ 2, 2, 1, 2, 2, 2, 1, 2, 1, 1, 50, 40], [ 2, 1, 1, 3, 2, 2, 1, 1, 2, 1, 39, 46], [ 2, 3, 2, 4, 1, 2, 2, 1, 1, 1, 17, 10], [ 3, 4, 2, 1, 1, 2, 1, 1, 1, 1, 32, 20]])
# We define a model that will act as the encoder in the encoder/decoder
# architecture. This could be any of: TabMlp, TabResnet or TabNet
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
)
tab_mlp
TabMlp( (cat_embed): DiffSizeCatEmbeddings( (embed_layers): ModuleDict( (emb_layer_workclass): Embedding(10, 5, padding_idx=0) (emb_layer_education): Embedding(17, 8, padding_idx=0) (emb_layer_marital_status): Embedding(8, 5, padding_idx=0) (emb_layer_occupation): Embedding(16, 7, padding_idx=0) (emb_layer_relationship): Embedding(7, 4, padding_idx=0) (emb_layer_race): Embedding(6, 4, padding_idx=0) (emb_layer_gender): Embedding(3, 2, padding_idx=0) (emb_layer_capital_gain): Embedding(124, 24, padding_idx=0) (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0) (emb_layer_native_country): Embedding(42, 13, padding_idx=0) ) (embedding_dropout): Dropout(p=0.0, inplace=False) ) (cont_norm): Identity() (encoder): MLP( (mlp): Sequential( (dense_layer_0): Sequential( (0): Linear(in_features=95, out_features=200, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) (dense_layer_1): Sequential( (0): Linear(in_features=200, out_features=100, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) ) ) )
# If we do not pass a custom decoder, which is perfectly possible via the
# decoder param, the EncoderDecoderTrainer will automatically build a
# decoder which will be the 'mirror' image of the encoder
encoder_decoder_trainer = EncoderDecoderTrainer(encoder=tab_mlp)
# let's have a look at the encoder_decoder_model (aka ed_model)
encoder_decoder_trainer.ed_model
EncoderDecoderModel( (encoder): TabMlp( (cat_embed): DiffSizeCatEmbeddings( (embed_layers): ModuleDict( (emb_layer_workclass): Embedding(10, 5, padding_idx=0) (emb_layer_education): Embedding(17, 8, padding_idx=0) (emb_layer_marital_status): Embedding(8, 5, padding_idx=0) (emb_layer_occupation): Embedding(16, 7, padding_idx=0) (emb_layer_relationship): Embedding(7, 4, padding_idx=0) (emb_layer_race): Embedding(6, 4, padding_idx=0) (emb_layer_gender): Embedding(3, 2, padding_idx=0) (emb_layer_capital_gain): Embedding(124, 24, padding_idx=0) (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0) (emb_layer_native_country): Embedding(42, 13, padding_idx=0) ) (embedding_dropout): Dropout(p=0.0, inplace=False) ) (cont_norm): Identity() (encoder): MLP( (mlp): Sequential( (dense_layer_0): Sequential( (0): Linear(in_features=95, out_features=200, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) (dense_layer_1): Sequential( (0): Linear(in_features=200, out_features=100, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) ) ) ) (decoder): TabMlpDecoder( (decoder): MLP( (mlp): Sequential( (dense_layer_0): Sequential( (0): Linear(in_features=100, out_features=200, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) (dense_layer_1): Sequential( (0): Linear(in_features=200, out_features=95, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) ) ) ) (masker): RandomObfuscator() )
Ignoring the masker, which just...well...masks, the ed_model consists of the following components (a conceptual sketch of how they interact follows the list):

- An encoder: a TabMlp model that is itself comprised of a collection of embedding layers for the categorical features (cat_embed in the printed model above), a normalisation layer for the continuous ones (cont_norm) and an encoder (a simple MLP, referred to as encoder)
- A decoder: an "inverted" MLP (referred to as decoder)
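Putting those pieces together, the pre-training objective is conceptually a denoising reconstruction: obfuscate part of the embedded input, encode it, decode it back to the embedding dimension, and penalise the reconstruction error. The snippet below is a schematic sketch in plain PyTorch, not the library's actual EncoderDecoderModel; the stand-in modules enc and dec, the masking probability p_mask and the plain MSE on the masked positions are simplifying assumptions.

import torch
import torch.nn as nn

# dims loosely based on the model printed above (95-dim embeddings, 100-dim encoding)
embed_dim, hidden_dim, p_mask = 95, 100, 0.2

enc = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.ReLU())  # stand-in encoder
dec = nn.Linear(hidden_dim, embed_dim)                            # stand-in 'mirror' decoder

x = torch.randn(256, embed_dim)                 # a batch of embedded inputs
mask = (torch.rand_like(x) > p_mask).float()    # the 'masker' step: 1 = keep, 0 = obfuscate
x_masked = x * mask

x_rec = dec(enc(x_masked))                      # encode and reconstruct

# reconstruction loss computed on the obfuscated positions only
denom = (1 - mask).sum().clamp(min=1.0)
loss = (((x_rec - x) ** 2) * (1 - mask)).sum() / denom
loss.backward()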
# And we just...pretrain
encoder_decoder_trainer.pretrain(X_tab, n_epochs=5, batch_size=256)
epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 82.90it/s, loss=4.07] epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 89.87it/s, loss=3.09] epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 92.86it/s, loss=2.53] epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 91.24it/s, loss=2.09] epoch 5: 100%|████████████████████████████████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 91.38it/s, loss=1.78]
At this point we have two options: we could either save the model for later use or continue to the supervised training. The latter is rather simple; after running:
encoder_decoder_trainer.pretrain(X_tab, n_epochs=5, batch_size=256)
you just have to run:
model = WideDeep(deeptabular=tab_mlp)
trainer = Trainer(model=model, objective="binary", metrics=[Accuracy])
trainer.fit(X_tab=X_tab, target=target, n_epochs=5, batch_size=256)
# And, you know...we get a test metric
X_tab_te = tab_preprocessor.transform(df_te)
target_te = df_te[target_col].values
preds = trainer.predict(X_tab=X_tab_te)
test_acc = accuracy_score(target_te, preds)
Let's say that in any case, we are 'decent' scientists/people and we want to save the model:
encoder_decoder_trainer.save(
path="pretrained_weights", model_filename="encoder_decoder_model.pt"
)
some time has passed...
encoder_decoder_model = torch.load("pretrained_weights/encoder_decoder_model.pt")
Now, AND THIS IS IMPORTANT: we have loaded the encoder AND the decoder. To proceed to the supervised training we ONLY need the encoder.
pretrained_encoder = encoder_decoder_model.encoder
pretrained_encoder
TabMlp( (cat_embed): DiffSizeCatEmbeddings( (embed_layers): ModuleDict( (emb_layer_workclass): Embedding(10, 5, padding_idx=0) (emb_layer_education): Embedding(17, 8, padding_idx=0) (emb_layer_marital_status): Embedding(8, 5, padding_idx=0) (emb_layer_occupation): Embedding(16, 7, padding_idx=0) (emb_layer_relationship): Embedding(7, 4, padding_idx=0) (emb_layer_race): Embedding(6, 4, padding_idx=0) (emb_layer_gender): Embedding(3, 2, padding_idx=0) (emb_layer_capital_gain): Embedding(124, 24, padding_idx=0) (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0) (emb_layer_native_country): Embedding(42, 13, padding_idx=0) ) (embedding_dropout): Dropout(p=0.0, inplace=False) ) (cont_norm): Identity() (encoder): MLP( (mlp): Sequential( (dense_layer_0): Sequential( (0): Linear(in_features=95, out_features=200, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) (dense_layer_1): Sequential( (0): Linear(in_features=200, out_features=100, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.1, inplace=False) ) ) ) )
# and as always, ANY supervised model in this library has to go through the WideDeep class:
model = WideDeep(deeptabular=pretrained_encoder)
trainer = Trainer(model=model, objective="binary", metrics=[Accuracy])
trainer.fit(X_tab=X_tab, target=target, n_epochs=5, batch_size=256)
X_tab_te = tab_preprocessor.transform(df_te)
target_te = df_te[target_col].values
preds = trainer.predict(X_tab=X_tab_te)
test_acc = accuracy_score(target_te, preds)
print(test_acc)
epoch 1: 100%|██████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 88.04it/s, loss=0.374, metrics={'acc': 0.8253}] epoch 2: 100%|██████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 85.63it/s, loss=0.324, metrics={'acc': 0.8491}] epoch 3: 100%|██████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 87.56it/s, loss=0.301, metrics={'acc': 0.8608}] epoch 4: 100%|███████████████████████████████████████████████████████████| 153/153 [00:02<00:00, 73.38it/s, loss=0.29, metrics={'acc': 0.8655}] epoch 5: 100%|██████████████████████████████████████████████████████████| 153/153 [00:01<00:00, 78.68it/s, loss=0.284, metrics={'acc': 0.8686}] predict: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:00<00:00, 173.02it/s]
0.8730678677449074
As we mentioned before, we can also use a TabResnet or TabNet model and a custom decoder. Let's have a look:
from pytorch_widedeep.models import TabResnet as TabResnetEncoder, TabResnetDecoder
resnet_encoder = TabResnetEncoder(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=continuous_cols,
blocks_dims=[200, 100, 100],
)
Let's have a look at the model:
resnet_encoder
TabResnet( (cat_embed): DiffSizeCatEmbeddings( (embed_layers): ModuleDict( (emb_layer_workclass): Embedding(10, 5, padding_idx=0) (emb_layer_education): Embedding(17, 8, padding_idx=0) (emb_layer_marital_status): Embedding(8, 5, padding_idx=0) (emb_layer_occupation): Embedding(16, 7, padding_idx=0) (emb_layer_relationship): Embedding(7, 4, padding_idx=0) (emb_layer_race): Embedding(6, 4, padding_idx=0) (emb_layer_gender): Embedding(3, 2, padding_idx=0) (emb_layer_capital_gain): Embedding(124, 24, padding_idx=0) (emb_layer_capital_loss): Embedding(98, 21, padding_idx=0) (emb_layer_native_country): Embedding(42, 13, padding_idx=0) ) (embedding_dropout): Dropout(p=0.0, inplace=False) ) (cont_norm): Identity() (encoder): DenseResnet( (dense_resnet): Sequential( (lin_inp): Linear(in_features=95, out_features=200, bias=False) (bn_inp): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (block_0): BasicBlock( (resize): Sequential( (0): Linear(in_features=200, out_features=100, bias=False) (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (lin1): Linear(in_features=200, out_features=100, bias=False) (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True) (dp): Dropout(p=0.1, inplace=False) (lin2): Linear(in_features=100, out_features=100, bias=False) (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (block_1): BasicBlock( (lin1): Linear(in_features=100, out_features=100, bias=False) (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True) (dp): Dropout(p=0.1, inplace=False) (lin2): Linear(in_features=100, out_features=100, bias=False) (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) ) )
As we can see, the tensor we are trying to reconstruct, the embeddings, is of size 95: the sum of the categorical embedding dimensions (5 + 8 + 5 + 7 + 4 + 4 + 2 + 24 + 21 + 13 = 93) plus the 2 continuous columns. This number can be obtained as resnet_encoder.cat_out_dim + resnet_encoder.cont_out_dim (used below). With that information we could build our own decoder as:
# for all possible params see the docs
resnet_decoder = TabResnetDecoder(
embed_dim=resnet_encoder.cat_out_dim + resnet_encoder.cont_out_dim,
blocks_dims=[100, 100, 200],
)
resnet_decoder
TabResnetDecoder( (decoder): DenseResnet( (dense_resnet): Sequential( (block_0): BasicBlock( (lin1): Linear(in_features=100, out_features=100, bias=False) (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True) (dp): Dropout(p=0.1, inplace=False) (lin2): Linear(in_features=100, out_features=100, bias=False) (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (block_1): BasicBlock( (resize): Sequential( (0): Linear(in_features=100, out_features=200, bias=False) (1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (lin1): Linear(in_features=100, out_features=200, bias=False) (bn1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (leaky_relu): LeakyReLU(negative_slope=0.01, inplace=True) (dp): Dropout(p=0.1, inplace=False) (lin2): Linear(in_features=200, out_features=200, bias=False) (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) ) (reconstruction_layer): Linear(in_features=200, out_features=95, bias=False) )
and now:
ec_trainer = EncoderDecoderTrainer(
encoder=resnet_encoder,
decoder=resnet_decoder,
masked_prob=0.2,
)
ec_trainer.pretrain(X_tab, n_epochs=5, batch_size=256)
epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████| 153/153 [00:03<00:00, 46.89it/s, loss=1.52] epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████| 153/153 [00:03<00:00, 46.78it/s, loss=0.81] epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████| 153/153 [00:03<00:00, 39.82it/s, loss=0.56] epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████████| 153/153 [00:03<00:00, 46.73it/s, loss=0.417] epoch 5: 100%|███████████████████████████████████████████████████████████████████████████████████| 153/153 [00:03<00:00, 46.24it/s, loss=0.329]
# and as always, ANY supervised model in this library has to go through the WideDeep class:
model = WideDeep(deeptabular=resnet_encoder)
trainer = Trainer(model=model, objective="binary", metrics=[Accuracy])
trainer.fit(X_tab=X_tab, target=target, n_epochs=5, batch_size=256)
X_tab_te = tab_preprocessor.transform(df_te)
target_te = df_te[target_col].values
preds = trainer.predict(X_tab=X_tab_te)
test_acc = accuracy_score(target_te, preds)
print(test_acc)
epoch 1: 100%|██████████████████████████████████████████████████████████| 153/153 [00:02<00:00, 58.63it/s, loss=0.335, metrics={'acc': 0.8442}] epoch 2: 100%|███████████████████████████████████████████████████████████| 153/153 [00:02<00:00, 58.02it/s, loss=0.296, metrics={'acc': 0.864}] epoch 3: 100%|██████████████████████████████████████████████████████████| 153/153 [00:02<00:00, 55.91it/s, loss=0.283, metrics={'acc': 0.8687}] epoch 4: 100%|███████████████████████████████████████████████████████████| 153/153 [00:02<00:00, 55.00it/s, loss=0.276, metrics={'acc': 0.871}] epoch 5: 100%|██████████████████████████████████████████████████████████| 153/153 [00:02<00:00, 51.95it/s, loss=0.272, metrics={'acc': 0.8732}] predict: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:00<00:00, 120.15it/s]
0.8725560446309756