TSViT

Temporal-Spatial Vision Transformer for Sentinel-2 time series segmentation.

models.architectures.tsvit

Temporal-Spatial Vision Transformer (TSViT) implementation.

This module adapts the TSViT architecture from "ViTs for SITS: Vision Transformers for Satellite Image Time Series" (CVPR 2023, Tarasiou et al.) to the FLAIR-2 Sentinel-2-only scenario. The implementation follows the original design (a temporal transformer followed by a spatial transformer over patch tokens) while relaxing assumptions about sequence length and date encodings, so it can ingest the monthly-averaged Sentinel-2 stacks produced by this repository.
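
As a minimal usage sketch (all hyperparameter values below are illustrative, not the repository's defaults), the model can be constructed like this:

```python
import torch

from models.architectures.tsvit import TSViT

# Illustrative configuration for monthly Sentinel-2 stacks:
# max_seq_len=13 gives month-of-year positions 0-11 plus one padding index.
model = TSViT(
    image_size=32,
    patch_size=2,
    in_channels=10,
    num_classes=13,
    max_seq_len=13,
    dim=128,
    temporal_depth=4,
    spatial_depth=2,
    num_heads=4,
    mlp_dim=256,
)
```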

MultiHeadSelfAttention(dim: int, num_heads: int, dropout: float = 0.0)

Bases: Module

Simplified multi-head self-attention with optional padding mask.

Source code in src/models/architectures/tsvit.py
def __init__(self, dim: int, num_heads: int, dropout: float = 0.0) -> None:
    super().__init__()
    if dim % num_heads != 0:
        msg = "Embedding dimension must be divisible by number of heads"
        raise ValueError(msg)

    self.num_heads = num_heads
    self.head_dim = dim // num_heads

    self.qkv = nn.Linear(dim, dim * 3, bias=False)
    self.proj = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))

forward(x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor

Apply multi-head self-attention.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor of shape (batch, seq_len, dim) | required |
| key_padding_mask | Tensor \| None | Optional boolean mask of shape (batch, seq_len) where True indicates positions to ignore | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor of shape (batch, seq_len, dim) |

Source code in src/models/architectures/tsvit.py
def forward(
    self,
    x: torch.Tensor,
    key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply multi-head self-attention.

    Args:
        x: Input tensor of shape (batch, seq_len, dim)
        key_padding_mask: Optional boolean mask of shape (batch, seq_len)
            where True indicates positions to ignore

    Returns:
        Output tensor of shape (batch, seq_len, dim)

    """
    batch, seq_len, _ = x.shape
    q, k, v = self.qkv(x).chunk(3, dim=-1)

    # Reshape to (batch, num_heads, seq_len, head_dim)
    q = q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    k = k.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    v = v.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    # Convert key_padding_mask to attention mask format for SDPA
    # SDPA expects: True = attend, False = ignore (opposite of key_padding_mask)
    attn_mask = None
    if key_padding_mask is not None:
        # Expand to (batch, 1, 1, seq_len) for broadcasting
        attn_mask = ~key_padding_mask.unsqueeze(1).unsqueeze(2)

    # Use PyTorch's fused scaled_dot_product_attention (FlashAttention when available)
    out = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )

    out = out.transpose(1, 2).reshape(batch, seq_len, -1)
    return self.proj(out)
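
A short sketch of the padding-mask convention (shapes follow the docstring above; the sizes are arbitrary):

```python
import torch

attn = MultiHeadSelfAttention(dim=64, num_heads=4)

x = torch.randn(2, 10, 64)  # (batch, seq_len, dim)

# Mark the last three timesteps of each sequence as padding.
# Note the convention: True = ignore, which forward() inverts
# before handing the mask to scaled_dot_product_attention.
mask = torch.zeros(2, 10, dtype=torch.bool)
mask[:, 7:] = True

out = attn(x, key_padding_mask=mask)  # (2, 10, 64); padded keys are never attended to
```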

TSViT(*, image_size: int, patch_size: int, in_channels: int, num_classes: int, max_seq_len: int, dim: int, temporal_depth: int, spatial_depth: int, num_heads: int, mlp_dim: int, dropout: float = 0.0, emb_dropout: float = 0.0, temporal_metadata_channels: int = 0)

Bases: Module

Temporal-Spatial Vision Transformer tailored to Sentinel-2 patches.

Initialize the TSViT module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| image_size | int | Height/width (in Sentinel pixels) of the cropped patch. | required |
| patch_size | int | Edge length of each Vision Transformer patch. | required |
| in_channels | int | Number of spectral channels produced by the dataset. | required |
| num_classes | int | Number of segmentation categories. | required |
| max_seq_len | int | Size of the one-hot position encoding. The last index (max_seq_len - 1) is reserved for padding tokens, so valid position indices must be in range [0, max_seq_len - 2]. Set to 13 for month-of-year encoding (months 0-11 + padding) or 367 for day-of-year encoding (days 0-365 + padding). | required |
| dim | int | Embedding dimension of token representations. | required |
| temporal_depth | int | Number of transformer blocks in the temporal encoder. | required |
| spatial_depth | int | Number of transformer blocks in the spatial encoder. | required |
| num_heads | int | Number of attention heads for both encoders. | required |
| mlp_dim | int | Hidden size of the feed-forward sublayers. | required |
| dropout | float | Dropout applied inside the transformer blocks. | 0.0 |
| emb_dropout | float | Dropout applied after adding positional embeddings. | 0.0 |
| temporal_metadata_channels | int | Optional number of metadata channels reserved at the end of the spectral dimension (e.g., timestamps). | 0 |
Source code in src/models/architectures/tsvit.py
def __init__(
    self,
    *,
    image_size: int,
    patch_size: int,
    in_channels: int,
    num_classes: int,
    max_seq_len: int,
    dim: int,
    temporal_depth: int,
    spatial_depth: int,
    num_heads: int,
    mlp_dim: int,
    dropout: float = 0.0,
    emb_dropout: float = 0.0,
    temporal_metadata_channels: int = 0,
) -> None:
    """Initialize the TSViT module.

    Args:
        image_size: Height/width (in Sentinel pixels) of the cropped patch.
        patch_size: Edge length of each Vision Transformer patch.
        in_channels: Number of spectral channels produced by the dataset.
        num_classes: Number of segmentation categories.
        max_seq_len: Size of the one-hot position encoding. The last index
            (max_seq_len - 1) is reserved for padding tokens, so valid position
            indices must be in range [0, max_seq_len - 2]. Set to 13 for
            month-of-year encoding (months 0-11 + padding) or 367 for
            day-of-year encoding (days 0-365 + padding).
        dim: Embedding dimension of token representations.
        temporal_depth: Number of transformer blocks in the temporal encoder.
        spatial_depth: Number of transformer blocks in the spatial encoder.
        num_heads: Number of attention heads for both encoders.
        mlp_dim: Hidden size of the feed-forward sublayers.
        dropout: Dropout applied inside the transformer blocks.
        emb_dropout: Dropout applied after adding positional embeddings.
        temporal_metadata_channels: Optional number of metadata channels reserved
            at the end of the spectral dimension (e.g., timestamps).

    """
    super().__init__()
    if image_size % patch_size != 0:
        msg = "image_size must be divisible by patch_size"
        raise ValueError(msg)
    if in_channels <= temporal_metadata_channels:
        msg = "in_channels must be larger than temporal_metadata_channels"
        raise ValueError(msg)

    self.image_size = image_size
    self.patch_size = patch_size
    self.num_classes = num_classes
    self.max_seq_len = max_seq_len
    self.dim = dim
    self.temporal_metadata_channels = temporal_metadata_channels

    self.num_patches_1d = image_size // patch_size
    self.num_patches = self.num_patches_1d**2
    patch_dim = (in_channels - temporal_metadata_channels) * (patch_size**2)

    self.to_patch_embedding = nn.Sequential(
        Rearrange(
            "b t c (h p1) (w p2) -> (b h w) t (p1 p2 c)",
            p1=patch_size,
            p2=patch_size,
        ),
        nn.Linear(patch_dim, dim),
    )
    # Use one-hot encoding + linear projection for temporal positions (as in paper)
    # max_seq_len should be 367 for day-of-year or 13 for month-of-year
    # (the last index is reserved for padding; see the docstring above)
    self.to_temporal_embedding_input = nn.Linear(max_seq_len, dim)
    self.temporal_dropout = nn.Dropout(emb_dropout)
    self.temporal_cls_tokens = nn.Parameter(torch.randn(1, num_classes, dim))
    self.temporal_encoder = TransformerEncoder(
        dim=dim,
        depth=temporal_depth,
        num_heads=num_heads,
        mlp_dim=mlp_dim,
        dropout=dropout,
    )

    self.space_pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, dim))
    self.space_dropout = nn.Dropout(emb_dropout)
    self.space_encoder = TransformerEncoder(
        dim=dim,
        depth=spatial_depth,
        num_heads=num_heads,
        mlp_dim=mlp_dim,
        dropout=dropout,
    )
    self.head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, patch_size**2))
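
The temporal embedding set up above follows the paper's one-hot scheme; below is a standalone sketch of how a position index becomes an embedding (the exact bookkeeping inside forward() is handled by internal helpers and may differ in detail):

```python
import torch
import torch.nn.functional as F

# Standalone illustration of the one-hot temporal embedding configured above.
max_seq_len, dim = 13, 128
to_temporal = torch.nn.Linear(max_seq_len, dim)

positions = torch.tensor([[0, 1, 2, 11, 12]])  # 12 = padding index (max_seq_len - 1)
one_hot = F.one_hot(positions, num_classes=max_seq_len).float()  # (1, 5, 13)
temporal_emb = to_temporal(one_hot)  # (1, 5, 128), added to patch tokens per timestep
```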

forward(x: torch.Tensor, *, batch_positions: torch.Tensor | None = None, pad_mask: torch.Tensor | None = None) -> torch.Tensor

Run the temporal-spatial transformer over a Sentinel-2 sequence.

Source code in src/models/architectures/tsvit.py
def forward(
    self,
    x: torch.Tensor,
    *,
    batch_positions: torch.Tensor | None = None,
    pad_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run the temporal-spatial transformer over a Sentinel-2 sequence."""
    if x.ndim != TEMPORAL_INPUT_NDIM:
        msg = "TSViT expects input of shape (B, T, C, H, W)"
        raise ValueError(msg)

    batch_size, time, _, height, width = x.shape
    if height != self.image_size or width != self.image_size:
        msg = (
            "Input spatial resolution does not match configured image_size: "
            f"expected {self.image_size}, got {(height, width)}"
        )
        raise ValueError(msg)

    clean_x, metadata = self._split_metadata(x)
    temporal_pos = self._prepare_positions(batch_positions, batch_size, time, x.device)
    effective_mask = self._resolve_pad_mask(pad_mask, metadata, clean_x.device)

    temporal_tokens = self._encode_temporal(
        clean_x,
        temporal_pos,
        effective_mask,
        batch_size,
        time,
    )

    return self._decode_spatial(temporal_tokens, batch_size)
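
An end-to-end sketch of the expected call (sizes are illustrative; the output shape is inferred from the TSViT design, where the head maps each class token to patch_size**2 logits per patch, rather than read from the source shown here):

```python
import torch

model = TSViT(
    image_size=32, patch_size=2, in_channels=10, num_classes=13,
    max_seq_len=13, dim=128, temporal_depth=4, spatial_depth=2,
    num_heads=4, mlp_dim=256,
)
x = torch.randn(2, 12, 10, 32, 32)          # (B, T, C, H, W)
positions = torch.arange(12).expand(2, 12)  # (B, T), values in [0, max_seq_len - 2]
pad = torch.zeros(2, 12, dtype=torch.bool)  # (B, T), True = padded timestep
pad[:, 10:] = True                          # e.g. last two acquisitions are missing

out = model(x, batch_positions=positions, pad_mask=pad)
# Following the TSViT design, out should be per-pixel class logits,
# i.e. (B, num_classes, H, W) = (2, 13, 32, 32).
```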

TransformerBlock(dim: int, num_heads: int, mlp_dim: int, dropout: float = 0.0)

Bases: Module

Transformer block with pre-norm and residual connections.

Source code in src/models/architectures/tsvit.py
def __init__(
    self,
    dim: int,
    num_heads: int,
    mlp_dim: int,
    dropout: float = 0.0,
) -> None:
    super().__init__()
    self.norm1 = nn.LayerNorm(dim)
    self.attn = MultiHeadSelfAttention(dim, num_heads, dropout=dropout)
    self.norm2 = nn.LayerNorm(dim)
    self.mlp = nn.Sequential(
        nn.Linear(dim, mlp_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(mlp_dim, dim),
        nn.Dropout(dropout),
    )

forward(x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor

Apply transformer block with attention and feedforward.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor of shape (batch, seq_len, dim) | required |
| key_padding_mask | Tensor \| None | Optional boolean mask for attention | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor of shape (batch, seq_len, dim) |

Source code in src/models/architectures/tsvit.py
def forward(
    self,
    x: torch.Tensor,
    key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply transformer block with attention and feedforward.

    Args:
        x: Input tensor of shape (batch, seq_len, dim)
        key_padding_mask: Optional boolean mask for attention

    Returns:
        Output tensor of shape (batch, seq_len, dim)

    """
    attn_out = self.attn(self.norm1(x), key_padding_mask=key_padding_mask)
    x = x + attn_out
    return x + self.mlp(self.norm2(x))
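
Since both sublayers are residual, the block necessarily preserves the token shape; a quick sanity check (arbitrary sizes):

```python
import torch

block = TransformerBlock(dim=64, num_heads=4, mlp_dim=128)
x = torch.randn(2, 10, 64)
assert block(x).shape == x.shape  # residual sublayers preserve (batch, seq_len, dim)
```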

TransformerEncoder(dim: int, depth: int, num_heads: int, mlp_dim: int, dropout: float = 0.0)

Bases: Module

Stack of Transformer blocks with shared padding mask.

Source code in src/models/architectures/tsvit.py
def __init__(
    self,
    dim: int,
    depth: int,
    num_heads: int,
    mlp_dim: int,
    dropout: float = 0.0,
) -> None:
    super().__init__()
    self.layers = nn.ModuleList(
        [TransformerBlock(dim, num_heads, mlp_dim, dropout) for _ in range(depth)],
    )
    self.norm = nn.LayerNorm(dim)

forward(x: torch.Tensor, key_padding_mask: torch.Tensor | None = None) -> torch.Tensor

Apply stack of transformer blocks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor of shape (batch, seq_len, dim) | required |
| key_padding_mask | Tensor \| None | Optional boolean mask for attention | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor of shape (batch, seq_len, dim) after layer normalization |

Source code in src/models/architectures/tsvit.py
def forward(
    self,
    x: torch.Tensor,
    key_padding_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Apply stack of transformer blocks.

    Args:
        x: Input tensor of shape (batch, seq_len, dim)
        key_padding_mask: Optional boolean mask for attention

    Returns:
        Output tensor of shape (batch, seq_len, dim) after layer normalization

    """
    for layer in self.layers:
        x = layer(x, key_padding_mask=key_padding_mask)
    return self.norm(x)
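
A small usage sketch showing the shared mask being applied across the whole stack (arbitrary sizes):

```python
import torch

encoder = TransformerEncoder(dim=64, depth=3, num_heads=4, mlp_dim=128)
x = torch.randn(2, 10, 64)
mask = torch.zeros(2, 10, dtype=torch.bool)
mask[:, 8:] = True  # the same mask is passed to every block in the stack

out = encoder(x, key_padding_mask=mask)  # (2, 10, 64), final LayerNorm applied
```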