Metrics


ℹ️ NOTE: the metrics in this module expect the predictions and ground truth to have the same dimensions for regression and binary classification problems: \((N_{samples}, 1)\). For multiclass classification problems, the ground truth is expected to be a 1D tensor with the corresponding class indices. See the Examples below.
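
For instance, a minimal sketch of the expected shapes, using the same toy tensors as the Examples below:

import torch

# binary classification / regression: predictions and ground truth
# are both of shape (N_samples, 1)
y_pred_binary = torch.tensor([0.3, 0.2, 0.6, 0.7]).view(-1, 1)
y_true_binary = torch.tensor([0, 1, 0, 1]).view(-1, 1)

# multiclass classification: predictions are (N_samples, n_classes) and
# the ground truth is a 1D tensor of class indices, (N_samples,)
y_pred_multiclass = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
y_true_multiclass = torch.tensor([0, 1, 2])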


We have added the possibility of using the metrics available in the torchmetrics library. Note that this library is still in its early versions, so this option should be used with caution. To use torchmetrics, simply import the metrics and use them as you would any of the pytorch-widedeep metrics described below.

from torchmetrics import Accuracy, Precision

from pytorch_widedeep import Trainer

accuracy = Accuracy(average=None, num_classes=2)
precision = Precision(average='micro', num_classes=2)

# `model` is a previously built WideDeep model
trainer = Trainer(model, objective="binary", metrics=[accuracy, precision])

A functioning example of pytorch-widedeep using torchmetrics can be found in the Examples folder of the repository.

ℹ️ NOTE: the forward method for all metrics in this module takes two tensors, y_pred and y_true (in that order). Therefore, we do not include the method in the documentation.
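
For reference, a custom metric only needs to follow the same convention: subclass Metric, keep running counters that reset() clears, and implement __call__(y_pred, y_true) returning a numpy array. Below is a minimal, hypothetical sketch; the Metric base class and the reset/__call__ pattern come from the source listings further down, while MeanAbsoluteError itself is an illustration and not part of the library:

import numpy as np
from torch import Tensor

from pytorch_widedeep.metrics import Metric


class MeanAbsoluteError(Metric):
    # hypothetical metric following the same pattern as the metrics below
    def __init__(self):
        super(MeanAbsoluteError, self).__init__()
        self.abs_error_sum = 0.0
        self.total_count = 0
        self._name = "mae"

    def reset(self):
        # resets counters to 0, as the built-in metrics do between epochs
        self.abs_error_sum = 0.0
        self.total_count = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        # y_pred and y_true are both expected to be of shape (N_samples, 1)
        self.abs_error_sum += (y_pred - y_true).abs().sum().item()
        self.total_count += len(y_pred)
        return np.array(self.abs_error_sum / self.total_count)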

Accuracy

Bases: Metric

Class to calculate the accuracy for both binary and categorical problems

Parameters:

    top_k: int, default = 1
        Accuracy will be computed using the top k most likely classes in
        multiclass problems

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.metrics import Accuracy
>>>
>>> acc = Accuracy()
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> acc(y_pred, y_true)
array(0.5)
>>>
>>> acc = Accuracy(top_k=2)
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> acc(y_pred, y_true)
array(0.66666667)
Source code in pytorch_widedeep/metrics.py
class Accuracy(Metric):
    r"""Class to calculate the accuracy for both binary and categorical problems

    Parameters
    ----------
    top_k: int, default = 1
        Accuracy will be computed using the top k most likely classes in
        multiclass problems

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.metrics import Accuracy
    >>>
    >>> acc = Accuracy()
    >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
    >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
    >>> acc(y_pred, y_true)
    array(0.5)
    >>>
    >>> acc = Accuracy(top_k=2)
    >>> y_true = torch.tensor([0, 1, 2])
    >>> y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
    >>> acc(y_pred, y_true)
    array(0.66666667)
    """

    def __init__(self, top_k: int = 1):
        super(Accuracy, self).__init__()

        self.top_k = top_k
        self.correct_count = 0
        self.total_count = 0
        self._name = "acc"

    def reset(self):
        """
        resets counters to 0
        """
        self.correct_count = 0
        self.total_count = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        num_classes = y_pred.size(1)

        if num_classes == 1:
            y_pred = y_pred.round()
            y_true = y_true
        elif num_classes > 1:
            y_pred = y_pred.topk(self.top_k, 1)[1]
            y_true = y_true.view(-1, 1).expand_as(y_pred)

        self.correct_count += y_pred.eq(y_true).sum().item()  # type: ignore[assignment]
        self.total_count += len(y_pred)
        accuracy = float(self.correct_count) / float(self.total_count)
        return np.array(accuracy)

reset

reset()

resets counters to 0

Source code in pytorch_widedeep/metrics.py
def reset(self):
    """
    resets counters to 0
    """
    self.correct_count = 0
    self.total_count = 0

Precision

Bases: Metric

Class to calculate the precision for both binary and categorical problems

Parameters:

    average: bool, default = True
        This applies only to multiclass problems. If True, calculate
        precision for each label and find their unweighted mean.

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.metrics import Precision
>>>
>>> prec = Precision()
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> prec(y_pred, y_true)
array(0.5)
>>>
>>> prec = Precision(average=True)
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> prec(y_pred, y_true)
array(0.33333334)
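
In the multiclass example above, the predicted classes (the argmax of each row of y_pred) are [0, 2, 1], so each class is predicted exactly once and only the prediction for class 0 is correct. The per-label precisions are therefore (1, 0, 0), whose unweighted mean is 1/3 ≈ 0.3333.
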
Source code in pytorch_widedeep/metrics.py
class Precision(Metric):
    r"""Class to calculate the precision for both binary and categorical problems

    Parameters
    ----------
    average: bool, default = True
        This applies only to multiclass problems. If ``True``, calculate
        precision for each label and find their unweighted mean.

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.metrics import Precision
    >>>
    >>> prec = Precision()
    >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
    >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
    >>> prec(y_pred, y_true)
    array(0.5)
    >>>
    >>> prec = Precision(average=True)
    >>> y_true = torch.tensor([0, 1, 2])
    >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
    >>> prec(y_pred, y_true)
    array(0.33333334)
    """

    def __init__(self, average: bool = True):
        super(Precision, self).__init__()

        self.average = average
        self.true_positives = 0
        self.all_positives = 0
        self.eps = 1e-20
        self._name = "prec"

    def reset(self):
        """
        resets counters to 0
        """
        self.true_positives = 0
        self.all_positives = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        num_class = y_pred.size(1)

        if num_class == 1:
            y_pred = y_pred.round()
            y_true = y_true
        elif num_class > 1:
            y_true = torch.eye(num_class)[y_true.squeeze().cpu().long()]
            y_pred = y_pred.topk(1, 1)[1].view(-1)
            y_pred = torch.eye(num_class)[y_pred.cpu().long()]

        self.true_positives += (y_true * y_pred).sum(dim=0)  # type:ignore
        self.all_positives += y_pred.sum(dim=0)  # type:ignore

        precision = self.true_positives / (self.all_positives + self.eps)

        if self.average:
            return np.array(precision.mean().item())  # type:ignore
        else:
            return precision.detach().cpu().numpy()  # type: ignore[attr-defined]

reset

reset()

resets counters to 0

Source code in pytorch_widedeep/metrics.py
def reset(self):
    """
    resets counters to 0
    """
    self.true_positives = 0
    self.all_positives = 0

Recall

Bases: Metric

Class to calculate the recall for both binary and categorical problems

Parameters:

    average: bool, default = True
        This applies only to multiclass problems. If True, calculate recall
        for each label and find their unweighted mean.

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.metrics import Recall
>>>
>>> rec = Recall()
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> rec(y_pred, y_true)
array(0.5)
>>>
>>> rec = Recall(average=True)
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> rec(y_pred, y_true)
array(0.33333334)
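
Here each class appears exactly once in y_true and, with predicted classes [0, 2, 1], only the sample of class 0 is recovered, so the per-label recalls are (1, 0, 0) and their unweighted mean is again 1/3 ≈ 0.3333.
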
Source code in pytorch_widedeep/metrics.py
class Recall(Metric):
    r"""Class to calculate the recall for both binary and categorical problems

    Parameters
    ----------
    average: bool, default = True
        This applies only to multiclass problems. If ``True``, calculate recall
        for each label and find their unweighted mean.

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.metrics import Recall
    >>>
    >>> rec = Recall()
    >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
    >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
    >>> rec(y_pred, y_true)
    array(0.5)
    >>>
    >>> rec = Recall(average=True)
    >>> y_true = torch.tensor([0, 1, 2])
    >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
    >>> rec(y_pred, y_true)
    array(0.33333334)
    """

    def __init__(self, average: bool = True):
        super(Recall, self).__init__()

        self.average = average
        self.true_positives = 0
        self.actual_positives = 0
        self.eps = 1e-20
        self._name = "rec"

    def reset(self):
        """
        resets counters to 0
        """
        self.true_positives = 0
        self.actual_positives = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        num_class = y_pred.size(1)

        if num_class == 1:
            y_pred = y_pred.round()
            y_true = y_true
        elif num_class > 1:
            y_true = torch.eye(num_class)[y_true.squeeze().cpu().long()]
            y_pred = y_pred.topk(1, 1)[1].view(-1)
            y_pred = torch.eye(num_class)[y_pred.cpu().long()]

        self.true_positives += (y_true * y_pred).sum(dim=0)  # type: ignore
        self.actual_positives += y_true.sum(dim=0)  # type: ignore

        recall = self.true_positives / (self.actual_positives + self.eps)

        if self.average:
            return np.array(recall.mean().item())  # type:ignore
        else:
            return recall.detach().cpu().numpy()  # type: ignore[attr-defined]

reset

reset()

resets counters to 0

Source code in pytorch_widedeep/metrics.py
def reset(self):
    """
    resets counters to 0
    """
    self.true_positives = 0
    self.actual_positives = 0

FBetaScore

Bases: Metric

Class to calculate the fbeta score for both binary and categorical problems

\[ F_{\beta} = (1 + \beta^2) \cdot \frac{precision \cdot recall}{\beta^2 \cdot precision + recall} \]
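
With \(\beta = 1\) this reduces to the familiar F1 score, the harmonic mean of precision and recall:

\[ F_1 = 2 \cdot \frac{precision \cdot recall}{precision + recall} \]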

Parameters:

    beta: int
        Coefficient to control the balance between precision and recall
    average: bool, default = True
        This applies only to multiclass problems. If True, calculate fbeta
        for each label and find their unweighted mean.

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.metrics import FBetaScore
>>>
>>> fbeta = FBetaScore(beta=2)
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> fbeta(y_pred, y_true)
array(0.5)
>>>
>>> fbeta = FBetaScore(beta=2)
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> fbeta(y_pred, y_true)
array(0.33333334)
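
In the multiclass example, the per-label precisions and recalls are both (1, 0, 0) (see the Precision and Recall examples above), so the per-label fbeta scores are (1, 0, 0) regardless of beta, and their unweighted mean is 1/3 ≈ 0.3333.
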
Source code in pytorch_widedeep/metrics.py
class FBetaScore(Metric):
    r"""Class to calculate the fbeta score for both binary and categorical problems

    $$
    F_{\beta} = (1 + \beta^2) \cdot \frac{precision \cdot recall}{\beta^2 \cdot precision + recall}
    $$

    Parameters
    ----------
    beta: int
        Coefficient to control the balance between precision and recall
    average: bool, default = True
        This applies only to multiclass problems. If ``True``, calculate fbeta
        for each label and find their unweighted mean.

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.metrics import FBetaScore
    >>>
    >>> fbeta = FBetaScore(beta=2)
    >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
    >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
    >>> fbeta(y_pred, y_true)
    array(0.5)
    >>>
    >>> fbeta = FBetaScore(beta=2)
    >>> y_true = torch.tensor([0, 1, 2])
    >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
    >>> fbeta(y_pred, y_true)
    array(0.33333334)
    """

    def __init__(self, beta: int, average: bool = True):
        super(FBetaScore, self).__init__()

        self.beta = beta
        self.average = average
        self.precision = Precision(average=False)
        self.recall = Recall(average=False)
        self.eps = 1e-20
        self._name = "".join(["f", str(self.beta)])

    def reset(self):
        """
        resets precision and recall
        """
        self.precision.reset()
        self.recall.reset()

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        prec = self.precision(y_pred, y_true)
        rec = self.recall(y_pred, y_true)
        beta2 = self.beta**2

        fbeta = ((1 + beta2) * prec * rec) / (beta2 * prec + rec + self.eps)

        if self.average:
            return np.array(fbeta.mean().item())  # type: ignore[attr-defined]
        else:
            return fbeta

reset

reset()

resets precision and recall

Source code in pytorch_widedeep/metrics.py
def reset(self):
    """
    resets precision and recall
    """
    self.precision.reset()
    self.recall.reset()

F1Score

Bases: Metric

Class to calculate the f1 score for both binary and categorical problems

Parameters:

    average: bool, default = True
        This applies only to multiclass problems. If True, calculate f1 for
        each label and find their unweighted mean.

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.metrics import F1Score
>>>
>>> f1 = F1Score()
>>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
>>> f1(y_pred, y_true)
array(0.5)
>>>
>>> f1 = F1Score()
>>> y_true = torch.tensor([0, 1, 2])
>>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
>>> f1(y_pred, y_true)
array(0.33333334)
Source code in pytorch_widedeep/metrics.py
class F1Score(Metric):
    r"""Class to calculate the f1 score for both binary and categorical problems

    Parameters
    ----------
    average: bool, default = True
        This applies only to multiclass problems. If ``True``, calculate f1 for
        each label and find their unweighted mean.

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.metrics import F1Score
    >>>
    >>> f1 = F1Score()
    >>> y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)
    >>> y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)
    >>> f1(y_pred, y_true)
    array(0.5)
    >>>
    >>> f1 = F1Score()
    >>> y_true = torch.tensor([0, 1, 2])
    >>> y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
    >>> f1(y_pred, y_true)
    array(0.33333334)
    """

    def __init__(self, average: bool = True):
        super(F1Score, self).__init__()

        self.average = average
        self.f1 = FBetaScore(beta=1, average=self.average)
        self._name = self.f1._name

    def reset(self):
        """
        resets counters to 0
        """
        self.f1.reset()

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        return self.f1(y_pred, y_true)

reset

reset()

resets counters to 0

Source code in pytorch_widedeep/metrics.py
def reset(self):
    """
    resets counters to 0
    """
    self.f1.reset()

R2Score

Bases: Metric

Calculates R-Squared, the coefficient of determination:

\[ R^2 = 1 - \frac{\sum_{j=1}^n(y_j - \hat{y_j})^2}{\sum_{j=1}^n(y_j - \bar{y})^2} \]

where \(y_j\) is the ground truth, \(\hat{y_j}\) is the predicted value and \(\bar{y}\) is the mean of the ground truth.

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.metrics import R2Score
>>>
>>> r2 = R2Score()
>>> y_true = torch.tensor([3, -0.5, 2, 7]).view(-1, 1)
>>> y_pred = torch.tensor([2.5, 0.0, 2, 8]).view(-1, 1)
>>> r2(y_pred, y_true)
array(0.94860814)
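
As a quick check of the formula against this example: the residual sum of squares is \(0.5^2 + 0.5^2 + 0^2 + 1^2 = 1.5\), the mean of y_true is \(2.875\), the total sum of squares is \(29.1875\), and so \(R^2 = 1 - 1.5/29.1875 \approx 0.9486\).
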
Source code in pytorch_widedeep/metrics.py
class R2Score(Metric):
    r"""
    Calculates R-Squared, the
    [coefficient of determination](https://en.wikipedia.org/wiki/Coefficient_of_determination):

    $$
    R^2 = 1 - \frac{\sum_{j=1}^n(y_j - \hat{y_j})^2}{\sum_{j=1}^n(y_j - \bar{y})^2}
    $$

    where $y_j$ is the ground truth, $\hat{y_j}$ is the predicted value and
    $\bar{y}$ is the mean of the ground truth.

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.metrics import R2Score
    >>>
    >>> r2 = R2Score()
    >>> y_true = torch.tensor([3, -0.5, 2, 7]).view(-1, 1)
    >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]).view(-1, 1)
    >>> r2(y_pred, y_true)
    array(0.94860814)
    """

    def __init__(self):
        self.numerator = 0
        self.denominator = 0
        self.num_examples = 0
        self.y_true_sum = 0

        self._name = "r2"

    def reset(self):
        """
        resets counters to 0
        """
        self.numerator = 0
        self.denominator = 0
        self.num_examples = 0
        self.y_true_sum = 0

    def __call__(self, y_pred: Tensor, y_true: Tensor) -> np.ndarray:
        self.numerator += ((y_pred - y_true) ** 2).sum().item()

        self.num_examples += y_true.shape[0]
        self.y_true_sum += y_true.sum().item()
        y_true_avg = self.y_true_sum / self.num_examples
        self.denominator += ((y_true - y_true_avg) ** 2).sum().item()
        return np.array((1 - (self.numerator / self.denominator)))

reset

reset()

resets counters to 0

Source code in pytorch_widedeep/metrics.py
def reset(self):
    """
    resets counters to 0
    """
    self.numerator = 0
    self.denominator = 0
    self.num_examples = 0
    self.y_true_sum = 0