
Dataloaders


ℹ️ NOTE: This module contains the custom dataloaders that the user might want to use. At the moment pytorch-widedeep offers one custom dataloader, DataLoaderImbalanced.


DataLoaderImbalanced

DataLoaderImbalanced(dataset, batch_size, num_workers, **kwargs)

Bases: DataLoader

Class to load and shuffle batches with weights adjusted for imbalanced datasets. If the class labels do not start at 0, they must first be remapped to a contiguous 0-based range (see the sketch below).
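
A minimal remapping sketch, assuming the raw labels are arbitrary integers (non-contiguous or not starting at 0); the array names are purely illustrative:

import numpy as np

# Hypothetical raw labels that do not start at 0
raw_labels = np.array([1, 3, 3, 2, 1, 3])

# Map each distinct label to a contiguous range starting at 0
classes = np.unique(raw_labels)                      # array([1, 2, 3])
remap = {c: i for i, c in enumerate(classes)}        # {1: 0, 2: 1, 3: 2}
target = np.array([remap[c] for c in raw_labels])    # array([0, 2, 2, 1, 0, 2])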

Parameters:

  • dataset (WideDeepDataset) –

    see pytorch_widedeep.training._wd_dataset

  • batch_size (int) –

    size of batch

  • num_workers (int) –

    number of workers

Other Parameters:

  • **kwargs

    Any other parameter that can be passed to the 'standard' pytorch DataLoader and that is not already explicitly passed to the class. In addition, kwargs can include the extra parameter oversample_mul, which multiplies the number of minority-class samples drawn by the WeightedRandomSampler.

    In other words, the num_samples param in WeightedRandomSampler will be defined as:

    \[ num\_samples = \text{minority class count} \times \text{number of classes} \times oversample\_mul \]
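
    For example, with a minority class of 500 samples, 2 classes and oversample_mul=2, the sampler will draw int(500 × 2 × 2) = 2000 samples per epoch, with replacement.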
Source code in pytorch_widedeep/dataloaders.py
def __init__(
    self, dataset: WideDeepDataset, batch_size: int, num_workers: int, **kwargs
):
    assert dataset.Y is not None, (
        "The 'dataset' instance of WideDeepDataset must contain a "
        "target array 'Y'"
    )

    self.with_lds = dataset.with_lds
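    # 'oversample_mul' is not a valid DataLoader argument, so it is popped from
    # kwargs before these are forwarded to the parent DataLoader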
    if "oversample_mul" in kwargs:
        oversample_mul = kwargs["oversample_mul"]
        del kwargs["oversample_mul"]
    else:
        oversample_mul = 1
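    # Per-sample weights (one per entry in dataset.Y) feed a WeightedRandomSampler
    # that draws num_samples rows with replacement, oversampling the minority class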
    weights, minor_cls_cnt, num_clss = get_class_weights(dataset)
    num_samples = int(minor_cls_cnt * num_clss * oversample_mul)
    samples_weight = list(np.array([weights[i] for i in dataset.Y]))
    sampler = WeightedRandomSampler(samples_weight, num_samples, replacement=True)
    super().__init__(
        dataset, batch_size, num_workers=num_workers, sampler=sampler, **kwargs
    )
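
For orientation, below is a minimal usage sketch. It assumes synthetic, already-preprocessed inputs and that WideDeepDataset (see pytorch_widedeep.training._wd_dataset) accepts X_wide, X_tab and target keyword arguments; names and shapes are illustrative only, and the exact signatures should be checked against the installed version.

import numpy as np

from pytorch_widedeep.dataloaders import DataLoaderImbalanced
from pytorch_widedeep.training._wd_dataset import WideDeepDataset

# Hypothetical preprocessed inputs: 1000 rows with a heavily imbalanced binary target
X_wide = np.random.randint(0, 10, (1000, 5))
X_tab = np.random.rand(1000, 8).astype("float32")
target = np.random.choice([0, 1], size=1000, p=[0.95, 0.05])

dataset = WideDeepDataset(X_wide=X_wide, X_tab=X_tab, target=target)

# Batches are drawn by a WeightedRandomSampler, so minority-class rows appear
# far more often than they would with a plain shuffled DataLoader
dataloader = DataLoaderImbalanced(dataset, batch_size=32, num_workers=0, oversample_mul=2)

for X, y in dataloader:
    ...  # each batch: model inputs X and targets y

In practice this class is usually not iterated by hand but handed over to the library's Trainer (for example through its custom dataloader argument to fit, where the installed version supports it).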