Samba Encoder
Samba (State Space Model) encoder backbone for UNetFormer.
models.encoders.samba_encoder
Samba Encoder for semantic segmentation.
Adapted from the Samba repository: https://github.com/zhuqinfeng1999/Samba
Original paper: "Samba: Semantic Segmentation of Remotely Sensed Images with State Space Model", https://doi.org/10.1016/j.heliyon.2024.e38495
DWConv(dim: int = 768)
DownSamples(in_channels: int, out_channels: int)
Bases: Module
Downsampling layer between stages.
Source code in src/models/encoders/samba_encoder.py
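A minimal sketch of a stage-transition layer of this kind, assuming the common PVT-style design: a 3x3 convolution with stride 2 halves the spatial size and changes the channel count, and a LayerNorm normalizes the flattened tokens. The class name `DownSamplesSketch`, the kernel/padding choice, the LayerNorm and the `(tokens, H, W)` return convention are all assumptions, not read from the source.

```python
from __future__ import annotations

import torch
from torch import nn


class DownSamplesSketch(nn.Module):
    """Hypothetical stage-transition layer: strided conv + LayerNorm on tokens."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Assumed: 3x3 kernel, stride 2, padding 1 -> halves H and W (rounding up).
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
        x = self.proj(x)                   # (B, out_channels, H/2, W/2)
        _, _, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)   # (B, h*w, out_channels) token sequence
        return self.norm(x), h, w


down = DownSamplesSketch(64, 128)
tokens, h, w = down(torch.randn(2, 64, 32, 32))
print(tokens.shape, h, w)  # torch.Size([2, 256, 128]) 16 16
```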
MambaLayer(dim: int, d_state: int = 64, d_conv: int = 4, expand: int = 2)
Bases: Module
Single Mamba layer with LayerNorm.
Source code in src/models/encoders/samba_encoder.py
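The sketch below shows the usual way a Mamba block is applied to image features, assuming the layer works on a token sequence of shape (B, N, C) with N = H * W: LayerNorm over channels, then the sequence mixer. In the real layer the mixer would presumably be `mamba_ssm.Mamba(d_model=dim, d_state=d_state, d_conv=d_conv, expand=expand)`; because that kernel needs a CUDA build of `mamba_ssm`, this hypothetical `MambaLayerSketch` accepts the mixer as an argument so it can be checked on CPU with a shape-preserving stand-in. Whether the real layer also adds a residual is not documented here; the sketch does not.

```python
from __future__ import annotations

import torch
from torch import nn


class MambaLayerSketch(nn.Module):
    """Hypothetical pre-norm wrapper: LayerNorm, then a sequence mixer over tokens."""

    def __init__(self, dim: int, mixer: nn.Module | None = None,
                 d_state: int = 64, d_conv: int = 4, expand: int = 2) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        if mixer is None:
            # Needs the mamba_ssm package (CUDA kernels) to actually run.
            from mamba_ssm import Mamba
            mixer = Mamba(d_model=dim, d_state=d_state, d_conv=d_conv, expand=expand)
        self.mixer = mixer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, C) tokens; the output keeps the same shape.
        return self.mixer(self.norm(x))


# CPU-only check: nn.Linear stands in for the Mamba mixer.
layer = MambaLayerSketch(dim=64, mixer=nn.Linear(64, 64))
out = layer(torch.randn(2, 16 * 16, 64))
print(out.shape)  # torch.Size([2, 256, 64])
```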
PVT2FFN(in_features: int, hidden_features: int)
Bases: Module
Feed-forward network with depth-wise convolution (PVTv2 style).
Source code in src/models/encoders/samba_encoder.py
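A minimal sketch of a PVTv2-style feed-forward network, assuming the usual layout Linear, depth-wise 3x3 convolution, GELU, Linear, with the tokens reshaped to a feature map around the convolution. Setting `groups` equal to the channel count is what makes the convolution depth-wise. The class name `PVT2FFNSketch`, the kernel size, and the idea that `hidden_features = in_features * mlp_ratio` are assumptions.

```python
import torch
from torch import nn


class PVT2FFNSketch(nn.Module):
    """Hypothetical PVTv2-style FFN: Linear -> depth-wise 3x3 conv -> GELU -> Linear."""

    def __init__(self, in_features: int, hidden_features: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        # groups == channels: one 3x3 filter per channel (depth-wise convolution).
        self.dwconv = nn.Conv2d(hidden_features, hidden_features, kernel_size=3,
                                padding=1, groups=hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        b = x.shape[0]                                   # x: (B, N, C) with N == h * w
        x = self.fc1(x)
        x = x.transpose(1, 2).reshape(b, -1, h, w)       # tokens -> feature map
        x = self.dwconv(x).flatten(2).transpose(1, 2)    # feature map -> tokens
        return self.fc2(self.act(x))


ffn = PVT2FFNSketch(in_features=64, hidden_features=64 * 8)  # e.g. an mlp_ratio of 8
y = ffn(torch.randn(2, 16 * 16, 64), h=16, w=16)
print(y.shape)  # torch.Size([2, 256, 64])
```

The depth-wise convolution adds local spatial mixing at a cost of only 9 weights plus a bias per channel, instead of the 9 * C weights per output channel a dense 3x3 convolution would need.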
SambaBlock(dim: int, mlp_ratio: float = 4.0, drop_path: float = 0.0, norm_layer: type[nn.Module] = nn.LayerNorm)
Bases: Module
Samba block: Mamba + FFN with residual connections.
Source code in src/models/encoders/samba_encoder.py
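The residual structure can be sketched as two branches, each wrapped in stochastic depth (DropPath). This assumes the Mamba branch normalizes its own input (consistent with "Single Mamba layer with LayerNorm" above) and that `norm_layer` provides the pre-norm for the FFN branch; the exact norm placement in the real block is not documented here. The mixer and FFN are injected so the sketch runs without `mamba_ssm`, and the real `drop_path` argument is renamed `drop_prob` to avoid shadowing the helper. `SambaBlockSketch` and `StandInFFN` are hypothetical names.

```python
from __future__ import annotations

import torch
from torch import nn


def drop_path(x: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    """Stochastic depth: drop the whole residual branch per sample with probability p."""
    if p == 0.0 or not training:
        return x
    keep = 1.0 - p
    # One Bernoulli draw per sample, broadcast over the remaining dimensions.
    mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep)
    return x * mask / keep  # rescale so the expected output is unchanged


class SambaBlockSketch(nn.Module):
    """Hypothetical block: x + DropPath(mixer(x)), then x + DropPath(ffn(norm(x)))."""

    def __init__(self, dim: int, mixer: nn.Module, ffn: nn.Module,
                 drop_prob: float = 0.0,
                 norm_layer: type[nn.Module] = nn.LayerNorm) -> None:
        super().__init__()
        self.mixer = mixer            # assumed to apply its own LayerNorm
        self.norm = norm_layer(dim)   # pre-norm for the FFN branch
        self.ffn = ffn                # called as ffn(tokens, h, w), like a PVTv2 FFN
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        x = x + drop_path(self.mixer(x), self.drop_prob, self.training)
        x = x + drop_path(self.ffn(self.norm(x), h, w), self.drop_prob, self.training)
        return x


class StandInFFN(nn.Module):
    """Shape-preserving stand-in with the (tokens, h, w) call signature."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        return self.fc(x)


block = SambaBlockSketch(64, mixer=nn.Linear(64, 64), ffn=StandInFFN(64), drop_prob=0.1)
block.eval()  # DropPath is a no-op at inference time
x = torch.randn(2, 256, 64)
print(block(x, 16, 16).shape)  # torch.Size([2, 256, 64])
```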
SambaEncoder(in_channels: int = 3, stem_hidden_dim: int = 32, embed_dims: list[int] | None = None, mlp_ratios: list[float] | None = None, drop_path_rate: float = 0.0, depths: list[int] | None = None, num_stages: int = 4)
Bases: Module
Samba Encoder: Hierarchical Mamba-based encoder for segmentation.
Produces 4-stage hierarchical features compatible with UNet-style decoders.
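To illustrate what "compatible with UNet-style decoders" means in practice, the sketch below builds four dummy feature maps with the default per-stage channels (64, 128, 320, 448) and fuses them top-down with 1x1 lateral projections and upsample-and-add. The output strides of 4, 8, 16 and 32, the decoder width, and the fusion scheme are assumptions for illustration; this is not UNetFormer's actual decoder. In a real model the channel list would presumably come from `get_channels()`.

```python
import torch
import torch.nn.functional as F
from torch import nn

embed_dims = [64, 128, 320, 448]  # documented default per-stage channels
strides = [4, 8, 16, 32]          # assumed output stride of each stage
decoder_dim = 64                  # hypothetical common decoder width

b, h, w = 2, 256, 256
# Dummy stand-ins for the list returned by the encoder's forward().
features = [torch.randn(b, c, h // s, w // s) for c, s in zip(embed_dims, strides)]

# 1x1 lateral convolutions project every stage to the same width.
laterals = nn.ModuleList([nn.Conv2d(c, decoder_dim, kernel_size=1) for c in embed_dims])

# Top-down fusion: start from the coarsest map, upsample, add the next finer lateral.
x = laterals[-1](features[-1])
for i in range(len(features) - 2, -1, -1):
    x = F.interpolate(x, size=features[i].shape[-2:], mode="bilinear", align_corners=False)
    x = x + laterals[i](features[i])

print(x.shape)  # torch.Size([2, 64, 64, 64])
```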
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input image channels. | 3 |
| stem_hidden_dim | int | Hidden dimension of the stem convolutions. | 32 |
| embed_dims | list[int] \| None | Feature dimension of each stage. None selects [64, 128, 320, 448]. | None |
| mlp_ratios | list[float] \| None | MLP expansion ratio of each stage. None selects [8, 8, 4, 4]. | None |
| drop_path_rate | float | Stochastic depth rate. | 0.0 |
| depths | list[int] \| None | Number of blocks in each stage. None selects [3, 4, 6, 3]. | None |
| num_stages | int | Number of encoder stages. | 4 |
Source code in src/models/encoders/samba_encoder.py
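The None defaults above are presumably resolved inside the constructor. The pure-Python sketch below shows that bookkeeping together with the linearly increasing stochastic-depth schedule that PVT-style encoders commonly derive from drop_path_rate and depths (equivalent to `torch.linspace(0, drop_path_rate, sum(depths))` split per stage). The helper name `resolve_samba_config` and the linear schedule are assumptions, not confirmed behavior of this module.

```python
from __future__ import annotations


def resolve_samba_config(embed_dims: list[int] | None = None,
                         mlp_ratios: list[float] | None = None,
                         depths: list[int] | None = None,
                         drop_path_rate: float = 0.0,
                         num_stages: int = 4) -> dict:
    """Hypothetical helper: fill documented defaults and build per-block drop rates."""
    embed_dims = embed_dims if embed_dims is not None else [64, 128, 320, 448]
    mlp_ratios = mlp_ratios if mlp_ratios is not None else [8, 8, 4, 4]
    depths = depths if depths is not None else [3, 4, 6, 3]
    for name, values in (("embed_dims", embed_dims), ("mlp_ratios", mlp_ratios),
                         ("depths", depths)):
        if len(values) != num_stages:
            raise ValueError(f"{name} needs {num_stages} entries, got {len(values)}")

    # Linear schedule from 0 to drop_path_rate across all blocks (deeper = higher).
    total = sum(depths)
    rates = [drop_path_rate * i / (total - 1) if total > 1 else 0.0 for i in range(total)]

    # Split the flat schedule into one list per stage.
    per_stage, start = [], 0
    for d in depths:
        per_stage.append(rates[start:start + d])
        start += d
    return {"embed_dims": embed_dims, "mlp_ratios": mlp_ratios, "depths": depths,
            "drop_path_per_stage": per_stage}


cfg = resolve_samba_config(drop_path_rate=0.1)
print([len(s) for s in cfg["drop_path_per_stage"]])  # [3, 4, 6, 3]
```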
forward(x: torch.Tensor) -> list[torch.Tensor]
Forward pass returning hierarchical features.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (B, C, H, W). | required |
Returns:

| Type | Description |
|---|---|
| list[Tensor] | List of 4 feature tensors at different scales, one per encoder stage. |
Source code in src/models/encoders/samba_encoder.py
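The per-stage output shapes are not listed above. Assuming a PVT-style layout (a stem with two stride-2 convolutions for an overall stride of 4, then one stride-2 downsampling before each later stage, each using a 3x3 kernel with padding 1) and (B, C, H, W) outputs, the expected shapes for a given input can be computed with the standard Conv2d output-size formula. The helper `expected_feature_shapes` is hypothetical and the strides are an assumption.

```python
from __future__ import annotations


def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Standard Conv2d output-size formula (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def expected_feature_shapes(batch: int, height: int, width: int,
                            embed_dims: list[int] | None = None
                            ) -> list[tuple[int, int, int, int]]:
    """Hypothetical: shapes of forward()'s outputs under the assumed strides."""
    embed_dims = embed_dims or [64, 128, 320, 448]
    h, w = height, width
    # Assumed stem: two stride-2 convolutions, i.e. overall stride 4.
    for _ in range(2):
        h, w = conv_out(h), conv_out(w)
    shapes = []
    for stage, channels in enumerate(embed_dims):
        if stage > 0:  # assumed stride-2 downsampling before every later stage
            h, w = conv_out(h), conv_out(w)
        shapes.append((batch, channels, h, w))
    return shapes


print(expected_feature_shapes(2, 512, 512))
# [(2, 64, 128, 128), (2, 128, 64, 64), (2, 320, 32, 32), (2, 448, 16, 16)]
```

With a 3x3, stride-2, padding-1 convolution each step rounds up, so inputs whose size is not a multiple of 32 still produce valid (slightly larger) maps; a decoder should therefore upsample to the skip feature's actual size rather than by a fixed factor of 2.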
get_channels() -> list[int]
Stem(in_channels: int, stem_hidden_dim: int, out_channels: int)
Bases: Module
Stem module for initial feature extraction.
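A sketch of a convolutional stem of the kind commonly used in front of hierarchical encoders, assuming it reduces resolution by 4: a stride-2 3x3 convolution to stem_hidden_dim, two stride-1 3x3 convolutions (each followed by BatchNorm and ReLU), then a stride-2 3x3 projection to out_channels and a LayerNorm on the flattened tokens. The class name `StemSketch`, the exact layer sequence and the `(tokens, H, W)` return convention are assumptions.

```python
from __future__ import annotations

import torch
from torch import nn


class StemSketch(nn.Module):
    """Hypothetical stride-4 stem: conv-BN-ReLU x3, strided projection, LayerNorm."""

    def __init__(self, in_channels: int, stem_hidden_dim: int, out_channels: int) -> None:
        super().__init__()

        def conv_bn_relu(cin: int, cout: int, stride: int) -> nn.Sequential:
            return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
            )

        self.convs = nn.Sequential(
            conv_bn_relu(in_channels, stem_hidden_dim, stride=2),      # H/2
            conv_bn_relu(stem_hidden_dim, stem_hidden_dim, stride=1),
            conv_bn_relu(stem_hidden_dim, stem_hidden_dim, stride=1),
        )
        # Second stride-2 step brings the overall stride to 4.
        self.proj = nn.Conv2d(stem_hidden_dim, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
        x = self.proj(self.convs(x))          # (B, out_channels, H/4, W/4)
        _, _, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)      # (B, h*w, out_channels)
        return self.norm(x), h, w


stem = StemSketch(in_channels=3, stem_hidden_dim=32, out_channels=64)
tokens, h, w = stem(torch.randn(2, 3, 224, 224))
print(tokens.shape, h, w)  # torch.Size([2, 3136, 64]) 56 56
```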