myco.losses

Loss functions for evaluating model performance and optimizing model training. Losses are a subset of metrics: all losses are metrics, but not all metrics are losses. This is because a loss must decrease in value to indicate an increase in model performance. R-squared is an example of a metric that is not a loss, since a higher R-squared corresponds to better performance. Losses and Metrics are also independent Python object types with different assumptions about subclassing and class methods.
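
The typical entry point is to wrap a custom callable in MycoLoss (documented below) and pass the result to model.compile. A minimal sketch, assuming a hypothetical weighted_mae callable and an already-built Keras model:

import tensorflow as tf
from myco.losses import MycoLoss

# Hypothetical loss callable: MycoLoss forwards (y_true, y_pred, weights)
# to whatever callable it wraps.
def weighted_mae(y_true, y_pred, weights):
    abs_err = tf.abs(y_true - y_pred)
    return tf.reduce_sum(abs_err * weights) / tf.reduce_sum(weights)

# Exclude a nodata fill value from the loss calculation during training.
loss = MycoLoss(weighted_mae, name="weighted_mae", nodata=-9999.0)

# `model` is assumed to be an already-built tf.keras.Model
# model.compile(optimizer="adam", loss=loss)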

MycoLoss

Bases: tf.keras.losses.Loss

Base class for creating loss functions to evaluate performance on slices (subsets) of data

Source code in myco/losses.py
@tf.keras.utils.register_keras_serializable(package="myco")
class MycoLoss(tf.keras.losses.Loss):
    """Base class for creating loss functions to evaluate performance on slices (subsets) of data"""

    loss_function: Callable = None
    name: str = None
    nodata: float = None
    greater_than: float = None
    less_than: float = None
    dtype: tf.dtypes.DType = None

    def __init__(
        self,
        loss_function: Callable,
        name: str = None,
        nodata: float = None,
        greater_than: float = None,
        less_than: float = None,
        dtype: tf.dtypes.DType = tf.float32,
        **kwargs,
    ):
        """Create a custom loss function that filters input data prior to the loss calculation.

        Args:
            loss_function: a callable function that takes (y_true, y_pred) arguments and returns a scalar value
            name: the name of the loss function (this is printed during model training)
            nodata: the nodata value to exclude from calculations
            greater_than: only include values above this threshold
            less_than: only include values below this threshold
            dtype: the tf data type to compute the metric in. Not implemented yet.
        """
        if name is None:
            if hasattr(loss_function, "name"):
                name = loss_function.name
            else:
                name = "custom_loss"

        super().__init__(name=name)

        self.loss_function = loss_function
        self.nodata = nodata
        self.greater_than = greater_than
        self.less_than = less_than
        self.dtype = dtype

    def _format_tensors(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, weights: tf.Tensor
    ) -> tuple:
        """Ensure consistent data types, shapes, and weights values for all inputs"""
        y_true = tf.cast(y_true, dtype=self.dtype)
        y_pred = tf.cast(y_pred, dtype=self.dtype)

        # set weight size to (n_samples, ysize, xsize, 1) in case of multi-label
        if weights is None:
            weights = tf.ones_like(y_true, dtype=tf.float32)
        else:
            weights = tf.cast(weights, dtype=tf.float32)

        return y_true, y_pred, weights

    def _filter_nodata(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, weights: tf.Tensor
    ) -> tuple:
        """Removes nodata values from y_true and y_pred tensors"""
        is_valid = tf.math.not_equal(
            y_true, tf.convert_to_tensor(self.nodata, dtype=y_true.dtype)
        )
        is_valid = tf.reduce_all(is_valid, axis=-1)
        y_true = tf.boolean_mask(y_true, is_valid)
        y_pred = tf.boolean_mask(y_pred, is_valid)
        weights = tf.boolean_mask(weights, is_valid)
        return y_true, y_pred, weights

    def _filter_greater_than(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, weights: tf.Tensor
    ) -> tuple:
        """Removes values below the `greater_than` threshold from tensors"""
        is_valid = tf.math.greater(
            y_true, tf.convert_to_tensor(self.greater_than, dtype=y_true.dtype)
        )
        is_valid = tf.reduce_all(is_valid, axis=-1)
        y_true = tf.boolean_mask(y_true, is_valid)
        y_pred = tf.boolean_mask(y_pred, is_valid)
        weights = tf.boolean_mask(weights, is_valid)
        return y_true, y_pred, weights

    def _filter_less_than(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, weights: tf.Tensor
    ) -> tuple:
        """Removes values above the `less_than` threshold from tensors"""
        is_valid = tf.math.less(
            y_true, tf.convert_to_tensor(self.less_than, dtype=y_true.dtype)
        )
        is_valid = tf.reduce_all(is_valid, axis=-1)
        y_true = tf.boolean_mask(y_true, is_valid)
        y_pred = tf.boolean_mask(y_pred, is_valid)
        weights = tf.boolean_mask(weights, is_valid)
        return y_true, y_pred, weights

    def call(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None
    ):
        """Logic for the loss calculation"""

        # convert input data to dtype-formatted tensors
        y_true, y_pred, weights = self._format_tensors(y_true, y_pred, sample_weight)

        # filter nodata/low/high values
        if self.nodata is not None:
            y_true, y_pred, weights = self._filter_nodata(y_true, y_pred, weights)
        if self.greater_than is not None:
            y_true, y_pred, weights = self._filter_greater_than(y_true, y_pred, weights)
        if self.less_than is not None:
            y_true, y_pred, weights = self._filter_less_than(y_true, y_pred, weights)

        return self.loss_function(y_true, y_pred, weights)

    def get_config(self) -> dict:
        config = super().get_config().copy()
        config.update(
            nodata=self.nodata,
            greater_than=self.greater_than,
            less_than=self.less_than,
        )
        return config

    @classmethod
    def from_config(cls, config) -> "MycoLoss":
        return cls(**config)
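
Because greater_than and less_than filter on y_true before the wrapped callable runs, the same callable can be scored on different slices of the label distribution. A minimal sketch, assuming a hypothetical weighted_mae callable:

import tensorflow as tf
from myco.losses import MycoLoss

def weighted_mae(y_true, y_pred, weights):  # hypothetical callable
    return tf.reduce_sum(tf.abs(y_true - y_pred) * weights) / tf.reduce_sum(weights)

high_loss = MycoLoss(weighted_mae, name="mae_high", greater_than=0.5)
low_loss = MycoLoss(weighted_mae, name="mae_low", less_than=0.5)

y_true = tf.constant([[0.2], [0.4], [0.6], [0.8]])
y_pred = tf.constant([[0.3], [0.3], [0.5], [0.5]])

print(high_loss(y_true, y_pred).numpy())  # averaged over the 0.6 and 0.8 rows only
print(low_loss(y_true, y_pred).numpy())   # averaged over the 0.2 and 0.4 rows only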

__init__(loss_function, name=None, nodata=None, greater_than=None, less_than=None, dtype=tf.float32, **kwargs)

Create a custom loss function that filters input data prior to the loss calculation.

Parameters:

loss_function (Callable, required):
    a callable function that takes (y_true, y_pred) arguments and returns a scalar value
name (str, default None):
    the name of the loss function (this is printed during model training)
nodata (float, default None):
    the nodata value to exclude from calculations
greater_than (float, default None):
    only include values above this threshold
less_than (float, default None):
    only include values below this threshold
dtype (tf.dtypes.DType, default tf.float32):
    the tf data type to compute the metric in. Not implemented yet.
Source code in myco/losses.py
def __init__(
    self,
    loss_function: Callable,
    name: str = None,
    nodata: float = None,
    greater_than: float = None,
    less_than: float = None,
    dtype: tf.dtypes.DType = tf.float32,
    **kwargs,
):
    """Create a custom loss function that filters input data prior to the loss calculation.

    Args:
        loss_function: a callable function that takes (y_true, y_pred) arguments and returns a scalar value
        name: the name of the loss function (this is printed during model training)
        nodata: the nodata value to exclude from calculations
        greater_than: only include values above this threshold
        less_than: only include values below this threshold
        dtype: the tf data type to compute the metric in. Not implemented yet.
    """
    if name is None:
        if hasattr(loss_function, "name"):
            name = loss_function.name
        else:
            name = "custom_loss"

    super().__init__(name=name)

    self.loss_function = loss_function
    self.nodata = nodata
    self.greater_than = greater_than
    self.less_than = less_than
    self.dtype = dtype
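
Note the name resolution above: an explicit name wins, then a name attribute on the wrapped callable, then the "custom_loss" fallback. A minimal sketch, assuming a hypothetical rmse callable with no name attribute:

import tensorflow as tf
from myco.losses import MycoLoss

def rmse(y_true, y_pred, weights):  # hypothetical callable; has no `name` attribute
    mse = tf.reduce_sum(tf.square(y_true - y_pred) * weights) / tf.reduce_sum(weights)
    return tf.sqrt(mse)

print(MycoLoss(rmse).name)                      # "custom_loss" (fallback)
print(MycoLoss(rmse, name="masked_rmse").name)  # "masked_rmse" (explicit name wins)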

call(y_true, y_pred, sample_weight=None)

Logic for the loss calculation

Source code in myco/losses.py
def call(
    self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight: tf.Tensor = None
):
    """Logic for the loss calculation"""

    # convert input data to dtype-formatted tensors
    y_true, y_pred, weights = self._format_tensors(y_true, y_pred, sample_weight)

    # filter nodata/low/high values
    if self.nodata is not None:
        y_true, y_pred, weights = self._filter_nodata(y_true, y_pred, weights)
    if self.greater_than is not None:
        y_true, y_pred, weights = self._filter_greater_than(y_true, y_pred, weights)
    if self.less_than is not None:
        y_true, y_pred, weights = self._filter_less_than(y_true, y_pred, weights)

    return self.loss_function(y_true, y_pred, weights)
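
Because the filters are applied inside call, they take effect both during model training and when the loss object is invoked directly. A minimal sketch of the nodata filter, assuming a hypothetical masked_mse callable:

import tensorflow as tf
from myco.losses import MycoLoss

def masked_mse(y_true, y_pred, weights):  # hypothetical callable
    return tf.reduce_sum(tf.square(y_true - y_pred) * weights) / tf.reduce_sum(weights)

loss = MycoLoss(masked_mse, name="masked_mse", nodata=-9999.0)

y_true = tf.constant([[1.0], [2.0], [-9999.0], [4.0]])
y_pred = tf.constant([[1.5], [2.0], [0.0], [3.0]])

# The -9999.0 row is dropped before masked_mse is computed,
# so the result is averaged over the three valid rows.
print(loss(y_true, y_pred).numpy())  # ~0.4167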

get_loss(name)

Returns an un-initiated loss function object by name.

Source code in myco/losses.py
def get_loss(name: str) -> MycoLoss:
    """Returns an un-initiated loss function object by name."""
    assert name in get_names(), f"Invalid loss function: {name}"
    return SUPPORTED[name]

get_names()

Returns a list of the available loss functions supported in configuration.

Source code in myco/losses.py
def get_names():
    """Returns a list of the available loss functions supported in configuration."""
    return list(SUPPORTED.keys())
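
A minimal sketch of looking up a loss from the registry. The available names depend on what your installation registers in SUPPORTED, so this takes the first entry returned by get_names() rather than assuming a specific key:

from myco import losses

names = losses.get_names()
print(names)  # the loss names accepted in configuration

# get_loss returns the un-initiated loss object registered under that name;
# an unknown name raises an AssertionError.
loss_obj = losses.get_loss(names[0])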