
Dataloaders


ℹ️ NOTE: This module should contain custom dataloaders that the user might want to implement. At the moment pytorch-widedeep offers one custom dataloader, DataLoaderImbalanced.


DataLoaderImbalanced

Bases: CustomDataLoader

Class to load and shuffle batches with adjusted weights for imbalanced datasets. If the classes do not begin from 0, remapping is necessary. See [here](https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab).
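For instance, if the target labels are `{1, 2, 3}` rather than `{0, 1, 2}`, the remapping could look like the following sketch (this helper is not part of the library; the target array here is made up):

```python
import numpy as np

# Hypothetical target array whose classes do not begin at 0
y = np.array([1, 3, 2, 1, 3, 3])

# Map the sorted unique labels onto 0..n_classes-1
mapping = {label: idx for idx, label in enumerate(np.unique(y))}
y_remapped = np.array([mapping[label] for label in y])

print(y_remapped.tolist())  # [0, 2, 1, 0, 2, 2]
```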

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `Optional[WideDeepDataset]` | see `pytorch_widedeep.training._wd_dataset` | `None` |

Other Parameters:

| Name | Description |
| --- | --- |
| `*args` | Positional arguments to be passed to the parent `CustomDataLoader`. |
| `**kwargs` | This can include any parameter that can be passed to the _'standard'_ pytorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) and that is not already explicitly passed to the class. In addition, the dictionary can also include the extra parameter `oversample_mul`, which will multiply the number of samples of the minority class to be sampled by the [`WeightedRandomSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler). |

In other words, the num_samples param in WeightedRandomSampler will be defined as:

\[ \text{minority class count} \times \text{number of classes} \times oversample\_mul \]
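As a concrete illustration of the formula (the counts here are made up, not from any real dataset): with 100 samples in the minority class, 3 classes, and `oversample_mul=2`, the sampler would draw:

```python
# Hypothetical counts, purely for illustration
minority_class_count = 100
num_classes = 3
oversample_mul = 2

# num_samples as defined above for WeightedRandomSampler
num_samples = int(minority_class_count * num_classes * oversample_mul)
print(num_samples)  # 600
```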
Source code in pytorch_widedeep/dataloaders.py
```python
class DataLoaderImbalanced(CustomDataLoader):
    r"""Class to load and shuffle batches with adjusted weights for imbalanced
    datasets. If the classes do not begin from 0 remapping is necessary. See
    [here](https://towardsdatascience.com/pytorch-tabular-multiclass-classification-9f8211a123ab).

    Parameters
    ----------
    dataset: `WideDeepDataset`
        see `pytorch_widedeep.training._wd_dataset`

    Other Parameters
    ----------------
    *args: Any
        Positional arguments to be passed to the parent CustomDataLoader.
    **kwargs: Dict
        This can include any parameter that can be passed to the _'standard'_
        pytorch
        [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
        and that is not already explicitly passed to the class. In addition,
        the dictionary can also include the extra parameter `oversample_mul` which
        will multiply the number of samples of the minority class to be sampled by
        the [`WeightedRandomSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler).

        In other words, the `num_samples` param in `WeightedRandomSampler` will be defined as:

        $$
        minority \space class \space count \times number \space of \space classes \times oversample\_mul
        $$
    """

    def __init__(self, dataset: Optional[WideDeepDataset] = None, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

        if dataset is not None:
            self._setup_sampler(dataset)
            super().__init__(dataset, *args, sampler=self.sampler, **kwargs)
        else:
            super().__init__()

    def set_dataset(self, dataset: WideDeepDataset):
        sampler = self._setup_sampler(dataset)
        # update the kwargs with the new sampler
        self.kwargs["sampler"] = sampler
        super().set_dataset(dataset)

    def _setup_sampler(self, dataset: WideDeepDataset) -> WeightedRandomSampler:
        assert dataset.Y is not None, (
            "The 'dataset' instance of WideDeepDataset must contain a "
            "target array 'Y'"
        )

        oversample_mul = self.kwargs.pop("oversample_mul", 1)
        weights, minor_cls_cnt, num_clss = get_class_weights(dataset)
        num_samples = int(minor_cls_cnt * num_clss * oversample_mul)
        samples_weight = list(np.array([weights[i] for i in dataset.Y]))
        sampler = WeightedRandomSampler(samples_weight, num_samples, replacement=True)
        return sampler
```
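The weighting logic in `_setup_sampler` can be sketched without the library itself. The snippet below assumes, for illustration only, that `get_class_weights` returns inverse-frequency class weights along with the minority-class count and the number of classes; the target array is made up:

```python
import numpy as np

# Hypothetical imbalanced target array (class 1 is the minority)
Y = np.array([0, 0, 0, 0, 1, 1])

counts = np.bincount(Y)        # samples per class: [4, 2]
weights = 1.0 / counts         # inverse-frequency class weights
minor_cls_cnt = counts.min()   # minority class count: 2
num_clss = len(counts)         # number of classes: 2

# num_samples as computed in _setup_sampler (oversample_mul defaults to 1)
oversample_mul = 1
num_samples = int(minor_cls_cnt * num_clss * oversample_mul)

# per-sample weights: each sample gets its class weight,
# so minority-class samples are drawn more often
samples_weight = [weights[i] for i in Y]
print(num_samples, samples_weight)
```

These `samples_weight` and `num_samples` values are what the real class hands to `WeightedRandomSampler(samples_weight, num_samples, replacement=True)`.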