# UNetFormer

UNet-like Transformer with Global-Local Attention for aerial image segmentation.

## models.architectures.unetformer

Implementation of the UNetFormer architecture. This file is adapted from the SSRS repository: https://github.com/sstary/SSRS/blob/main/RS3Mamba/model/UNetFormer.py

Original paper: "UNetFormer: A UNet-like transformer for efficient semantic segmentation of remote sensing urban scene imagery", https://arxiv.org/abs/2109.08937
## AuxHead

`AuxHead(in_channels: int = 64, num_classes: int = 8)`

Bases: `Module`

Auxiliary head for deep supervision.

Source code in src/models/architectures/unetformer.py
### forward

`forward(x: torch.Tensor, h: int, w: int) -> torch.Tensor`

Produce auxiliary segmentation logits upsampled to (h, w).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Feature tensor from the decoder. | required |
| h | int | Target output height. | required |
| w | int | Target output width. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Logits tensor of shape (B, num_classes, h, w). |

Source code in src/models/architectures/unetformer.py
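To make the shape contract concrete, here is a minimal sketch of what such an auxiliary head typically looks like: a small conv block, a 1x1 classifier, and bilinear upsampling to the requested (h, w). The `SketchAuxHead` name and its internal layer structure are illustrative assumptions, not the repository's exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchAuxHead(nn.Module):
    """Minimal auxiliary head sketch: 3x3 conv block, 1x1 classifier,
    then bilinear upsampling of the logits to the target (h, w)."""

    def __init__(self, in_channels: int = 64, num_classes: int = 8):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        x = self.classifier(self.conv(x))
        # Upsample logits to the requested output resolution.
        return F.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)

# Shape check: decoder features at 1/4 resolution -> full-resolution logits.
feats = torch.randn(2, 64, 64, 64)
logits = SketchAuxHead()(feats, 256, 256)  # (2, 8, 256, 256)
```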
## Decoder

`Decoder(encoder_channels: tuple[int, ...] = (64, 128, 256, 512), decode_channels: int = 64, dropout: float = 0.1, window_size: int = 8, num_classes: int = 6)`

Bases: `Module`

UNetFormer decoder with Global-Local Attention blocks.

Source code in src/models/architectures/unetformer.py
### forward

`forward(res1: torch.Tensor, res2: torch.Tensor, res3: torch.Tensor, res4: torch.Tensor, h: int, w: int, *, return_aux_features: bool = False) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]`

Decode multi-scale encoder features into segmentation logits.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| res1 | Tensor | Stage-1 encoder features (highest resolution). | required |
| res2 | Tensor | Stage-2 encoder features. | required |
| res3 | Tensor | Stage-3 encoder features. | required |
| res4 | Tensor | Stage-4 encoder features (lowest resolution). | required |
| h | int | Target output height. | required |
| w | int | Target output width. | required |
| return_aux_features | bool | If True, also return intermediate decoder features for the auxiliary head. | False |

Returns:

| Type | Description |
|---|---|
| Tensor \| tuple[Tensor, Tensor] | Segmentation logits of shape (B, num_classes, h, w), or a tuple of (logits, aux_features) when return_aux_features is True. |

Source code in src/models/architectures/unetformer.py
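The Global-Local Attention blocks operate on non-overlapping spatial windows, so feature maps flowing through the decoder should have sides divisible by `window_size` (8 by default). The following torch-only sketch shows the window partition/reverse round trip that window attention relies on; the exact tensor layout inside the repository may differ.

```python
import torch

def window_partition(x: torch.Tensor, ws: int) -> torch.Tensor:
    """Split (B, C, H, W) into (B * H/ws * W/ws, C, ws, ws) windows."""
    b, c, h, w = x.shape
    assert h % ws == 0 and w % ws == 0, "H and W must be divisible by window_size"
    x = x.view(b, c, h // ws, ws, w // ws, ws)
    # (B, H/ws, W/ws, C, ws, ws): flatten the window grid into the batch dim.
    x = x.permute(0, 2, 4, 1, 3, 5).contiguous()
    return x.view(-1, c, ws, ws)

def window_reverse(windows: torch.Tensor, ws: int, h: int, w: int) -> torch.Tensor:
    """Inverse of window_partition: reassemble windows into (B, C, H, W)."""
    c = windows.shape[1]
    b = windows.shape[0] // ((h // ws) * (w // ws))
    x = windows.view(b, h // ws, w // ws, c, ws, ws)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
    return x.view(b, c, h, w)

x = torch.randn(2, 64, 32, 32)
wins = window_partition(x, 8)            # (2 * 4 * 4, 64, 8, 8)
restored = window_reverse(wins, 8, 32, 32)
```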
## UNetFormer

`UNetFormer(decode_channels: int = 64, dropout: float = 0.1, backbone_name: str = 'swsl_resnet18', pretrained: bool = True, window_size: int = 8, num_classes: int = 6, in_channels: int = 3, img_size: int | tuple[int, int] | None = None, *, use_aux_head: bool = False, encoder_type: str = 'timm', samba_config: dict[str, Any] | None = None, drop_path_rate: float = 0.0)`

Bases: `Module`

UNetFormer: a UNet-like Transformer for semantic segmentation.

This model uses a CNN backbone (from timm) or a Samba encoder for feature extraction, and a transformer-style decoder with Global-Local Attention.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| decode_channels | int | Number of decoder channels. | 64 |
| dropout | float | Dropout rate in the decoder. | 0.1 |
| backbone_name | str | Name of the timm backbone. | 'swsl_resnet18' |
| pretrained | bool | Whether to use a pretrained backbone. | True |
| window_size | int | Window size for attention. | 8 |
| num_classes | int | Number of output classes. | 6 |
| in_channels | int | Number of input channels. | 3 |
| use_aux_head | bool | Whether to use the auxiliary head for deep supervision. | False |
| encoder_type | str | Type of encoder: 'timm' or 'samba'. | 'timm' |
| samba_config | dict[str, Any] \| None | Configuration dict for the Samba encoder when encoder_type='samba'. | None |

Source code in src/models/architectures/unetformer.py
### forward

`forward(x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]`

Run the full encoder-decoder forward pass.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input image tensor of shape (B, C, H, W). | required |

Returns:

| Type | Description |
|---|---|
| Tensor \| tuple[Tensor, Tensor] | Segmentation logits of shape (B, num_classes, H, W). During training with use_aux_head=True, a tuple of (main_logits, aux_logits). |
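When the training-time forward returns (main_logits, aux_logits), the two heads are typically combined into a weighted deep-supervision loss. A hedged sketch of that training pattern follows; the tensors stand in for the model's outputs, and the 0.4 auxiliary weight is illustrative, not a value taken from the repository.

```python
import torch
import torch.nn.functional as F

num_classes = 6
# Stand-ins for the (main_logits, aux_logits) tuple that forward() returns
# in training mode with use_aux_head=True; both are (B, num_classes, H, W).
main_logits = torch.randn(2, num_classes, 64, 64)
aux_logits = torch.randn(2, num_classes, 64, 64)
target = torch.randint(0, num_classes, (2, 64, 64))

aux_weight = 0.4  # illustrative deep-supervision weight, not the repository's value
loss = (F.cross_entropy(main_logits, target)
        + aux_weight * F.cross_entropy(aux_logits, target))
```

At inference time only the main logits are returned, so the loss reduces to the first term.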