TSViT Lookup

TSViT variant with lookup-based temporal position embeddings.

models.architectures.tsvit_lookup

TSViT with lookup-based temporal position embeddings.

MultiHeadSelfAttention(dim: int, num_heads: int, dropout: float = 0.0)

Bases: Module

Simplified multi-head self-attention with optional padding mask.

Source code in src/models/architectures/tsvit_lookup.py
def __init__(self, dim: int, num_heads: int, dropout: float = 0.0) -> None:
    super().__init__()
    if dim % num_heads != 0:
        msg = "Embedding dimension must be divisible by number of heads"
        raise ValueError(msg)

    self.num_heads = num_heads
    self.head_dim = dim // num_heads

    self.qkv = nn.Linear(dim, dim * 3, bias=False)
    self.proj = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))

forward(x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor

Apply multi-head self-attention.

Parameters:

- `x` (`Tensor`, required): Input tensor of shape `(batch, seq_len, dim)`.
- `key_padding_mask` (`Tensor | None`, default `None`): Optional boolean mask of shape `(batch, seq_len)` where `True` marks positions to ignore.

Returns:

- `Tensor`: Output tensor of shape `(batch, seq_len, dim)`.

Source code in src/models/architectures/tsvit_lookup.py
def forward(
    self,
    x: torch.Tensor,
    key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply multi-head self-attention.

    Args:
        x: Input tensor of shape (batch, seq_len, dim)
        key_padding_mask: Optional boolean mask of shape (batch, seq_len)
            where True indicates positions to ignore

    Returns:
        Output tensor of shape (batch, seq_len, dim)

    """
    batch, seq_len, _ = x.shape
    q, k, v = self.qkv(x).chunk(3, dim=-1)

    # Reshape to (batch, num_heads, seq_len, head_dim)
    q = q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    k = k.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    v = v.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    # Convert key_padding_mask to attention mask format for SDPA
    # SDPA expects: True = attend, False = ignore (opposite of key_padding_mask)
    attn_mask = None
    if key_padding_mask is not None:
        # Expand to (batch, 1, 1, seq_len) for broadcasting
        attn_mask = ~key_padding_mask.unsqueeze(1).unsqueeze(2)

    # Use PyTorch's fused scaled_dot_product_attention (FlashAttention when available)
    out = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )

    out = out.transpose(1, 2).reshape(batch, seq_len, -1)
    return self.proj(out)
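
A minimal usage sketch, assuming the class is imported from `models.architectures.tsvit_lookup`; shapes and hyperparameters are illustrative:

import torch

from models.architectures.tsvit_lookup import MultiHeadSelfAttention

attn = MultiHeadSelfAttention(dim=128, num_heads=4, dropout=0.1)
x = torch.randn(2, 16, 128)                  # (batch, seq_len, dim)
mask = torch.zeros(2, 16, dtype=torch.bool)
mask[0, 10:] = True                          # last 6 steps of sample 0 are padding
out = attn(x, key_padding_mask=mask)         # -> (2, 16, 128)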

TSViTLookup(*, image_size: int, patch_size: int, in_channels: int, num_classes: int, train_dates: list[int] | torch.Tensor, dim: int, temporal_depth: int, spatial_depth: int, num_heads: int, mlp_dim: int, dropout: float = 0.0, emb_dropout: float = 0.0, temporal_metadata_channels: int = 0, date_range: tuple[int, int] | None = None)

Bases: Module

Temporal-Spatial ViT with lookup-based temporal position embeddings.

This is the most advanced temporal encoding approach from the TSViT paper: it learns a separate position embedding for each unique date seen during training, then linearly interpolates between the nearest learned embeddings for unseen dates at inference.
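
The interpolation can be pictured with a standalone sketch. This mirrors the idea only; the module's private implementation is not reproduced on this page, and all values below are hypothetical:

import torch

# Learned table: one embedding per unique (sorted) training date.
train_dates = torch.tensor([10, 50, 200])   # hypothetical day-of-year values
train_emb = torch.randn(3, 128)             # stands in for the learned (num_dates, dim) table

# Build embeddings for every date in the evaluation range by blending the two
# nearest training dates; dates outside the range clamp to the endpoint embeddings.
eval_dates = torch.arange(1, 366)
idx = torch.searchsorted(train_dates, eval_dates).clamp(1, len(train_dates) - 1)
lo, hi = train_dates[idx - 1], train_dates[idx]
w = ((eval_dates - lo).float() / (hi - lo).float()).clamp(0.0, 1.0)
eval_emb = torch.lerp(train_emb[idx - 1], train_emb[idx], w.unsqueeze(-1))
# eval_emb[29] (day 30) is the 50/50 blend of the embeddings for days 10 and 50.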

Initialize the TSViT module with lookup temporal embeddings.

Parameters:

- `image_size` (`int`, required): Height/width (in Sentinel pixels) of the cropped patch.
- `patch_size` (`int`, required): Edge length of each Vision Transformer patch.
- `in_channels` (`int`, required): Number of spectral channels produced by the dataset.
- `num_classes` (`int`, required): Number of segmentation categories.
- `train_dates` (`list[int] | Tensor`, required): List/tensor of unique dates (e.g., day-of-year 1-365) seen in training. During inference, dates not in this list are interpolated. IMPORTANT: all positions passed to `forward()` must use the same indexing scheme. For day-of-year, use 1-365 (not 0-364); for months, use 1-12 (not 0-11) or pass a custom `date_range`.
- `dim` (`int`, required): Embedding dimension of token representations.
- `temporal_depth` (`int`, required): Number of transformer blocks in the temporal encoder.
- `spatial_depth` (`int`, required): Number of transformer blocks in the spatial encoder.
- `num_heads` (`int`, required): Number of attention heads for both encoders.
- `mlp_dim` (`int`, required): Hidden size of the feed-forward sublayers.
- `dropout` (`float`, default `0.0`): Dropout applied inside the transformer blocks.
- `emb_dropout` (`float`, default `0.0`): Dropout applied after adding positional embeddings.
- `temporal_metadata_channels` (`int`, default `0`): Optional number of metadata channels reserved at the end of the spectral dimension (e.g., timestamps).
- `date_range` (`tuple[int, int] | None`, default `None`): Optional `(min, max)` tuple defining the range of valid dates for interpolation during inference. Defaults to `(1, 365)` for day-of-year. For 0-indexed DOY use `(0, 364)`; for months use `(1, 12)` or `(0, 11)`.

Source code in src/models/architectures/tsvit_lookup.py
def __init__(
    self,
    *,
    image_size: int,
    patch_size: int,
    in_channels: int,
    num_classes: int,
    train_dates: list[int] | torch.Tensor,
    dim: int,
    temporal_depth: int,
    spatial_depth: int,
    num_heads: int,
    mlp_dim: int,
    dropout: float = 0.0,
    emb_dropout: float = 0.0,
    temporal_metadata_channels: int = 0,
    date_range: tuple[int, int] | None = None,
) -> None:
    """Initialize the TSViT module with lookup temporal embeddings.

    Args:
        image_size: Height/width (in Sentinel pixels) of the cropped patch.
        patch_size: Edge length of each Vision Transformer patch.
        in_channels: Number of spectral channels produced by the dataset.
        num_classes: Number of segmentation categories.
        train_dates: List/tensor of unique dates (e.g., day-of-year 1-365) seen in training.
            During inference, dates not in this list are interpolated.
            IMPORTANT: All positions passed to forward() must use the same indexing scheme.
            For day-of-year: use 1-365 (not 0-364).
            For months: use 1-12 (not 0-11) or pass custom date_range.
        dim: Embedding dimension of token representations.
        temporal_depth: Number of transformer blocks in the temporal encoder.
        spatial_depth: Number of transformer blocks in the spatial encoder.
        num_heads: Number of attention heads for both encoders.
        mlp_dim: Hidden size of the feed-forward sublayers.
        dropout: Dropout applied inside the transformer blocks.
        emb_dropout: Dropout applied after adding positional embeddings.
        temporal_metadata_channels: Optional number of metadata channels reserved
            at the end of the spectral dimension (e.g., timestamps).
        date_range: Optional (min, max) tuple defining the range of valid dates for
            interpolation during inference. Defaults to (1, 365) for day-of-year.
            For 0-indexed DOY use (0, 364), for months use (1, 12) or (0, 11).

    """
    super().__init__()
    if image_size % patch_size != 0:
        msg = "image_size must be divisible by patch_size"
        raise ValueError(msg)
    if in_channels <= temporal_metadata_channels:
        msg = "in_channels must be larger than temporal_metadata_channels"
        raise ValueError(msg)

    self.image_size = image_size
    self.patch_size = patch_size
    self.num_classes = num_classes
    self.dim = dim
    self.temporal_metadata_channels = temporal_metadata_channels

    # Register train dates as buffer (non-trainable parameter)
    if isinstance(train_dates, list):
        train_dates = torch.tensor(sorted(train_dates), dtype=torch.long)
    else:
        train_dates = torch.sort(train_dates.long())[0]

    # For inference, we'll interpolate over all possible dates in the specified range
    if date_range is None:
        date_range = (1, 365)  # Default to 1-indexed day-of-year
    self.date_range = date_range

    # Validate train_dates are within date_range
    if train_dates.min() < date_range[0] or train_dates.max() > date_range[1]:
        msg = (
            f"train_dates must be within date_range [{date_range[0]}, {date_range[1]}]. "
            f"Got train_dates range: [{train_dates.min()}, {train_dates.max()}]"
        )
        raise ValueError(msg)

    self.register_buffer("train_dates", train_dates)
    self.register_buffer(
        "eval_dates",
        torch.arange(date_range[0], date_range[1] + 1, dtype=torch.long),
    )

    self.num_patches_1d = image_size // patch_size
    self.num_patches = self.num_patches_1d**2
    patch_dim = (in_channels - temporal_metadata_channels) * (patch_size**2)

    self.to_patch_embedding = nn.Sequential(
        Rearrange(
            "b t c (h p1) (w p2) -> (b h w) t (p1 p2 c)",
            p1=patch_size,
            p2=patch_size,
        ),
        nn.Linear(patch_dim, dim),
    )

    # Learnable position embedding for each training date
    self.temporal_pos_embedding = nn.Parameter(
        torch.randn(len(train_dates), dim),
        requires_grad=True,
    )

    # Pre-compute interpolated embeddings for evaluation (lazy evaluation)
    num_eval_dates = date_range[1] - date_range[0] + 1
    self.register_buffer(
        "inference_temporal_pos_embedding",
        torch.zeros(num_eval_dates, dim),
    )
    self._inference_embeddings_stale = True  # Mark for lazy computation

    self.temporal_dropout = nn.Dropout(emb_dropout)
    self.temporal_cls_tokens = nn.Parameter(torch.randn(1, num_classes, dim))
    self.temporal_encoder = TransformerEncoder(
        dim=dim,
        depth=temporal_depth,
        num_heads=num_heads,
        mlp_dim=mlp_dim,
        dropout=dropout,
    )

    self.space_pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, dim))
    self.space_dropout = nn.Dropout(emb_dropout)
    self.space_encoder = TransformerEncoder(
        dim=dim,
        depth=spatial_depth,
        num_heads=num_heads,
        mlp_dim=mlp_dim,
        dropout=dropout,
    )
    self.head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, patch_size**2))
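
A construction sketch with illustrative hyperparameters (none of these values are defaults of the module):

from models.architectures.tsvit_lookup import TSViTLookup

model = TSViTLookup(
    image_size=24,
    patch_size=2,
    in_channels=11,                 # e.g. 10 spectral bands + 1 metadata channel
    num_classes=19,
    train_dates=[12, 47, 98, 156, 231, 290, 334],   # day-of-year, 1-365 scheme
    dim=128,
    temporal_depth=4,
    spatial_depth=2,
    num_heads=4,
    mlp_dim=256,
    temporal_metadata_channels=1,   # last channel carries timestamps, not spectra
)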

forward(x: torch.Tensor, *, batch_positions: torch.Tensor | None = None, pad_mask: torch.Tensor | None = None, inference: bool | None = None) -> torch.Tensor

Run the temporal-spatial transformer over a Sentinel-2 sequence.

Parameters:

- `x` (`Tensor`, required): Input tensor of shape `(B, T, C, H, W)`.
- `batch_positions` (`Tensor | None`, default `None`): Optional `(B, T)` tensor with date indices (e.g., day-of-year).
- `pad_mask` (`Tensor | None`, default `None`): Optional `(B, T)` boolean mask (`True` = padded/invalid).
- `inference` (`bool | None`, default `None`): If `True`, use interpolated embeddings; if `False`, use direct lookup. If `None`, resolves to `not self.training`, i.e. `True` after `model.eval()` and `False` after `model.train()`.

Returns:

- `Tensor`: Output logits of shape `(B, num_classes, H, W)`.

Source code in src/models/architectures/tsvit_lookup.py
def forward(
    self,
    x: torch.Tensor,
    *,
    batch_positions: torch.Tensor | None = None,
    pad_mask: torch.Tensor | None = None,
    inference: bool | None = None,
) -> torch.Tensor:
    """Run the temporal-spatial transformer over a Sentinel-2 sequence.

    Args:
        x: Input tensor of shape (B, T, C, H, W)
        batch_positions: Optional (B, T) tensor with date indices (e.g., day-of-year)
        pad_mask: Optional (B, T) boolean mask (True = padded/invalid)
        inference: If True, use interpolated embeddings; if False, use direct lookup.
            If None (default), resolves to not self.training (True after
            model.eval(), False after model.train()).

    Returns:
        Output logits of shape (B, num_classes, H, W)

    """
    # Auto-detect inference mode if not specified
    if inference is None:
        inference = not self.training
    if x.ndim != 5:
        msg = "TSViTLookup expects input of shape (B, T, C, H, W)"
        raise ValueError(msg)

    batch_size, time, _, height, width = x.shape
    if height != self.image_size or width != self.image_size:
        msg = (
            f"Input spatial resolution {(height, width)} does not match "
            f"configured image_size {self.image_size}"
        )
        raise ValueError(msg)

    # Update interpolated embeddings if in inference mode
    if inference:
        self._update_inference_temporal_position_embeddings()

    clean_x, metadata = self._split_metadata(x)

    # Get temporal positions (default to sequential if not provided)
    if batch_positions is None:
        # Use sequential dates starting from the minimum of the date range
        batch_positions = (
            torch.arange(
                self.date_range[0],
                self.date_range[0] + time,
                device=x.device,
            )
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )

    temporal_pos = self._get_temporal_position_embeddings(
        batch_positions,
        inference=inference,
    )

    effective_mask = self._resolve_pad_mask(pad_mask, metadata, clean_x.device)

    temporal_tokens = self._encode_temporal(
        clean_x,
        temporal_pos,
        effective_mask,
        batch_size,
        time,
    )

    return self._decode_spatial(temporal_tokens, batch_size)
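
Continuing the construction sketch above (same hypothetical model), a forward pass with explicit dates and padding might look like:

import torch

x = torch.randn(2, 7, 11, 24, 24)                # (B, T, C, H, W)
doy = torch.tensor([12, 47, 98, 156, 231, 290, 334]).repeat(2, 1)  # (B, T) day-of-year
pad = torch.zeros(2, 7, dtype=torch.bool)        # no padded time steps

model.eval()                                     # inference=None -> interpolated embeddings
with torch.no_grad():
    logits = model(x, batch_positions=doy, pad_mask=pad)
print(logits.shape)                              # torch.Size([2, 19, 24, 24])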

train(mode: bool = True) -> TSViTLookup

Mark inference embeddings as stale when switching to eval mode.

Parameters:

- `mode` (`bool`, default `True`): If `True`, sets training mode; if `False`, sets eval mode.

Returns:

- `TSViTLookup`: Self, for method chaining.

Source code in src/models/architectures/tsvit_lookup.py
def train(self, mode: bool = True) -> TSViTLookup:
    """Mark inference embeddings as stale when switching to eval mode.

    Args:
        mode: If True, sets to training mode; if False, sets to eval mode.

    Returns:
        Self for method chaining.

    """
    result = super().train(mode)
    if not mode:  # Switching to eval mode
        self._inference_embeddings_stale = True
    return result
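
A small sketch of the intended mode toggling, reusing `model`, `x`, and `doy` from the sketches above:

model.train()    # training: positions are looked up directly in the learned table
model.eval()     # calls train(False); flags the interpolated table as stale
logits = model(x, batch_positions=doy)   # next inference forward rebuilds the table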

TransformerBlock(dim: int, num_heads: int, mlp_dim: int, dropout: float = 0.0)

Bases: Module

Transformer block with pre-norm and residual connections.

Source code in src/models/architectures/tsvit_lookup.py
def __init__(
    self,
    dim: int,
    num_heads: int,
    mlp_dim: int,
    dropout: float = 0.0,
) -> None:
    super().__init__()
    self.norm1 = nn.LayerNorm(dim)
    self.attn = MultiHeadSelfAttention(dim, num_heads, dropout=dropout)
    self.norm2 = nn.LayerNorm(dim)
    self.mlp = nn.Sequential(
        nn.Linear(dim, mlp_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(mlp_dim, dim),
        nn.Dropout(dropout),
    )
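
The block's forward() is not reproduced on this page; given the layers above, a pre-norm residual pass would plausibly look like this sketch:

def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor:
    # Pre-norm: normalize before each sublayer, then add the residual.
    x = x + self.attn(self.norm1(x), key_padding_mask=key_padding_mask)
    return x + self.mlp(self.norm2(x))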

TransformerEncoder(dim: int, depth: int, num_heads: int, mlp_dim: int, dropout: float = 0.0)

Bases: Module

Stack of Transformer blocks with shared padding mask.

Source code in src/models/architectures/tsvit_lookup.py
def __init__(
    self,
    dim: int,
    depth: int,
    num_heads: int,
    mlp_dim: int,
    dropout: float = 0.0,
) -> None:
    super().__init__()
    self.layers = nn.ModuleList(
        [TransformerBlock(dim, num_heads, mlp_dim, dropout) for _ in range(depth)],
    )
    self.norm = nn.LayerNorm(dim)
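
Likewise, the encoder's forward() is not shown here; threading the shared padding mask through the stack would look roughly like:

def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor:
    # Every block receives the same padding mask; a final LayerNorm caps the stack.
    for layer in self.layers:
        x = layer(x, key_padding_mask=key_padding_mask)
    return self.norm(x)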