Training Loop¶
Training and validation loop with mixed precision, gradient accumulation, and MLflow integration.
training.train¶
```python
prepare_output_for_comparison(
    outputs: torch.Tensor,
    target_size: tuple[int, int],
    output_size: int | None = None,
) -> torch.Tensor
```
Prepare model outputs for comparison with target mask.
When using context window (model output larger than output_size), center-crops to output_size first, then upsamples to target_size.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `outputs` | `Tensor` | Model predictions with shape (B, C, H, W). | *required* |
| `target_size` | `tuple[int, int]` | Target spatial size (height, width) to match the mask. | *required* |
| `output_size` | `int \| None` | Expected output spatial size for center-cropping. If provided and the output is larger, center-crops to this size. Use sentinel_patch_size when using a context window. | `None` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Tensor with shape (B, C, target_size[0], target_size[1]). |
Source code in src/training/train.py
```python
train(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    scheduler: LRScheduler | None = None,
    epochs: int = 100,
    patience: int = 20,
    num_classes: int = 13,
    other_class_index: int | None = None,
    accumulation_steps: int = 1,
    early_stopping_criterion: str = 'loss',
    *,
    use_amp: bool = False,
    apply_augmentations: bool = True,
    data_config: dict[str, Any] | None = None,
    log_evaluation_metrics: bool = True,
    log_model: bool = True,
    pruning_callback: Any | None = None,
    output_size: int | None = None,
    gradient_clip_val: float | None = None,
    sentinel_augmenter: Any | None = None,
) -> dict[str, list[float] | float]
```
Train a segmentation model, monitoring validation loss and saving the best model.
Detailed metrics should be calculated separately after training using an evaluation function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | PyTorch model. | *required* |
| `train_loader` | `DataLoader` | DataLoader for training data. | *required* |
| `val_loader` | `DataLoader` | DataLoader for validation data. | *required* |
| `criterion` | `Module` | Loss function. | *required* |
| `optimizer` | `Optimizer` | Optimizer for training. | *required* |
| `device` | `device` | Device (CPU/GPU). | *required* |
| `scheduler` | `LRScheduler \| None` | Optional learning rate scheduler. | `None` |
| `apply_augmentations` | `bool` | Whether to apply augmentations to the training data. Defaults to True. | `True` |
| `data_config` | `dict[str, Any] \| None` | Full data configuration dict (contains augmentation config, normalization settings, and channel selections). | `None` |
| `epochs` | `int` | Maximum number of epochs to train. Defaults to 100. | `100` |
| `patience` | `int` | Early stopping patience. Defaults to 20. | `20` |
| `num_classes` | `int` | Number of classes in the segmentation task. Defaults to 13. | `13` |
| `accumulation_steps` | `int` | Number of steps to accumulate gradients before updating. Defaults to 1. | `1` |
| `use_amp` | `bool` | Whether to use Automatic Mixed Precision (AMP). Defaults to False. | `False` |
| `log_evaluation_metrics` | `bool` | Whether to log metrics to MLflow. Defaults to True. | `True` |
| `log_model` | `bool` | Whether to log the best model to MLflow. Defaults to True. | `True` |
Returns:

| Name | Type | Description |
|---|---|---|
| `dict` | `dict[str, list[float] \| float]` | History of training and validation losses/mIoUs, and best values. The best model is logged as an MLflow artifact 'best_model' if `log_model` is True. |
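The interaction of `accumulation_steps` and `use_amp` described above can be sketched as follows. This is an illustrative stand-in, not the project's actual loop: the tiny model, synthetic batches, and learning rate are placeholders, and AMP is simply disabled when no GPU is present.

```python
import torch

# Stand-in model, loss, and optimizer (not the project's configuration).
model = torch.nn.Conv2d(3, 13, kernel_size=1)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

accumulation_steps = 2
use_amp = torch.cuda.is_available()  # corresponds to use_amp=True on GPU
device_type = "cuda" if use_amp else "cpu"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Synthetic (image, mask) batches standing in for a DataLoader.
batches = [
    (torch.randn(2, 3, 8, 8), torch.randint(0, 13, (2, 8, 8)))
    for _ in range(4)
]

optimizer.zero_grad()
for step, (images, masks) in enumerate(batches):
    with torch.autocast(device_type=device_type, enabled=use_amp):
        # Divide by accumulation_steps so the summed gradients match the
        # gradient of the mean loss over the effective (larger) batch.
        loss = criterion(model(images), masks) / accumulation_steps
    scaler.scale(loss).backward()
    if (step + 1) % accumulation_steps == 0:
        scaler.step(optimizer)  # unscales grads; skips step on inf/NaN
        scaler.update()
        optimizer.zero_grad()
```

With `accumulation_steps=2`, the optimizer steps every second batch, giving an effective batch size twice that of the DataLoader at the memory cost of one.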