17_feature_importance_via_attention_weights
Feature Importance via the attention weights¶
I will start by saying that I consider this feature of the library purely experimental. First of all I think there are multiple ways one could address finding the features importances for these models. However, and more importantly, one has to bear in mind that even tree-based algorithms on the same dataset produce different feature importances. This is more "dramatic" if one uses different techniques, such as shap or feature permutation (see for example this and references therein). All this to say that, sometimes, feature importance is just a measure contained within the experiment run, and for the model used.
With that in mind, each instantiation of a deep tabular model, that has millions of trainable parameters, will potentially produce a different set of feature importances, even if the model has the same architecture. Moreover, this effect will become more apparent if the dataset is relatively easy and there are dependent/related columns so that one could get to the same success metric with different parameters.
In summary, feature importances are implemented in this librray for all attention-based models for tabular data, with the exception of the TabPerceiver
. However this functionality has to be used and interpreted with care and consider of value within the 'universe' (or context) of the model with which these features were produced.
Nonetheless, let's have a look to how one would access to the feature importances when using this library.
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabTransformer, ContextAttentionMLP, WideDeep
from pytorch_widedeep.callbacks import EarlyStopping
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
# use_cuda = torch.cuda.is_available()
df = load_adult(as_frame=True)
df.columns = [c.replace("-", "_") for c in df.columns]
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop(["income", "fnlwgt", "educational_num"], axis=1, inplace=True)
target_colname = "income_label"
df.head()
age | workclass | education | marital_status | occupation | relationship | race | gender | capital_gain | capital_loss | hours_per_week | native_country | income_label | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 25 | Private | 11th | Never-married | Machine-op-inspct | Own-child | Black | Male | 0 | 0 | 40 | United-States | 0 |
1 | 38 | Private | HS-grad | Married-civ-spouse | Farming-fishing | Husband | White | Male | 0 | 0 | 50 | United-States | 0 |
2 | 28 | Local-gov | Assoc-acdm | Married-civ-spouse | Protective-serv | Husband | White | Male | 0 | 0 | 40 | United-States | 1 |
3 | 44 | Private | Some-college | Married-civ-spouse | Machine-op-inspct | Husband | Black | Male | 7688 | 0 | 40 | United-States | 1 |
4 | 18 | ? | Some-college | Never-married | ? | Own-child | White | Female | 0 | 0 | 30 | United-States | 0 |
cat_embed_cols = []
for col in df.columns:
if df[col].dtype == "O" or df[col].nunique() < 200 and col != target_colname:
cat_embed_cols.append(col)
# all cols will be categorical
assert len(cat_embed_cols) == df.shape[1] - 1
train, test = train_test_split(
df, test_size=0.1, random_state=1, stratify=df[[target_colname]]
)
tab_preprocessor = TabPreprocessor(cat_embed_cols=cat_embed_cols, with_attention=True)
X_tab_train = tab_preprocessor.fit_transform(train)
X_tab_test = tab_preprocessor.transform(test)
target = train[target_colname].values
tab_transformer = TabTransformer(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
embed_continuous_method="standard",
input_dim=8,
n_heads=2,
n_blocks=1,
attn_dropout=0.1,
transformer_activation="relu",
)
model = WideDeep(deeptabular=tab_transformer)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
threshold=0.001,
threshold_mode="abs",
patience=10,
)
early_stopping = EarlyStopping(
min_delta=0.001, patience=30, restore_best_weights=True, verbose=True
)
trainer = Trainer(
model,
objective="binary",
optimizers=optimizer,
lr_schedulers=lr_scheduler,
reducelronplateau_criterion="loss",
callbacks=[early_stopping],
metrics=[Accuracy],
)
The feature importances will be computed after training, using a sample of the training dataset of size feature_importance_sample_size
trainer.fit(
X_tab=X_tab_train,
target=target,
val_split=0.2,
n_epochs=100,
batch_size=128,
validation_freq=1,
feature_importance_sample_size=1000,
)
epoch 1: 100%|███████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 81.80it/s, loss=0.334, metrics={'acc': 0.847}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 111.34it/s, loss=0.294, metrics={'acc': 0.8669}] epoch 2: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 83.02it/s, loss=0.293, metrics={'acc': 0.8656}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 124.03it/s, loss=0.283, metrics={'acc': 0.8678}] epoch 3: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 87.69it/s, loss=0.282, metrics={'acc': 0.8703}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 118.22it/s, loss=0.279, metrics={'acc': 0.8717}] epoch 4: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.24it/s, loss=0.277, metrics={'acc': 0.8718}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 115.29it/s, loss=0.277, metrics={'acc': 0.8731}] epoch 5: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 83.76it/s, loss=0.275, metrics={'acc': 0.8727}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 120.80it/s, loss=0.276, metrics={'acc': 0.8727}] epoch 6: 100%|███████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 82.78it/s, loss=0.273, metrics={'acc': 0.873}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 100.43it/s, loss=0.276, metrics={'acc': 0.871}] epoch 7: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.35it/s, loss=0.271, metrics={'acc': 0.8742}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 116.14it/s, loss=0.275, metrics={'acc': 0.8726}] epoch 8: 100%|███████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 74.29it/s, loss=0.271, metrics={'acc': 0.875}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 117.46it/s, loss=0.276, metrics={'acc': 0.8718}] epoch 9: 100%|███████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.10it/s, loss=0.27, metrics={'acc': 0.8761}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 105.49it/s, loss=0.275, metrics={'acc': 0.8728}] epoch 10: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 70.40it/s, loss=0.269, metrics={'acc': 0.8747}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 105.47it/s, loss=0.275, metrics={'acc': 0.8726}] epoch 11: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 72.83it/s, loss=0.268, metrics={'acc': 0.8742}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 105.03it/s, loss=0.274, metrics={'acc': 0.873}] epoch 12: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 71.86it/s, loss=0.267, metrics={'acc': 0.8743}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 106.61it/s, loss=0.274, metrics={'acc': 0.8734}] epoch 13: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 72.39it/s, loss=0.267, metrics={'acc': 0.876}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 108.05it/s, loss=0.275, metrics={'acc': 0.8717}] epoch 14: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 73.36it/s, loss=0.265, metrics={'acc': 0.8767}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 109.60it/s, loss=0.276, metrics={'acc': 0.8747}] epoch 15: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 73.34it/s, loss=0.264, metrics={'acc': 0.876}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 117.55it/s, loss=0.276, metrics={'acc': 0.8706}] epoch 16: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.35it/s, loss=0.264, metrics={'acc': 0.8777}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 122.08it/s, loss=0.275, metrics={'acc': 0.8753}] epoch 17: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.96it/s, loss=0.263, metrics={'acc': 0.877}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 120.83it/s, loss=0.277, metrics={'acc': 0.8739}] epoch 18: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.38it/s, loss=0.263, metrics={'acc': 0.8779}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 118.18it/s, loss=0.278, metrics={'acc': 0.8714}] epoch 19: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.61it/s, loss=0.261, metrics={'acc': 0.8784}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 119.19it/s, loss=0.278, metrics={'acc': 0.8712}] epoch 20: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 72.43it/s, loss=0.261, metrics={'acc': 0.8791}] valid: 100%|███████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 118.87it/s, loss=0.28, metrics={'acc': 0.873}] epoch 21: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 73.97it/s, loss=0.26, metrics={'acc': 0.8787}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 107.50it/s, loss=0.279, metrics={'acc': 0.8732}] epoch 22: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 71.76it/s, loss=0.253, metrics={'acc': 0.8816}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 108.11it/s, loss=0.279, metrics={'acc': 0.8707}] epoch 23: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 71.92it/s, loss=0.252, metrics={'acc': 0.8828}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 100.14it/s, loss=0.28, metrics={'acc': 0.8711}] epoch 24: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.04it/s, loss=0.252, metrics={'acc': 0.8829}] valid: 100%|███████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 98.36it/s, loss=0.28, metrics={'acc': 0.8708}] epoch 25: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.31it/s, loss=0.251, metrics={'acc': 0.883}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 122.97it/s, loss=0.281, metrics={'acc': 0.8709}] epoch 26: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.63it/s, loss=0.25, metrics={'acc': 0.8834}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 123.07it/s, loss=0.281, metrics={'acc': 0.8698}] epoch 27: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 78.37it/s, loss=0.251, metrics={'acc': 0.884}] valid: 100%|███████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 114.75it/s, loss=0.281, metrics={'acc': 0.87}] epoch 28: 100%|███████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.19it/s, loss=0.25, metrics={'acc': 0.883}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 117.99it/s, loss=0.282, metrics={'acc': 0.8699}] epoch 29: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.19it/s, loss=0.25, metrics={'acc': 0.8829}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 111.11it/s, loss=0.282, metrics={'acc': 0.8695}] epoch 30: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.57it/s, loss=0.249, metrics={'acc': 0.8839}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 115.64it/s, loss=0.283, metrics={'acc': 0.8689}] epoch 31: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.55it/s, loss=0.249, metrics={'acc': 0.8846}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 125.10it/s, loss=0.283, metrics={'acc': 0.869}] epoch 32: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 82.56it/s, loss=0.248, metrics={'acc': 0.8841}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 118.45it/s, loss=0.284, metrics={'acc': 0.8687}] epoch 33: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 81.06it/s, loss=0.248, metrics={'acc': 0.8848}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 129.26it/s, loss=0.284, metrics={'acc': 0.8689}] epoch 34: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 80.53it/s, loss=0.248, metrics={'acc': 0.8854}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 120.61it/s, loss=0.283, metrics={'acc': 0.869}] epoch 35: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 80.78it/s, loss=0.248, metrics={'acc': 0.8853}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 127.31it/s, loss=0.283, metrics={'acc': 0.8694}] epoch 36: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 82.51it/s, loss=0.248, metrics={'acc': 0.8863}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 125.94it/s, loss=0.283, metrics={'acc': 0.8693}] epoch 37: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 81.35it/s, loss=0.247, metrics={'acc': 0.8844}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 126.77it/s, loss=0.283, metrics={'acc': 0.8692}] epoch 38: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 80.62it/s, loss=0.248, metrics={'acc': 0.8837}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 116.62it/s, loss=0.283, metrics={'acc': 0.8692}] epoch 39: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.82it/s, loss=0.248, metrics={'acc': 0.8842}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 123.64it/s, loss=0.283, metrics={'acc': 0.8695}] epoch 40: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.86it/s, loss=0.247, metrics={'acc': 0.8855}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 120.34it/s, loss=0.283, metrics={'acc': 0.8692}]
Best Epoch: 10. Best val_loss: 0.27451 Restoring model weights from the end of the best epoch
trainer.feature_importance
{'age': 0.09718182, 'workclass': 0.090637445, 'education': 0.08910798, 'marital_status': 0.08971319, 'occupation': 0.12546304, 'relationship': 0.086381145, 'race': 0.050686445, 'gender': 0.05116429, 'capital_gain': 0.08165918, 'capital_loss': 0.07702667, 'hours_per_week': 0.08205996, 'native_country': 0.07891885}
preds = trainer.predict(X_tab=X_tab_test)
predict: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:00<00:00, 199.63it/s]
accuracy_score(preds, test.income_label)
0.8685772773797339
test.reset_index(drop=True, inplace=True)
test[test.income_label == 0].head(1)
age | workclass | education | marital_status | occupation | relationship | race | gender | capital_gain | capital_loss | hours_per_week | native_country | income_label | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 26 | Private | Some-college | Never-married | Exec-managerial | Not-in-family | White | Male | 0 | 0 | 60 | United-States | 0 |
test[test.income_label == 1].head(1)
age | workclass | education | marital_status | occupation | relationship | race | gender | capital_gain | capital_loss | hours_per_week | native_country | income_label | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
3 | 36 | Local-gov | Doctorate | Married-civ-spouse | Prof-specialty | Husband | White | Male | 0 | 1887 | 50 | United-States | 1 |
To get the feature importance of a test dataset, simply use the explain
method
feat_imp_per_sample = trainer.explain(X_tab_test, save_step_masks=False)
list(test.iloc[0].index[np.argsort(-feat_imp_per_sample[0])])
['hours_per_week', 'education', 'relationship', 'occupation', 'workclass', 'capital_gain', 'native_country', 'marital_status', 'capital_loss', 'age', 'race', 'gender']
list(test.iloc[3].index[np.argsort(-feat_imp_per_sample[3])])
['age', 'capital_loss', 'hours_per_week', 'marital_status', 'native_country', 'relationship', 'race', 'education', 'occupation', 'capital_gain', 'gender', 'workclass']
We could do the same with the ContextAttentionMLP
context_attn_mlp = ContextAttentionMLP(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
cat_embed_dropout=0.0,
input_dim=16,
attn_dropout=0.1,
attn_activation="relu",
)
mlp_model = WideDeep(deeptabular=context_attn_mlp)
mlp_optimizer = torch.optim.Adam(mlp_model.parameters(), lr=0.01, weight_decay=0.0)
mlp_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
mlp_optimizer,
threshold=0.001,
threshold_mode="abs",
patience=10,
)
mlp_early_stopping = EarlyStopping(
min_delta=0.001, patience=30, restore_best_weights=True, verbose=True
)
mlp_trainer = Trainer(
mlp_model,
objective="binary",
optimizers=mlp_optimizer,
lr_schedulers=mlp_lr_scheduler,
reducelronplateau_criterion="loss",
callbacks=[mlp_early_stopping],
metrics=[Accuracy],
)
mlp_trainer.fit(
X_tab=X_tab_train,
target=target,
val_split=0.2,
n_epochs=100,
batch_size=128,
validation_freq=1,
feature_importance_sample_size=1000,
)
epoch 1: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 73.11it/s, loss=0.405, metrics={'acc': 0.8094}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 119.26it/s, loss=0.309, metrics={'acc': 0.8583}] epoch 2: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 71.70it/s, loss=0.332, metrics={'acc': 0.8447}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 116.36it/s, loss=0.293, metrics={'acc': 0.8646}] epoch 3: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 78.42it/s, loss=0.319, metrics={'acc': 0.8505}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 118.05it/s, loss=0.293, metrics={'acc': 0.8654}] epoch 4: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.00it/s, loss=0.312, metrics={'acc': 0.8554}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 116.49it/s, loss=0.291, metrics={'acc': 0.8661}] epoch 5: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.50it/s, loss=0.308, metrics={'acc': 0.8583}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 120.48it/s, loss=0.287, metrics={'acc': 0.8669}] epoch 6: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 80.84it/s, loss=0.303, metrics={'acc': 0.8605}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 128.70it/s, loss=0.288, metrics={'acc': 0.8673}] epoch 7: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.93it/s, loss=0.301, metrics={'acc': 0.8597}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 121.47it/s, loss=0.298, metrics={'acc': 0.8628}] epoch 8: 100%|████████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 80.56it/s, loss=0.3, metrics={'acc': 0.8592}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 119.84it/s, loss=0.281, metrics={'acc': 0.8718}] epoch 9: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.17it/s, loss=0.298, metrics={'acc': 0.8619}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 126.32it/s, loss=0.28, metrics={'acc': 0.8716}] epoch 10: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 82.13it/s, loss=0.297, metrics={'acc': 0.8615}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 120.50it/s, loss=0.281, metrics={'acc': 0.8718}] epoch 11: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 82.54it/s, loss=0.293, metrics={'acc': 0.8641}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 122.57it/s, loss=0.284, metrics={'acc': 0.867}] epoch 12: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 80.92it/s, loss=0.293, metrics={'acc': 0.863}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 126.42it/s, loss=0.282, metrics={'acc': 0.8701}] epoch 13: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.61it/s, loss=0.293, metrics={'acc': 0.8635}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 117.56it/s, loss=0.276, metrics={'acc': 0.8719}] epoch 14: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 74.92it/s, loss=0.29, metrics={'acc': 0.8633}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 112.06it/s, loss=0.286, metrics={'acc': 0.8669}] epoch 15: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.83it/s, loss=0.291, metrics={'acc': 0.865}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 112.88it/s, loss=0.282, metrics={'acc': 0.8677}] epoch 16: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.74it/s, loss=0.29, metrics={'acc': 0.8653}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 120.85it/s, loss=0.285, metrics={'acc': 0.8672}] epoch 17: 100%|███████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.99it/s, loss=0.29, metrics={'acc': 0.865}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 113.53it/s, loss=0.282, metrics={'acc': 0.8681}] epoch 18: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 71.22it/s, loss=0.288, metrics={'acc': 0.8651}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 112.89it/s, loss=0.288, metrics={'acc': 0.8676}] epoch 19: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.85it/s, loss=0.29, metrics={'acc': 0.8661}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 114.26it/s, loss=0.284, metrics={'acc': 0.8662}] epoch 20: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.80it/s, loss=0.289, metrics={'acc': 0.8661}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 119.44it/s, loss=0.281, metrics={'acc': 0.8703}] epoch 21: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.72it/s, loss=0.29, metrics={'acc': 0.8661}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 113.04it/s, loss=0.285, metrics={'acc': 0.8648}] epoch 22: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 74.86it/s, loss=0.289, metrics={'acc': 0.8656}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 111.75it/s, loss=0.282, metrics={'acc': 0.8666}] epoch 23: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.68it/s, loss=0.289, metrics={'acc': 0.8668}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 115.86it/s, loss=0.282, metrics={'acc': 0.8724}] epoch 24: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.93it/s, loss=0.288, metrics={'acc': 0.8653}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 114.69it/s, loss=0.285, metrics={'acc': 0.8656}] epoch 25: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.99it/s, loss=0.284, metrics={'acc': 0.8671}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 118.27it/s, loss=0.277, metrics={'acc': 0.8707}] epoch 26: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.86it/s, loss=0.282, metrics={'acc': 0.8686}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 123.94it/s, loss=0.276, metrics={'acc': 0.8712}] epoch 27: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.86it/s, loss=0.283, metrics={'acc': 0.8691}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 114.11it/s, loss=0.277, metrics={'acc': 0.8716}] epoch 28: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.43it/s, loss=0.281, metrics={'acc': 0.8696}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 114.64it/s, loss=0.277, metrics={'acc': 0.8712}] epoch 29: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.39it/s, loss=0.281, metrics={'acc': 0.8696}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 115.83it/s, loss=0.277, metrics={'acc': 0.872}] epoch 30: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 74.93it/s, loss=0.28, metrics={'acc': 0.8706}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 113.07it/s, loss=0.275, metrics={'acc': 0.8714}] epoch 31: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.35it/s, loss=0.281, metrics={'acc': 0.8697}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 112.68it/s, loss=0.276, metrics={'acc': 0.872}] epoch 32: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.72it/s, loss=0.28, metrics={'acc': 0.8693}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 122.50it/s, loss=0.276, metrics={'acc': 0.8709}] epoch 33: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.68it/s, loss=0.28, metrics={'acc': 0.8716}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 110.07it/s, loss=0.277, metrics={'acc': 0.8709}] epoch 34: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 73.58it/s, loss=0.279, metrics={'acc': 0.8704}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 111.07it/s, loss=0.274, metrics={'acc': 0.8719}] epoch 35: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 74.03it/s, loss=0.28, metrics={'acc': 0.8687}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 110.15it/s, loss=0.276, metrics={'acc': 0.871}] epoch 36: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 73.11it/s, loss=0.279, metrics={'acc': 0.8706}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 111.10it/s, loss=0.278, metrics={'acc': 0.8705}] epoch 37: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 74.30it/s, loss=0.279, metrics={'acc': 0.869}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 110.75it/s, loss=0.279, metrics={'acc': 0.8702}] epoch 38: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 72.34it/s, loss=0.28, metrics={'acc': 0.8691}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 117.79it/s, loss=0.277, metrics={'acc': 0.8698}] epoch 39: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.49it/s, loss=0.279, metrics={'acc': 0.8694}] valid: 100%|███████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 115.25it/s, loss=0.279, metrics={'acc': 0.87}] epoch 40: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.51it/s, loss=0.28, metrics={'acc': 0.8694}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 110.90it/s, loss=0.277, metrics={'acc': 0.8694}] epoch 41: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.35it/s, loss=0.278, metrics={'acc': 0.8716}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 115.74it/s, loss=0.28, metrics={'acc': 0.8675}] epoch 42: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.35it/s, loss=0.279, metrics={'acc': 0.8695}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 122.76it/s, loss=0.277, metrics={'acc': 0.8699}] epoch 43: 100%|█████████████████████████████████████████████████████████| 275/275 [00:04<00:00, 66.14it/s, loss=0.279, metrics={'acc': 0.8681}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 106.20it/s, loss=0.277, metrics={'acc': 0.8714}] epoch 44: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 73.37it/s, loss=0.279, metrics={'acc': 0.8704}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 115.18it/s, loss=0.277, metrics={'acc': 0.8716}] epoch 45: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.23it/s, loss=0.278, metrics={'acc': 0.8702}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 124.83it/s, loss=0.278, metrics={'acc': 0.8707}] epoch 46: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.12it/s, loss=0.278, metrics={'acc': 0.8704}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 126.62it/s, loss=0.279, metrics={'acc': 0.8693}] epoch 47: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 78.55it/s, loss=0.276, metrics={'acc': 0.8713}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 117.99it/s, loss=0.279, metrics={'acc': 0.8691}] epoch 48: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.25it/s, loss=0.278, metrics={'acc': 0.8719}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 124.52it/s, loss=0.278, metrics={'acc': 0.8695}] epoch 49: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 78.35it/s, loss=0.277, metrics={'acc': 0.8721}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 119.82it/s, loss=0.279, metrics={'acc': 0.8691}] epoch 50: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.15it/s, loss=0.277, metrics={'acc': 0.8717}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 122.62it/s, loss=0.278, metrics={'acc': 0.8699}] epoch 51: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 78.55it/s, loss=0.277, metrics={'acc': 0.8713}] valid: 100%|███████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 117.63it/s, loss=0.278, metrics={'acc': 0.87}] epoch 52: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.04it/s, loss=0.276, metrics={'acc': 0.8721}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 126.39it/s, loss=0.278, metrics={'acc': 0.8697}] epoch 53: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 78.15it/s, loss=0.277, metrics={'acc': 0.8721}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 127.56it/s, loss=0.278, metrics={'acc': 0.8699}] epoch 54: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.41it/s, loss=0.277, metrics={'acc': 0.8711}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 118.95it/s, loss=0.278, metrics={'acc': 0.8698}] epoch 55: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 76.35it/s, loss=0.277, metrics={'acc': 0.8718}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 126.90it/s, loss=0.278, metrics={'acc': 0.8699}] epoch 56: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.83it/s, loss=0.277, metrics={'acc': 0.8707}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 115.13it/s, loss=0.279, metrics={'acc': 0.8691}] epoch 57: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.20it/s, loss=0.277, metrics={'acc': 0.8722}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 123.16it/s, loss=0.279, metrics={'acc': 0.8691}] epoch 58: 100%|██████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 73.33it/s, loss=0.276, metrics={'acc': 0.871}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 123.37it/s, loss=0.278, metrics={'acc': 0.8691}] epoch 59: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 78.41it/s, loss=0.277, metrics={'acc': 0.8714}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 125.17it/s, loss=0.278, metrics={'acc': 0.8695}] epoch 60: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 79.26it/s, loss=0.276, metrics={'acc': 0.8721}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 120.60it/s, loss=0.278, metrics={'acc': 0.869}] epoch 61: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 75.88it/s, loss=0.278, metrics={'acc': 0.8703}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 124.47it/s, loss=0.278, metrics={'acc': 0.8692}] epoch 62: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.13it/s, loss=0.276, metrics={'acc': 0.8711}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 125.80it/s, loss=0.278, metrics={'acc': 0.8691}] epoch 63: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 77.20it/s, loss=0.277, metrics={'acc': 0.8715}] valid: 100%|█████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 118.50it/s, loss=0.278, metrics={'acc': 0.8695}] epoch 64: 100%|█████████████████████████████████████████████████████████| 275/275 [00:03<00:00, 78.11it/s, loss=0.276, metrics={'acc': 0.8719}] valid: 100%|██████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 114.52it/s, loss=0.278, metrics={'acc': 0.869}]
Best Epoch: 34. Best val_loss: 0.27449 Restoring model weights from the end of the best epoch
mlp_trainer.feature_importance
{'age': 0.116632804, 'workclass': 0.050255153, 'education': 0.094621316, 'marital_status': 0.12328919, 'occupation': 0.107893184, 'relationship': 0.11747801, 'race': 0.054717205, 'gender': 0.07514235, 'capital_gain': 0.059732802, 'capital_loss': 0.06738944, 'hours_per_week': 0.0610674, 'native_country': 0.07178114}
mlp_preds = mlp_trainer.predict(X_tab=X_tab_test)
predict: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:00<00:00, 212.38it/s]
accuracy_score(mlp_preds, test.income_label)
0.8726714431934494