RS3Mamba¶
State Space Model (SSM) architecture for remote sensing segmentation.
models.architectures.rs3mamba
¶
RS3Mamba: Visual State Space Model for Remote Sensing Semantic Segmentation.
This module is adapted from the SSRS repository: https://github.com/sstary/SSRS/blob/main/RS3Mamba/model/RS3Mamba.py
Original paper: "RS3Mamba: Visual State Space Model for Remote Sensing Image Semantic Segmentation" (https://arxiv.org/abs/2404.02457).
ChannelAttention(gate_channels: int, reduction_ratio: int = 2, pool_types: list[str] | None = None)
¶
Bases: Module
Channel attention module with avg, max, and soft pooling.
Source code in src/models/architectures/rs3mamba.py
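To illustrate the gating pattern this module uses, here is a minimal sketch of CBAM-style channel attention with avg and max pooling gates sharing one MLP. It is purely illustrative (the class name `ChannelAttentionSketch` is hypothetical): the actual module additionally supports a soft-pooling gate and a configurable `pool_types` list.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttentionSketch(nn.Module):
    """Illustrative channel attention: per-channel gates from pooled stats."""

    def __init__(self, gate_channels: int, reduction_ratio: int = 2):
        super().__init__()
        # Shared bottleneck MLP applied to each pooled descriptor.
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pool each channel to a single value, run both descriptors through
        # the shared MLP, then sum and squash into a per-channel scale.
        avg = self.mlp(F.adaptive_avg_pool2d(x, 1))
        mx = self.mlp(F.adaptive_max_pool2d(x, 1))
        scale = torch.sigmoid(avg + mx).unsqueeze(-1).unsqueeze(-1)
        return x * scale  # reweight channels, spatial layout unchanged
```

The output has the same shape as the input; only the per-channel scaling changes.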
Decoder(encoder_channels: tuple[int, ...] = (64, 128, 256, 512), decode_channels: int = 64, dropout: float = 0.1, window_size: int = 8, num_classes: int = 6)
¶
Bases: Module
UNetFormer-style decoder with Global-Local Attention.
Initialize the decoder used to upsample and produce segmentation maps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `encoder_channels` | `tuple[int, ...]` | Channels of the encoder stages. | `(64, 128, 256, 512)` |
| `decode_channels` | `int` | Number of decoder channels. | `64` |
| `dropout` | `float` | Dropout probability in the segmentation head. | `0.1` |
| `window_size` | `int` | Window size used by the attention blocks. | `8` |
| `num_classes` | `int` | Number of segmentation classes. | `6` |
Source code in src/models/architectures/rs3mamba.py
forward(res1: torch.Tensor, res2: torch.Tensor, res3: torch.Tensor, res4: torch.Tensor, h: int, w: int) -> torch.Tensor
¶
Run the decoder to produce segmentation logits at size (h, w).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `res1` | `Tensor` | Shallowest encoder feature map. | *required* |
| `res2` | `Tensor` | Intermediate encoder feature map. | *required* |
| `res3` | `Tensor` | Deeper encoder feature map. | *required* |
| `res4` | `Tensor` | Deepest encoder feature map. | *required* |
| `h` | `int` | Target output height. | *required* |
| `w` | `int` | Target output width. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Segmentation logits resized to `(h, w)`. |
Source code in src/models/architectures/rs3mamba.py
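The resizing step at the end of `forward` can be illustrated with `torch.nn.functional.interpolate`. The shapes below are hypothetical examples (a 512×512 input with six classes), not taken from the source:

```python
import torch
import torch.nn.functional as F

# Hypothetical coarse decoder output: batch 2, num_classes=6, 64x64 grid.
logits = torch.randn(2, 6, 64, 64)
h, w = 512, 512  # target spatial size passed to the decoder

# Bilinear upsampling back to the requested (h, w).
out = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
assert out.shape == (2, 6, 512, 512)
```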
init_weight() -> None
¶
Initialize Conv2d weights using Kaiming normal initialization.
This initializes the weights of the convolutional layers in the decoder.
Source code in src/models/architectures/rs3mamba.py
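A sketch of what such an initializer typically looks like, assuming the decoder walks its submodules and applies `nn.init.kaiming_normal_` to every `Conv2d` (the function name `init_weight_sketch` is hypothetical):

```python
import torch.nn as nn

def init_weight_sketch(module: nn.Module) -> None:
    """Apply Kaiming normal init to all Conv2d layers in a module tree."""
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            # fan_out + relu is the common choice for conv stacks followed
            # by ReLU-family activations; the actual decoder may differ.
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
```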
FusionAttention(dim: int = 256, ssmdims: int = 256, num_heads: int = 16, qkv_bias: bool = False, window_size: int = 8, relative_pos_embedding: bool = True)
¶
Bases: Module
Attention module for fusing CNN and Mamba features.
Source code in src/models/architectures/rs3mamba.py
forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
¶
Fuse CNN features (x) with Mamba features (y).
Source code in src/models/architectures/rs3mamba.py
FusionBlock(dim: int = 256, ssmdims: int = 256, num_heads: int = 16, mlp_ratio: float = 4.0, qkv_bias: bool = False, drop: float = 0.0, drop_path: float = 0.0, act_layer: type[nn.Module] = nn.ReLU6, norm_layer: type[nn.Module] = nn.BatchNorm2d, window_size: int = 8, use_channel_attention: bool = True)
¶
Bases: Module
Block for fusing CNN and Mamba features with attention and MLP.
Source code in src/models/architectures/rs3mamba.py
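The fuse-then-refine pattern this block follows can be sketched as below. The sketch is a simplification (class name `FusionBlockSketch` is hypothetical): fusion here is a plain 1×1-conv mix of the concatenated streams, whereas the real block uses windowed `FusionAttention`, drop path, and optional channel attention.

```python
import torch
import torch.nn as nn

class FusionBlockSketch(nn.Module):
    """Illustrative fusion block: mix CNN and Mamba features, refine with MLP."""

    def __init__(self, dim: int = 64, mlp_ratio: float = 4.0):
        super().__init__()
        # Stand-in for FusionAttention: mix the two streams channel-wise.
        self.fuse = nn.Conv2d(2 * dim, dim, kernel_size=1)
        hidden = int(dim * mlp_ratio)
        self.norm = nn.BatchNorm2d(dim)
        # Conv-MLP with the ReLU6 activation the real block defaults to.
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, hidden, kernel_size=1),
            nn.ReLU6(),
            nn.Conv2d(hidden, dim, kernel_size=1),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # x: CNN features, y: Mamba features, both (B, dim, H, W).
        z = self.fuse(torch.cat([x, y], dim=1))
        return z + self.mlp(self.norm(z))  # residual refinement
```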
RS3Mamba(decode_channels: int = 64, dropout: float = 0.1, backbone_name: str = 'swsl_resnet18', pretrained: bool = True, window_size: int = 8, num_classes: int = 6, in_channels: int = 3, use_channel_attention: bool = True)
¶
Bases: Module
RS³Mamba: Visual State Space Model for Remote Sensing Semantic Segmentation.
This model combines a CNN backbone (ResNet) with a VMamba encoder for capturing both local and global features in remote sensing images.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `decode_channels` | `int` | Number of decoder channels. | `64` |
| `dropout` | `float` | Dropout rate in the decoder. | `0.1` |
| `backbone_name` | `str` | Name of the timm backbone. | `'swsl_resnet18'` |
| `pretrained` | `bool` | Whether to use a pretrained backbone. | `True` |
| `window_size` | `int` | Window size for attention. | `8` |
| `num_classes` | `int` | Number of output classes. | `6` |
| `in_channels` | `int` | Number of input channels. | `3` |
Initialize the RS3Mamba model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `decode_channels` | `int` | Number of decoder channels. | `64` |
| `dropout` | `float` | Dropout rate used in the decoder. | `0.1` |
| `backbone_name` | `str` | Name of the backbone model from timm. | `'swsl_resnet18'` |
| `pretrained` | `bool` | Whether to load pretrained backbone weights. | `True` |
| `window_size` | `int` | Window size used for the attention modules. | `8` |
| `num_classes` | `int` | Number of output classes. | `6` |
| `in_channels` | `int` | Number of input image channels. | `3` |
| `use_channel_attention` | `bool` | Whether to use channel attention in the `FusionBlock`. | `True` |
Source code in src/models/architectures/rs3mamba.py
forward(x: torch.Tensor) -> torch.Tensor
¶
Forward pass combining VMamba encoder and CNN backbone.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape `(B, C, H, W)`. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Segmentation logits resized to the input spatial size. |
Source code in src/models/architectures/rs3mamba.py
SoftPool2d(kernel_size: int, stride: int | None = None)
¶
Bases: Module
Soft pooling layer (exponentially weighted average pooling), used for the soft-pooling gate in ChannelAttention.
Source code in src/models/architectures/rs3mamba.py
load_pretrained_ckpt(model: RS3Mamba, ckpt_path: str = './pretrain/vmamba_tiny_e292.pth') -> RS3Mamba
¶
Load pretrained VMamba weights into RS3Mamba model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `RS3Mamba` | RS3Mamba model instance. | *required* |
| `ckpt_path` | `str` | Path to the VMamba pretrained weights. | `'./pretrain/vmamba_tiny_e292.pth'` |

Returns:

| Type | Description |
|---|---|
| `RS3Mamba` | Model with the loaded weights. |
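A common pattern for this kind of partial checkpoint loading is to keep only the keys whose names and shapes match the target model and load non-strictly. The sketch below assumes that pattern; the real function's key remapping for the VMamba checkpoint may differ, and `load_pretrained_sketch` is a hypothetical name:

```python
import torch
import torch.nn as nn

def load_pretrained_sketch(model: nn.Module, ckpt_path: str) -> nn.Module:
    """Load matching weights from a checkpoint into `model`, non-strictly."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Checkpoints often nest the weights under a "model" key.
    state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
    own = model.state_dict()
    # Keep only entries that exist in the model with an identical shape.
    filtered = {k: v for k, v in state.items()
                if k in own and own[k].shape == v.shape}
    model.load_state_dict(filtered, strict=False)
    return model
```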