15_Self-Supervised Pre-Training pt 2
Self Supervised Pretraining for Tabular Data¶
We have implemented two Self Supervised Pre-training routines that allow the user to pre-train all tabular models in the library with the exception of the TabPerceiver (which is a special monster).
The two routines implemented are illustrated in the figures below. The first is from TabNet: Attentive Interpretable Tabular Learning and is designed for models that do not use transformer-based architectures, while the second is from SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training, and is designed for models that use transformer-based architectures.
Fig 1. Figure 2 in their paper. I have included the original caption in case it is useful, although the figure itself is pretty self-explanatory.
Fig 2. Figure 1 in their paper. Here the caption is necessary 😏
It is beyond the scope of this notebook to explain these implementations in detail. Therefore, we strongly recommend that the user read the papers if this functionality is of interest to them.
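That said, to give a flavour of the contrastive part of the second routine, here is a heavily simplified, self-contained InfoNCE-style sketch in plain PyTorch. This is NOT the library's actual loss (which also includes a denoising component and differs in several details); it only illustrates the idea of pulling together the projections of a row and of its augmented version.
import torch
import torch.nn.functional as F
def simplified_info_nce(z_orig, z_aug, temperature=0.7):
    # z_orig, z_aug: (batch_size, proj_dim) projections of the original rows and of
    # their augmented/corrupted counterparts (names and shapes are illustrative only)
    z_orig = F.normalize(z_orig, dim=1)
    z_aug = F.normalize(z_aug, dim=1)
    logits = z_orig @ z_aug.t() / temperature  # (bsz, bsz) similarity matrix
    labels = torch.arange(z_orig.size(0))  # positives sit on the diagonal
    return F.cross_entropy(logits, labels)
# e.g. simplified_info_nce(torch.randn(16, 32), torch.randn(16, 32))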
One thing is worth noticing, however. As seen in Fig 1 (the TabNet paper's Fig 2), the masking of the input features happens in the feature space. However, the implementation in this library is inspired by that at the dreamquark-ai repo, which is itself inspired by the original implementation (by the way, at this point I will write it once again: all TabNet-related things in this library are inspired by, when not directly based on, the code in that repo, therefore ALL CREDIT TO THE GUYS AT dreamquark-ai).
In that implementation the masking happens in the embedding space and it currently does not mask the entire embedding (i.e. the categorical feature). We decided to release it as is in this version and we will implement the exact process described in the paper in future releases.
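To make the distinction concrete, here is a minimal, illustrative sketch in plain PyTorch (not the library's code; shapes and mask probability are made up) of feature-space masking versus embedding-space masking:
import torch
batch_size, n_feats, embed_dim = 4, 6, 8
# feature-space masking (TabNet paper's Fig 2): entire input features are masked
# before being embedded
x = torch.randn(batch_size, n_feats)
feature_mask = torch.bernoulli(torch.full_like(x, 0.8))  # 1 = keep, 0 = mask
x_masked = x * feature_mask
# embedding-space masking (as in the dreamquark-ai-inspired implementation): the mask
# is applied to the embedded representation, so individual embedding dimensions of a
# categorical feature can be masked rather than the whole embedding
e = torch.randn(batch_size, n_feats, embed_dim)
embed_mask = torch.bernoulli(torch.full_like(e, 0.8))
e_masked = e * embed_mask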
Having said all of the above, let's see how to use self-supervision for tabular data with pytorch-widedeep. In this notebook we will concentrate on the 2nd of the two approaches (the 'SAINT approach'). For details on the 1st approach (the 'TabNet approach') please see 16_Self_Supervised_Pretraning_pt1.
Self Supervision for transformer-based models...¶
...or, in general, for models where the embeddings all have the same dimension. In this library, these are:
- TabTransformer
- FTTransformer
- SAINT
- TabFastFormer
Note that there is one additional Transformer-based model, the TabPerceiver. However, this is a "particular" model and at the moment we do not support self-supervision for it, but it will come.
Let's see an example using the FTTransformer.
import torch
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import WideDeep, FTTransformer
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.self_supervised_training import (
ContrastiveDenoisingTrainer,
)
df = load_adult(as_frame=True)
df.columns = [c.replace("-", "_") for c in df.columns]
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
# one could choose to use a validation set for early stopping, hyperparameter
# optimization, etc. (see the sketch after this cell). This is just an example,
# so we simply use a train/test split
df_tr, df_te = train_test_split(df, test_size=0.2, stratify=df.income_label)
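As mentioned in the comment above, one could also carve a validation set out of the training data. A purely illustrative sketch (the variable names are ours and df_val is not used in the rest of this notebook):
# illustrative only: a further split for early stopping / hyperparameter optimization
df_train_sub, df_val = train_test_split(
    df_tr, test_size=0.2, stratify=df_tr.income_label, random_state=42
)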
cat_embed_cols = [
"workclass",
"education",
"marital_status",
"occupation",
"relationship",
"race",
"gender",
"capital_gain",
"capital_loss",
"native_country",
]
continuous_cols = ["age", "hours_per_week"]
target_col = "income_label"
tab_preprocessor = TabPreprocessor(
cat_embed_cols=cat_embed_cols,
continuous_cols=continuous_cols,
with_attention=True,
with_cls_token=True, # this is optional
)
X_tab = tab_preprocessor.fit_transform(df_tr)
target = df_tr[target_col].values
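If you want to inspect what the preprocessor produced (optional, output not shown), something along these lines helps; note that with with_cls_token=True an extra token column is prepended to the categorical and continuous ones:
# X_tab is a numpy array with one column per feature (plus the extra cls-token column)
print(X_tab.shape)
# mapping from column name to position, used by the model later on
print(tab_preprocessor.column_idx)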
/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:364: UserWarning: Continuous columns will not be normalised warnings.warn("Continuous columns will not be normalised")
ft_transformer = FTTransformer(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
embed_continuous_method="standard",
input_dim=32,
kv_compression_factor=0.5,
n_blocks=3,
n_heads=4,
)
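As an optional sanity check (not part of the original flow, output not shown), one can count the trainable parameters of the encoder we are about to pre-train using plain PyTorch:
# number of trainable parameters of the FTTransformer defined above
n_params = sum(p.numel() for p in ft_transformer.parameters() if p.requires_grad)
print(f"{n_params:,} trainable parameters")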
# for a full list of the params of the ContrastiveDenoisingTrainer (which are many) please see the docs.
# Note that using these params involves some knowledge of the routine and the architecture of the model used
contrastive_denoising_trainer = ContrastiveDenoisingTrainer(
model=ft_transformer,
preprocessor=tab_preprocessor,
)
contrastive_denoising_trainer.pretrain(X_tab, n_epochs=5, batch_size=256)
epoch 1: 100%|██████████| 153/153 [00:22<00:00, 6.91it/s, loss=656] epoch 2: 100%|██████████| 153/153 [00:07<00:00, 20.84it/s, loss=141] epoch 3: 100%|██████████| 153/153 [00:07<00:00, 20.75it/s, loss=137] epoch 4: 100%|██████████| 153/153 [00:07<00:00, 21.14it/s, loss=135] epoch 5: 100%|██████████| 153/153 [00:07<00:00, 21.05it/s, loss=134]
contrastive_denoising_trainer.save(
path="pretrained_weights", model_filename="contrastive_denoising_model.pt"
)
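If you also want to persist the fitted TabPreprocessor (you will need it to transform new data at fine-tuning or inference time), a simple approach is to serialize it yourself, for example with pickle. This is just a sketch and any serialization method will do:
import pickle
# keep the fitted preprocessor next to the pre-trained weights so that new data
# can be transformed consistently later on
with open("pretrained_weights/tab_preprocessor.pkl", "wb") as f:
    pickle.dump(tab_preprocessor, f)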
some time has passed
# some time has passed, we load the model with torch as usual:
contrastive_denoising_model = torch.load(
"pretrained_weights/contrastive_denoising_model.pt"
)
/var/folders/pd/_2wz_qt16yq1fk6jn_xtxqk40000gn/T/ipykernel_71058/884975850.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. contrastive_denoising_model = torch.load(
NOW, AND THIS IS IMPORTANT! We have loaded the entire contrastive-denoising model. To proceed to the supervised training we ONLY need the attention-based model, which is the 'model' attribute of the trainer. Let's have a look:
contrastive_denoising_model.model
FTTransformer( (cat_embed): SameSizeCatEmbeddings( (embed): Embedding(322, 32, padding_idx=0) (dropout): Dropout(p=0.0, inplace=False) ) (cont_norm): Identity() (cont_embed): ContEmbeddings( INFO: [ContLinear = weight(n_cont_cols, embed_dim) + bias(n_cont_cols, embed_dim)] (linear): ContLinear(n_cont_cols=2, embed_dim=32, embed_dropout=0.0) (dropout): Dropout(p=0.0, inplace=False) ) (encoder): Sequential( (fttransformer_block0): FTTransformerEncoder( (attn): LinearAttentionLinformer( (dropout): Dropout(p=0.2, inplace=False) (qkv_proj): Linear(in_features=32, out_features=96, bias=False) (out_proj): Linear(in_features=32, out_features=32, bias=False) ) (ff): FeedForward( (w_1): Linear(in_features=32, out_features=84, bias=True) (w_2): Linear(in_features=42, out_features=32, bias=True) (dropout): Dropout(p=0.1, inplace=False) (activation): REGLU() ) (attn_normadd): NormAdd( (dropout): Dropout(p=0.2, inplace=False) (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True) ) (ff_normadd): NormAdd( (dropout): Dropout(p=0.1, inplace=False) (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True) ) ) (fttransformer_block1): FTTransformerEncoder( (attn): LinearAttentionLinformer( (dropout): Dropout(p=0.2, inplace=False) (qkv_proj): Linear(in_features=32, out_features=96, bias=False) (out_proj): Linear(in_features=32, out_features=32, bias=False) ) (ff): FeedForward( (w_1): Linear(in_features=32, out_features=84, bias=True) (w_2): Linear(in_features=42, out_features=32, bias=True) (dropout): Dropout(p=0.1, inplace=False) (activation): REGLU() ) (attn_normadd): NormAdd( (dropout): Dropout(p=0.2, inplace=False) (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True) ) (ff_normadd): NormAdd( (dropout): Dropout(p=0.1, inplace=False) (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True) ) ) (fttransformer_block2): FTTransformerEncoder( (attn): LinearAttentionLinformer( (dropout): Dropout(p=0.2, inplace=False) (qkv_proj): Linear(in_features=32, out_features=96, bias=False) (out_proj): Linear(in_features=32, out_features=32, bias=False) ) (ff): FeedForward( (w_1): Linear(in_features=32, out_features=84, bias=True) (w_2): Linear(in_features=42, out_features=32, bias=True) (dropout): Dropout(p=0.1, inplace=False) (activation): REGLU() ) (attn_normadd): NormAdd( (dropout): Dropout(p=0.2, inplace=False) (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True) ) (ff_normadd): NormAdd( (dropout): Dropout(p=0.1, inplace=False) (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True) ) ) ) )
pretrained_model = contrastive_denoising_model.model
# and as always, ANY supervised model in this library has to go through the WideDeep class:
model = WideDeep(deeptabular=pretrained_model)
trainer = Trainer(model=model, objective="binary", metrics=[Accuracy])
trainer.fit(X_tab=X_tab, target=target, n_epochs=5, batch_size=256)
# And, you know...we get a test metric
X_tab_te = tab_preprocessor.transform(df_te)
target_te = df_te[target_col].values
preds = trainer.predict(X_tab=X_tab_te)
test_acc = accuracy_score(target_te, preds)
print(test_acc)
epoch 1: 100%|██████████| 153/153 [00:03<00:00, 50.99it/s, loss=0.383, metrics={'acc': 0.8218}] epoch 2: 100%|██████████| 153/153 [00:02<00:00, 53.22it/s, loss=0.328, metrics={'acc': 0.8505}] epoch 3: 100%|██████████| 153/153 [00:02<00:00, 54.50it/s, loss=0.309, metrics={'acc': 0.8588}] epoch 4: 100%|██████████| 153/153 [00:02<00:00, 53.90it/s, loss=0.299, metrics={'acc': 0.8641}] epoch 5: 100%|██████████| 153/153 [00:02<00:00, 54.02it/s, loss=0.292, metrics={'acc': 0.8665}] predict: 100%|██████████| 39/39 [00:00<00:00, 88.79it/s]
0.8679496366055891
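To gauge what the self-supervised pre-training buys you, one could train the very same architecture from scratch and compare test accuracies. The sketch below simply reuses the code from this notebook (not run here; results will vary with seeds and hyperparameters):
# same architecture, randomly initialised, i.e. no self-supervised pre-training
ft_transformer_scratch = FTTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=tab_preprocessor.continuous_cols,
    embed_continuous_method="standard",
    input_dim=32,
    kv_compression_factor=0.5,
    n_blocks=3,
    n_heads=4,
)
model_scratch = WideDeep(deeptabular=ft_transformer_scratch)
trainer_scratch = Trainer(model=model_scratch, objective="binary", metrics=[Accuracy])
trainer_scratch.fit(X_tab=X_tab, target=target, n_epochs=5, batch_size=256)
preds_scratch = trainer_scratch.predict(X_tab=X_tab_te)
print(accuracy_score(target_te, preds_scratch))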