TSViT¶
Temporal-Spatial Vision Transformer for Sentinel-2 time series segmentation.
models.architectures.tsvit¶
Temporal-Spatial Vision Transformer (TSViT) implementation.
This module adapts the TSViT architecture from "ViTs for SITS: Vision Transformers for Satellite Image Time Series" (CVPR 2023, Tarasiou et al.) to the FLAIR-2 Sentinel-2-only scenario. The implementation follows the original design (a temporal transformer followed by a spatial transformer over patch tokens) while relaxing assumptions about sequence length and date encodings, so it can ingest the monthly averaged Sentinel-2 stacks produced by this repository.
MultiHeadSelfAttention(dim: int, num_heads: int, dropout: float = 0.0)¶
Bases: Module
Simplified multi-head self-attention with optional padding mask.
Source code in src/models/architectures/tsvit.py
forward(x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor¶
Apply multi-head self-attention.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape `(batch, seq_len, dim)`. | *required* |
| `key_padding_mask` | `Tensor \| None` | Optional boolean mask of shape `(batch, seq_len)` where `True` indicates positions to ignore. | `None` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Output tensor of shape `(batch, seq_len, dim)`. |
Source code in src/models/architectures/tsvit.py
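A minimal usage sketch based on the signature above; the import path follows the module name shown at the top of this page, and the tensor sizes are purely illustrative.

```python
import torch

from models.architectures.tsvit import MultiHeadSelfAttention

attn = MultiHeadSelfAttention(dim=128, num_heads=4, dropout=0.1)

x = torch.randn(2, 12, 128)                 # (batch, seq_len, dim)
pad = torch.zeros(2, 12, dtype=torch.bool)  # True marks positions to ignore
pad[:, 10:] = True                          # last two timesteps treated as padding

out = attn(x, key_padding_mask=pad)
assert out.shape == x.shape                 # output keeps (batch, seq_len, dim)
```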
TSViT(*, image_size: int, patch_size: int, in_channels: int, num_classes: int, max_seq_len: int, dim: int, temporal_depth: int, spatial_depth: int, num_heads: int, mlp_dim: int, dropout: float = 0.0, emb_dropout: float = 0.0, temporal_metadata_channels: int = 0)¶
Bases: Module
Temporal-Spatial Vision Transformer tailored to Sentinel-2 patches.
Initialize the TSViT module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `image_size` | `int` | Height/width (in Sentinel pixels) of the cropped patch. | *required* |
| `patch_size` | `int` | Edge length of each Vision Transformer patch. | *required* |
| `in_channels` | `int` | Number of spectral channels produced by the dataset. | *required* |
| `num_classes` | `int` | Number of segmentation categories. | *required* |
| `max_seq_len` | `int` | Size of the one-hot position encoding. The last index (`max_seq_len - 1`) is reserved for padding tokens, so valid position indices must be in range `[0, max_seq_len - 2]`. Set to 13 for month-of-year encoding (months 0-11 plus padding) or 367 for day-of-year encoding (days 0-365 plus padding). | *required* |
| `dim` | `int` | Embedding dimension of token representations. | *required* |
| `temporal_depth` | `int` | Number of transformer blocks in the temporal encoder. | *required* |
| `spatial_depth` | `int` | Number of transformer blocks in the spatial encoder. | *required* |
| `num_heads` | `int` | Number of attention heads for both encoders. | *required* |
| `mlp_dim` | `int` | Hidden size of the feed-forward sublayers. | *required* |
| `dropout` | `float` | Dropout applied inside the transformer blocks. | `0.0` |
| `emb_dropout` | `float` | Dropout applied after adding positional embeddings. | `0.0` |
| `temporal_metadata_channels` | `int` | Optional number of metadata channels reserved at the end of the spectral dimension (e.g., timestamps). | `0` |
Source code in src/models/architectures/tsvit.py
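A hedged instantiation sketch; the hyperparameter values below are illustrative examples, not the repository's configured defaults.

```python
from models.architectures.tsvit import TSViT

# Illustrative settings: a 32x32 Sentinel-2 crop, 10 spectral bands,
# monthly composites with month-of-year position encoding.
model = TSViT(
    image_size=32,
    patch_size=2,
    in_channels=10,
    num_classes=13,
    max_seq_len=13,        # months 0-11 plus one reserved padding index
    dim=128,
    temporal_depth=4,
    spatial_depth=2,
    num_heads=4,
    mlp_dim=256,
    dropout=0.1,
    emb_dropout=0.1,
)
```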
forward(x: torch.Tensor, *, batch_positions: torch.Tensor | None = None, pad_mask: torch.Tensor | None = None) -> torch.Tensor¶
Run the temporal-spatial transformer over a Sentinel-2 sequence.
Source code in src/models/architectures/tsvit.py
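Continuing from the instantiation sketch above, a hedged example of calling forward. The `(batch, time, channels, height, width)` input layout and the per-pixel logits output are assumptions to verify against the source.

```python
import torch

B, T, C, H, W = 2, 12, 10, 32, 32

# Assumed input layout (batch, time, channels, height, width).
x = torch.randn(B, T, C, H, W)

# Month index of each acquisition, within [0, max_seq_len - 2].
positions = torch.arange(T).unsqueeze(0).expand(B, T)

# True marks padded timesteps; here the full sequence is valid.
pad_mask = torch.zeros(B, T, dtype=torch.bool)

logits = model(x, batch_positions=positions, pad_mask=pad_mask)
# Expected: per-pixel class scores, e.g. (B, num_classes, H, W).
```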
TransformerBlock(dim: int, num_heads: int, mlp_dim: int, dropout: float = 0.0)¶
Bases: Module
Transformer block with pre-norm and residual connections.
Source code in src/models/architectures/tsvit.py
forward(x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor¶
Apply transformer block with attention and feedforward.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape `(batch, seq_len, dim)`. | *required* |
| `key_padding_mask` | `Tensor \| None` | Optional boolean mask for attention. | `None` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Output tensor of shape `(batch, seq_len, dim)`. |
Source code in src/models/architectures/tsvit.py
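For orientation, a pre-norm block of this kind typically computes the following. This is a conceptual sketch built on `torch.nn.MultiheadAttention`, not the repository implementation.

```python
import torch
from torch import nn


class PreNormBlockSketch(nn.Module):
    """Conceptual pre-norm transformer block: LayerNorm before each sublayer,
    with a residual connection around it. Not the repository implementation."""

    def __init__(self, dim: int, num_heads: int, mlp_dim: int, dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim), nn.Dropout(dropout),
        )

    def forward(self, x, key_padding_mask=None):
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, key_padding_mask=key_padding_mask)
        x = x + attn_out                 # residual around attention
        x = x + self.mlp(self.norm2(x))  # residual around feed-forward
        return x
```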
TransformerEncoder(dim: int, depth: int, num_heads: int, mlp_dim: int, dropout: float = 0.0)¶
Bases: Module
Stack of Transformer blocks with shared padding mask.
Source code in src/models/architectures/tsvit.py
forward(x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor¶
Apply stack of transformer blocks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape `(batch, seq_len, dim)`. | *required* |
| `key_padding_mask` | `Tensor \| None` | Optional boolean mask for attention. | `None` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Output tensor of shape `(batch, seq_len, dim)` after layer normalization. |
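A minimal usage sketch of the encoder; the dimensions and the shared `key_padding_mask` below are illustrative.

```python
import torch

from models.architectures.tsvit import TransformerEncoder

encoder = TransformerEncoder(dim=128, depth=4, num_heads=4, mlp_dim=256, dropout=0.1)

tokens = torch.randn(8, 16, 128)             # (batch, seq_len, dim)
mask = torch.zeros(8, 16, dtype=torch.bool)  # shared padding mask, True = ignore
mask[:, 12:] = True                          # last four tokens are padding

out = encoder(tokens, key_padding_mask=mask)
assert out.shape == tokens.shape             # normalized output, same shape
```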