# Dataloaders

NOTE: This module should contain custom dataloaders that the user might
want to implement. At the moment, pytorch-widedeep offers one custom
dataloader, `DataLoaderImbalanced`.
## DataLoaderImbalanced

Bases: `CustomDataLoader`

Class to load and shuffle batches with adjusted weights for imbalanced
datasets. If the classes do not begin from 0, remapping is necessary (see
[here](https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab)
and the sketch below).
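A minimal remapping sketch, assuming numpy labels; the array `y` and its
values are purely illustrative:

```python
import numpy as np

# hypothetical labels that do not start at 0
y = np.array([2, 5, 5, 9, 2])

# np.unique with return_inverse=True returns, for each element, its index
# into the sorted unique values, i.e. labels remapped to 0..n_classes-1
y_remapped = np.unique(y, return_inverse=True)[1]  # array([0, 1, 1, 2, 0])
```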
Parameters:

| Name    | Type                        | Description                                  | Default |
|---------|-----------------------------|----------------------------------------------|---------|
| dataset | `Optional[WideDeepDataset]` | see `pytorch_widedeep.training._wd_dataset`  | `None`  |
Other Parameters:

| Name       | Type | Description |
|------------|------|-------------|
| `*args`    |      | Positional arguments to be passed to the parent `CustomDataLoader`. |
| `**kwargs` |      | This can include any parameter that can be passed to the _'standard'_ pytorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) and that is not already explicitly passed to the class. In addition, the dictionary can also include the extra parameter `oversample_mul`, which will multiply the number of samples of the minority class to be sampled by the [`WeightedRandomSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler). In other words, the `num_samples` param in `WeightedRandomSampler` will be defined as shown below. |

$$
\text{minority class count} \times \text{number of classes} \times \text{oversample\_mul}
$$
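For instance, with purely illustrative numbers: in a binary problem (2
classes) whose minority class has 500 samples, `oversample_mul=2` yields

$$
\text{num\_samples} = 500 \times 2 \times 2 = 2000,
$$

i.e. the sampler draws 2,000 (approximately class-balanced) samples per epoch.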
Source code in `pytorch_widedeep/dataloaders.py`

```python
class DataLoaderImbalanced(CustomDataLoader):
    r"""Class to load and shuffle batches with adjusted weights for imbalanced
    datasets. If the classes do not begin from 0, remapping is necessary. See
    [here](https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab).

    Parameters
    ----------
    dataset: `WideDeepDataset`
        see `pytorch_widedeep.training._wd_dataset`

    Other Parameters
    ----------------
    *args: Any
        Positional arguments to be passed to the parent CustomDataLoader.
    **kwargs: Dict
        This can include any parameter that can be passed to the _'standard'_
        pytorch
        [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
        and that is not already explicitly passed to the class. In addition,
        the dictionary can also include the extra parameter `oversample_mul` which
        will multiply the number of samples of the minority class to be sampled by
        the [`WeightedRandomSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler).
        In other words, the `num_samples` param in `WeightedRandomSampler` will be defined as:

        $$
        \text{minority class count} \times \text{number of classes} \times \text{oversample\_mul}
        $$
    """

    def __init__(self, dataset: Optional[WideDeepDataset] = None, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        if dataset is not None:
            # build the weighted sampler and hand it to the parent loader
            self.sampler = self._setup_sampler(dataset)
            super().__init__(dataset, *args, sampler=self.sampler, **kwargs)
        else:
            super().__init__()

    def set_dataset(self, dataset: WideDeepDataset):
        sampler = self._setup_sampler(dataset)
        # update the kwargs with the new sampler
        self.kwargs["sampler"] = sampler
        super().set_dataset(dataset)

    def _setup_sampler(self, dataset: WideDeepDataset) -> WeightedRandomSampler:
        assert dataset.Y is not None, (
            "The 'dataset' instance of WideDeepDataset must contain a "
            "target array 'Y'"
        )
        oversample_mul = self.kwargs.pop("oversample_mul", 1)
        weights, minor_cls_cnt, num_clss = get_class_weights(dataset)
        num_samples = int(minor_cls_cnt * num_clss * oversample_mul)
        # per-sample weight: each sample gets the weight of its class
        samples_weight = list(np.array([weights[i] for i in dataset.Y]))
        sampler = WeightedRandomSampler(samples_weight, num_samples, replacement=True)
        return sampler
```
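A minimal usage sketch. The data here is random and purely illustrative;
`WideDeepDataset` is imported from `pytorch_widedeep.training._wd_dataset`
as referenced above, and any remaining keyword arguments (e.g. `batch_size`)
are forwarded to the standard pytorch `DataLoader`:

```python
import numpy as np

from pytorch_widedeep.dataloaders import DataLoaderImbalanced
from pytorch_widedeep.training._wd_dataset import WideDeepDataset

# hypothetical, already-preprocessed tabular matrix with an imbalanced
# binary target (950 negatives vs 50 positives)
X_tab = np.random.rand(1000, 10).astype("float32")
y = np.array([0] * 950 + [1] * 50)

dataset = WideDeepDataset(X_tab=X_tab, target=y)

# oversample_mul scales num_samples in the WeightedRandomSampler;
# batch_size is forwarded to the underlying pytorch DataLoader
loader = DataLoaderImbalanced(dataset, batch_size=32, oversample_mul=2)

for X, y_batch in loader:
    ...  # each batch is now approximately class-balanced
```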