# UNetFormer

UNet-like Transformer with Global-Local Attention for aerial image segmentation.

`models.architectures.unetformer`

Implementation of the UNetFormer architecture.

This file is adapted from the SSRS repository: https://github.com/sstary/SSRS/blob/main/RS3Mamba/model/UNetFormer.py

Original paper: *UNetFormer: A UNet-like Transformer for Efficient Semantic Segmentation of Remote Sensing Urban Scene Imagery*, https://arxiv.org/abs/2109.08937

## AuxHead

```python
AuxHead(in_channels: int = 64, num_classes: int = 8)
```

Bases: `Module`

Auxiliary head for deep supervision.

Source code in `src/models/architectures/unetformer.py`

```python
def __init__(self, in_channels: int = 64, num_classes: int = 8) -> None:
    super().__init__()
    self.conv = ConvBNReLU(in_channels, in_channels)
    self.drop = nn.Dropout(0.1)
    self.conv_out = Conv(in_channels, num_classes, kernel_size=1)
```

### forward

```python
forward(x: torch.Tensor, h: int, w: int) -> torch.Tensor
```

Produce auxiliary segmentation logits upsampled to `(h, w)`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Feature tensor from the decoder. | *required* |
| `h` | `int` | Target output height. | *required* |
| `w` | `int` | Target output width. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Logits tensor of shape `(B, num_classes, h, w)`. |

Source code in `src/models/architectures/unetformer.py`

```python
def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Produce auxiliary segmentation logits upsampled to (h, w).

    Args:
        x: Feature tensor from the decoder.
        h: Target output height.
        w: Target output width.

    Returns:
        Logits tensor of shape (B, num_classes, h, w).

    """
    feat = self.conv(x)
    feat = self.drop(feat)
    feat = self.conv_out(feat)
    feat = functional.interpolate(feat, size=(h, w), mode="bilinear", align_corners=False)
    return feat
```
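
As a quick illustration, here is a minimal sketch of exercising `AuxHead` on its own, assuming the class is imported from `models.architectures.unetformer`; the batch size, feature resolution, and target size below are illustrative values, not fixtures from the repository:

```python
import torch

# Hypothetical smoke test: 2 samples of 64-channel decoder features.
aux_head = AuxHead(in_channels=64, num_classes=8)
aux_head.eval()

features = torch.randn(2, 64, 64, 64)
with torch.no_grad():
    aux_logits = aux_head(features, h=256, w=256)

print(aux_logits.shape)  # torch.Size([2, 8, 256, 256])
```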

## Decoder

```python
Decoder(
    encoder_channels: tuple[int, ...] = (64, 128, 256, 512),
    decode_channels: int = 64,
    dropout: float = 0.1,
    window_size: int = 8,
    num_classes: int = 6,
)
```

Bases: `Module`

UNetFormer decoder with Global-Local Attention blocks.

Source code in `src/models/architectures/unetformer.py`

```python
def __init__(
    self,
    encoder_channels: tuple[int, ...] = (64, 128, 256, 512),
    decode_channels: int = 64,
    dropout: float = 0.1,
    window_size: int = 8,
    num_classes: int = 6,
) -> None:
    super().__init__()

    self.pre_conv = ConvBN(encoder_channels[-1], decode_channels, kernel_size=1)
    self.b4 = Block(dim=decode_channels, num_heads=8, window_size=window_size)

    self.b3 = Block(dim=decode_channels, num_heads=8, window_size=window_size)
    self.p3 = WF(encoder_channels[-2], decode_channels)

    self.b2 = Block(dim=decode_channels, num_heads=8, window_size=window_size)
    self.p2 = WF(encoder_channels[-3], decode_channels)

    self.p1 = FeatureRefinementHead(encoder_channels[-4], decode_channels)

    self.segmentation_head = nn.Sequential(
        ConvBNReLU(decode_channels, decode_channels),
        nn.Dropout2d(p=dropout, inplace=True),
        Conv(decode_channels, num_classes, kernel_size=1),
    )
    self.init_weight()
```

### forward

```python
forward(
    res1: torch.Tensor,
    res2: torch.Tensor,
    res3: torch.Tensor,
    res4: torch.Tensor,
    h: int,
    w: int,
    *,
    return_aux_features: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]
```

Decode multi-scale encoder features into segmentation logits.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `res1` | `Tensor` | Stage-1 encoder features (highest resolution). | *required* |
| `res2` | `Tensor` | Stage-2 encoder features. | *required* |
| `res3` | `Tensor` | Stage-3 encoder features. | *required* |
| `res4` | `Tensor` | Stage-4 encoder features (lowest resolution). | *required* |
| `h` | `int` | Target output height. | *required* |
| `w` | `int` | Target output width. | *required* |
| `return_aux_features` | `bool` | If `True`, also return pre-head features for the auxiliary loss. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` or `tuple[Tensor, Tensor]` | Segmentation logits of shape `(B, num_classes, h, w)`, or a tuple of `(logits, aux_features)` when `return_aux_features` is `True`. |

Source code in `src/models/architectures/unetformer.py`

```python
def forward(
    self,
    res1: torch.Tensor,
    res2: torch.Tensor,
    res3: torch.Tensor,
    res4: torch.Tensor,
    h: int,
    w: int,
    *,
    return_aux_features: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Decode multi-scale encoder features into segmentation logits.

    Args:
        res1: Stage-1 encoder features (highest resolution).
        res2: Stage-2 encoder features.
        res3: Stage-3 encoder features.
        res4: Stage-4 encoder features (lowest resolution).
        h: Target output height.
        w: Target output width.
        return_aux_features: If ``True``, also return pre-head features
            for the auxiliary loss.

    Returns:
        Segmentation logits of shape (B, num_classes, h, w), or a tuple
        of (logits, aux_features) when *return_aux_features* is ``True``.

    """
    x = self.b4(self.pre_conv(res4))
    x = self.p3(x, res3)
    x = self.b3(x)

    x = self.p2(x, res2)
    x = self.b2(x)

    x = self.p1(x, res1)

    # Store features for auxiliary head before segmentation
    aux_features = x

    x = self.segmentation_head(x)
    x = functional.interpolate(x, size=(h, w), mode="bilinear", align_corners=False)

    if return_aux_features:
        return x, aux_features
    return x
```
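
A minimal sketch of driving the `Decoder` directly with dummy features, assuming the default ResNet18-style channel layout `(64, 128, 256, 512)`; the feature strides (4 through 32 for a 256x256 input) follow the usual ResNet convention and are an assumption here:

```python
import torch

decoder = Decoder(encoder_channels=(64, 128, 256, 512), decode_channels=64, num_classes=6)
decoder.eval()

# Dummy multi-scale features for a 256x256 input at strides 4, 8, 16, and 32.
res1 = torch.randn(2, 64, 64, 64)
res2 = torch.randn(2, 128, 32, 32)
res3 = torch.randn(2, 256, 16, 16)
res4 = torch.randn(2, 512, 8, 8)

with torch.no_grad():
    logits = decoder(res1, res2, res3, res4, 256, 256)
print(logits.shape)  # torch.Size([2, 6, 256, 256])

# With return_aux_features=True, the pre-head features are returned as well.
with torch.no_grad():
    logits, aux = decoder(res1, res2, res3, res4, 256, 256, return_aux_features=True)
```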

## UNetFormer

```python
UNetFormer(
    decode_channels: int = 64,
    dropout: float = 0.1,
    backbone_name: str = "swsl_resnet18",
    pretrained: bool = True,
    window_size: int = 8,
    num_classes: int = 6,
    in_channels: int = 3,
    img_size: int | tuple[int, int] | None = None,
    *,
    use_aux_head: bool = False,
    encoder_type: str = "timm",
    samba_config: dict[str, Any] | None = None,
    drop_path_rate: float = 0.0,
)
```

Bases: `Module`

UNetFormer: a UNet-like Transformer for semantic segmentation.

This model uses a CNN backbone (from timm) or a Samba encoder for feature extraction, together with a transformer-style decoder that applies Global-Local Attention.
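
A minimal end-to-end sketch, assuming the class is imported from `models.architectures.unetformer` and that `timm` is installed; `pretrained=False` is used here only to avoid a weight download in a smoke test:

```python
import torch

model = UNetFormer(
    backbone_name="swsl_resnet18",
    pretrained=False,  # skip the pretrained-weight download for a quick check
    num_classes=6,
)
model.eval()

images = torch.randn(2, 3, 256, 256)
with torch.no_grad():
    logits = model(images)

print(logits.shape)  # torch.Size([2, 6, 256, 256])
```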

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `decode_channels` | `int` | Number of decoder channels. | `64` |
| `dropout` | `float` | Dropout rate in the decoder. | `0.1` |
| `backbone_name` | `str` | Name of the timm backbone. | `'swsl_resnet18'` |
| `pretrained` | `bool` | Whether to use a pretrained backbone. | `True` |
| `window_size` | `int` | Window size for attention. | `8` |
| `num_classes` | `int` | Number of output classes. | `6` |
| `in_channels` | `int` | Number of input channels. | `3` |
| `img_size` | `int` or `tuple[int, int]` or `None` | Input image size, forwarded to timm backbones that accept it. | `None` |
| `use_aux_head` | `bool` | Whether to use an auxiliary head for deep supervision. | `False` |
| `encoder_type` | `str` | Type of encoder: `'timm'` or `'samba'`. | `'timm'` |
| `samba_config` | `dict[str, Any]` or `None` | Configuration dict for the Samba encoder when `encoder_type='samba'`. | `None` |
| `drop_path_rate` | `float` | Stochastic depth (drop path) rate passed to the timm backbone. | `0.0` |
Source code in `src/models/architectures/unetformer.py`

```python
def __init__(
    self,
    decode_channels: int = 64,
    dropout: float = 0.1,
    backbone_name: str = "swsl_resnet18",
    pretrained: bool = True,
    window_size: int = 8,
    num_classes: int = 6,
    in_channels: int = 3,
    img_size: int | tuple[int, int] | None = None,
    *,
    use_aux_head: bool = False,
    encoder_type: str = "timm",
    samba_config: dict[str, Any] | None = None,
    drop_path_rate: float = 0.0,
) -> None:
    super().__init__()
    self.use_aux_head = use_aux_head
    self.encoder_type = encoder_type
    self._warned_nhwc_features = False

    if encoder_type == "samba":
        if SambaEncoder is None:
            msg = "SambaEncoder could not be imported. Check dependencies (e.g. mamba_ssm)."
            raise ImportError(msg)

        samba_config = samba_config or {}
        self.backbone = SambaEncoder(
            in_channels=in_channels,
            stem_hidden_dim=samba_config.get("stem_hidden_dim", 32),
            embed_dims=samba_config.get("embed_dims", [64, 128, 320, 448]),
            mlp_ratios=samba_config.get("mlp_ratios", [8, 8, 4, 4]),
            drop_path_rate=samba_config.get("drop_path_rate", 0.0),
            depths=samba_config.get("depths", [3, 4, 6, 3]),
        )
        encoder_channels = tuple(self.backbone.get_channels())
    else:
        # Select out_indices based on encoder architecture
        # ConvNeXt family has 4 stages (0-3), ResNet family has 5 stages (1-4)
        backbone_lower = backbone_name.lower()
        if "convnext" in backbone_lower or "swin" in backbone_lower:
            out_indices = (0, 1, 2, 3)
        else:
            out_indices = (1, 2, 3, 4)

        timm_kwargs: dict[str, Any] = {
            "features_only": True,
            "out_indices": out_indices,
            "pretrained": pretrained,
            "in_chans": in_channels,
            "drop_path_rate": drop_path_rate,
        }
        if img_size is not None:
            timm_kwargs["img_size"] = img_size

        try:
            self.backbone = timm.create_model(backbone_name, **timm_kwargs)
        except TypeError:
            if "img_size" not in timm_kwargs:
                raise
            logger.warning(
                "Backbone '%s' does not accept img_size=%s. Falling back to default size.",
                backbone_name,
                img_size,
            )
            timm_kwargs.pop("img_size", None)
            self.backbone = timm.create_model(backbone_name, **timm_kwargs)
        encoder_channels = tuple(self.backbone.feature_info.channels())
    self.encoder_channels = encoder_channels

    self.decoder = Decoder(
        encoder_channels,
        decode_channels,
        dropout,
        window_size,
        num_classes,
    )

    # Auxiliary head for deep supervision (as per paper)
    if use_aux_head:
        self.aux_head = AuxHead(
            in_channels=decode_channels,
            num_classes=num_classes,
        )
```
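
The only `samba_config` keys read above are `stem_hidden_dim`, `embed_dims`, `mlp_ratios`, `drop_path_rate`, and `depths`. A sketch of overriding them explicitly, with the values below simply mirroring the defaults in `__init__` rather than tuned settings:

```python
# Assumes the SambaEncoder dependencies (e.g. mamba_ssm) are installed;
# otherwise __init__ raises ImportError before the encoder is built.
samba_config = {
    "stem_hidden_dim": 32,
    "embed_dims": [64, 128, 320, 448],
    "mlp_ratios": [8, 8, 4, 4],
    "drop_path_rate": 0.0,
    "depths": [3, 4, 6, 3],
}

model = UNetFormer(encoder_type="samba", samba_config=samba_config, num_classes=6)
```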

### forward

```python
forward(x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]
```

Run the full encoder-decoder forward pass.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input image tensor of shape `(B, C, H, W)`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` or `tuple[Tensor, Tensor]` | Segmentation logits of shape `(B, num_classes, H, W)`. During training with `use_aux_head=True`, returns a tuple of `(main_logits, aux_logits)`. |

Source code in `src/models/architectures/unetformer.py`

```python
def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Run the full encoder-decoder forward pass.

    Args:
        x: Input image tensor of shape (B, C, H, W).

    Returns:
        Segmentation logits of shape (B, num_classes, H, W). During
        training with ``use_aux_head=True``, returns a tuple of
        (main_logits, aux_logits).

    """
    h, w = x.size()[-2:]
    res1, res2, res3, res4 = self.backbone(x)
    res1 = self._ensure_nchw(res1, int(self.encoder_channels[0]))
    res2 = self._ensure_nchw(res2, int(self.encoder_channels[1]))
    res3 = self._ensure_nchw(res3, int(self.encoder_channels[2]))
    res4 = self._ensure_nchw(res4, int(self.encoder_channels[3]))

    if self.use_aux_head and self.training:
        main_out, aux_features = self.decoder(
            res1,
            res2,
            res3,
            res4,
            h,
            w,
            return_aux_features=True,
        )
        aux_out = self.aux_head(aux_features, h, w)
        return main_out, aux_out

    return self.decoder(res1, res2, res3, res4, h, w)
```
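
A sketch of one training step with deep supervision enabled. The 0.4 auxiliary-loss weight is an assumption chosen for illustration (a common choice for auxiliary heads), not a value taken from this repository:

```python
import torch
from torch import nn

model = UNetFormer(pretrained=False, num_classes=6, use_aux_head=True)
model.train()

criterion = nn.CrossEntropyLoss()
images = torch.randn(2, 3, 256, 256)
targets = torch.randint(0, 6, (2, 256, 256))

# In training mode with use_aux_head=True, forward returns two logit maps.
main_logits, aux_logits = model(images)
loss = criterion(main_logits, targets) + 0.4 * criterion(aux_logits, targets)
loss.backward()

# In eval mode the same model returns a single logits tensor.
model.eval()
with torch.no_grad():
    logits = model(images)
```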