Multimodal Late Fusion¶
Late fusion model combining aerial and Sentinel-2 sub-models with learnable class-wise weights.
models.architectures.multimodal_fusion¶
Multimodal Late Fusion model combining aerial and Sentinel-2 modalities.
This module implements a late fusion architecture that combines predictions from a pre-trained aerial model (e.g., UNetFormer) and a pre-trained Sentinel-2 temporal model (e.g., TSViT) using learnable per-class modality weights.
MultiScaleChannelAttention(channels: int, r: int = 16)¶
Bases: Module
Multi-Scale Channel Attention Module (MS-CAM).
Based on 'Attentional Feature Fusion' (Dai et al., 2021). Fuses features by considering both global context (GAP) and local context through pointwise 1x1 convolutions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| channels | int | Number of input/output channels. | required |
| r | int | Reduction ratio for the bottleneck. | 16 |
Initialize the MS-CAM module.
Source code in src/models/architectures/multimodal_fusion.py
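The gating computation can be sketched as follows. This is a simplified, illustrative re-implementation (BatchNorm layers omitted), not the repository's code; the actual module lives in src/models/architectures/multimodal_fusion.py and may differ in detail:

```python
import torch
import torch.nn as nn

class MSCAMSketch(nn.Module):
    """Illustrative MS-CAM (Dai et al., 2021) sketch, not the repository's code."""

    def __init__(self, channels: int, r: int = 16):
        super().__init__()
        mid = max(channels // r, 1)
        # Local branch: per-position pointwise 1x1 conv bottleneck.
        self.local_att = nn.Sequential(
            nn.Conv2d(channels, mid, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, channels, kernel_size=1),
        )
        # Global branch: GAP collapses H x W to 1 x 1, then the same bottleneck.
        self.global_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, mid, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, channels, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sum local and (broadcast) global context, squash to (0, 1) gates.
        return torch.sigmoid(self.local_att(x) + self.global_att(x))
```

The output has the same shape as the input and every value lies in (0, 1), so it can be used directly as a spatially varying gate.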
forward(x: torch.Tensor) -> torch.Tensor¶
Compute spatial-varying gating weights.
MultimodalLateFusion(aerial_model: nn.Module, sentinel_model: nn.Module, num_classes: int, *, freeze_encoders: bool = True, freeze_encoder_stats: bool | None = None, fusion_mode: str = 'weighted', aerial_resolution: tuple[int, int] = (512, 512), sentinel_resolution: tuple[int, int] = (10, 10), sentinel_output_resolution: tuple[int, int] | None = None, use_cloud_uncertainty: bool = False, modality_weights: list[float] | None = None, init_class_weights: dict[int, list[float]] | list[list[float]] | list[float] | None = None, gate_class_priors: dict[int, float] | list[float] | float | None = None)¶
Bases: Module
Late fusion model combining aerial and Sentinel-2 predictions.
This model fuses predictions from two pre-trained modality-specific models:

- An aerial model (e.g., UNetFormer) for high-resolution imagery
- A Sentinel model (e.g., TSViT) for temporal satellite data
The fusion uses learnable per-class weights to determine each modality's contribution for each semantic class.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| aerial_model | Module | Pre-trained model for aerial imagery. | required |
| sentinel_model | Module | Pre-trained model for Sentinel-2 time series. | required |
| num_classes | int | Number of output segmentation classes. | required |
| freeze_encoders | bool | Whether to freeze pre-trained encoder weights. | True |
| fusion_mode | str | Fusion strategy: 'weighted' (per-class weights), 'gated' (content-aware spatial gates), 'concat' (channel concatenation), or 'average'. | 'weighted' |
| aerial_resolution | tuple[int, int] | Tuple (H, W) for aerial model output resolution. | (512, 512) |
| sentinel_resolution | tuple[int, int] | Tuple (H, W) for Sentinel model output resolution. | (10, 10) |
| use_cloud_uncertainty | bool | Whether to use cloud coverage as input to gated fusion. | False |
| init_class_weights | dict[int, list[float]] \| list[list[float]] \| list[float] \| None | Optional initial modality weights for weighted fusion. A list[float] of length 2 gives global [aerial, sentinel] weights for all classes; a list[list[float]] of length num_classes gives per-class [aerial, sentinel] weights; a dict[int, list[float]] gives per-class weights by class index. | None |
| gate_class_priors | dict[int, float] \| list[float] \| float \| None | Optional initial aerial priors for gated fusion: prior gate values in [0, 1] (higher = trust aerial more). A float gives a global aerial prior for all classes; a list[float] of length num_classes gives per-class priors; a dict[int, float] gives per-class priors by class index. | None |
Initialize the multimodal late fusion model.
Source code in src/models/architectures/multimodal_fusion.py
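The 'weighted' fusion mode can be illustrated with a minimal sketch, assuming the model stores a (num_classes, 2) tensor of weight logits that is softmaxed per class into [aerial, sentinel] contributions, and that Sentinel logits are bilinearly upsampled to the aerial grid before mixing (the repository implementation may differ in detail):

```python
import torch
import torch.nn.functional as F

def weighted_fuse(aerial_logits: torch.Tensor,
                  sentinel_logits: torch.Tensor,
                  class_weight_logits: torch.Tensor) -> torch.Tensor:
    """Illustrative 'weighted' fusion sketch.

    aerial_logits:       (B, num_classes, H, W) at aerial resolution
    sentinel_logits:     (B, num_classes, h, w) at Sentinel resolution
    class_weight_logits: (num_classes, 2) learnable logits, softmaxed per class
    """
    # Bring Sentinel predictions up to the aerial grid.
    sentinel_up = F.interpolate(
        sentinel_logits, size=aerial_logits.shape[-2:],
        mode="bilinear", align_corners=False)
    # Normalized per-class [aerial, sentinel] weights.
    w = torch.softmax(class_weight_logits, dim=-1)   # (num_classes, 2)
    w_aerial = w[:, 0].view(1, -1, 1, 1)
    w_sentinel = w[:, 1].view(1, -1, 1, 1)
    return w_aerial * aerial_logits + w_sentinel * sentinel_up
```

With zero-initialized logits the softmax yields 0.5/0.5 per class, i.e. a plain average of the two (aligned) predictions; training then shifts each class toward the more reliable modality.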
forward(aerial_input: torch.Tensor, sentinel_input: torch.Tensor, batch_positions: torch.Tensor | None = None, pad_mask: torch.Tensor | None = None, cloud_coverage: torch.Tensor | None = None) -> torch.Tensor¶
Forward pass combining both modalities.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| aerial_input | Tensor | Aerial imagery tensor of shape (B, C, H, W). | required |
| sentinel_input | Tensor | Sentinel-2 time series of shape (B, T, C, H, W). | required |
| batch_positions | Tensor \| None | Temporal positions of shape (B, T) for the Sentinel model. | None |
| pad_mask | Tensor \| None | Boolean padding mask of shape (B, T) where True indicates a padded (invalid) timestep to be ignored by the Sentinel model. | None |
| cloud_coverage | Tensor \| None | Cloud coverage tensor for gated fusion, shape (B, 1, H, W) at Sentinel resolution. Will be upsampled to aerial resolution. | None |
Returns:

| Type | Description |
|---|---|
| Tensor | Fused predictions of shape (B, num_classes, H, W) at aerial resolution. |
Source code in src/models/architectures/multimodal_fusion.py
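The shape contract of the forward pass can be exercised end to end with trivial stand-in sub-models and a plain 'average' fusion. Everything here is hypothetical scaffolding for illustration (the real sub-models would be UNetFormer- and TSViT-like networks):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes = 5

class TinyAerial(nn.Module):
    """Stand-in aerial model: (B, 4, H, W) -> (B, num_classes, H, W)."""
    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(4, num_classes, kernel_size=1)
    def forward(self, x):
        return self.head(x)

class TinySentinel(nn.Module):
    """Stand-in temporal model: (B, T, 10, h, w) -> (B, num_classes, h, w)."""
    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(10, num_classes, kernel_size=1)
    def forward(self, x, batch_positions=None, pad_mask=None):
        return self.head(x.mean(dim=1))  # collapse time, for illustration only

def average_fuse(aerial, sentinel, a_in, s_in,
                 batch_positions=None, pad_mask=None):
    """Minimal 'average' fusion forward: mean of both logits at aerial resolution."""
    a = aerial(a_in)
    s = F.interpolate(sentinel(s_in, batch_positions, pad_mask),
                      size=a.shape[-2:], mode="bilinear", align_corners=False)
    return 0.5 * (a + s)

out = average_fuse(TinyAerial(), TinySentinel(),
                   torch.randn(2, 4, 64, 64),      # aerial (B, C, H, W)
                   torch.randn(2, 6, 10, 10, 10),  # sentinel (B, T, C, H, W)
                   batch_positions=torch.arange(6).repeat(2, 1))
```

Whatever the fusion mode, the output is always at aerial resolution with one channel per class.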
get_fusion_weights() -> dict[str, torch.Tensor]¶
Get the current per-class fusion weights.
Returns:

| Type | Description |
|---|---|
| dict[str, Tensor] | Dictionary with 'raw' unnormalized weights and 'normalized' softmax weights. |
Source code in src/models/architectures/multimodal_fusion.py
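A sketch of the documented return shape, assuming the weighted-fusion logits are stored as a (num_classes, 2) tensor (the variable names here are hypothetical):

```python
import torch

# Hypothetical stand-in for the model's learnable per-class weight logits:
# shape (num_classes, 2), columns [aerial, sentinel].
raw_weights = torch.zeros(5, 2)

def get_fusion_weights_sketch(raw: torch.Tensor) -> dict:
    """Mirror of the documented return: raw logits plus per-class softmax."""
    return {"raw": raw.detach(), "normalized": torch.softmax(raw.detach(), dim=-1)}

weights = get_fusion_weights_sketch(raw_weights)
# Each class's normalized [aerial, sentinel] pair sums to 1.
```

Inspecting 'normalized' after training shows, per class, how much the fusion trusts each modality.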
train(mode: bool = True) -> MultimodalLateFusion¶
Set training mode.
When encoders are frozen, it's common to keep them in eval mode during fusion training so BatchNorm running stats don't drift and Dropout stays disabled.
Source code in src/models/architectures/multimodal_fusion.py
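The frozen-encoder pattern described above can be sketched on a toy module (the class and submodule names here are illustrative, not the repository's):

```python
import torch.nn as nn

class FrozenAwareModel(nn.Module):
    """Illustrates the train() override: frozen encoder stays in eval mode."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.fusion_head = nn.Conv2d(8, 2, kernel_size=1)
        for p in self.encoder.parameters():
            p.requires_grad_(False)

    def train(self, mode: bool = True):
        super().train(mode)
        # Keep the frozen encoder in eval mode so BatchNorm running stats
        # don't drift and Dropout stays disabled while the head trains.
        self.encoder.eval()
        return self

m = FrozenAwareModel().train()
```

After `.train()`, the model as a whole is in training mode while the encoder subtree remains in eval mode.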
trainable_parameters() -> list[nn.Parameter]¶
Return only the trainable fusion parameters (not frozen encoders).
Returns:

| Type | Description |
|---|---|
| list[Parameter] | List of trainable parameters. |
Source code in src/models/architectures/multimodal_fusion.py
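The typical use is to hand only the unfrozen fusion parameters to the optimizer. A minimal sketch on a stand-in model (the helper shown is a generic equivalent, not the repository's method):

```python
import torch
import torch.nn as nn

# Stand-in: first layer plays the frozen encoder, second the fusion head.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in model[0].parameters():
    p.requires_grad_(False)

def trainable_parameters(model: nn.Module) -> list[nn.Parameter]:
    """Only parameters that still require gradients (i.e. not frozen)."""
    return [p for p in model.parameters() if p.requires_grad]

params = trainable_parameters(model)
optimizer = torch.optim.AdamW(params, lr=1e-3)
```

This keeps the optimizer state small and guarantees no gradient step ever touches the frozen encoders.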
load_pretrained_multimodal(aerial_checkpoint: str | None, sentinel_checkpoint: str | None, aerial_model: nn.Module, sentinel_model: nn.Module, *, device: torch.device | str = 'cpu', strict: bool = True, strip_prefixes: list[str] | None = None) -> tuple[nn.Module, nn.Module]¶
Load pre-trained weights into aerial and Sentinel models.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| aerial_checkpoint | str \| None | Path to aerial model checkpoint, or None to skip. | required |
| sentinel_checkpoint | str \| None | Path to Sentinel model checkpoint, or None to skip. | required |
| aerial_model | Module | Aerial model instance to load weights into. | required |
| sentinel_model | Module | Sentinel model instance to load weights into. | required |
| device | device \| str | Device to load checkpoints onto. | 'cpu' |
| strict | bool | Whether to enforce that checkpoint keys match the model exactly. | True |
| strip_prefixes | list[str] \| None | Optional list of prefixes to strip from checkpoint keys (e.g., ['module.'] for DataParallel checkpoints). A prefix is only stripped when all keys share it. | None |
Returns:

| Type | Description |
|---|---|
| tuple[Module, Module] | Tuple of (aerial_model, sentinel_model) with loaded weights. |
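The all-or-nothing prefix-stripping rule can be sketched as follows (an illustrative helper, assuming the documented behavior; not the repository's function):

```python
import torch
import torch.nn as nn

def strip_prefix(state_dict: dict, prefix: str) -> dict:
    """Strip prefix only if every key carries it, per the documented rule."""
    if state_dict and all(k.startswith(prefix) for k in state_dict):
        return {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict  # mixed or missing prefix: leave keys untouched

model = nn.Linear(3, 2)
# Simulate a DataParallel checkpoint whose keys are prefixed with 'module.'.
ckpt = {"module." + k: v for k, v in model.state_dict().items()}
model.load_state_dict(strip_prefix(ckpt, "module."), strict=True)
```

Stripping only when all keys share the prefix avoids silently corrupting checkpoints where the prefix is a genuine submodule name on some keys but not others.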