Losses

Custom loss functions for semantic segmentation.

training.losses

Custom loss functions for segmentation models.

CombinedDiceFocalLoss(dice_weight: float = 0.5, focal_weight: float = 0.5, class_weights: list | tuple | torch.Tensor | None = None, dice_kwargs: dict | None = None, focal_kwargs: dict | None = None, smooth: float = 1e-06)

Bases: Module

Combined loss that weights Dice and Focal losses with optional class weights.

Initialize the combined Dice + Focal loss module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dice_weight` | `float` | Weight for the Dice loss component. | `0.5` |
| `focal_weight` | `float` | Weight for the Focal loss component. | `0.5` |
| `class_weights` | `list \| tuple \| Tensor \| None` | Optional per-class weights for handling class imbalance. | `None` |
| `dice_kwargs` | `dict \| None` | Optional kwargs for the Dice loss constructor. | `None` |
| `focal_kwargs` | `dict \| None` | Optional kwargs for the Focal loss constructor. | `None` |
| `smooth` | `float` | Smoothing factor for Dice loss numerical stability. | `1e-06` |
Source code in src/training/losses.py:

```python
def __init__(
    self,
    dice_weight: float = 0.5,
    focal_weight: float = 0.5,
    class_weights: list | tuple | torch.Tensor | None = None,
    dice_kwargs: dict | None = None,
    focal_kwargs: dict | None = None,
    smooth: float = 1e-6,
) -> None:
    """Initialize the combined Dice + Focal loss module.

    Args:
        dice_weight: Weight for the Dice loss component.
        focal_weight: Weight for the Focal loss component.
        class_weights: Optional per-class weights for handling class imbalance.
        dice_kwargs: Optional kwargs for the Dice loss constructor.
        focal_kwargs: Optional kwargs for the Focal loss constructor.
        smooth: Smoothing factor for Dice loss numerical stability.

    """
    super().__init__()

    self.dice_weight = dice_weight
    self.focal_weight = focal_weight
    self.smooth = smooth

    if class_weights is not None:
        if not isinstance(class_weights, torch.Tensor):
            class_weights = torch.tensor(class_weights, dtype=torch.float)
        else:
            class_weights = class_weights.float()
        self.register_buffer("class_weights", class_weights)
    else:
        self.register_buffer("class_weights", None)

    dice_params = dice_kwargs or {"mode": "multiclass"}
    self.ignore_index = dice_params.get("ignore_index", None)
    self.dice_loss = smp.losses.DiceLoss(**dice_params)

    focal_params = {"mode": "multiclass"}  # Default
    if focal_kwargs:
        focal_params.update(focal_kwargs)  # Merge user params
    self.focal_loss = smp.losses.FocalLoss(**focal_params)
```
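The Dice and Focal components above come from segmentation_models_pytorch (smp). As a rough illustration of what the focal term does (a sketch of the idea, not smp's exact implementation), multiclass focal loss is cross-entropy scaled by (1 - p_t)^gamma, which down-weights pixels the model already classifies confidently:

```python
import torch
import torch.nn.functional as F


def focal_loss_sketch(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Illustrative multiclass focal loss: CE scaled by (1 - p_t) ** gamma."""
    ce = F.cross_entropy(logits, targets.long(), reduction="none")  # -log(p_t) per pixel
    pt = torch.exp(-ce)                                             # p_t, prob of the true class
    return ((1.0 - pt) ** gamma * ce).mean()


# A confidently correct pixel contributes almost nothing to the loss:
logits = torch.tensor([[[[4.0]], [[0.0]]]])  # (N=1, C=2, H=1, W=1), strongly favors class 0
targets = torch.tensor([[[0]]])              # true class is 0
```

With `gamma=0` this reduces to plain cross-entropy; larger `gamma` focuses training on hard pixels, which is why it pairs well with Dice on imbalanced segmentation data.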

forward(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor

Compute the combined loss value.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `predictions` | `Tensor` | Raw logits of shape (N, C, H, W). | *required* |
| `targets` | `Tensor` | Ground truth of shape (N, H, W) with class indices. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Combined weighted loss as a scalar tensor. |

Source code in src/training/losses.py:

```python
def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Compute the combined loss value.

    Args:
        predictions: Raw logits of shape (N, C, H, W).
        targets: Ground truth of shape (N, H, W) with class indices.

    Returns:
        Combined weighted loss as a scalar tensor.

    """
    # Use weighted Dice if class weights provided, otherwise use SMP's DiceLoss
    if self.class_weights is not None:
        dice_loss_value = _compute_weighted_dice_loss(
            predictions,
            targets,
            class_weights=self.class_weights,
            ignore_index=self.ignore_index,
            smooth=self.smooth,
        )
    else:
        dice_loss_value = self.dice_loss(predictions, targets)

    focal_loss_value = self.focal_loss(predictions, targets)

    return self.dice_weight * dice_loss_value + self.focal_weight * focal_loss_value
```

WeightedCrossEntropyDiceLoss(ce_weight: float = 1.0, dice_weight: float = 1.0, class_weights: list | tuple | torch.Tensor | None = None, ce_kwargs: dict | None = None, dice_kwargs: dict | None = None, smooth: float = 1e-06)

Bases: Module

Weighted Cross-Entropy + Weighted Dice loss.

Combines a (optionally weighted) cross-entropy loss with a weighted Dice loss. Both components apply class weights to handle class imbalance.

Notes: - class_weights may be a list, tuple or torch.Tensor. It will be converted to a tensor on the same device as the predictions at runtime to avoid device-mismatch issues. - ce_kwargs allows customization of the cross-entropy component. - The Dice component computes per-class Dice scores and weights them before averaging, giving minority classes more influence.
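The device handling described in the notes relies on `register_buffer`: buffers move with the module on `.to(device)` and are saved in its `state_dict`, so the weights end up on the same device as the predictions without manual bookkeeping. A minimal sketch of the pattern:

```python
import torch
import torch.nn as nn


class WeightHolder(nn.Module):
    """Minimal sketch of the register_buffer pattern used by the loss modules."""

    def __init__(self, class_weights) -> None:
        super().__init__()
        # Buffers are not trainable parameters, but they follow .to(device)
        # and are serialized with the module's state_dict.
        self.register_buffer(
            "class_weights", torch.as_tensor(class_weights, dtype=torch.float)
        )


holder = WeightHolder([1.0, 2.0, 5.0])
```

Calling `holder.to("cuda")` would move `holder.class_weights` along with the rest of the module, which is what prevents device-mismatch errors at loss time.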

Initialize the weighted Cross-Entropy + Dice loss module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ce_weight` | `float` | Weight for the Cross-Entropy loss component. | `1.0` |
| `dice_weight` | `float` | Weight for the Dice loss component. | `1.0` |
| `class_weights` | `list \| tuple \| Tensor \| None` | Optional per-class weights for both CE and Dice. | `None` |
| `ce_kwargs` | `dict \| None` | Optional kwargs for the Cross-Entropy loss. | `None` |
| `dice_kwargs` | `dict \| None` | Optional kwargs for Dice loss (used to extract ignore_index). | `None` |
| `smooth` | `float` | Smoothing factor for Dice loss numerical stability. | `1e-06` |
Source code in src/training/losses.py:

```python
def __init__(
    self,
    ce_weight: float = 1.0,
    dice_weight: float = 1.0,
    class_weights: list | tuple | torch.Tensor | None = None,
    ce_kwargs: dict | None = None,
    dice_kwargs: dict | None = None,
    smooth: float = 1e-6,
) -> None:
    """Initialize the weighted Cross-Entropy + Dice loss module.

    Args:
        ce_weight: Weight for the Cross-Entropy loss component.
        dice_weight: Weight for the Dice loss component.
        class_weights: Optional per-class weights for both CE and Dice.
        ce_kwargs: Optional kwargs for the Cross-Entropy loss.
        dice_kwargs: Optional kwargs for Dice loss (used to extract ignore_index).
        smooth: Smoothing factor for Dice loss numerical stability.

    """
    super().__init__()

    self.ce_weight = float(ce_weight)
    self.dice_weight = float(dice_weight)
    self.smooth = smooth

    if class_weights is not None:
        if not isinstance(class_weights, torch.Tensor):
            class_weights = torch.tensor(class_weights, dtype=torch.float)
        else:
            class_weights = class_weights.float()
        self.register_buffer("class_weights", class_weights)
    else:
        self.register_buffer("class_weights", None)

    self.ce_kwargs = ce_kwargs or {}
    dice_params = dice_kwargs or {}
    self.ignore_index = dice_params.get("ignore_index", self.ce_kwargs.get("ignore_index"))
```

forward(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor

Compute the weighted combination of cross-entropy and Dice loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `predictions` | `Tensor` | Model predictions of shape (N, C, H, W), where N is batch size, C is number of classes, and H, W are spatial dimensions. Must be raw logits (not softmax probabilities), as both F.cross_entropy and the Dice computation apply softmax internally. | *required* |
| `targets` | `Tensor` | Ground truth labels of shape (N, H, W) with class indices. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Combined loss value as a scalar tensor. |

Notes:

- Cross-entropy loss applies class weights and respects ignore_index.
- Dice loss applies class weights via weighted averaging of per-class scores.
- Both losses expect raw logits and handle softmax internally.
Source code in src/training/losses.py:

```python
def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Compute the weighted combination of cross-entropy and Dice loss.

    Args:
        predictions: Model predictions of shape (N, C, H, W) where N is batch size,
            C is number of classes, and H, W are spatial dimensions.
            **Must be raw logits (not softmax probabilities)** as both
            F.cross_entropy and the Dice computation apply softmax internally.
        targets: Ground truth labels of shape (N, H, W) with class indices.

    Returns:
        Combined loss value as a scalar tensor.

    Notes:
        - Cross-entropy loss applies class weights and respects ignore_index.
        - Dice loss applies class weights via weighted averaging of per-class scores.
        - Both losses expect raw logits and handle softmax internally.

    """
    ce_loss = F.cross_entropy(
        predictions,
        targets.long(),
        weight=self.class_weights,
        **self.ce_kwargs,
    )

    dice_loss = _compute_weighted_dice_loss(
        predictions,
        targets,
        class_weights=self.class_weights,
        ignore_index=self.ignore_index,
        smooth=self.smooth,
    )

    return self.ce_weight * ce_loss + self.dice_weight * dice_loss
```
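As a small illustration of the ignore_index behavior noted above (illustrative values, not taken from the module): when `ignore_index` is set, `F.cross_entropy` averages only over the valid pixels, so ignored labels contribute nothing to the loss.

```python
import torch
import torch.nn.functional as F

# (N=1, C=2, H=1, W=2): two pixels, both logits favoring class 0
logits = torch.tensor([[[[3.0, 3.0]], [[0.0, 0.0]]]])
targets = torch.tensor([[[0, 255]]])  # second pixel marked as ignore

# Mean CE over valid pixels only; the ignored pixel is excluded entirely.
loss = F.cross_entropy(logits, targets, ignore_index=255)

# Same result as CE over just the first (valid) pixel.
reference = F.cross_entropy(logits[..., :1], targets[..., :1])
```

This is why passing `{"ignore_index": 255}` through `ce_kwargs` (or `dice_kwargs`) is enough to mask out unlabeled regions in both components.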