Skip to content

myco.metrics

Metrics for evaluating model performance during and after model training. Metrics are a superset of loss functions - all losses are metrics, but not all metrics are losses. This is because a loss must decrease in value to indicate an increase in model performance. R-squared is an example of a metric that is not a loss - an increase in r2 corresponds to better performance. Losses and Metrics are also independent Python object types that have different assumptions regarding subclasses and class functions.

MycoMetric

Bases: tf.keras.metrics.Metric

Common methods to attach to metrics, like nodata filtering.

Source code in myco/metrics.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
@tf.keras.utils.register_keras_serializable(package="myco")
class MycoMetric(tf.keras.metrics.Metric):
    """Base Keras metric that filters and rescales inputs before scoring.

    The class itself defines no `update_state`/`result`; it only provides the
    shared input-preparation helpers (dtype casting, nodata / threshold
    filtering, inverse scaling, confusion-matrix counts) plus config
    round-tripping and state reset.
    """

    def __init__(
        self,
        name: str = None,
        nodata: float = None,
        greater_than: float = None,
        less_than: float = None,
        scaler: TFScaler = None,
        dtype: tf.dtypes.DType = tf.float32,
        is_categorical: bool = False,
        **kwargs,
    ) -> None:
        """Create a custom metric function that filters input data prior to the calculation.

        Args:
            name: the name of the metric (logged in model training)
            nodata: the nodata value to exclude from calculations.
            greater_than: only include values strictly above this threshold.
            less_than: only include values strictly below this threshold.
            scaler: object whose `inverse_transform` is applied to `y_true`
                and `y_pred` in `format_tensors`.
            dtype: the tf data type to compute the metric in.
            is_categorical: add base categorical states (`tp`, `fp`, `tn`,
                `fn`) for computing confusion matrix statistics.
            **kwargs: forwarded unchanged to the parent metric constructor.
        """
        super().__init__(name=name, dtype=dtype, **kwargs)

        self.nodata = nodata
        self.greater_than = greater_than
        self.less_than = less_than
        self.scaler = scaler

        if is_categorical:

            def scalar_state(weight_name: str):
                # Every confusion-matrix counter is a zero-initialised scalar.
                return self.add_weight(
                    name=weight_name,
                    shape=(),
                    initializer="zeros",
                    dtype=dtype,
                )

            # Creation order (tp, fp, tn, fn) fixes the order of
            # `self.variables`, so it is kept as-is.
            self.tp = scalar_state("true_positives")
            self.fp = scalar_state("false_positives")
            self.tn = scalar_state("true_negatives")
            self.fn = scalar_state("false_negatives")

    def format_tensors(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor
    ) -> tuple:
        """Cast, weight, filter and rescale the inputs of a metric update.

        Steps, in order:
            1. `y_true` / `y_pred` are cast to `self.dtype`.
            2. `sample_weight` is cast to `self.dtype`, or replaced by ones
               shaped like `y_true` when it is None.
            3. The nodata, `greater_than` and `less_than` filters are applied
               (each only if its setting is not None).
            4. `self.scaler.inverse_transform` is applied to `y_true` and
               `y_pred` when a scaler is set.

        Filtering runs before inverse scaling, so the nodata value and the
        thresholds are compared against the values as passed in.

        Returns:
            The tuple `(y_true, y_pred, sample_weight)` after processing.
        """
        y_true = tf.cast(y_true, dtype=self.dtype)
        y_pred = tf.cast(y_pred, dtype=self.dtype)

        sample_weight = (
            tf.ones_like(y_true, dtype=self.dtype)
            if sample_weight is None
            else tf.cast(sample_weight, dtype=self.dtype)
        )

        # Each filter is optional and is looked up on `self` so subclass
        # overrides are honoured.
        for setting, apply_filter in (
            (self.nodata, self.filter_nodata),
            (self.greater_than, self.filter_greater_than),
            (self.less_than, self.filter_less_than),
        ):
            if setting is not None:
                y_true, y_pred, sample_weight = apply_filter(
                    y_true, y_pred, sample_weight
                )

        if self.scaler is not None:
            y_true = self.scaler.inverse_transform(y_true)
            y_pred = self.scaler.inverse_transform(y_pred)

        return y_true, y_pred, sample_weight

    def format_categorical(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None
    ) -> tuple:
        """Compute categorical confusion matrix components.

        A fresh Keras `TruePositives`, `FalsePositives`, `TrueNegatives` and
        `FalseNegatives` metric is built with default constructor arguments
        and called once on the inputs. The `tp`/`fp`/`tn`/`fn` weights of this
        object are not modified here.

        Returns:
            The tuple `(tp, fp, tn, fn)` of the four computed values.
        """
        # NOTE(review): new metric objects (and their variables) are created
        # on every call - confirm this is compatible with how callers trace
        # `update_state`.
        tp = tf.keras.metrics.TruePositives()(y_true, y_pred, sample_weight)
        fp = tf.keras.metrics.FalsePositives()(y_true, y_pred, sample_weight)
        tn = tf.keras.metrics.TrueNegatives()(y_true, y_pred, sample_weight)
        fn = tf.keras.metrics.FalseNegatives()(y_true, y_pred, sample_weight)

        return tp, fp, tn, fn

    @staticmethod
    def _apply_valid_mask(
        is_valid: tf.Tensor,
        y_true: tf.Tensor,
        y_pred: tf.Tensor,
        sample_weight: tf.Tensor,
    ) -> tuple:
        """Keep only the entries whose last axis is entirely valid.

        `is_valid` is an element-wise boolean tensor derived from `y_true`.
        It is reduced with a logical AND over the last axis and the result is
        used as a `tf.boolean_mask` on all three tensors, so one invalid value
        removes the whole entry.
        """
        # assumes y_pred and sample_weight share y_true's leading dimensions,
        # as tf.boolean_mask requires - TODO confirm against callers
        keep = tf.reduce_all(is_valid, axis=-1)
        return (
            tf.boolean_mask(y_true, keep),
            tf.boolean_mask(y_pred, keep),
            tf.boolean_mask(sample_weight, keep),
        )

    def filter_nodata(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor
    ) -> tuple:
        """Drop entries where any `y_true` value on the last axis equals `nodata`.

        Returns:
            The masked `(y_true, y_pred, sample_weight)` tuple.
        """
        nodata = tf.convert_to_tensor(self.nodata, dtype=y_true.dtype)
        return self._apply_valid_mask(
            tf.math.not_equal(y_true, nodata), y_true, y_pred, sample_weight
        )

    def filter_greater_than(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor
    ) -> tuple:
        """Keep entries whose `y_true` values are all strictly above `greater_than`.

        Values equal to the threshold are removed as well.

        Returns:
            The masked `(y_true, y_pred, sample_weight)` tuple.
        """
        threshold = tf.convert_to_tensor(self.greater_than, dtype=y_true.dtype)
        return self._apply_valid_mask(
            tf.math.greater(y_true, threshold), y_true, y_pred, sample_weight
        )

    def filter_less_than(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor
    ) -> tuple:
        """Keep entries whose `y_true` values are all strictly below `less_than`.

        Values equal to the threshold are removed as well.

        Returns:
            The masked `(y_true, y_pred, sample_weight)` tuple.
        """
        threshold = tf.convert_to_tensor(self.less_than, dtype=y_true.dtype)
        return self._apply_valid_mask(
            tf.math.less(y_true, threshold), y_true, y_pred, sample_weight
        )

    def get_config(self) -> dict:
        """Return the parent config extended with the filter settings.

        Adds `nodata`, `greater_than`, `less_than` and `scaler`.
        `is_categorical` is not stored on the instance and is not part of the
        returned config.
        """
        # NOTE(review): `scaler` is stored as the object itself - confirm it
        # serializes the way the saving code expects.
        config = super().get_config()
        config.update(
            nodata=self.nodata,
            greater_than=self.greater_than,
            less_than=self.less_than,
            scaler=self.scaler,
        )
        return config

    @classmethod
    def from_config(cls, config) -> "MycoMetric":
        """Build an instance by passing `config` as keyword arguments."""
        return cls(**config)

    def reset_state(self) -> None:
        """Overwrite every variable of the metric with zeros of its own shape."""
        K.batch_set_value([(v, np.zeros(v.shape)) for v in self.variables])

__init__(name=None, nodata=None, greater_than=None, less_than=None, scaler=None, dtype=tf.float32, is_categorical=False, **kwargs)

Create a custom metric function that filters input data prior to the calculation.

Parameters:

Name Type Description Default
name str

the name of the metric (logged in model training)

None
nodata float

the nodata value to exclude from calculations.

None
greater_than float

only include values above this threshold.

None
less_than float

only include values below this threshold.

None
scaler TFScaler

apply inverse_transform methods to rescale data

None
dtype tf.dtypes.DType

the tf data type to compute the metric in.

tf.float32
is_categorical bool

add base categorical states for computing confusion matrix statistics

False
Source code in myco/metrics.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def __init__(
    self,
    name: str = None,
    nodata: float = None,
    greater_than: float = None,
    less_than: float = None,
    scaler: TFScaler = None,
    dtype: tf.dtypes.DType = tf.float32,
    is_categorical: bool = False,
    **kwargs,
) -> None:
    """Set up a metric that filters its inputs before computing a score.

    Args:
        name: metric name, as logged during model training.
        nodata: value marking entries to leave out of the calculation.
        greater_than: lower threshold; only values above it are included.
        less_than: upper threshold; only values below it are included.
        scaler: object whose inverse_transform methods rescale the data.
        dtype: tf data type the metric is computed in.
        is_categorical: when True, create the `tp`, `fp`, `tn` and `fn`
            scalar weights used for confusion matrix statistics.
        **kwargs: passed through to the parent constructor.
    """
    super().__init__(name=name, dtype=dtype, **kwargs)

    # Filter / rescale settings.
    self.nodata = nodata
    self.greater_than = greater_than
    self.less_than = less_than
    self.scaler = scaler

    if not is_categorical:
        return

    def scalar_state(weight_name: str):
        # All four counters share the same scalar, zero-initialised layout.
        return self.add_weight(
            name=weight_name,
            shape=(),
            initializer="zeros",
            dtype=dtype,
        )

    # Created in the order tp, fp, tn, fn.
    self.tp = scalar_state("true_positives")
    self.fp = scalar_state("false_positives")
    self.tn = scalar_state("true_negatives")
    self.fn = scalar_state("false_negatives")

filter_greater_than(y_true, y_pred, sample_weight)

Removes values below the greater_than threshold from tensors

Source code in myco/metrics.py
138
139
140
141
142
143
144
145
146
147
148
149
def filter_greater_than(
    self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor
) -> tuple:
    """Keep entries whose `y_true` values all exceed the `greater_than` threshold.

    The comparison is strict and is reduced with a logical AND over the last
    axis of `y_true`; the resulting mask is applied to all three tensors.
    """
    threshold = tf.convert_to_tensor(self.greater_than, dtype=y_true.dtype)
    keep = tf.reduce_all(tf.math.greater(y_true, threshold), axis=-1)
    return (
        tf.boolean_mask(y_true, keep),
        tf.boolean_mask(y_pred, keep),
        tf.boolean_mask(sample_weight, keep),
    )

filter_less_than(y_true, y_pred, sample_weight)

Removes values above the less_than threshold from tensors

Source code in myco/metrics.py
151
152
153
154
155
156
157
158
159
160
161
162
def filter_less_than(
    self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor
) -> tuple:
    """Keep entries whose `y_true` values are all under the `less_than` threshold.

    The comparison is strict and is reduced with a logical AND over the last
    axis of `y_true`; the resulting mask is applied to all three tensors.
    """
    threshold = tf.convert_to_tensor(self.less_than, dtype=y_true.dtype)
    keep = tf.reduce_all(tf.math.less(y_true, threshold), axis=-1)
    return (
        tf.boolean_mask(y_true, keep),
        tf.boolean_mask(y_pred, keep),
        tf.boolean_mask(sample_weight, keep),
    )

filter_nodata(y_true, y_pred, sample_weight)

Removes nodata values from y_true and y_pred tensors

Source code in myco/metrics.py
125
126
127
128
129
130
131
132
133
134
135
136
def filter_nodata(
    self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor
) -> tuple:
    """Drop entries whose `y_true` contains the nodata value.

    An entry is kept only when every `y_true` value along the last axis
    differs from `self.nodata`; the same mask is applied to `y_pred` and
    `sample_weight`.
    """
    nodata = tf.convert_to_tensor(self.nodata, dtype=y_true.dtype)
    keep = tf.reduce_all(tf.math.not_equal(y_true, nodata), axis=-1)
    return (
        tf.boolean_mask(y_true, keep),
        tf.boolean_mask(y_pred, keep),
        tf.boolean_mask(sample_weight, keep),
    )

format_categorical(y_true, y_pred, sample_weight=None)

Compute categorical confusion matrix components

Source code in myco/metrics.py
114
115
116
117
118
119
120
121
122
123
def format_categorical(
    self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None
) -> tuple:
    """Compute categorical confusion matrix components.

    Builds a new Keras `TruePositives`, `FalsePositives`, `TrueNegatives`
    and `FalseNegatives` metric (default constructor arguments) and calls
    each once on the inputs.

    Returns:
        The tuple `(tp, fp, tn, fn)` of the four computed values.
    """
    # Order matters: callers unpack the result positionally.
    counters = (
        tf.keras.metrics.TruePositives,
        tf.keras.metrics.FalsePositives,
        tf.keras.metrics.TrueNegatives,
        tf.keras.metrics.FalseNegatives,
    )
    return tuple(
        counter()(y_true, y_pred, sample_weight) for counter in counters
    )

format_tensors(y_true, y_pred, sample_weight)

Ensure consistent data types, shapes, and weights values for all inputs

Source code in myco/metrics.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def format_tensors(
    self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor
) -> tuple:
    """Cast, weight, filter and rescale the inputs of a metric update.

    `y_true`, `y_pred` and `sample_weight` are cast to `self.dtype` (missing
    weights become ones shaped like `y_true`), the configured nodata /
    threshold filters are applied, and finally the scaler's
    `inverse_transform` is applied to `y_true` and `y_pred` if a scaler is
    set.
    """
    y_true = tf.cast(y_true, dtype=self.dtype)
    y_pred = tf.cast(y_pred, dtype=self.dtype)

    # Default to uniform weights when none are supplied.
    sample_weight = (
        tf.ones_like(y_true, dtype=self.dtype)
        if sample_weight is None
        else tf.cast(sample_weight, dtype=self.dtype)
    )

    # Filters run in a fixed order and only when their setting is present.
    for setting, apply_filter in (
        (self.nodata, self.filter_nodata),
        (self.greater_than, self.filter_greater_than),
        (self.less_than, self.filter_less_than),
    ):
        if setting is not None:
            y_true, y_pred, sample_weight = apply_filter(
                y_true, y_pred, sample_weight
            )

    # Rescaling happens after filtering.
    scaler = self.scaler
    if scaler is not None:
        y_true = scaler.inverse_transform(y_true)
        y_pred = scaler.inverse_transform(y_pred)

    return y_true, y_pred, sample_weight

get_metric(name)

Returns an un-initialized metric object by name.

Source code in myco/metrics.py
800
801
802
803
def get_metric(name: str) -> MycoMetric:
    """Look up the `SUPPORTED` entry registered under `name`.

    The name is first checked against `get_names()` with an assertion.
    """
    # NOTE(review): the summary on the page says "un-initialized", so the
    # value is presumably a metric class rather than an instance - confirm.
    valid_names = get_names()
    assert name in valid_names, f"Invalid metric: {name}"
    metric = SUPPORTED[name]
    return metric

get_names()

Returns a list of the available metrics supported in configuration.

Source code in myco/metrics.py
795
796
797
def get_names():
    """List the keys of `SUPPORTED`, i.e. the metric names usable in configuration."""
    return [*SUPPORTED.keys()]