
Losses

pytorch-widedeep accepts a number of losses and objectives that can be passed to the Trainer class via the objective parameter (see pytorch_widedeep.training.Trainer). In most cases, the loss function that pytorch-widedeep uses internally is already implemented in PyTorch.

In addition, pytorch-widedeep implements a series of "custom" loss functions. These are described below for completeness since, as mentioned above, they are used internally by the Trainer. They can also be used on their own, imported as:

from pytorch_widedeep.losses import FocalLoss


ℹ️ NOTE: Losses in this module expect the predictions and ground truth to have the same dimensions for regression and binary classification problems, i.e. \((N_{samples}, 1)\). For multiclass classification problems, the ground truth is expected to be a 1D tensor with the corresponding classes. See the Examples below.


MSELoss

MSELoss()

Bases: Module

Mean squared error loss with the option of using Label Distribution Smoothing (LDS)

LDS is based on Delving into Deep Imbalanced Regression.

Source code in pytorch_widedeep/losses.py
def __init__(self):
    super().__init__()

forward

forward(input, target, lds_weight=None)

Parameters:

  • input (Tensor) –

    Input tensor with predictions

  • target (Tensor) –

    Target tensor with the actual values

  • lds_weight (Optional[Tensor], default: None ) –

    Tensor of weights that will multiply the loss value.

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import MSELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> lds_weight = torch.tensor([0.1, 0.2, 0.3, 0.4]).view(-1, 1)
>>> loss = MSELoss()(input, target, lds_weight)
Source code in pytorch_widedeep/losses.py
def forward(
    self, input: Tensor, target: Tensor, lds_weight: Optional[Tensor] = None
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values
    lds_weight: Tensor, Optional
        Tensor of weights that will multiply the loss value.

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import MSELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> lds_weight = torch.tensor([0.1, 0.2, 0.3, 0.4]).view(-1, 1)
    >>> loss = MSELoss()(input, target, lds_weight)
    """
    loss = (input - target) ** 2
    if lds_weight is not None:
        loss *= lds_weight
    return torch.mean(loss)
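As a quick check of the arithmetic in MSELoss.forward, the following plain-Python sketch reproduces the computation on the example values above without PyTorch: each squared error is multiplied by its LDS weight and the result is averaged over all samples. The weighted_mse helper is hypothetical and not part of pytorch-widedeep.

```python
from typing import List, Optional


def weighted_mse(
    inputs: List[float], targets: List[float], lds_weight: Optional[List[float]] = None
) -> float:
    """Mirror of MSELoss.forward: mean of (optionally weighted) squared errors."""
    if lds_weight is None:
        lds_weight = [1.0] * len(inputs)  # no LDS: every sample has weight 1
    terms = [w * (x - y) ** 2 for x, y, w in zip(inputs, targets, lds_weight)]
    return sum(terms) / len(terms)


target = [1.0, 1.2, 0.0, 2.0]
preds = [0.6, 0.7, 0.3, 0.8]
weights = [0.1, 0.2, 0.3, 0.4]

print(round(weighted_mse(preds, target), 5))           # 0.485
print(round(weighted_mse(preds, target, weights), 5))  # 0.16725
```

The weights scale each sample's contribution, but the denominator is still the number of samples rather than the sum of the weights, since the forward above ends with torch.mean.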

MSLELoss

MSLELoss()

Bases: Module

Mean squared logarithmic error loss with the option of using Label Distribution Smoothing (LDS)

LDS is based on Delving into Deep Imbalanced Regression.

Source code in pytorch_widedeep/losses.py
def __init__(self):
    super().__init__()

forward

forward(input, target, lds_weight=None)

Parameters:

  • input (Tensor) –

    Input tensor with predictions (all values must be >= 0)

  • target (Tensor) –

    Target tensor with the actual values (all values must be >= 0)

  • lds_weight (Optional[Tensor], default: None ) –

    Tensor of weights that will multiply the loss value.

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import MSLELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> lds_weight = torch.tensor([0.1, 0.2, 0.3, 0.4]).view(-1, 1)
>>> loss = MSLELoss()(input, target, lds_weight)
Source code in pytorch_widedeep/losses.py
def forward(
    self, input: Tensor, target: Tensor, lds_weight: Optional[Tensor] = None
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (must be >= 0)
    target: Tensor
        Target tensor with the actual values (must be >= 0)
    lds_weight: Tensor, Optional
        Tensor of weights that will multiply the loss value.

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import MSLELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> lds_weight = torch.tensor([0.1, 0.2, 0.3, 0.4]).view(-1, 1)
    >>> loss = MSLELoss()(input, target, lds_weight)
    """
    assert (
        input.min() >= 0
    ), """All input values must be >=0, if your model is predicting
        values <0 try to enforce positive values by activation function
        on last layer with `trainer.enforce_positive_output=True`"""
    assert target.min() >= 0, "All target values must be >=0"

    loss = (torch.log(input + 1) - torch.log(target + 1)) ** 2
    if lds_weight is not None:
        loss *= lds_weight
    return torch.mean(loss)
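The log transform in MSLELoss makes the loss asymmetric: for the same absolute error, under-predicting costs more than over-predicting. The hypothetical msle helper below is a plain-Python sketch of the formula in MSLELoss.forward (raising ValueError where the library uses assert) that makes this visible.

```python
import math
from typing import List, Optional


def msle(
    inputs: List[float], targets: List[float], lds_weight: Optional[List[float]] = None
) -> float:
    """Sketch of MSLELoss.forward: mean of w * (log(x + 1) - log(y + 1)) ** 2."""
    if min(inputs) < 0 or min(targets) < 0:
        raise ValueError("inputs and targets must be >= 0")
    if lds_weight is None:
        lds_weight = [1.0] * len(inputs)
    terms = [
        w * (math.log1p(x) - math.log1p(y)) ** 2  # log1p(v) == log(v + 1)
        for x, y, w in zip(inputs, targets, lds_weight)
    ]
    return sum(terms) / len(terms)


# Same target, same absolute error of 5, opposite directions
under = msle([5.0], [10.0])  # (log 6 - log 11) ** 2
over = msle([15.0], [10.0])  # (log 16 - log 11) ** 2
print(under > over)  # True
```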

RMSELoss

RMSELoss()

Bases: Module

Root mean squared error loss with the option of using Label Distribution Smoothing (LDS)

LDS is based on Delving into Deep Imbalanced Regression.

Source code in pytorch_widedeep/losses.py
def __init__(self):
    super().__init__()

forward

forward(input, target, lds_weight=None)

Parameters:

  • input (Tensor) –

    Input tensor with predictions

  • target (Tensor) –

    Target tensor with the actual values

  • lds_weight (Optional[Tensor], default: None ) –

    Tensor of weights that will multiply the loss value.

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import RMSELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> lds_weight = torch.tensor([0.1, 0.2, 0.3, 0.4]).view(-1, 1)
>>> loss = RMSELoss()(input, target, lds_weight)
Source code in pytorch_widedeep/losses.py
def forward(
    self, input: Tensor, target: Tensor, lds_weight: Optional[Tensor] = None
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values
    lds_weight: Tensor, Optional
        Tensor of weights that will multiply the loss value.

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import RMSELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> lds_weight = torch.tensor([0.1, 0.2, 0.3, 0.4]).view(-1, 1)
    >>> loss = RMSELoss()(input, target, lds_weight)
    """
    loss = (input - target) ** 2
    if lds_weight is not None:
        loss *= lds_weight
    return torch.sqrt(torch.mean(loss))
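In RMSELoss.forward the LDS weights are applied to the squared errors before averaging and taking the square root, so the result is the square root of the weighted MSE rather than a weighted average of absolute errors. A plain-Python sketch (the weighted_rmse helper is hypothetical):

```python
import math
from typing import List, Optional


def weighted_rmse(
    inputs: List[float], targets: List[float], lds_weight: Optional[List[float]] = None
) -> float:
    """Sketch of RMSELoss.forward: sqrt(mean(w * (x - y) ** 2))."""
    if lds_weight is None:
        lds_weight = [1.0] * len(inputs)
    sq = [w * (x - y) ** 2 for x, y, w in zip(inputs, targets, lds_weight)]
    return math.sqrt(sum(sq) / len(sq))  # the root is taken after averaging


target = [1.0, 1.2, 0.0, 2.0]
preds = [0.6, 0.7, 0.3, 0.8]
weights = [0.1, 0.2, 0.3, 0.4]

rmse = weighted_rmse(preds, target, weights)
print(round(rmse**2, 5))  # 0.16725, i.e. the weighted MSE
```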

RMSLELoss

RMSLELoss()

Bases: Module

Root mean squared logarithmic error loss with the option of using Label Distribution Smoothing (LDS)

LDS is based on Delving into Deep Imbalanced Regression.

Source code in pytorch_widedeep/losses.py
def __init__(self):
    super().__init__()

forward

forward(input, target, lds_weight=None)

Parameters:

  • input (Tensor) –

    Input tensor with predictions (all values must be >= 0)

  • target (Tensor) –

    Target tensor with the actual values (all values must be >= 0)

  • lds_weight (Optional[Tensor], default: None ) –

    Tensor of weights that will multiply the loss value.

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import RMSLELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> lds_weight = torch.tensor([0.1, 0.2, 0.3, 0.4]).view(-1, 1)
>>> loss = RMSLELoss()(input, target, lds_weight)
Source code in pytorch_widedeep/losses.py
def forward(
    self, input: Tensor, target: Tensor, lds_weight: Optional[Tensor] = None
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (must be >= 0)
    target: Tensor
        Target tensor with the actual values (must be >= 0)
    lds_weight: Tensor, Optional
        Tensor of weights that will multiply the loss value.

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import RMSLELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> lds_weight = torch.tensor([0.1, 0.2, 0.3, 0.4]).view(-1, 1)
    >>> loss = RMSLELoss()(input, target, lds_weight)
    """
    assert (
        input.min() >= 0
    ), """All input values must be >=0, if your model is predicting
        values <0 try to enforce positive values by activation function
        on last layer with `trainer.enforce_positive_output=True`"""
    assert target.min() >= 0, "All target values must be >=0"

    loss = (torch.log(input + 1) - torch.log(target + 1)) ** 2
    if lds_weight is not None:
        loss *= lds_weight
    return torch.sqrt(torch.mean(loss))
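RMSLELoss is the square root of the (optionally LDS-weighted) mean squared logarithmic error and, like MSLELoss, rejects negative inputs or targets. Because log(x + 1) is exactly 1 when x = e - 1, a target of 0 and a prediction of e - 1 give an RMSLE of 1, which is a convenient hand check for the hypothetical rmsle sketch below.

```python
import math
from typing import List, Optional


def rmsle(
    inputs: List[float], targets: List[float], lds_weight: Optional[List[float]] = None
) -> float:
    """Sketch of RMSLELoss.forward: sqrt(mean(w * (log(x + 1) - log(y + 1)) ** 2))."""
    if min(inputs) < 0 or min(targets) < 0:
        raise ValueError("inputs and targets must be >= 0")
    if lds_weight is None:
        lds_weight = [1.0] * len(inputs)
    terms = [
        w * (math.log1p(x) - math.log1p(y)) ** 2
        for x, y, w in zip(inputs, targets, lds_weight)
    ]
    return math.sqrt(sum(terms) / len(terms))


print(round(rmsle([math.e - 1], [0.0]), 6))  # 1.0
print(round(rmsle([math.e**2 - 1, 0.0], [0.0, 0.0]), 6))  # sqrt((4 + 0) / 2) = 1.414214
```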

QuantileLoss

QuantileLoss(quantiles=[0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98])

Bases: Module

Quantile loss defined as:

\[ Loss = \max\left(q \times (y - y_{pred}), (1 - q) \times (y_{pred} - y)\right) \]

All credits go to the implementation at pytorch-forecasting.

Parameters:

  • quantiles (List[float], default: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98] ) –

    List of quantiles

Source code in pytorch_widedeep/losses.py
def __init__(
    self,
    quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
):
    super().__init__()
    self.quantiles = quantiles

forward

forward(input, target)

Parameters:

  • input (Tensor) –

    Input tensor with predictions

  • target (Tensor) –

    Target tensor with the actual values

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import QuantileLoss
>>>
>>> # REGRESSION
>>> target = torch.tensor([[0.6, 1.5]]).view(-1, 1)
>>> input = torch.tensor([[.1, .2,], [.4, .5]])
>>> qloss = QuantileLoss([0.25, 0.75])
>>> loss = qloss(input, target)
Source code in pytorch_widedeep/losses.py
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import QuantileLoss
    >>>
    >>> # REGRESSION
    >>> target = torch.tensor([[0.6, 1.5]]).view(-1, 1)
    >>> input = torch.tensor([[.1, .2,], [.4, .5]])
    >>> qloss = QuantileLoss([0.25, 0.75])
    >>> loss = qloss(input, target)
    """

    assert input.shape == torch.Size([target.shape[0], len(self.quantiles)]), (
        "The input and target have inconsistent shape. The dimension of the prediction "
        "of the model that is using QuantileLoss must be equal to number of quantiles, "
        f"i.e. {len(self.quantiles)}."
    )
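    target = target.view(-1, 1).float()
    losses = []
    for i, q in enumerate(self.quantiles):
        # slice with i : i + 1 so errors keep shape (N, 1); indexing with [..., i]
        # would broadcast the (N, 1) target against an (N,) column into (N, N)
        errors = target - input[..., i : i + 1]
        losses.append(torch.max((q - 1) * errors, q * errors))

    loss = torch.cat(losses, dim=1)  # (N, n_quantiles)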

    return torch.mean(loss)
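The pinball loss in the QuantileLoss formula can be worked out by hand for the example above: each row of input holds one prediction per quantile, and each row is compared with that row's target. The hypothetical quantile_loss helper below is a plain-Python sketch of this row-wise computation, averaged over all samples and quantiles; it illustrates the formula and is not a replacement for the library class.

```python
from typing import List, Sequence


def quantile_loss(
    preds: List[Sequence[float]], targets: List[float], quantiles: Sequence[float]
) -> float:
    """Mean pinball loss; preds[n][i] is the prediction for quantiles[i]."""
    if any(len(row) != len(quantiles) for row in preds):
        raise ValueError("each prediction row needs one value per quantile")
    losses = []
    for row, y in zip(preds, targets):
        for q, y_hat in zip(quantiles, row):
            err = y - y_hat
            # q * err when under-predicting, (q - 1) * err when over-predicting
            losses.append(max(q * err, (q - 1) * err))
    return sum(losses) / len(losses)


# Same numbers as the QuantileLoss example above
loss = quantile_loss([[0.1, 0.2], [0.4, 0.5]], [0.6, 1.5], [0.25, 0.75])
print(round(loss, 4))  # (0.125 + 0.3 + 0.275 + 0.75) / 4 = 0.3625
```

For a high quantile such as q = 0.9, under-predicting by 1 costs 0.9 while over-predicting by 1 costs only 0.1, which is what pushes that output towards the upper tail of the target distribution.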

FocalLoss

FocalLoss(alpha=0.25, gamma=1.0)

Bases: Module

Implementation of the Focal loss for both binary and multiclass classification:

\[ FL(p_t) = -\alpha (1 - p_t)^{\gamma} \log(p_t) \]

where, for a binary classification problem, \(p\) is the predicted probability of the positive class and

\[ p_t = \begin{cases} p, & \text{if } y = 1,\\ 1 - p, & \text{otherwise.} \end{cases} \]

Parameters:

  • alpha (float, default: 0.25 ) –

    Focal Loss alpha parameter

  • gamma (float, default: 1.0 ) –

    Focal Loss gamma parameter

Source code in pytorch_widedeep/losses.py
def __init__(self, alpha: float = 0.25, gamma: float = 1.0):
    super().__init__()
    self.alpha = alpha
    self.gamma = gamma

forward

forward(input, target)

Parameters:

  • input (Tensor) –

    Input tensor with predictions (not probabilities)

  • target (Tensor) –

    Target tensor with the actual classes

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import FocalLoss
>>>
>>> # BINARY
>>> target = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> input = torch.tensor([[0.6, 0.7, 0.3, 0.8]]).t()
>>> loss = FocalLoss()(input, target)
>>>
>>> # MULTICLASS
>>> target = torch.tensor([1, 0, 2]).view(-1, 1)
>>> input = torch.tensor([[0.2, 0.5, 0.3], [0.8, 0.1, 0.1], [0.7, 0.2, 0.1]])
>>> loss = FocalLoss()(input, target)
Source code in pytorch_widedeep/losses.py
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (not probabilities)
    target: Tensor
        Target tensor with the actual classes

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import FocalLoss
    >>>
    >>> # BINARY
    >>> target = torch.tensor([0, 1, 0, 1]).view(-1, 1)
    >>> input = torch.tensor([[0.6, 0.7, 0.3, 0.8]]).t()
    >>> loss = FocalLoss()(input, target)
    >>>
    >>> # MULTICLASS
    >>> target = torch.tensor([1, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([[0.2, 0.5, 0.3], [0.8, 0.1, 0.1], [0.7, 0.2, 0.1]])
    >>> loss = FocalLoss()(input, target)
    """
    input_prob = torch.sigmoid(input)
    if input.size(1) == 1:
        input_prob = torch.cat([1 - input_prob, input_prob], dim=1)
        num_class = 2
    else:
        num_class = input_prob.size(1)
    # one-hot encode the targets directly on the device of the predictions;
    # reshape(-1) (rather than squeeze()) keeps a batch of size 1 one-dimensional
    binary_target = torch.eye(num_class, device=input.device)[
        target.reshape(-1).long().to(input.device)
    ]
    binary_target = binary_target.contiguous()
    # _get_weight is a helper method of this class (not shown on this page)
    weight = self._get_weight(input_prob, binary_target)

    return F.binary_cross_entropy(
        input_prob, binary_target, weight, reduction="mean"
    )
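The effect of the focusing parameter gamma can be checked with a textbook binary focal loss computed from logits. The hypothetical binary_focal_loss helper below is a sketch of the formula above, with alpha applied to positives and 1 - alpha to negatives as in the original focal loss paper. It does not reproduce the library's internals exactly; for instance, the forward above one-hot encodes binary targets into two columns and relies on a _get_weight helper that is not shown.

```python
import math
from typing import List


def binary_focal_loss(
    logits: List[float], targets: List[int], alpha: float = 0.25, gamma: float = 1.0
) -> float:
    """Mean of -alpha_t * (1 - p_t) ** gamma * log(p_t) over the samples."""
    losses = []
    for z, y in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-z))  # sigmoid: probability of the positive class
        p_t = p if y == 1 else 1.0 - p
        alpha_t = alpha if y == 1 else 1.0 - alpha
        losses.append(-alpha_t * (1.0 - p_t) ** gamma * math.log(p_t))
    return sum(losses) / len(losses)


# A confidently correct positive (logit 4) is heavily down-weighted as gamma grows
easy_ce = binary_focal_loss([4.0], [1], alpha=0.5, gamma=0.0)
easy_fl = binary_focal_loss([4.0], [1], alpha=0.5, gamma=2.0)
print(easy_fl / easy_ce)  # equals (1 - sigmoid(4)) ** 2, roughly 3.2e-4
```

With gamma = 0 the expression reduces to an alpha-weighted binary cross-entropy, so gamma only controls how strongly easy, well-classified samples are suppressed.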

BayesianSELoss

BayesianSELoss()

Bases: Module

Squared loss (log Gaussian) for regression, as specified in the original publication Weight Uncertainty in Neural Networks.

Source code in pytorch_widedeep/losses.py
def __init__(self):
    super().__init__()

forward

forward(input, target)

Parameters:

  • input (Tensor) –

    Input tensor with predictions

  • target (Tensor) –

    Target tensor with the actual values

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import BayesianSELoss
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = BayesianSELoss()(input, target)
Source code in pytorch_widedeep/losses.py
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import BayesianSELoss
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = BayesianSELoss()(input, target)
    """
    return (0.5 * (input - target) ** 2).sum()
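BayesianSELoss returns 0.5 times the sum (not the mean) of the squared errors. Up to a constant of 0.5 * log(2 * pi) per sample, this equals the negative log-likelihood of the targets under a unit-variance Gaussian centred on the predictions, which is consistent with the "log Gaussian" description. A plain-Python check of that identity (both helper names are hypothetical):

```python
import math
from typing import List


def bayesian_se(inputs: List[float], targets: List[float]) -> float:
    """Sketch of BayesianSELoss.forward: 0.5 * sum of squared errors (no averaging)."""
    return sum(0.5 * (x - y) ** 2 for x, y in zip(inputs, targets))


def gaussian_nll(inputs: List[float], targets: List[float]) -> float:
    """Negative log-likelihood of targets under N(mean=input, variance=1), summed."""
    return sum(
        0.5 * math.log(2 * math.pi) + 0.5 * (y - x) ** 2 for x, y in zip(inputs, targets)
    )


target = [1.0, 1.2, 0.0, 2.0]
preds = [0.6, 0.7, 0.3, 0.8]
print(round(bayesian_se(preds, target), 4))  # 0.5 * 1.94 = 0.97
constant = len(target) * 0.5 * math.log(2 * math.pi)
print(math.isclose(gaussian_nll(preds, target) - constant, bayesian_se(preds, target)))  # True
```

Because the loss is summed, its magnitude grows with the batch size, unlike the averaged losses on this page.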

TweedieLoss

TweedieLoss()

Bases: Module

Tweedie loss for extremely unbalanced zero-inflated data

All credits go to Wenbo Shi. See this post and the original publication for details.

Source code in pytorch_widedeep/losses.py
def __init__(self):
    super().__init__()

forward

forward(input, target, lds_weight=None, p=1.5)

Parameters:

  • input (Tensor) –

    Input tensor with predictions (all values must be > 0)

  • target (Tensor) –

    Target tensor with the actual values

  • lds_weight (Optional[Tensor], default: None ) –

    If we choose to use LDS this is the tensor of weights that will multiply the loss value.

  • p (float, default: 1.5 ) –

    The power used to compute the loss. See the original publication for details.

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import TweedieLoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> lds_weight = torch.tensor([0.1, 0.2, 0.3, 0.4]).view(-1, 1)
>>> loss = TweedieLoss()(input, target, lds_weight)
Source code in pytorch_widedeep/losses.py
def forward(
    self,
    input: Tensor,
    target: Tensor,
    lds_weight: Optional[Tensor] = None,
    p: float = 1.5,
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values
    lds_weight: Tensor, Optional
        If we choose to use LDS this is the tensor of weights that will
        multiply the loss value.
    p: float, default = 1.5
        the power to be used to compute the loss. See the original
        publication for details

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import TweedieLoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> lds_weight = torch.tensor([0.1, 0.2, 0.3, 0.4]).view(-1, 1)
    >>> loss = TweedieLoss()(input, target, lds_weight)
    """

    assert (
        input.min() > 0
    ), """All input values must be >0, if your model is predicting
        values <=0 try to enforce positive values by activation function
        on last layer with `trainer.enforce_positive_output=True`"""
    assert target.min() >= 0, "All target values must be >=0"
    loss = -target * torch.pow(input, 1 - p) / (1 - p) + torch.pow(input, 2 - p) / (
        2 - p
    )
    if lds_weight is not None:
        loss *= lds_weight

    return torch.mean(loss)
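For a single sample with prediction mu and target y, TweedieLoss.forward computes -y * mu^(1-p) / (1-p) + mu^(2-p) / (2-p). Its derivative with respect to mu is mu^(-p) * (mu - y), so for a positive target it is minimised at mu = y, and because 1 - p < 0 for p > 1 the first term is undefined at mu = 0, which is why predictions must be strictly positive. The hypothetical tweedie helper below is a plain-Python sketch of that formula with the default p = 1.5.

```python
from typing import List, Optional


def tweedie(
    inputs: List[float],
    targets: List[float],
    lds_weight: Optional[List[float]] = None,
    p: float = 1.5,
) -> float:
    """Sketch of TweedieLoss.forward: mean of w * (-y*mu**(1-p)/(1-p) + mu**(2-p)/(2-p))."""
    if min(inputs) <= 0:
        raise ValueError("predictions must be > 0")
    if min(targets) < 0:
        raise ValueError("targets must be >= 0")
    if lds_weight is None:
        lds_weight = [1.0] * len(inputs)
    terms = [
        w * (-y * mu ** (1 - p) / (1 - p) + mu ** (2 - p) / (2 - p))
        for mu, y, w in zip(inputs, targets, lds_weight)
    ]
    return sum(terms) / len(terms)


# For a positive target the per-sample loss is smallest when the prediction equals it
losses = {mu: tweedie([mu], [2.0]) for mu in (1.5, 2.0, 2.5)}
print(min(losses, key=losses.get))  # 2.0
```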

ZILNLoss

ZILNLoss()

Bases: Module

Adjusted implementation of the Zero Inflated LogNormal Loss

See A Deep Probabilistic Model for Customer Lifetime Value Prediction and the corresponding code.

Source code in pytorch_widedeep/losses.py
def __init__(self):
    super().__init__()

forward

forward(input, target)

Parameters:

  • input (Tensor) –

    Input tensor with predictions with shape (N, 3), where N is the batch size

  • target (Tensor) –

    Target tensor with the actual target values

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import ZILNLoss
>>>
>>> target = torch.tensor([[0., 1.5]]).view(-1, 1)
>>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
>>> loss = ZILNLoss()(input, target)
Source code in pytorch_widedeep/losses.py
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions with shape (N, 3), where N is the batch size
    target: Tensor
        Target tensor with the actual target values

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import ZILNLoss
    >>>
    >>> target = torch.tensor([[0., 1.5]]).view(-1, 1)
    >>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
    >>> loss = ZILNLoss()(input, target)
    """
    positive = target > 0
    positive = positive.float()

    assert input.shape == torch.Size([target.shape[0], 3]), (
        "Wrong shape of the 'input' tensor. The pred_dim of the "
        "model that is using ZILNLoss must be equal to 3."
    )

    positive_input = input[..., :1]

    classification_loss = F.binary_cross_entropy_with_logits(
        positive_input, positive, reduction="none"
    ).flatten()

    loc = input[..., 1:2]

    # when using max the two input tensors (input and other) have to be of
    # the same type
    max_input = F.softplus(input[..., 2:])
    max_other = torch.sqrt(torch.Tensor([torch.finfo(torch.double).eps])).type(
        max_input.type()
    )
    scale = torch.max(max_input, max_other)
    safe_labels = positive * target + (1 - positive) * torch.ones_like(target)

    regression_loss = -torch.mean(
        positive
        * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(
            safe_labels
        ),
        dim=-1,
    )

    return torch.mean(classification_loss + regression_loss)
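Per sample, ZILNLoss.forward adds a binary cross-entropy term for "is the target positive?" (using the first output as a logit) to the negative log-density of a log-normal distribution evaluated at the target, which only counts when the target is positive. The second output is the log-normal loc, and the third becomes the scale after a softplus, floored at the square root of double-precision machine epsilon. The hypothetical ziln helper below is a plain-Python sketch of that computation using the closed-form log-normal log-density.

```python
import math
import sys
from typing import List, Sequence


def softplus(x: float) -> float:
    # numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def ziln(preds: List[Sequence[float]], targets: List[float]) -> float:
    """preds[n] = (logit, loc, raw_scale); returns the mean per-sample loss."""
    min_scale = math.sqrt(sys.float_info.epsilon)  # same floor as the forward above
    losses = []
    for (logit, loc, raw_scale), y in zip(preds, targets):
        positive = 1.0 if y > 0 else 0.0
        # binary cross-entropy with logits: softplus(z) - t * z
        classification = softplus(logit) - positive * logit
        regression = 0.0
        if positive:
            scale = max(softplus(raw_scale), min_scale)
            log_y = math.log(y)
            log_pdf = (
                -log_y
                - math.log(scale)
                - 0.5 * math.log(2 * math.pi)
                - (log_y - loc) ** 2 / (2 * scale**2)
            )
            regression = -log_pdf
        losses.append(classification + regression)
    return sum(losses) / len(losses)


# A zero target only contributes the classification term: softplus(0) = log(2)
print(math.isclose(ziln([(0.0, 0.3, 0.5)], [0.0]), math.log(2)))  # True
```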

L1Loss

L1Loss()

Bases: Module

L1 loss with the option of using Label Distribution Smoothing (LDS)

LDS is based on Delving into Deep Imbalanced Regression.

Source code in pytorch_widedeep/losses.py
def __init__(self):
    super().__init__()

forward

forward(input, target, lds_weight=None)

Parameters:

  • input (Tensor) –

    Input tensor with predictions

  • target (Tensor) –

    Target tensor with the actual values

  • lds_weight (Optional[Tensor], default: None ) –

    If we choose to use LDS this is the tensor of weights that will multiply the loss value.

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import L1Loss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = L1Loss()(input, target)
Source code in pytorch_widedeep/losses.py
def forward(
    self, input: Tensor, target: Tensor, lds_weight: Optional[Tensor] = None
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values
    lds_weight: Tensor, Optional
        If we choose to use LDS this is the tensor of weights that will
        multiply the loss value.

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import L1Loss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = L1Loss()(input, target)
    """
    loss = F.l1_loss(input, target, reduction="none")
    if lds_weight is not None:
        loss *= lds_weight
    return torch.mean(loss)
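L1Loss follows the same pattern as MSELoss, with absolute instead of squared errors. On the example values the per-sample absolute errors are 0.4, 0.5, 0.3 and 1.2; the plain-Python sketch below (hypothetical weighted_l1 helper) gives the unweighted and LDS-weighted results.

```python
from typing import List, Optional


def weighted_l1(
    inputs: List[float], targets: List[float], lds_weight: Optional[List[float]] = None
) -> float:
    """Sketch of L1Loss.forward: mean of w * |x - y|."""
    if lds_weight is None:
        lds_weight = [1.0] * len(inputs)
    terms = [w * abs(x - y) for x, y, w in zip(inputs, targets, lds_weight)]
    return sum(terms) / len(terms)


target = [1.0, 1.2, 0.0, 2.0]
preds = [0.6, 0.7, 0.3, 0.8]
weights = [0.1, 0.2, 0.3, 0.4]

print(round(weighted_l1(preds, target), 4))           # 2.4 / 4 = 0.6
print(round(weighted_l1(preds, target, weights), 4))  # 0.71 / 4 = 0.1775
```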

FocalR_L1Loss

FocalR_L1Loss(beta=0.2, gamma=1.0, activation_fn='sigmoid')

Bases: Module

Focal-R L1 loss

Based on Delving into Deep Imbalanced Regression.

Parameters:

  • beta (float, default: 0.2 ) –

    Focal-R beta parameter, as named in the original authors' implementation

  • gamma (float, default: 1.0 ) –

    Focal Loss gamma parameter

  • activation_fn (Literal[sigmoid, tanh], default: 'sigmoid' ) –

    Activation function to be used during the computation of the loss. Possible values are 'sigmoid' and 'tanh'. See the original publication for details.

Source code in pytorch_widedeep/losses.py
def __init__(
    self,
    beta: float = 0.2,
    gamma: float = 1.0,
    activation_fn: Literal["sigmoid", "tanh"] = "sigmoid",
):
    super().__init__()
    self.beta = beta
    self.gamma = gamma
    self.activation_fn = activation_fn

forward

forward(input, target, lds_weight=None)

Parameters:

  • input (Tensor) –

    Input tensor with predictions

  • target (Tensor) –

    Target tensor with the actual values

  • lds_weight (Optional[Tensor], default: None ) –

    If we choose to use LDS this is the tensor of weights that will multiply the loss value.

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import FocalR_L1Loss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = FocalR_L1Loss()(input, target)
Source code in pytorch_widedeep/losses.py
def forward(
    self,
    input: Tensor,
    target: Tensor,
    lds_weight: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values
    lds_weight: Tensor, Optional
        If we choose to use LDS this is the tensor of weights that will
        multiply the loss value.

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import FocalR_L1Loss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = FocalR_L1Loss()(input, target)
    """
    loss = F.l1_loss(input, target, reduction="none")
    if self.activation_fn == "tanh":
        loss *= (torch.tanh(self.beta * torch.abs(input - target))) ** self.gamma
    elif self.activation_fn == "sigmoid":
        loss *= (
            2 * torch.sigmoid(self.beta * torch.abs(input - target)) - 1
        ) ** self.gamma
    else:
        raise ValueError(
            "Incorrect activation function value - must be in ['sigmoid', 'tanh']"
        )
    if lds_weight is not None:
        loss *= lds_weight
    return torch.mean(loss)
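The Focal-R modulating factor in FocalR_L1Loss scales each absolute error |e| by tanh(beta * |e|)^gamma or (2 * sigmoid(beta * |e|) - 1)^gamma. Both factors are close to 0 for small errors and approach 1 for large ones, so well-fitted samples contribute little. Since 2 * sigmoid(z) - 1 = tanh(z / 2), the sigmoid option behaves like a flatter tanh. The hypothetical focal_r_l1 helper below is a plain-Python sketch of that computation.

```python
import math
from typing import List, Optional


def focal_r_l1(
    inputs: List[float],
    targets: List[float],
    lds_weight: Optional[List[float]] = None,
    beta: float = 0.2,
    gamma: float = 1.0,
    activation_fn: str = "sigmoid",
) -> float:
    """Sketch of FocalR_L1Loss.forward: mean of w * |e| * factor(|e|) ** gamma."""
    if activation_fn not in ("sigmoid", "tanh"):
        raise ValueError("activation_fn must be 'sigmoid' or 'tanh'")
    if lds_weight is None:
        lds_weight = [1.0] * len(inputs)
    terms = []
    for x, y, w in zip(inputs, targets, lds_weight):
        abs_err = abs(x - y)
        if activation_fn == "tanh":
            factor = math.tanh(beta * abs_err)
        else:
            factor = 2.0 / (1.0 + math.exp(-beta * abs_err)) - 1.0  # 2 * sigmoid - 1
        terms.append(w * abs_err * factor**gamma)
    return sum(terms) / len(terms)


# Small errors are almost ignored, large ones keep (nearly) their full L1 value
print(focal_r_l1([0.1], [0.0]) / 0.1)      # ~0.01
print(focal_r_l1([100.0], [0.0]) / 100.0)  # ~1.0
```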

FocalR_MSELoss

FocalR_MSELoss(beta=0.2, gamma=1.0, activation_fn='sigmoid')

Bases: Module

Focal-R MSE loss

Based on Delving into Deep Imbalanced Regression.

Parameters:

  • beta (float, default: 0.2 ) –

    Focal-R beta parameter, as named in the original authors' implementation

  • gamma (float, default: 1.0 ) –

    Focal Loss gamma parameter

  • activation_fn (Literal[sigmoid, tanh], default: 'sigmoid' ) –

    Activation function to be used during the computation of the loss. Possible values are 'sigmoid' and 'tanh'. See the original publication for details.

Source code in pytorch_widedeep/losses.py
def __init__(
    self,
    beta: float = 0.2,
    gamma: float = 1.0,
    activation_fn: Literal["sigmoid", "tanh"] = "sigmoid",
):
    super().__init__()
    self.beta = beta
    self.gamma = gamma
    self.activation_fn = activation_fn

forward

forward(input, target, lds_weight=None)

Parameters:

  • input (Tensor) –

    Input tensor with predictions

  • target (Tensor) –

    Target tensor with the actual values

  • lds_weight (Optional[Tensor], default: None ) –

    If we choose to use LDS this is the tensor of weights that will multiply the loss value.

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import FocalR_MSELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = FocalR_MSELoss()(input, target)
Source code in pytorch_widedeep/losses.py
def forward(
    self,
    input: Tensor,
    target: Tensor,
    lds_weight: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values
    lds_weight: Tensor, Optional
        If we choose to use LDS this is the tensor of weights that will
        multiply the loss value.

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import FocalR_MSELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = FocalR_MSELoss()(input, target)
    """
    loss = (input - target) ** 2
    if self.activation_fn == "tanh":
        loss *= (torch.tanh(self.beta * torch.abs(input - target))) ** self.gamma
    elif self.activation_fn == "sigmoid":
        loss *= (
            2 * torch.sigmoid(self.beta * torch.abs((input - target) ** 2)) - 1
        ) ** self.gamma
    else:
        raise ValueError(
            "Incorrect activation function value - must be in ['sigmoid', 'tanh']"
        )
    if lds_weight is not None:
        loss *= lds_weight
    return torch.mean(loss)
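In the source of FocalR_MSELoss.forward shown above, the two activation options do not use the same argument: the tanh branch modulates the squared error by tanh(beta * |e|)^gamma, whereas the sigmoid branch uses 2 * sigmoid(beta * e^2) - 1 = tanh(beta * e^2 / 2), i.e. the squared error. For a non-zero error the two coincide only when |e| = 2. The hypothetical focal_r_mse helper below is a plain-Python sketch that mirrors the code as written, which makes the difference easy to inspect.

```python
import math
from typing import List, Optional


def focal_r_mse(
    inputs: List[float],
    targets: List[float],
    lds_weight: Optional[List[float]] = None,
    beta: float = 0.2,
    gamma: float = 1.0,
    activation_fn: str = "sigmoid",
) -> float:
    """Sketch mirroring FocalR_MSELoss.forward as shown above."""
    if activation_fn not in ("sigmoid", "tanh"):
        raise ValueError("activation_fn must be 'sigmoid' or 'tanh'")
    if lds_weight is None:
        lds_weight = [1.0] * len(inputs)
    terms = []
    for x, y, w in zip(inputs, targets, lds_weight):
        sq_err = (x - y) ** 2
        if activation_fn == "tanh":
            factor = math.tanh(beta * abs(x - y))  # modulated by |e|
        else:
            factor = 2.0 / (1.0 + math.exp(-beta * sq_err)) - 1.0  # modulated by e**2
        terms.append(w * sq_err * factor**gamma)
    return sum(terms) / len(terms)


# The tanh and sigmoid columns only agree at err = 2.0
for err in (1.0, 2.0, 4.0):
    t = focal_r_mse([err], [0.0], activation_fn="tanh")
    s = focal_r_mse([err], [0.0], activation_fn="sigmoid")
    print(err, round(t, 4), round(s, 4))
```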

FocalR_RMSELoss

FocalR_RMSELoss(beta=0.2, gamma=1.0, activation_fn='sigmoid')

Bases: Module

Focal-R RMSE loss

Based on Delving into Deep Imbalanced Regression.

Parameters:

  • beta (float, default: 0.2 ) –

    Focal-R beta parameter, as named in the original authors' implementation

  • gamma (float, default: 1.0 ) –

    Focal Loss gamma parameter

  • activation_fn (Literal[sigmoid, tanh], default: 'sigmoid' ) –

    Activation function to be used during the computation of the loss. Possible values are 'sigmoid' and 'tanh'. See the original publication for details.

Source code in pytorch_widedeep/losses.py
def __init__(
    self,
    beta: float = 0.2,
    gamma: float = 1.0,
    activation_fn: Literal["sigmoid", "tanh"] = "sigmoid",
):
    super().__init__()
    self.beta = beta
    self.gamma = gamma
    self.activation_fn = activation_fn

forward

forward(input, target, lds_weight=None)

Parameters:

  • input (Tensor) –

    Input tensor with predictions

  • target (Tensor) –

    Target tensor with the actual values

  • lds_weight (Optional[Tensor], default: None ) –

    If we choose to use LDS this is the tensor of weights that will multiply the loss value.

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import FocalR_RMSELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = FocalR_RMSELoss()(input, target)
Source code in pytorch_widedeep/losses.py
def forward(
    self,
    input: Tensor,
    target: Tensor,
    lds_weight: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values
    lds_weight: Tensor, Optional
        If we choose to use LDS this is the tensor of weights that will
        multiply the loss value.

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import FocalR_RMSELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = FocalR_RMSELoss()(input, target)
    """
    loss = (input - target) ** 2
    if self.activation_fn == "tanh":
        loss *= (torch.tanh(self.beta * torch.abs(input - target))) ** self.gamma
    elif self.activation_fn == "sigmoid":
        loss *= (
            2 * torch.sigmoid(self.beta * torch.abs((input - target) ** 2)) - 1
        ) ** self.gamma
    else:
        raise ValueError(
            "Incorrect activation function value - must be in ['sigmoid', 'tanh']"
        )
    if lds_weight is not None:
        loss *= lds_weight
    return torch.sqrt(torch.mean(loss))
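FocalR_RMSELoss applies the same modulation and LDS weighting to the squared errors as FocalR_MSELoss and then takes the square root of their mean, so for a single sample with the tanh option it reduces to |e| * sqrt(tanh(beta * |e|)^gamma). A plain-Python sketch, restricted to the tanh option for brevity (the focal_r_rmse_tanh helper is hypothetical):

```python
import math
from typing import List, Optional


def focal_r_rmse_tanh(
    inputs: List[float],
    targets: List[float],
    lds_weight: Optional[List[float]] = None,
    beta: float = 0.2,
    gamma: float = 1.0,
) -> float:
    """Sketch of FocalR_RMSELoss.forward with activation_fn='tanh'."""
    if lds_weight is None:
        lds_weight = [1.0] * len(inputs)
    terms = [
        w * (x - y) ** 2 * math.tanh(beta * abs(x - y)) ** gamma
        for x, y, w in zip(inputs, targets, lds_weight)
    ]
    return math.sqrt(sum(terms) / len(terms))  # root of the mean, as in the forward above


err = 3.0
print(math.isclose(focal_r_rmse_tanh([err], [0.0]), err * math.sqrt(math.tanh(0.2 * err))))  # True
```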

HuberLoss

HuberLoss(beta=0.2)

Bases: Module

Huber Loss

Based on Delving into Deep Imbalanced Regression.

Source code in pytorch_widedeep/losses.py
def __init__(self, beta: float = 0.2):
    super().__init__()
    self.beta = beta

forward

forward(input, target, lds_weight=None)

Parameters:

  • input (Tensor) –

    Input tensor with predictions

  • target (Tensor) –

    Target tensor with the actual values

  • lds_weight (Optional[Tensor], default: None ) –

    If LDS is used, this is the tensor of weights that will multiply the loss value.

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import HuberLoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = HuberLoss()(input, target)
Source code in pytorch_widedeep/losses.py
def forward(
    self,
    input: Tensor,
    target: Tensor,
    lds_weight: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values
    lds_weight: Tensor, Optional
        If LDS is used, this is the tensor of weights that will
        multiply the loss value.

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import HuberLoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = HuberLoss()(input, target)
    """
    l1_loss = torch.abs(input - target)
    cond = l1_loss < self.beta
    loss = torch.where(
        cond, 0.5 * l1_loss**2 / self.beta, l1_loss - 0.5 * self.beta
    )
    if lds_weight is not None:
        loss *= lds_weight
    return torch.mean(loss)
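The piecewise rule in `HuberLoss.forward` is quadratic for errors below `beta` and linear above it, and both pieces equal `0.5 * beta` at the threshold, so the loss is continuous. A plain-Python sketch that reproduces it on lists (the function name `huber_loss` is hypothetical; mean reduction as in the listing):

```python
from typing import List, Optional


def huber_loss(
    preds: List[float],
    targets: List[float],
    beta: float = 0.2,
    weights: Optional[List[float]] = None,
) -> float:
    """Plain-Python sketch of HuberLoss.forward with mean reduction."""
    losses = []
    for i, (p, t) in enumerate(zip(preds, targets)):
        err = abs(p - t)
        if err < beta:
            loss = 0.5 * err ** 2 / beta  # quadratic region for small errors
        else:
            loss = err - 0.5 * beta  # linear region for large errors
        if weights is not None:
            loss *= weights[i]  # optional LDS weight per sample
        losses.append(loss)
    return sum(losses) / len(losses)
```

This is the formula PyTorch documents for `torch.nn.SmoothL1Loss` with the same `beta`, which differs from `torch.nn.HuberLoss(delta=beta)` by a factor of `1 / beta`. That observation may help when cross-checking the unweighted case; it is a statement about the formula, not about how the library was written.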

InfoNCELoss

InfoNCELoss(temperature=0.1, reduction='mean')

Bases: Module

InfoNCE Loss. Loss applied during the Contrastive Denoising Self Supervised Pre-training routine available in this library.

ℹ️ NOTE: This loss is in principle not exposed to the user, as it is used internally in the library, but it is included here for completeness.

See SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training and the references therein.

Partially inspired by the code in this repo

Source code in pytorch_widedeep/losses.py
def __init__(self, temperature: float = 0.1, reduction: str = "mean"):
    super(InfoNCELoss, self).__init__()

    self.temperature = temperature
    self.reduction = reduction

forward

forward(g_projs)

Parameters:

  • g_projs (Tuple[Tensor, Tensor]) –

    Tuple with the two tensors corresponding to the output of the two projection heads, as described in 'SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training'.

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import InfoNCELoss
>>> g_projs = (torch.rand(3, 5, 16), torch.rand(3, 5, 16))
>>> loss = InfoNCELoss()
>>> res = loss(g_projs)
Source code in pytorch_widedeep/losses.py
def forward(self, g_projs: Tuple[Tensor, Tensor]) -> Tensor:
    r"""
    Parameters
    ----------
    g_projs: Tuple
        Tuple with the two tensors corresponding to the output of the two
        projection heads, as described in 'SAINT: Improved Neural Networks
        for Tabular Data via Row Attention and Contrastive Pre-Training'.

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import InfoNCELoss
    >>> g_projs = (torch.rand(3, 5, 16), torch.rand(3, 5, 16))
    >>> loss = InfoNCELoss()
    >>> res = loss(g_projs)
    """
    z, z_ = g_projs[0], g_projs[1]

    norm_z = F.normalize(z, dim=-1).flatten(1)
    norm_z_ = F.normalize(z_, dim=-1).flatten(1)

    logits = (norm_z @ norm_z_.t()) / self.temperature
    logits_ = (norm_z_ @ norm_z.t()) / self.temperature

    # the target/labels are the entries on the diagonal
    target = torch.arange(len(norm_z), device=norm_z.device)

    loss = F.cross_entropy(logits, target, reduction=self.reduction)
    loss_ = F.cross_entropy(logits_, target, reduction=self.reduction)

    return (loss + loss_) / 2.0
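`InfoNCELoss.forward` normalizes every feature embedding, flattens each sample into one vector, and treats the matching row of the other view as the positive: the logits are temperature-scaled dot products, the target for row `i` is column `i`, and the cross-entropy is computed in both directions and averaged. A plain-Python sketch of that computation on nested lists of shape `(batch, n_features, dim)`, assuming `reduction="mean"`; the helper names are hypothetical:

```python
import math
from typing import List

Matrix = List[List[float]]


def _normalize_and_flatten(x: List[Matrix]) -> Matrix:
    # L2-normalize each feature embedding (last dim), then flatten per sample
    rows = []
    for sample in x:
        flat: List[float] = []
        for vec in sample:
            norm = max(math.sqrt(sum(v * v for v in vec)), 1e-12)  # eps as in F.normalize
            flat.extend(v / norm for v in vec)
        rows.append(flat)
    return rows


def _cross_entropy_diag(logits: Matrix) -> float:
    # cross-entropy where the correct "class" for row i is column i
    total = 0.0
    for i, row in enumerate(logits):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def info_nce(z: List[Matrix], z_: List[Matrix], temperature: float = 0.1) -> float:
    a = _normalize_and_flatten(z)
    b = _normalize_and_flatten(z_)
    dot = lambda u, v: sum(x * y for x, y in zip(u, v))
    logits = [[dot(ra, rb) / temperature for rb in b] for ra in a]
    logits_ = [[dot(rb, ra) / temperature for ra in a] for rb in b]
    # symmetric loss: average both directions
    return (_cross_entropy_diag(logits) + _cross_entropy_diag(logits_)) / 2.0
```

With two aligned views the loss is close to zero, while swapping the samples in one view makes it large, which is the behavior the contrastive objective is meant to reward and penalize.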

DenoisingLoss

DenoisingLoss(lambda_cat=1.0, lambda_cont=1.0, reduction='mean')

Bases: Module

Denoising Loss. Loss applied during the Contrastive Denoising Self Supervised Pre-training routine available in this library.

ℹ️ NOTE: This loss is in principle not exposed to the user, as it is used internally in the library, but it is included here for completeness.

See SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training and the references therein.

Source code in pytorch_widedeep/losses.py
def __init__(
    self, lambda_cat: float = 1.0, lambda_cont: float = 1.0, reduction: str = "mean"
):
    super(DenoisingLoss, self).__init__()

    self.lambda_cat = lambda_cat
    self.lambda_cont = lambda_cont
    self.reduction = reduction

forward

forward(x_cat_and_cat_, x_cont_and_cont_)

Parameters:

  • x_cat_and_cat_ (Optional[Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]]) –

    Tuple of tensors containing the raw input features and their encodings, referred to in the SAINT paper as \(x\) and \(x''\) respectively. If one denoising MLP is used per categorical feature, x_cat_and_cat_ will be a list of tuples, one per categorical feature.

  • x_cont_and_cont_ (Optional[Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]]) –

    Same as x_cat_and_cat_ but for continuous columns.

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import DenoisingLoss
>>> x_cat_and_cat_ = (torch.empty(3).random_(3).long(), torch.randn(3, 3))
>>> x_cont_and_cont_ = (torch.randn(3, 1), torch.randn(3, 1))
>>> loss = DenoisingLoss()
>>> res = loss(x_cat_and_cat_, x_cont_and_cont_)
Source code in pytorch_widedeep/losses.py
def forward(
    self,
    x_cat_and_cat_: Optional[
        Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
    ],
    x_cont_and_cont_: Optional[
        Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
    ],
) -> Tensor:
    r"""
    Parameters
    ----------
    x_cat_and_cat_: tuple of Tensors or lists of tuples
        Tuple of tensors containing the raw input features and their
        encodings, referred to in the SAINT paper as $x$ and $x''$
        respectively. If one denoising MLP is used per categorical
        feature, `x_cat_and_cat_` will be a list of tuples, one per
        categorical feature.
    x_cont_and_cont_: tuple of Tensors or lists of tuples
        Same as `x_cat_and_cat_` but for continuous columns.

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import DenoisingLoss
    >>> x_cat_and_cat_ = (torch.empty(3).random_(3).long(), torch.randn(3, 3))
    >>> x_cont_and_cont_ = (torch.randn(3, 1), torch.randn(3, 1))
    >>> loss = DenoisingLoss()
    >>> res = loss(x_cat_and_cat_, x_cont_and_cont_)
    """

    loss_cat = (
        self._compute_cat_loss(x_cat_and_cat_)
        if x_cat_and_cat_ is not None
        else torch.tensor(0.0)
    )
    loss_cont = (
        self._compute_cont_loss(x_cont_and_cont_)
        if x_cont_and_cont_ is not None
        else torch.tensor(0.0)
    )

    return self.lambda_cat * loss_cat + self.lambda_cont * loss_cont
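`DenoisingLoss.forward` returns `lambda_cat * loss_cat + lambda_cont * loss_cont`, and a missing input (`None`) contributes zero. The private helpers `_compute_cat_loss` and `_compute_cont_loss` are not shown on this page, so the sketch below assumes, based on the tensor shapes in the Examples (integer labels paired with per-class scores, and continuous values paired with reconstructions), that categorical reconstruction is scored with cross-entropy and continuous reconstruction with mean squared error. It covers only the single-tuple form, not the list-of-tuples form, and all names are hypothetical.

```python
import math
from typing import List, Optional, Tuple


def _cross_entropy(labels: List[int], logits: List[List[float]]) -> float:
    # assumed categorical term: mean cross-entropy of the raw labels vs. predicted scores
    total = 0.0
    for y, row in zip(labels, logits):
        m = max(row)  # subtract the max for numerical stability
        total += m + math.log(sum(math.exp(v - m) for v in row)) - row[y]
    return total / len(labels)


def _mse(x: List[float], x_rec: List[float]) -> float:
    # assumed continuous term: mean squared reconstruction error
    return sum((a - b) ** 2 for a, b in zip(x, x_rec)) / len(x)


def denoising_loss(
    cat: Optional[Tuple[List[int], List[List[float]]]],
    cont: Optional[Tuple[List[float], List[float]]],
    lambda_cat: float = 1.0,
    lambda_cont: float = 1.0,
) -> float:
    # a missing component contributes 0, mirroring the torch.tensor(0.0) fallback
    loss_cat = _cross_entropy(*cat) if cat is not None else 0.0
    loss_cont = _mse(*cont) if cont is not None else 0.0
    return lambda_cat * loss_cat + lambda_cont * loss_cont
```

The two `lambda_*` weights let one term dominate the other; setting either to zero effectively disables that reconstruction target.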

EncoderDecoderLoss

EncoderDecoderLoss(eps=1e-09)

Bases: Module

'Standard' Encoder-Decoder Loss. Loss applied during the Encoder-Decoder Self-Supervised Pre-Training routine available in this library.

ℹ️ NOTE: This loss is in principle not exposed to the user, as it is used internally in the library, but it is included here for completeness.

The implementation of this loss is based on the one in the tabnet repo, which is itself an adaptation of the loss in the original paper TabNet: Attentive Interpretable Tabular Learning.

Source code in pytorch_widedeep/losses.py
def __init__(self, eps: float = 1e-9):
    super(EncoderDecoderLoss, self).__init__()
    self.eps = eps

forward

forward(x_true, x_pred, mask)

Parameters:

  • x_true (Tensor) –

    Embeddings of the input data

  • x_pred (Tensor) –

    Reconstructed embeddings

  • mask (Tensor) –

    Mask with 1s marking the features on which the reconstruction, and therefore the loss, is based.

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import EncoderDecoderLoss
>>> x_true = torch.rand(3, 3)
>>> x_pred = torch.rand(3, 3)
>>> mask = torch.empty(3, 3).random_(2)
>>> loss = EncoderDecoderLoss()
>>> res = loss(x_true, x_pred, mask)
Source code in pytorch_widedeep/losses.py
def forward(self, x_true: Tensor, x_pred: Tensor, mask: Tensor) -> Tensor:
    r"""
    Parameters
    ----------
    x_true: Tensor
        Embeddings of the input data
    x_pred: Tensor
        Reconstructed embeddings
    mask: Tensor
        Mask with 1s marking the features on which the reconstruction,
        and therefore the loss, is based.

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import EncoderDecoderLoss
    >>> x_true = torch.rand(3, 3)
    >>> x_pred = torch.rand(3, 3)
    >>> mask = torch.empty(3, 3).random_(2)
    >>> loss = EncoderDecoderLoss()
    >>> res = loss(x_true, x_pred, mask)
    """

    errors = x_pred - x_true

    reconstruction_errors = torch.mul(errors, mask) ** 2

    x_true_means = torch.mean(x_true, dim=0)
    x_true_means[x_true_means == 0] = 1

    x_true_stds = torch.std(x_true, dim=0) ** 2
    x_true_stds[x_true_stds == 0] = x_true_means[x_true_stds == 0]

    features_loss = torch.matmul(reconstruction_errors, 1 / x_true_stds)
    nb_reconstructed_variables = torch.sum(mask, dim=1)
    features_loss_norm = features_loss / (nb_reconstructed_variables + self.eps)

    loss = torch.mean(features_loss_norm)

    return loss
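`EncoderDecoderLoss.forward` squares the masked reconstruction errors, divides each column by the batch variance of `x_true` (falling back to the column mean, with zero means replaced by 1, when the variance is zero), sums over features, and normalizes each row by the number of reconstructed features. A plain-Python sketch of the same steps for checking the tensor code by hand; the function name is hypothetical, and, like `torch.std` with its default settings, it uses the unbiased (n - 1) variance, so it needs at least two rows:

```python
from typing import List

Matrix = List[List[float]]


def encoder_decoder_loss(
    x_true: Matrix, x_pred: Matrix, mask: Matrix, eps: float = 1e-9
) -> float:
    n_rows, n_cols = len(x_true), len(x_true[0])
    # column means, with exact zeros replaced by 1
    raw_means = [sum(row[j] for row in x_true) / n_rows for j in range(n_cols)]
    means = [1.0 if m == 0 else m for m in raw_means]
    # unbiased per-column variance; zero variance falls back to the column mean
    variances = []
    for j in range(n_cols):
        var = sum((row[j] - raw_means[j]) ** 2 for row in x_true) / (n_rows - 1)
        variances.append(means[j] if var == 0 else var)
    per_row = []
    for t_row, p_row, m_row in zip(x_true, x_pred, mask):
        # masked squared error, scaled by each column's variance, summed over features
        feat_loss = sum(
            ((p - t) * m) ** 2 / v for t, p, m, v in zip(t_row, p_row, m_row, variances)
        )
        # normalize by the number of reconstructed features in this row
        per_row.append(feat_loss / (sum(m_row) + eps))
    return sum(per_row) / n_rows
```

Dividing by the column variance keeps features on large scales from dominating the reconstruction objective, and the per-row normalization makes rows with few masked features comparable to rows with many.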