TSViT Lookup
models.architectures.tsvit_lookup
TSViT variant with lookup-based temporal position embeddings.
MultiHeadSelfAttention(dim: int, num_heads: int, dropout: float = 0.0)
Bases: Module
Simplified multi-head self-attention with optional padding mask.
Source code in src/models/architectures/tsvit_lookup.py
forward(x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor
Apply multi-head self-attention.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch, seq_len, dim) | required |
| key_padding_mask | Tensor \| None | Optional boolean mask of shape (batch, seq_len) where True marks positions to ignore | None |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor of shape (batch, seq_len, dim) |
Source code in src/models/architectures/tsvit_lookup.py
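The mask convention above (True = ignore) matches PyTorch's own attention layers. A minimal illustration using the built-in `nn.MultiheadAttention` (the module documented here is a simplified equivalent, so this is a stand-in, not the actual source):

```python
import torch
import torch.nn as nn

# key_padding_mask convention: True marks padded positions to ignore.
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.randn(2, 5, 8)                    # (batch, seq_len, dim)
mask = torch.zeros(2, 5, dtype=torch.bool)
mask[:, 3:] = True                          # last two timesteps are padding
out, weights = attn(x, x, x, key_padding_mask=mask)
# out keeps the input shape; padded keys receive exactly zero attention weight
print(out.shape)                            # torch.Size([2, 5, 8])
print(weights[:, :, 3:].abs().max())        # tensor(0.)
```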
TSViTLookup(*, image_size: int, patch_size: int, in_channels: int, num_classes: int, train_dates: list[int] | torch.Tensor, dim: int, temporal_depth: int, spatial_depth: int, num_heads: int, mlp_dim: int, dropout: float = 0.0, emb_dropout: float = 0.0, temporal_metadata_channels: int = 0, date_range: tuple[int, int] | None = None)
Bases: Module
Temporal-Spatial ViT with lookup-based temporal position embeddings.
This is the most advanced temporal encoding approach from the paper. It learns a separate position embedding for each unique date seen during training, then uses linear interpolation for unseen dates during inference.
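The lookup-plus-interpolation idea can be sketched as follows. This is a hypothetical illustration of the mechanism described above, not the actual `TSViTLookup` internals; the class and attribute names (`LookupTemporalEmbedding`, `table`) are invented for the example:

```python
import torch
import torch.nn as nn

class LookupTemporalEmbedding(nn.Module):
    """One learned embedding per training date; linear interpolation otherwise."""

    def __init__(self, train_dates, dim, date_range=(1, 365)):
        super().__init__()
        dates = torch.as_tensor(sorted(set(train_dates)), dtype=torch.long)
        self.register_buffer("train_dates", dates)
        self.table = nn.Embedding(len(dates), dim)  # one row per training date
        self.date_range = date_range

    def forward(self, positions, inference=False):
        # positions: (B, T) integer dates, same indexing scheme as train_dates
        if not inference:
            # training: every date was seen, so map it directly to its row
            idx = torch.searchsorted(self.train_dates, positions.long())
            return self.table(idx)
        # inference: blend the two nearest training dates linearly
        pos = positions.float().clamp(*self.date_range)
        dates = self.train_dates.float()
        hi = torch.searchsorted(dates, pos).clamp(1, len(dates) - 1)
        lo = hi - 1
        d_lo, d_hi = dates[lo], dates[hi]
        w = ((pos - d_lo) / (d_hi - d_lo).clamp(min=1e-8)).clamp(0.0, 1.0)
        w = w.unsqueeze(-1)
        return (1 - w) * self.table(lo) + w * self.table(hi)
```

For a date exactly midway between two training dates, the interpolated embedding is the mean of the two neighboring rows; dates seen in training reduce to a plain table lookup.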
Initialize the TSViT module with lookup temporal embeddings.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| image_size | int | Height/width (in Sentinel pixels) of the cropped patch. | required |
| patch_size | int | Edge length of each Vision Transformer patch. | required |
| in_channels | int | Number of spectral channels produced by the dataset. | required |
| num_classes | int | Number of segmentation categories. | required |
| train_dates | list[int] \| Tensor | List/tensor of unique dates (e.g., day-of-year 1-365) seen in training. During inference, dates not in this list are interpolated. IMPORTANT: all positions passed to forward() must use the same indexing scheme. For day-of-year, use 1-365 (not 0-364); for months, use 1-12 (not 0-11) or pass a custom date_range. | required |
| dim | int | Embedding dimension of token representations. | required |
| temporal_depth | int | Number of transformer blocks in the temporal encoder. | required |
| spatial_depth | int | Number of transformer blocks in the spatial encoder. | required |
| num_heads | int | Number of attention heads for both encoders. | required |
| mlp_dim | int | Hidden size of the feed-forward sublayers. | required |
| dropout | float | Dropout applied inside the transformer blocks. | 0.0 |
| emb_dropout | float | Dropout applied after adding positional embeddings. | 0.0 |
| temporal_metadata_channels | int | Optional number of metadata channels reserved at the end of the spectral dimension (e.g., timestamps). | 0 |
| date_range | tuple[int, int] \| None | Optional (min, max) tuple defining the range of valid dates for interpolation during inference. Defaults to (1, 365) for day-of-year; for 0-indexed DOY use (0, 364), for months use (1, 12) or (0, 11). | None |
Source code in src/models/architectures/tsvit_lookup.py
forward(x: torch.Tensor, *, batch_positions: torch.Tensor | None = None, pad_mask: torch.Tensor | None = None, inference: bool | None = None) -> torch.Tensor
Run the temporal-spatial transformer over a Sentinel-2 sequence.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (B, T, C, H, W) | required |
| batch_positions | Tensor \| None | Optional (B, T) tensor with date indices (e.g., day-of-year) | None |
| pad_mask | Tensor \| None | Optional (B, T) boolean mask (True = padded/invalid) | None |
| inference | bool \| None | If True, use interpolated embeddings; if False, use direct lookup. If None (default), auto-detects from model.training: interpolation after model.eval(), direct lookup after model.train(). | None |

Returns:

| Type | Description |
|---|---|
| Tensor | Output logits of shape (B, num_classes, H, W) |
Source code in src/models/architectures/tsvit_lookup.py
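The input contract above can be made concrete by building a dummy batch. This only demonstrates the documented shapes and indexing conventions; the model itself is not instantiated here, and the dimension values are arbitrary:

```python
import torch

B, T, C, H, W = 2, 6, 10, 24, 24
x = torch.randn(B, T, C, H, W)                   # Sentinel-2 sequence (B, T, C, H, W)
batch_positions = torch.randint(1, 366, (B, T))  # day-of-year, 1-indexed (1-365)
pad_mask = torch.zeros(B, T, dtype=torch.bool)
pad_mask[:, 4:] = True                           # True marks padded timesteps
# logits = model(x, batch_positions=batch_positions, pad_mask=pad_mask)
# would return a tensor of shape (B, num_classes, H, W)
```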
train(mode: bool = True) -> TSViTLookup
Mark inference embeddings as stale when switching to eval mode.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| mode | bool | If True, sets training mode; if False, sets eval mode. | True |

Returns:

| Type | Description |
|---|---|
| TSViTLookup | Self, for method chaining. |
Source code in src/models/architectures/tsvit_lookup.py
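Overriding `train()` is the standard hook for this kind of cache invalidation, since `model.eval()` is just `model.train(False)`. A minimal sketch of the pattern described above, with an invented class and flag name:

```python
import torch.nn as nn

class StaleOnEval(nn.Module):
    """Marks a (hypothetical) cache of interpolated embeddings stale on eval()."""

    def __init__(self):
        super().__init__()
        self._inference_cache_stale = True  # illustrative flag, not the real attribute

    def train(self, mode: bool = True):
        if not mode:
            # entering eval mode invalidates any cached interpolated embeddings
            self._inference_cache_stale = True
        return super().train(mode)  # nn.Module.train returns self, enabling chaining

m = StaleOnEval()
m._inference_cache_stale = False
m.eval()  # equivalent to m.train(False); the cache is marked stale again
```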
TransformerBlock(dim: int, num_heads: int, mlp_dim: int, dropout: float = 0.0)
Bases: Module
Transformer block with pre-norm and residual connections.
Source code in src/models/architectures/tsvit_lookup.py
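"Pre-norm with residual connections" means LayerNorm is applied before each sublayer and the sublayer output is added back to its input. A generic sketch of that layout, built from standard PyTorch layers (illustrative only, not the source class):

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                          batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim), nn.Dropout(dropout),
        )

    def forward(self, x, key_padding_mask=None):
        h = self.norm1(x)                        # pre-norm before attention
        h, _ = self.attn(h, h, h, key_padding_mask=key_padding_mask)
        x = x + h                                # residual connection
        return x + self.mlp(self.norm2(x))       # pre-norm before the MLP

blk = PreNormBlock(dim=8, num_heads=2, mlp_dim=16)
y = blk(torch.randn(2, 5, 8))                    # shape is preserved: (2, 5, 8)
```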
TransformerEncoder(dim: int, depth: int, num_heads: int, mlp_dim: int, dropout: float = 0.0)
Bases: Module
Stack of Transformer blocks with shared padding mask.
Source code in src/models/architectures/tsvit_lookup.py
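A "shared padding mask" simply means the same `key_padding_mask` is threaded through every block in the stack. A sketch using PyTorch's built-in encoder layer rather than the source classes (names are illustrative):

```python
import torch
import torch.nn as nn

class MaskedEncoder(nn.Module):
    def __init__(self, dim, depth, num_heads, mlp_dim, dropout=0.0):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, num_heads, dim_feedforward=mlp_dim,
                                       dropout=dropout, batch_first=True,
                                       norm_first=True)
            for _ in range(depth)
        )

    def forward(self, x, key_padding_mask=None):
        for layer in self.layers:
            # the same padding mask is shared by every block
            x = layer(x, src_key_padding_mask=key_padding_mask)
        return x

enc = MaskedEncoder(dim=8, depth=2, num_heads=2, mlp_dim=16)
mask = torch.zeros(2, 5, dtype=torch.bool)
out = enc(torch.randn(2, 5, 8), key_padding_mask=mask)  # (2, 5, 8)
```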