
Samba Encoder

Samba (State Space Model) encoder backbone for UNetFormer.

models.encoders.samba_encoder

Samba Encoder for semantic segmentation.

Adapted from the Samba repository: https://github.com/zhuqinfeng1999/Samba

Original paper: "Samba: Semantic Segmentation of Remotely Sensed Images with State Space Model," https://doi.org/10.1016/j.heliyon.2024.e38495

DWConv(dim: int = 768)

Bases: Module

Depth-wise convolution for spatial mixing in FFN.

Source code in src/models/encoders/samba_encoder.py
def __init__(self, dim: int = 768) -> None:
    super().__init__()
    self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
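
For context, the matching forward pass in PVTv2-style FFNs folds the token sequence back into a 2-D feature map before applying the depth-wise convolution. A minimal sketch under that assumption (not the verbatim forward from this file):

import torch

# Hypothetical forward: tokens (B, N, C) are reshaped to (B, C, H, W)
# so the depth-wise 3x3 conv can mix information spatially per channel.
def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
    b, n, c = x.shape
    x = x.transpose(1, 2).view(b, c, h, w)  # tokens -> feature map
    x = self.dwconv(x)                      # groups=dim => per-channel conv
    return x.flatten(2).transpose(1, 2)     # feature map -> tokens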

DownSamples(in_channels: int, out_channels: int)

Bases: Module

Downsampling layer between stages.

Source code in src/models/encoders/samba_encoder.py
def __init__(self, in_channels: int, out_channels: int) -> None:
    super().__init__()
    self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
    self.norm = nn.LayerNorm(out_channels)
    self.apply(self._init_weights)
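
Assuming the forward follows the usual stride-2 downsampling pattern (project, then flatten to tokens and normalize), it would look roughly like this sketch:

import torch

# Hypothetical forward: halves spatial resolution, then returns the token
# sequence plus the new grid size for later reshaping.
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    x = self.proj(x)                  # (B, out_channels, H/2, W/2)
    _, _, h, w = x.shape
    x = x.flatten(2).transpose(1, 2)  # (B, h*w, out_channels)
    x = self.norm(x)
    return x, h, w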

MambaLayer(dim: int, d_state: int = 64, d_conv: int = 4, expand: int = 2)

Bases: Module

Single Mamba layer with LayerNorm.

Source code in src/models/encoders/samba_encoder.py
def __init__(
    self,
    dim: int,
    d_state: int = 64,
    d_conv: int = 4,
    expand: int = 2,
) -> None:
    super().__init__()
    self.dim = dim
    self.norm = nn.LayerNorm(dim)
    self.mamba = Mamba(
        d_model=dim,
        d_state=d_state,
        d_conv=d_conv,
        expand=expand,
    )
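
Here `Mamba` is the selective state space block from the `mamba_ssm` package, which operates on token sequences of shape (B, L, D). Given that only a norm and a Mamba module are defined, the forward is presumably a pre-norm application (an assumption; the source may differ in detail):

import torch

# Hypothetical pre-norm forward over a token sequence x of shape (B, L, dim).
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.mamba(self.norm(x))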

PVT2FFN(in_features: int, hidden_features: int)

Bases: Module

Feed-forward network with depth-wise convolution (PVTv2 style).

Source code in src/models/encoders/samba_encoder.py
def __init__(self, in_features: int, hidden_features: int) -> None:
    super().__init__()
    self.fc1 = nn.Linear(in_features, hidden_features)
    self.dwconv = DWConv(hidden_features)
    self.act = nn.GELU()
    self.fc2 = nn.Linear(hidden_features, in_features)
    self.apply(self._init_weights)
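
The layers compose in the standard PVTv2 order: linear expansion, depth-wise spatial mixing, GELU, linear projection. A hedged forward sketch (h and w are the token grid dimensions; the exact signature is an assumption):

import torch

# Hypothetical forward mirroring the PVTv2 Mlp pattern.
def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
    x = self.fc1(x)           # (B, N, hidden_features)
    x = self.dwconv(x, h, w)  # depth-wise 3x3 over the (h, w) token grid
    x = self.act(x)
    return self.fc2(x)        # back to (B, N, in_features)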

SambaBlock(dim: int, mlp_ratio: float = 4.0, drop_path: float = 0.0, norm_layer: type[nn.Module] = nn.LayerNorm)

Bases: Module

Samba block: Mamba + FFN with residual connections.

Source code in src/models/encoders/samba_encoder.py
def __init__(
    self,
    dim: int,
    mlp_ratio: float = 4.0,
    drop_path: float = 0.0,
    norm_layer: type[nn.Module] = nn.LayerNorm,
) -> None:
    super().__init__()
    self.norm2 = norm_layer(dim)
    self.attn = MambaLayer(dim)
    self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
    self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
    self.apply(self._init_weights)
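
Note that only `norm2` is defined here: `MambaLayer` normalizes its input internally, so the block needs just one extra norm before the FFN. Under that reading, the residual structure would be roughly as follows (an assumption, not the verbatim source):

import torch

# Hypothetical forward: Mamba mixing and FFN, each with a residual connection
# and stochastic depth.
def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
    x = x + self.drop_path(self.attn(x))                   # Mamba + residual
    x = x + self.drop_path(self.mlp(self.norm2(x), h, w))  # FFN + residual
    return x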

SambaEncoder(in_channels: int = 3, stem_hidden_dim: int = 32, embed_dims: list[int] | None = None, mlp_ratios: list[float] | None = None, drop_path_rate: float = 0.0, depths: list[int] | None = None, num_stages: int = 4)

Bases: Module

Samba Encoder: Hierarchical Mamba-based encoder for segmentation.

Produces 4-stage hierarchical features compatible with UNet-style decoders.

Parameters:

- in_channels (int): Number of input image channels. Default: 3
- stem_hidden_dim (int): Hidden dimension in stem convolutions. Default: 32
- embed_dims (list[int] | None): Feature dimensions at each stage. Default: [64, 128, 320, 448]
- mlp_ratios (list[float] | None): MLP expansion ratios per stage. Default: [8, 8, 4, 4]
- drop_path_rate (float): Stochastic depth rate. Default: 0.0
- depths (list[int] | None): Number of blocks per stage. Default: [3, 4, 6, 3]
- num_stages (int): Number of encoder stages. Default: 4
Source code in src/models/encoders/samba_encoder.py
def __init__(
    self,
    in_channels: int = 3,
    stem_hidden_dim: int = 32,
    embed_dims: list[int] | None = None,
    mlp_ratios: list[float] | None = None,
    drop_path_rate: float = 0.0,
    depths: list[int] | None = None,
    num_stages: int = 4,
) -> None:
    super().__init__()

    if embed_dims is None:
        embed_dims = [64, 128, 320, 448]
    if mlp_ratios is None:
        mlp_ratios = [8, 8, 4, 4]
    if depths is None:
        depths = [3, 4, 6, 3]

    self.num_stages = num_stages
    self.depths = depths
    self.embed_dims = embed_dims

    # Stochastic depth decay rule
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
    cur = 0

    for i in range(num_stages):
        # Patch embedding / downsampling
        if i == 0:
            patch_embed = Stem(in_channels, stem_hidden_dim, embed_dims[i])
        else:
            patch_embed = DownSamples(embed_dims[i - 1], embed_dims[i])

        # Samba blocks for this stage
        block = nn.ModuleList(
            [
                SambaBlock(
                    dim=embed_dims[i],
                    mlp_ratio=mlp_ratios[i],
                    drop_path=dpr[cur + j],
                )
                for j in range(depths[i])
            ]
        )

        norm = nn.LayerNorm(embed_dims[i])
        cur += depths[i]

        setattr(self, f"patch_embed{i + 1}", patch_embed)
        setattr(self, f"block{i + 1}", block)
        setattr(self, f"norm{i + 1}", norm)

    self.apply(self._init_weights)
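
A minimal construction example, assuming the module is importable under the path shown in this page's header:

from models.encoders.samba_encoder import SambaEncoder

# Default configuration: 4 stages, dims [64, 128, 320, 448], depths [3, 4, 6, 3].
encoder = SambaEncoder(in_channels=3, drop_path_rate=0.1)
print(encoder.get_channels())  # [64, 128, 320, 448]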

forward(x: torch.Tensor) -> list[torch.Tensor]

Forward pass returning hierarchical features.

Parameters:

- x (Tensor): Input tensor of shape (B, C, H, W). Required.

Returns:

- list[Tensor]: List of 4 feature tensors at different scales:
  - Stage 1: (B, embed_dims[0], H/4, W/4)
  - Stage 2: (B, embed_dims[1], H/8, W/8)
  - Stage 3: (B, embed_dims[2], H/16, W/16)
  - Stage 4: (B, embed_dims[3], H/32, W/32)
Source code in src/models/encoders/samba_encoder.py
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
    """Forward pass returning hierarchical features.

    Args:
        x: Input tensor of shape (B, C, H, W)

    Returns:
        List of 4 feature tensors at different scales:
        - Stage 1: (B, embed_dims[0], H/4, W/4)
        - Stage 2: (B, embed_dims[1], H/8, W/8)
        - Stage 3: (B, embed_dims[2], H/16, W/16)
        - Stage 4: (B, embed_dims[3], H/32, W/32)

    """
    b = x.shape[0]
    outs = []

    for i in range(self.num_stages):
        patch_embed = getattr(self, f"patch_embed{i + 1}")
        block = getattr(self, f"block{i + 1}")
        norm = getattr(self, f"norm{i + 1}")

        x, h, w = patch_embed(x)

        for blk in block:
            x = blk(x, h, w)

        x = norm(x)
        x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

    return outs
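
Continuing the construction example above, a quick shape check with a 512x512 input (spatial dimensions should be divisible by 32 given the /32 final stride; note that mamba_ssm kernels generally require a CUDA device, so the encoder and input may need to be moved to GPU):

import torch

x = torch.randn(2, 3, 512, 512)
feats = encoder(x)  # encoder from the construction example above
for f in feats:
    print(tuple(f.shape))
# With the default embed_dims this should print:
# (2, 64, 128, 128), (2, 128, 64, 64), (2, 320, 32, 32), (2, 448, 16, 16)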

get_channels() -> list[int]

Return output channels for each stage.

Source code in src/models/encoders/samba_encoder.py
def get_channels(self) -> list[int]:
    """Return output channels for each stage."""
    return list(self.embed_dims)
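
A decoder typically consumes these widths to size its lateral connections; for example (illustrative only, the 256-channel target width is an arbitrary choice):

import torch.nn as nn

# One 1x1 projection per encoder stage, mapping every scale to a common width.
laterals = nn.ModuleList(
    nn.Conv2d(c, 256, kernel_size=1) for c in encoder.get_channels()
)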

Stem(in_channels: int, stem_hidden_dim: int, out_channels: int)

Bases: Module

Stem module for initial feature extraction.

Source code in src/models/encoders/samba_encoder.py
def __init__(
    self,
    in_channels: int,
    stem_hidden_dim: int,
    out_channels: int,
) -> None:
    super().__init__()
    hidden_dim = stem_hidden_dim
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, hidden_dim, kernel_size=7, stride=2, padding=3, bias=False),
        nn.BatchNorm2d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(hidden_dim),
        nn.ReLU(inplace=True),
    )
    self.proj = nn.Conv2d(hidden_dim, out_channels, kernel_size=3, stride=2, padding=1)
    self.norm = nn.LayerNorm(out_channels)
    self.apply(self._init_weights)
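
Two stride-2 convolutions (the 7x7 entry conv and the final projection) give the stem an overall stride of 4, which is why stage 1 features come out at H/4 x W/4. A hedged forward sketch, mirroring the DownSamples pattern above:

import torch

# Hypothetical forward: conv stack to H/2, projection to H/4, then tokens.
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    x = self.conv(x)                  # (B, hidden_dim, H/2, W/2)
    x = self.proj(x)                  # (B, out_channels, H/4, W/4)
    _, _, h, w = x.shape
    x = x.flatten(2).transpose(1, 2)  # tokens: (B, h*w, out_channels)
    x = self.norm(x)
    return x, h, w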