Validation

Computation of evaluation metrics, including IoU, F1, accuracy, and confusion matrices.

training.validation

calculate_iou_scores(conf_matrix: torch.Tensor, num_classes: int, other_class_index: int = 13) -> tuple[float, dict[int, float]]

Compute mean IoU (mIoU) and per-class IoU.

Source code in src/training/validation.py
def calculate_iou_scores(
    conf_matrix: torch.Tensor,
    num_classes: int,
    other_class_index: int = 13,
) -> tuple[float, dict[int, float]]:
    """Compute mean IoU (mIoU) and per-class IoU."""
    tp = torch.diag(conf_matrix)
    fp = conf_matrix.sum(dim=0) - tp
    fn = conf_matrix.sum(dim=1) - tp
    union = tp + fp + fn

    iou = torch.where(
        union != 0,
        tp.float() / union.float(),
        torch.zeros_like(union, dtype=torch.float),
    )

    valid_classes_mask = torch.ones(
        num_classes,
        dtype=torch.bool,
        device=conf_matrix.device,
    )

    if 0 <= other_class_index < num_classes:
        valid_classes_mask[other_class_index] = False
    else:
        logger.warning(
            "other_class_index %s is out of range [0, %s). It will be ignored.",
            other_class_index,
            num_classes,
        )

    valid_indices = valid_classes_mask.nonzero(as_tuple=True)[0]
    per_class_iou = {int(i): iou[i].item() for i in valid_indices}

    mean_iou = iou[valid_classes_mask].mean().item() if valid_classes_mask.any() else 0.0

    return mean_iou, per_class_iou
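
A minimal usage sketch (the 3-class confusion matrix below is made up for illustration; rows are ground-truth classes, columns are predictions):

import torch

from training.validation import calculate_iou_scores

# Class index 2 plays the role of the excluded 'other' class.
conf_matrix = torch.tensor([
    [50, 5, 0],
    [3, 40, 2],
    [1, 1, 10],
])

mean_iou, per_class_iou = calculate_iou_scores(
    conf_matrix,
    num_classes=3,
    other_class_index=2,
)
# per_class_iou ~= {0: 0.847, 1: 0.784}; class 2 is dropped from the dict.
# mean_iou ~= 0.816, the average over the non-excluded classes only.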

compute_timing_metrics(inference_times: list[float], batch_sizes: list[int]) -> dict[str, float]

Compute timing metrics from inference times and batch sizes.

Source code in src/training/validation.py
def compute_timing_metrics(
    inference_times: list[float],
    batch_sizes: list[int],
) -> dict[str, float]:
    """Compute timing metrics from inference times and batch sizes."""
    total_inference_time = sum(inference_times)
    total_images = sum(batch_sizes)
    avg_time_per_image = total_inference_time / total_images if total_images > 0 else 0.0
    avg_time_per_batch = sum(inference_times) / len(inference_times) if inference_times else 0.0

    return {
        "total_inference_time": total_inference_time,
        "total_images": total_images,
        "avg_time_per_image": avg_time_per_image,
        "avg_time_per_batch": avg_time_per_batch,
    }
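
For example (per-batch times in seconds, values made up; results are exact up to float rounding):

from training.validation import compute_timing_metrics

inference_times = [0.12, 0.11, 0.13]  # wall-clock time per batch
batch_sizes = [8, 8, 4]               # images per batch

metrics = compute_timing_metrics(inference_times, batch_sizes)
# {'total_inference_time': 0.36, 'total_images': 20,
#  'avg_time_per_image': 0.018, 'avg_time_per_batch': 0.12}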

evaluate(model: nn.Module, device: torch.device, data_loader: DataLoader, num_classes: int, other_class_index: int = 13, *, output_size: int | None = None, log_eval_metrics: bool = True, log_confusion_matrix: bool = True, normalize_confusion_matrix: bool = True, sample_ids_to_plot: list[str] | None = None, warmup_runs: int = 10, visualization_labels: dict[str, str] | None = None, class_name_mapping: dict[int, str] | None = None, zone_mosaic_config: dict | None = None, zone_data_loader: DataLoader | None = None) -> dict[str, float]

Evaluate model and log metrics and plots.

Parameters:

    model (nn.Module): Model to evaluate. Required.
    device (torch.device): Torch device. Required.
    data_loader (DataLoader): Evaluation DataLoader. Required.
    num_classes (int): Number of classes. Required.
    other_class_index (int): Index of 'other' class to exclude from mIoU. Default: 13.
    output_size (int | None): Expected output spatial size after center-crop when using a context window (see upsample_predictions). Default: None.
    log_eval_metrics (bool): Whether to log scalar metrics. Default: True.
    log_confusion_matrix (bool): Whether to log confusion matrix plot and CSV. Default: True.
    normalize_confusion_matrix (bool): Normalize confusion matrix rows. Default: True.
    sample_ids_to_plot (list[str] | None): Optional list of sample ids for individual prediction plots. Default: None.
    warmup_runs (int): Warmup forward passes (ignored in timing). Default: 10.
    visualization_labels (dict[str, str] | None): Optional dict overriding plot text labels. Default: None.
    class_name_mapping (dict[int, str] | None): Mapping from class index to readable name. Default: None.
    zone_mosaic_config (dict | None): Optional config for zone prediction mosaic visualization. Expected keys: 'enabled', 'zone_name', 'grid_size', 'patch_size'. Default: None.
    zone_data_loader (DataLoader | None): Optional DataLoader for a specific zone (for mosaic visualization). Default: None.
Source code in src/training/validation.py
def evaluate(
    model: nn.Module,
    device: torch.device,
    data_loader: DataLoader,
    num_classes: int,
    other_class_index: int = 13,
    *,
    output_size: int | None = None,
    log_eval_metrics: bool = True,
    log_confusion_matrix: bool = True,
    normalize_confusion_matrix: bool = True,
    sample_ids_to_plot: list[str] | None = None,
    warmup_runs: int = 10,
    visualization_labels: dict[str, str] | None = None,
    class_name_mapping: dict[int, str] | None = None,
    zone_mosaic_config: dict | None = None,
    zone_data_loader: DataLoader | None = None,
) -> dict[str, float]:
    """Evaluate model and log metrics and plots.

    Args:
        model: Model to evaluate.
        device: Torch device.
        data_loader: Evaluation DataLoader.
        num_classes: Number of classes.
        other_class_index: Index of 'other' class to exclude from mIoU.
        output_size: Expected output spatial size after center-crop when using
            a context window (see upsample_predictions).
        log_eval_metrics: Whether to log scalar metrics.
        log_confusion_matrix: Whether to log confusion matrix plot and CSV.
        normalize_confusion_matrix: Normalize confusion matrix rows.
        sample_ids_to_plot: Optional list of sample ids for individual prediction plots.
        warmup_runs: Warmup forward passes (ignored in timing).
        visualization_labels: Optional dict overriding plot text labels.
        class_name_mapping: Mapping from class index to readable name.
        zone_mosaic_config: Optional config for zone prediction mosaic visualization.
            Expected keys: 'enabled', 'zone_name', 'grid_size', 'patch_size'.
        zone_data_loader: Optional DataLoader for a specific zone (for mosaic visualization).

    """
    if class_name_mapping is None:
        class_name_mapping = {i: f"class_{i}" for i in range(num_classes)}

    model.eval()
    model.to(device)
    evaluation_metrics_dict = get_evaluation_metrics_dict(
        num_classes,
        device,
        other_class_index=other_class_index,
    )
    logger.info("Starting evaluation on %d batches", len(data_loader))

    sample_ids_to_log = set(sample_ids_to_plot) if sample_ids_to_plot else set()

    first_batch = next(iter(data_loader))
    is_temporal_model = first_batch[BATCH_INDEX_INPUTS].ndim == TEMPORAL_MODEL_NDIM

    if is_temporal_model:
        logger.info("Detected temporal model (5D input). Using temporal evaluation.")
        _perform_warmup_temporal(model, device, data_loader, warmup_runs)
        inference_times, batch_sizes = _evaluate_batches_temporal(
            model,
            device,
            data_loader,
            num_classes,
            evaluation_metrics_dict,
            sample_ids_to_log,
            output_size=output_size,
        )
    else:
        logger.info("Detected standard model. Using standard evaluation.")
        _perform_warmup_standard(model, device, data_loader, warmup_runs)
        inference_times, batch_sizes = _evaluate_batches_standard(
            model,
            device,
            data_loader,
            num_classes,
            evaluation_metrics_dict,
            sample_ids_to_log,
        )

    if zone_mosaic_config and zone_mosaic_config.get("enabled", False):
        if zone_data_loader is None:
            logger.warning(
                "Zone mosaic enabled but no zone_data_loader provided. Skipping mosaic.",
            )
        else:
            zone_name = zone_mosaic_config.get("zone_name", "zone")
            zone_grid_size = zone_mosaic_config.get("grid_size", 10)
            zone_patch_size = zone_mosaic_config.get("patch_size", 512)

            log_prediction_mosaic_to_mlflow(
                model=model,
                data_loader=zone_data_loader,
                device=device,
                num_classes=num_classes,
                zone_name=zone_name,
                grid_size=zone_grid_size,
                patch_size=zone_patch_size,
            )

    logger.info("Finished evaluation")

    return _finalize_evaluation(
        evaluation_metrics_dict,
        inference_times,
        batch_sizes,
        class_name_mapping,
        other_class_index,
        normalize_confusion_matrix=normalize_confusion_matrix,
        visualization_labels=visualization_labels,
        log_eval_metrics=log_eval_metrics,
        log_confusion_matrix=log_confusion_matrix,
    )
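
A minimal calling sketch, assuming a trained model and DataLoaders already exist; model, val_loader, and zone_loader are hypothetical names, and the class count and zone settings are purely illustrative:

import torch

from training.validation import evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

results = evaluate(
    model,                 # trained nn.Module, defined elsewhere
    device,
    val_loader,            # evaluation DataLoader, defined elsewhere
    num_classes=14,
    other_class_index=13,  # excluded from mIoU
    zone_mosaic_config={
        "enabled": True,
        "zone_name": "zone_a",
        "grid_size": 10,
        "patch_size": 512,
    },
    zone_data_loader=zone_loader,  # zone-specific DataLoader, defined elsewhere
)
# results is a dict of scalar metrics (mIoU, F1, accuracy, timing, ...).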

get_evaluation_metrics_dict(num_classes: int, device: torch.device, other_class_index: int | None = None) -> dict[str, Metric]

Initialize TorchMetrics for multiclass classification.

Source code in src/training/validation.py
def get_evaluation_metrics_dict(
    num_classes: int,
    device: torch.device,
    other_class_index: int | None = None,
) -> dict[str, Metric]:
    """Initialize TorchMetrics for multiclass classification."""
    return {
        "conf_matrix": MulticlassConfusionMatrix(
            num_classes=num_classes,
            ignore_index=other_class_index,
        ).to(device),
        "macro_f1": MulticlassF1Score(
            num_classes=num_classes,
            average="macro",
            ignore_index=other_class_index,
        ).to(device),
        "f1_per_class": MulticlassF1Score(
            num_classes=num_classes,
            average=None,
            ignore_index=other_class_index,
        ).to(device),
        "overall_f1": MulticlassF1Score(
            num_classes=num_classes,
            average="micro",
            ignore_index=other_class_index,
        ).to(device),
        "macro_accuracy": MulticlassAccuracy(
            num_classes=num_classes,
            average="macro",
            ignore_index=other_class_index,
        ).to(
            device,
        ),
        "overall_accuracy": MulticlassAccuracy(
            num_classes=num_classes,
            average="micro",
            ignore_index=other_class_index,
        ).to(
            device,
        ),
    }
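
The returned objects follow the usual TorchMetrics lifecycle: update() accumulates state batch by batch, and compute() produces the final value. A small sketch with random tensors (shapes and class count are illustrative):

import torch

from training.validation import get_evaluation_metrics_dict

metrics = get_evaluation_metrics_dict(
    num_classes=14,
    device=torch.device("cpu"),
    other_class_index=13,
)

preds = torch.randint(0, 14, (2, 64, 64))   # predicted class indices per pixel
target = torch.randint(0, 14, (2, 64, 64))  # ground-truth class indices
for metric in metrics.values():
    metric.update(preds, target)

conf_matrix = metrics["conf_matrix"].compute()  # (14, 14) tensor
macro_f1 = metrics["macro_f1"].compute()        # zero-dim tensor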

upsample_predictions(outputs: torch.Tensor, target_size: tuple[int, int], output_size: int | None = None) -> torch.Tensor

Upsample model predictions to match the target mask size.

Uses bilinear interpolation on logits (before argmax) for smoother boundaries. When using context window, center-crops to output_size before upsampling.

Parameters:

    outputs (Tensor): Model outputs with shape (B, C, H, W) where C is num_classes. Required.
    target_size (tuple[int, int]): Tuple of (height, width) to upsample to (mask size). Required.
    output_size (int | None): Expected output spatial size after center-crop. If provided and model output is larger, center-crops to this size before upsampling. Use sentinel_patch_size when using context window. Default: None.

Returns:

    Tensor: Upsampled logits with shape (B, C, target_size[0], target_size[1]).

Source code in src/training/validation.py
def upsample_predictions(
    outputs: torch.Tensor,
    target_size: tuple[int, int],
    output_size: int | None = None,
) -> torch.Tensor:
    """Upsample model predictions to match the target mask size.

    Uses bilinear interpolation on logits (before argmax) for smoother boundaries.
    When using context window, center-crops to output_size before upsampling.

    Args:
        outputs: Model outputs with shape (B, C, H, W) where C is num_classes.
        target_size: Tuple of (height, width) to upsample to (mask size).
        output_size: Expected output spatial size after center-crop. If provided and
            model output is larger, center-crops to this size before upsampling.
            Use sentinel_patch_size when using context window.

    Returns:
        Upsampled logits with shape (B, C, target_size[0], target_size[1]).

    """
    if outputs.shape[-2:] == target_size:
        return outputs

    out_h, out_w = outputs.shape[-2:]

    # Center-crop if using context window (output_size specified and output is larger)
    if output_size is not None and out_h > output_size:
        crop_margin_h = (out_h - output_size) // 2
        crop_margin_w = (out_w - output_size) // 2
        outputs = outputs[
            :,
            :,
            crop_margin_h : crop_margin_h + output_size,
            crop_margin_w : crop_margin_w + output_size,
        ]

    upsampled = F.interpolate(
        outputs,
        size=target_size,
        mode="bilinear",
        align_corners=False,
    )

    return upsampled
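
For instance, in a context-window setup where the model sees a larger input but only the central region should be scored (all shapes below are illustrative):

import torch

from training.validation import upsample_predictions

logits = torch.randn(2, 14, 128, 128)  # (B, C, H, W) model outputs

# Keep the central 64x64 logits, then upsample them to the 256x256 mask grid.
upsampled = upsample_predictions(logits, target_size=(256, 256), output_size=64)
print(upsampled.shape)  # torch.Size([2, 14, 256, 256])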