# Losses

Custom loss functions for semantic segmentation.

## `training.losses`

Custom loss functions for segmentation models.

### `CombinedDiceFocalLoss(dice_weight: float = 0.5, focal_weight: float = 0.5, class_weights: list | tuple | torch.Tensor | None = None, dice_kwargs: dict | None = None, focal_kwargs: dict | None = None, smooth: float = 1e-06)`

Bases: `Module`
Combined loss that weights Dice and Focal losses with optional class weights.
Initialize the combined Dice + Focal loss module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dice_weight` | `float` | Weight for the Dice loss component. | `0.5` |
| `focal_weight` | `float` | Weight for the Focal loss component. | `0.5` |
| `class_weights` | `list \| tuple \| Tensor \| None` | Optional per-class weights for handling class imbalance. | `None` |
| `dice_kwargs` | `dict \| None` | Optional kwargs for the Dice loss constructor. | `None` |
| `focal_kwargs` | `dict \| None` | Optional kwargs for the Focal loss constructor. | `None` |
| `smooth` | `float` | Smoothing factor for Dice loss numerical stability. | `1e-06` |
Source code in src/training/losses.py
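Since the source of `CombinedDiceFocalLoss` is not reproduced here, the following is a minimal sketch of how such a Dice + Focal combination is commonly implemented. The class name `CombinedDiceFocalSketch` and the `gamma` focal exponent are illustrative assumptions, not part of this module's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CombinedDiceFocalSketch(nn.Module):
    """Sketch: weighted sum of a soft-Dice loss and a focal loss.

    Illustrative only -- not the actual training.losses implementation.
    """

    def __init__(self, dice_weight=0.5, focal_weight=0.5, gamma=2.0, smooth=1e-6):
        super().__init__()
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.gamma = gamma  # assumed focal modulation exponent
        self.smooth = smooth

    def forward(self, predictions, targets):
        # predictions: raw logits (N, C, H, W); targets: class indices (N, H, W)
        probs = F.softmax(predictions, dim=1)
        one_hot = F.one_hot(targets, probs.shape[1]).permute(0, 3, 1, 2).float()

        # Soft Dice over batch and spatial dims, averaged over classes
        dims = (0, 2, 3)
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        dice_loss = 1.0 - dice.mean()

        # Focal loss: per-pixel cross-entropy modulated by (1 - p_t)^gamma
        ce = F.cross_entropy(predictions, targets, reduction="none")
        pt = torch.exp(-ce)
        focal_loss = ((1.0 - pt) ** self.gamma * ce).mean()

        return self.dice_weight * dice_loss + self.focal_weight * focal_loss
```

The weighted-sum structure mirrors the `dice_weight` / `focal_weight` parameters documented above; the real class additionally threads `class_weights` and the `*_kwargs` dictionaries into both components.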
#### `forward(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor`
Compute the combined loss value.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `predictions` | `Tensor` | Raw logits of shape (N, C, H, W). | *required* |
| `targets` | `Tensor` | Ground truth of shape (N, H, W) with class indices. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Combined weighted loss as a scalar tensor. |
Source code in src/training/losses.py
### `WeightedCrossEntropyDiceLoss(ce_weight: float = 1.0, dice_weight: float = 1.0, class_weights: list | tuple | torch.Tensor | None = None, ce_kwargs: dict | None = None, dice_kwargs: dict | None = None, smooth: float = 1e-06)`

Bases: `Module`
Weighted Cross-Entropy + Weighted Dice loss.
Combines an (optionally weighted) cross-entropy loss with a weighted Dice loss. Both components apply class weights to handle class imbalance.
Notes:

- `class_weights` may be a list, tuple, or `torch.Tensor`. It is converted to a tensor on the same device as the predictions at runtime to avoid device-mismatch issues.
- `ce_kwargs` allows customization of the cross-entropy component.
- The Dice component computes per-class Dice scores and weights them before averaging, giving minority classes more influence.
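The runtime conversion described in the first note can be sketched as a small helper. The function name `as_class_weights` is hypothetical; the actual class performs this normalization internally.

```python
import torch


def as_class_weights(class_weights, predictions):
    """Sketch: normalize list/tuple/Tensor class weights to a tensor
    on the same device (and dtype) as the predictions.

    Illustrative helper -- name and signature are assumptions.
    """
    if class_weights is None:
        return None
    if not isinstance(class_weights, torch.Tensor):
        class_weights = torch.as_tensor(class_weights, dtype=predictions.dtype)
    # Moving to predictions.device avoids CPU/GPU mismatch errors at runtime
    return class_weights.to(device=predictions.device, dtype=predictions.dtype)
```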
Initialize the weighted Cross-Entropy + Dice loss module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ce_weight` | `float` | Weight for the Cross-Entropy loss component. | `1.0` |
| `dice_weight` | `float` | Weight for the Dice loss component. | `1.0` |
| `class_weights` | `list \| tuple \| Tensor \| None` | Optional per-class weights for both CE and Dice. | `None` |
| `ce_kwargs` | `dict \| None` | Optional kwargs for the Cross-Entropy loss. | `None` |
| `dice_kwargs` | `dict \| None` | Optional kwargs for the Dice loss (used to extract `ignore_index`). | `None` |
| `smooth` | `float` | Smoothing factor for Dice loss numerical stability. | `1e-06` |
Source code in src/training/losses.py
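The notes above describe class-weighted CE combined with a Dice term whose per-class scores are weighted before averaging. A functional sketch of that combination (the function name and exact reduction are assumptions, not this module's code):

```python
import torch
import torch.nn.functional as F


def weighted_ce_dice_sketch(predictions, targets, class_weights,
                            ce_weight=1.0, dice_weight=1.0, smooth=1e-6):
    """Sketch: class-weighted cross-entropy plus class-weighted soft Dice.

    Illustrative only -- not the actual WeightedCrossEntropyDiceLoss code.
    """
    w = torch.as_tensor(class_weights, dtype=predictions.dtype,
                        device=predictions.device)

    # Cross-entropy applies the per-class weights directly
    ce = F.cross_entropy(predictions, targets, weight=w)

    # Per-class soft Dice from softmax probabilities and one-hot targets
    probs = F.softmax(predictions, dim=1)
    one_hot = F.one_hot(targets, probs.shape[1]).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)
    inter = (probs * one_hot).sum(dims)
    card = probs.sum(dims) + one_hot.sum(dims)
    dice_per_class = (2 * inter + smooth) / (card + smooth)

    # Weighted average over classes: minority classes get more influence
    dice_loss = 1.0 - (w * dice_per_class).sum() / w.sum()

    return ce_weight * ce + dice_weight * dice_loss
```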
#### `forward(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor`
Compute the weighted combination of cross-entropy and Dice loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `predictions` | `Tensor` | Model predictions of shape (N, C, H, W), where N is the batch size, C is the number of classes, and H, W are spatial dimensions. Must be raw logits (not softmax probabilities), as both `F.cross_entropy` and the Dice computation apply softmax internally. | *required* |
| `targets` | `Tensor` | Ground truth labels of shape (N, H, W) with class indices. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Combined loss value as a scalar tensor. |
Notes:

- Cross-entropy loss applies class weights and respects `ignore_index`.
- Dice loss applies class weights via weighted averaging of per-class scores.
- Both losses expect raw logits and handle softmax internally.
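The raw-logits requirement in the last note can be demonstrated directly: inputs go in un-normalized, and the normalization happens inside the loss. The shapes here are arbitrary example values.

```python
import torch
import torch.nn.functional as F

N, C, H, W = 2, 3, 8, 8
logits = torch.randn(N, C, H, W)          # raw logits -- no softmax applied
targets = torch.randint(0, C, (N, H, W))  # integer class indices per pixel

# F.cross_entropy applies log-softmax internally, so logits go in as-is;
# passing pre-softmaxed probabilities would silently distort the loss.
ce = F.cross_entropy(logits, targets)

# The Dice component likewise normalizes internally via softmax:
probs = F.softmax(logits, dim=1)  # channel dim sums to 1 per pixel
```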