U-TAE++

Modernized U-TAE with ConvNeXt blocks and attention mechanisms for temporal segmentation.

models.architectures.utae_pp

U-TAE++ Implementation - Modernized U-TAE with ConvNeXt blocks and Flash Attention.

Based on U-TAE by Vivien Sainte Fare Garnot (github/VSainteuf). Improvements:

- ConvNeXt-style blocks with 7x7 depthwise convolutions
- Flash Attention (PyTorch 2.0+)
- Stochastic depth (DropPath)
- Configurable decoder attention (CBAM or Coordinate Attention)
- Deep supervision
- Layer Scale

CBAM(channels: int, reduction: int = 16, kernel_size: int = 7)

Bases: Module

Convolutional Block Attention Module.

Source code in src/models/architectures/utae_pp.py (lines 473-476)
def __init__(self, channels: int, reduction: int = 16, kernel_size: int = 7):
    super().__init__()
    self.ca = ChannelAttention(channels, reduction)
    self.sa = SpatialAttention(kernel_size)
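
Only the constructor is rendered above. As a reading aid, here is a minimal sketch of how a CBAM forward pass typically composes the two sub-modules, assuming ChannelAttention and SpatialAttention each return a multiplicative sigmoid gate (the actual forward is not shown here):

def forward(self, x):  # hypothetical forward, for illustration only
    x = x * self.ca(x)  # reweight channels first
    x = x * self.sa(x)  # then reweight spatial locations
    return x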

ChannelAttention(channels: int, reduction: int = 16)

Bases: Module

Channel attention from CBAM.

Source code in src/models/architectures/utae_pp.py (lines 440-448)
def __init__(self, channels: int, reduction: int = 16):
    super().__init__()
    self.avg_pool = nn.AdaptiveAvgPool2d(1)
    self.max_pool = nn.AdaptiveMaxPool2d(1)
    self.fc = nn.Sequential(
        nn.Conv2d(channels, channels // reduction, 1, bias=False),
        nn.GELU(),
        nn.Conv2d(channels // reduction, channels, 1, bias=False),
    )
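
The constructor declares both pooling branches and a shared bottleneck MLP. A minimal sketch of the usual CBAM channel-attention forward, assuming the module returns a per-channel sigmoid gate:

def forward(self, x):  # hypothetical forward, for illustration only
    avg = self.fc(self.avg_pool(x))  # (N, C, 1, 1) from average pooling
    mx = self.fc(self.max_pool(x))   # (N, C, 1, 1) from max pooling
    return torch.sigmoid(avg + mx)   # per-channel gate in [0, 1]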

ConvBlock(nkernels: list[int], pad_value: float | None = None, norm: str = 'batch', last_relu: bool = True, padding_mode: str = 'reflect')

Bases: TemporallySharedBlock

Convolutional block with temporal sharing.

Source code in src/models/architectures/utae_pp.py (lines 375-389)
def __init__(
    self,
    nkernels: list[int],
    pad_value: float | None = None,
    norm: str = "batch",
    last_relu: bool = True,
    padding_mode: str = "reflect",
):
    super().__init__(pad_value=pad_value)
    self.conv = ConvLayer(
        nkernels=nkernels,
        norm=norm,
        last_relu=last_relu,
        padding_mode=padding_mode,
    )

ConvLayer(nkernels: list[int], norm: Literal['batch', 'group', 'instance', 'layer'] = 'batch', k: int = 3, s: int = 1, p: int = 1, n_groups: int = 4, last_relu: bool = True, padding_mode: str = 'reflect')

Bases: Module

Basic convolution layer with norm and activation.

Source code in src/models/architectures/utae_pp.py (lines 287-328)
def __init__(
    self,
    nkernels: list[int],
    norm: Literal["batch", "group", "instance", "layer"] = "batch",
    k: int = 3,
    s: int = 1,
    p: int = 1,
    n_groups: int = 4,
    last_relu: bool = True,
    padding_mode: str = "reflect",
):
    super().__init__()

    if norm == "batch":
        nl = nn.BatchNorm2d
    elif norm == "instance":
        nl = nn.InstanceNorm2d
    elif norm == "group":
        nl = lambda c: nn.GroupNorm(num_channels=c, num_groups=n_groups)
    elif norm == "layer":
        nl = lambda c: nn.GroupNorm(num_channels=c, num_groups=1)
    else:
        nl = None

    layers = []
    for i in range(len(nkernels) - 1):
        layers.append(
            nn.Conv2d(
                nkernels[i],
                nkernels[i + 1],
                kernel_size=k,
                padding=p,
                stride=s,
                padding_mode=padding_mode,
            ),
        )
        if nl is not None:
            layers.append(nl(nkernels[i + 1]))
        if last_relu or i < len(nkernels) - 2:
            layers.append(nn.GELU())

    self.conv = nn.Sequential(*layers)
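
Note that nkernels lists the channel progression, so len(nkernels) - 1 conv/norm/GELU stages are built. A hypothetical usage sketch, assuming the module's forward simply applies self.conv:

import torch

layer = ConvLayer(nkernels=[10, 64, 64], norm="group", n_groups=4)
y = layer.conv(torch.randn(2, 10, 32, 32))  # (2, 64, 32, 32): two 3x3 stages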

ConvNeXtBlock(dim: int, expansion: int = 4, drop_path: float = 0.0, layer_scale_init: float = 1e-06, kernel_size: int = 7)

Bases: Module

ConvNeXt block with depthwise conv, inverted bottleneck, and layer scale.

Source code in src/models/architectures/utae_pp.py (lines 249-270)
def __init__(
    self,
    dim: int,
    expansion: int = 4,
    drop_path: float = 0.0,
    layer_scale_init: float = 1e-6,
    kernel_size: int = 7,
):
    super().__init__()
    self.dwconv = nn.Conv2d(
        dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim,
    )
    self.norm = nn.GroupNorm(1, dim)  # LayerNorm equivalent without permute
    self.pwconv1 = nn.Conv2d(dim, expansion * dim, kernel_size=1)
    self.act = nn.GELU()
    self.pwconv2 = nn.Conv2d(expansion * dim, dim, kernel_size=1)
    self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    # Layer scale
    self.gamma = (
        nn.Parameter(layer_scale_init * torch.ones(dim, 1, 1)) if layer_scale_init > 0 else None
    )
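
A minimal sketch of the standard ConvNeXt forward pass these layers imply (depthwise conv, norm, inverted-bottleneck MLP, layer scale, then a stochastically dropped residual); the actual forward is not rendered here:

def forward(self, x):  # hypothetical forward, for illustration only
    shortcut = x
    x = self.dwconv(x)   # 7x7 depthwise conv: spatial mixing
    x = self.norm(x)
    x = self.pwconv1(x)  # 1x1 conv: expand to expansion * dim
    x = self.act(x)
    x = self.pwconv2(x)  # 1x1 conv: project back to dim
    if self.gamma is not None:
        x = self.gamma * x               # layer scale
    return shortcut + self.drop_path(x)  # stochastic-depth residual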

CoordinateAttention(channels: int, reduction: int = 16)

Bases: Module

Coordinate Attention module (Hou et al., CVPR 2021).

Unlike CBAM, which loses spatial information via global pooling, Coordinate Attention encodes channel relationships while preserving precise positional information via 1D horizontal and vertical pooling.

Reference: https://arxiv.org/abs/2103.02907

Source code in src/models/architectures/utae_pp.py (lines 494-506)
def __init__(self, channels: int, reduction: int = 16):
    super().__init__()
    reduced_channels = max(8, channels // reduction)

    self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
    self.pool_w = nn.AdaptiveAvgPool2d((1, None))

    self.conv1 = nn.Conv2d(channels, reduced_channels, kernel_size=1, bias=False)
    self.bn1 = nn.BatchNorm2d(reduced_channels)
    self.act = nn.GELU()

    self.conv_h = nn.Conv2d(reduced_channels, channels, kernel_size=1, bias=False)
    self.conv_w = nn.Conv2d(reduced_channels, channels, kernel_size=1, bias=False)
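
A sketch of the forward pass as described in the paper (factorized 1D pooling, a shared bottleneck, then separate horizontal and vertical gates); treat this as illustrative rather than the exact implementation:

def forward(self, x):  # hypothetical forward, for illustration only
    n, c, h, w = x.shape
    x_h = self.pool_h(x)                      # (N, C, H, 1): pool along width
    x_w = self.pool_w(x).permute(0, 1, 3, 2)  # (N, C, W, 1): pool along height
    y = self.act(self.bn1(self.conv1(torch.cat([x_h, x_w], dim=2))))
    y_h, y_w = torch.split(y, [h, w], dim=2)
    a_h = torch.sigmoid(self.conv_h(y_h))                      # (N, C, H, 1)
    a_w = torch.sigmoid(self.conv_w(y_w.permute(0, 1, 3, 2)))  # (N, C, 1, W)
    return x * a_h * a_w  # position-aware channel gating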

DownConvBlock(d_in: int, d_out: int, k: int, s: int, p: int, pad_value: float | None = None, norm: str = 'batch', padding_mode: str = 'reflect', drop_path: float = 0.0, use_convnext: bool = True)

Bases: TemporallySharedBlock

Downsampling block with ConvNeXt-style processing.

Source code in src/models/architectures/utae_pp.py (lines 398-427)
def __init__(
    self,
    d_in: int,
    d_out: int,
    k: int,
    s: int,
    p: int,
    pad_value: float | None = None,
    norm: str = "batch",
    padding_mode: str = "reflect",
    drop_path: float = 0.0,
    use_convnext: bool = True,
):
    super().__init__(pad_value=pad_value)

    # Strided conv for downsampling
    self.down = nn.Sequential(
        nn.Conv2d(d_in, d_in, kernel_size=k, stride=s, padding=p, padding_mode=padding_mode),
        nn.GroupNorm(1, d_in) if use_convnext else nn.BatchNorm2d(d_in),
    )

    # Channel projection
    self.proj = nn.Conv2d(d_in, d_out, kernel_size=1)

    if use_convnext:
        self.conv1 = ConvNeXtBlock(d_out, drop_path=drop_path)
        self.conv2 = ConvNeXtBlock(d_out, drop_path=drop_path)
    else:
        self.conv1 = ConvLayer(nkernels=[d_out, d_out], norm=norm, padding_mode=padding_mode)
        self.conv2 = ConvLayer(nkernels=[d_out, d_out], norm=norm, padding_mode=padding_mode)

LTAE2d(in_channels: int = 128, n_head: int = 16, d_k: int = 4, mlp: list[int] | None = None, dropout: float = 0.2, d_model: int = 256, T: int = 1000, return_att: bool = False, positional_encoding: bool = True)

Bases: Module

Lightweight Temporal Attention Encoder for image time series.

Source code in src/models/architectures/utae_pp.py (lines 156-205)
def __init__(
    self,
    in_channels: int = 128,
    n_head: int = 16,
    d_k: int = 4,
    mlp: list[int] | None = None,
    dropout: float = 0.2,
    d_model: int = 256,
    T: int = 1000,
    return_att: bool = False,
    positional_encoding: bool = True,
):
    super().__init__()
    if mlp is None:
        mlp = [256, 128]

    self.in_channels = in_channels
    self.mlp = copy.deepcopy(mlp)
    self.return_att = return_att
    self.n_head = n_head

    if d_model is not None:
        self.d_model = d_model
        self.inconv = nn.Conv1d(in_channels, d_model, 1)
    else:
        self.d_model = in_channels
        self.inconv = None

    assert self.mlp[0] == self.d_model

    if positional_encoding:
        self.positional_encoder = PositionalEncoder(self.d_model // n_head, T=T, repeat=n_head)
    else:
        self.positional_encoder = None

    self.attention_heads = MultiHeadAttention(n_head=n_head, d_k=d_k, d_in=self.d_model)
    self.in_norm = nn.GroupNorm(num_groups=n_head, num_channels=in_channels)
    self.out_norm = nn.GroupNorm(num_groups=n_head, num_channels=mlp[-1])

    layers = []
    for i in range(len(self.mlp) - 1):
        layers.extend(
            [
                nn.Linear(self.mlp[i], self.mlp[i + 1]),
                nn.LayerNorm(self.mlp[i + 1]),  # LayerNorm instead of BatchNorm1d
                nn.GELU(),
            ],
        )
    self.mlp = nn.Sequential(*layers)
    self.dropout = nn.Dropout(dropout)
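
A hypothetical usage sketch, with shapes following the original L-TAE/U-TAE convention (the encoder collapses the temporal axis of the deepest feature maps):

import torch

ltae = LTAE2d(in_channels=128, d_model=256, n_head=16, return_att=True)
x = torch.randn(2, 12, 128, 8, 8)          # (B, T, C, H, W) deepest encoder features
dates = torch.arange(12).repeat(2, 1)      # (B, T) acquisition positions
out, att = ltae(x, batch_positions=dates)  # out: (B, 128, 8, 8); att: per-head temporal weights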

MultiHeadAttention(n_head: int, d_k: int, d_in: int)

Bases: Module

Multi-Head Attention with learnable query.

Source code in src/models/architectures/utae_pp.py (lines 106-119)
def __init__(self, n_head: int, d_k: int, d_in: int):
    super().__init__()
    self.n_head = n_head
    self.d_k = d_k
    self.d_in = d_in

    # Learnable query (shared across positions)
    self.Q = nn.Parameter(torch.zeros((n_head, d_k)))
    nn.init.normal_(self.Q, mean=0, std=math.sqrt(2.0 / d_k))

    self.fc1_k = nn.Linear(d_in, n_head * d_k)
    nn.init.normal_(self.fc1_k.weight, mean=0, std=math.sqrt(2.0 / d_k))

    self.attention = ScaledDotProductAttention(temperature=math.sqrt(d_k))

PositionalEncoder(d: int, T: int = 1000, repeat: int | None = None, offset: int = 0)

Bases: Module

Sinusoidal positional encoding for temporal sequences.

Source code in src/models/architectures/utae_pp.py (lines 28-34)
def __init__(self, d: int, T: int = 1000, repeat: int | None = None, offset: int = 0):
    super().__init__()
    self.d = d
    self.T = T
    self.repeat = repeat
    self.denom = torch.pow(T, 2 * (torch.arange(offset, offset + d).float() // 2) / d)
    self.updated_location = False
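
The precomputed denom implements the usual sinusoidal table, pos / T^(2*(i//2)/d). A sketch of the encoding step (the real forward also handles device placement via updated_location and the repeat option):

def forward(self, positions):  # hypothetical forward; positions: (B, T)
    table = positions[:, :, None].float() / self.denom[None, None, :]  # (B, T, d)
    table[:, :, 0::2] = torch.sin(table[:, :, 0::2])  # sine at even indices
    table[:, :, 1::2] = torch.cos(table[:, :, 1::2])  # cosine at odd indices
    return table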

ScaledDotProductAttention(temperature: float, attn_dropout: float = 0.1)

Bases: Module

Scaled Dot-Product Attention with Flash Attention support.

Source code in src/models/architectures/utae_pp.py (lines 53-56)
def __init__(self, temperature: float, attn_dropout: float = 0.1):
    super().__init__()
    self.temperature = temperature
    self.dropout_p = attn_dropout

forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pad_mask: torch.Tensor | None = None, return_comp: bool = False) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Args:

- q: Query tensor (N, d_k)
- k: Key tensor (N, T, d_k)
- v: Value tensor (N, T, d_v)
- pad_mask: Padding mask (N, T)
- return_comp: Whether to return attention compatibility scores

Source code in src/models/architectures/utae_pp.py (lines 58-100)
def forward(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    pad_mask: torch.Tensor | None = None,
    return_comp: bool = False,
) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Args:
    q: Query tensor (N, d_k)
    k: Key tensor (N, T, d_k)
    v: Value tensor (N, T, d_v)
    pad_mask: Padding mask (N, T)
    return_comp: Whether to return attention compatibility scores

    """
    q = q.unsqueeze(1)  # (N, 1, d_k)

    if return_comp:
        attn = torch.matmul(q, k.transpose(1, 2)) / self.temperature
        if pad_mask is not None:
            attn = attn.masked_fill(pad_mask.unsqueeze(1), -1e9)
        comp = attn
        attn = F.softmax(attn, dim=-1)
        attn = F.dropout(attn, p=self.dropout_p, training=self.training)
        output = torch.matmul(attn, v)
        return output, attn, comp
    # Flash Attention path (PyTorch 2.0+)
    attn_mask = None
    if pad_mask is not None:
        attn_mask = pad_mask.unsqueeze(1).float() * -1e9

    output = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_mask,
        dropout_p=self.dropout_p if self.training else 0.0,
        scale=1.0 / self.temperature,
    )
    # Return dummy attention weights for compatibility
    attn = torch.zeros(q.size(0), 1, k.size(1), device=q.device, dtype=q.dtype)
    return output, attn

SpatialAttention(kernel_size: int = 7)

Bases: Module

Spatial attention from CBAM.

Source code in src/models/architectures/utae_pp.py (lines 459-461)
def __init__(self, kernel_size: int = 7):
    super().__init__()
    self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
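
A sketch of the standard CBAM spatial-attention forward this layer implies (channel-wise average and max, concatenated and convolved into a single-channel gate):

def forward(self, x):  # hypothetical forward, for illustration only
    avg = x.mean(dim=1, keepdim=True)       # (N, 1, H, W)
    mx = x.max(dim=1, keepdim=True).values  # (N, 1, H, W)
    return torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))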

TemporalAggregator(mode: Literal['att_group', 'att_mean', 'mean'] = 'mean')

Bases: Module

Aggregates temporal features using attention masks.

Source code in src/models/architectures/utae_pp.py (lines 608-610)
def __init__(self, mode: Literal["att_group", "att_mean", "mean"] = "mean"):
    super().__init__()
    self.mode = mode
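
Only the mode is stored at construction; in the attention modes, the L-TAE attention masks are used as temporal weights at each skip resolution. As an illustration, the simplest 'mean' mode plausibly reduces (B, T, C, H, W) features to (B, C, H, W) while ignoring padded frames, along these lines:

# Hypothetical 'mean'-mode reduction, for illustration only
keep = (~pad_mask).float()[:, :, None, None, None]  # (B, T, 1, 1, 1)
out = (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)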

TemporallySharedBlock(pad_value: float | None = None)

Bases: Module

Base class for blocks that are shared across the temporal dimension.

Source code in src/models/architectures/utae_pp.py (lines 337-340)
def __init__(self, pad_value: float | None = None):
    super().__init__()
    self.pad_value = pad_value
    self.out_shape = None
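
Subclasses are invoked through smart_forward (see UTAE.forward below). The core idiom folds time into the batch so 2D layers process every frame at once; a minimal sketch, omitting the pad_value short-circuit that the pad_value and out_shape attributes suggest:

def smart_forward(self, x):  # hypothetical sketch; x: (B, T, C, H, W)
    b, t = x.shape[:2]
    out = self.forward(x.reshape(b * t, *x.shape[2:]))  # run the 2D block per frame
    return out.reshape(b, t, *out.shape[1:])            # restore the temporal axis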

UTAE(input_dim: int, encoder_widths: list[int] | None = None, decoder_widths: list[int] | None = None, out_conv: list[int] | None = None, str_conv_k: int = 4, str_conv_s: int = 2, str_conv_p: int = 1, agg_mode: str = 'att_group', encoder_norm: str = 'group', n_head: int = 16, d_model: int = 256, d_k: int = 4, encoder: bool = False, return_maps: bool = False, pad_value: float = 0, padding_mode: str = 'reflect', use_convnext: bool = True, attention_type: str = 'coord', drop_path_rate: float = 0.1, deep_supervision: bool = False)

Bases: Module

U-TAE++ - Modernized U-TAE with ConvNeXt blocks and Flash Attention.

Parameters:

- input_dim (int, required): Number of input channels
- encoder_widths (list[int] | None, default None): Channel widths for each encoder stage
- decoder_widths (list[int] | None, default None): Channel widths for each decoder stage
- out_conv (list[int] | None, default None): Output convolution channels [hidden, n_classes]
- str_conv_k (int, default 4): Kernel size for strided convolutions
- str_conv_s (int, default 2): Stride for strided convolutions
- str_conv_p (int, default 1): Padding for strided convolutions
- agg_mode (str, default 'att_group'): Temporal aggregation mode ('att_group', 'att_mean', 'mean')
- encoder_norm (str, default 'group'): Normalization type ('group', 'batch', 'instance')
- n_head (int, default 16): Number of attention heads in L-TAE
- d_model (int, default 256): Model dimension for L-TAE
- d_k (int, default 4): Key/query dimension for attention
- encoder (bool, default False): If True, return feature maps instead of predictions
- return_maps (bool, default False): If True, also return intermediate feature maps
- pad_value (float, default 0): Padding value for temporal sequences
- padding_mode (str, default 'reflect'): Padding mode for convolutions
- use_convnext (bool, default True): Use ConvNeXt-style blocks
- attention_type (str, default 'coord'): Decoder attention type ('coord', 'cbam', or 'none')
- drop_path_rate (float, default 0.1): Stochastic depth rate
- deep_supervision (bool, default False): Enable auxiliary outputs for deep supervision
Source code in src/models/architectures/utae_pp.py (lines 705-834)
def __init__(
    self,
    input_dim: int,
    encoder_widths: list[int] | None = None,
    decoder_widths: list[int] | None = None,
    out_conv: list[int] | None = None,
    str_conv_k: int = 4,
    str_conv_s: int = 2,
    str_conv_p: int = 1,
    agg_mode: str = "att_group",
    encoder_norm: str = "group",
    n_head: int = 16,
    d_model: int = 256,
    d_k: int = 4,
    encoder: bool = False,
    return_maps: bool = False,
    pad_value: float = 0,
    padding_mode: str = "reflect",
    use_convnext: bool = True,
    attention_type: str = "coord",
    drop_path_rate: float = 0.1,
    deep_supervision: bool = False,
):
    super().__init__()

    if encoder_widths is None:
        encoder_widths = [64, 64, 64, 128]
    if decoder_widths is None:
        decoder_widths = [32, 32, 64, 128]
    if out_conv is None:
        out_conv = [32, 20]

    self.n_stages = len(encoder_widths)
    self.return_maps = return_maps
    self.encoder_widths = encoder_widths
    self.decoder_widths = decoder_widths
    self.enc_dim = decoder_widths[0] if decoder_widths is not None else encoder_widths[0]
    self.stack_dim = sum(decoder_widths) if decoder_widths is not None else sum(encoder_widths)
    self.pad_value = pad_value
    self.encoder = encoder
    self.use_convnext = use_convnext
    self.deep_supervision = deep_supervision

    if encoder:
        self.return_maps = True

    if decoder_widths is not None:
        assert len(encoder_widths) == len(decoder_widths)
        assert encoder_widths[-1] == decoder_widths[-1]
    else:
        decoder_widths = encoder_widths

    # Stochastic depth
    total_blocks = 2 * (self.n_stages - 1)  # 2 ConvNeXt blocks per stage
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_blocks)]

    # Input conv
    self.in_conv = ConvBlock(
        nkernels=[input_dim, encoder_widths[0], encoder_widths[0]],
        pad_value=pad_value,
        norm=encoder_norm,
        padding_mode=padding_mode,
    )

    # Encoder
    self.down_blocks = nn.ModuleList(
        [
            DownConvBlock(
                d_in=encoder_widths[i],
                d_out=encoder_widths[i + 1],
                k=str_conv_k,
                s=str_conv_s,
                p=str_conv_p,
                pad_value=pad_value,
                norm=encoder_norm,
                padding_mode=padding_mode,
                drop_path=dpr[i * 2],
                use_convnext=use_convnext,
            )
            for i in range(self.n_stages - 1)
        ],
    )

    # Decoder
    self.up_blocks = nn.ModuleList(
        [
            UpConvBlock(
                d_in=decoder_widths[i],
                d_out=decoder_widths[i - 1],
                d_skip=encoder_widths[i - 1],
                k=str_conv_k,
                s=str_conv_s,
                p=str_conv_p,
                norm="batch",
                padding_mode=padding_mode,
                attention_type=attention_type,
                use_convnext=use_convnext,
                drop_path=dpr[-(i * 2 + 1)] if i < len(dpr) // 2 else 0.0,
            )
            for i in range(self.n_stages - 1, 0, -1)
        ],
    )

    # Temporal encoder
    self.temporal_encoder = LTAE2d(
        in_channels=encoder_widths[-1],
        d_model=d_model,
        n_head=n_head,
        mlp=[d_model, encoder_widths[-1]],
        return_att=True,
        d_k=d_k,
    )
    self.temporal_aggregator = TemporalAggregator(mode=agg_mode)

    # Output
    self.out_conv = ConvBlock(
        nkernels=[decoder_widths[0]] + out_conv,
        padding_mode=padding_mode,
    )

    # Deep supervision auxiliary heads
    if deep_supervision:
        self.aux_heads = nn.ModuleList(
            [
                nn.Conv2d(decoder_widths[i - 1], out_conv[-1], kernel_size=1)
                for i in range(self.n_stages - 1, 1, -1)
            ],
        )
    else:
        self.aux_heads = None

forward(input: torch.Tensor, batch_positions: torch.Tensor | None = None, pad_mask: torch.Tensor | None = None, return_att: bool = False) -> torch.Tensor | tuple

Parameters:

- input (Tensor, required): Input tensor (B, T, C, H, W)
- batch_positions (Tensor | None, default None): Temporal positions (B, T)
- pad_mask (Tensor | None, default None): Boolean padding mask (B, T) where True indicates padded timesteps. If not provided, computed from input using self.pad_value.
- return_att (bool, default False): Return attention maps

Returns:

- Tensor | tuple: Segmentation output and optionally attention/auxiliary outputs

Source code in src/models/architectures/utae_pp.py (lines 836-910)
def forward(
    self,
    input: torch.Tensor,
    batch_positions: torch.Tensor | None = None,
    pad_mask: torch.Tensor | None = None,
    return_att: bool = False,
) -> torch.Tensor | tuple:
    """Args:
        input: Input tensor (B, T, C, H, W)
        batch_positions: Temporal positions (B, T)
        pad_mask: Boolean padding mask (B, T) where True indicates padded timesteps.
            If not provided, computed from input using self.pad_value.
        return_att: Return attention maps

    Returns:
        Segmentation output and optionally attention/auxiliary outputs

    """
    # Use external pad_mask if provided, otherwise compute from input
    if pad_mask is None:
        pad_mask = (input == self.pad_value).all(dim=-1).all(dim=-1).all(dim=-1)

    # Input convolution
    out = self.in_conv.smart_forward(input)
    feature_maps = [out]

    # Spatial encoder
    for i in range(self.n_stages - 1):
        out = self.down_blocks[i].smart_forward(feature_maps[-1])
        feature_maps.append(out)

    # Temporal encoder
    out, att = self.temporal_encoder(
        feature_maps[-1], batch_positions=batch_positions, pad_mask=pad_mask,
    )

    # Spatial decoder
    if self.return_maps:
        maps = [out]

    aux_outputs = []
    for i in range(self.n_stages - 1):
        skip = self.temporal_aggregator(
            feature_maps[-(i + 2)], pad_mask=pad_mask, attn_mask=att,
        )
        out = self.up_blocks[i](out, skip)

        if self.return_maps:
            maps.append(out)

        # Deep supervision
        if self.training and self.aux_heads is not None and i < len(self.aux_heads):
            aux_out = self.aux_heads[i](out)
            aux_out = F.interpolate(
                aux_out, size=input.shape[-2:], mode="bilinear", align_corners=False,
            )
            aux_outputs.append(aux_out)

    # Final output
    if self.encoder:
        return out, maps

    out = self.out_conv(out)

    # Return format
    if self.training and self.deep_supervision and aux_outputs:
        if return_att:
            return out, att, aux_outputs
        return out, aux_outputs

    if return_att:
        return out, att
    if self.return_maps:
        return out, maps
    return out
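
A hypothetical end-to-end smoke test with the defaults above (encoder_widths [64, 64, 64, 128], out_conv [32, 20], hence 20 output classes); the channel count and sequence length are illustrative:

import torch

model = UTAE(input_dim=10)                     # e.g. 10 spectral bands
x = torch.randn(2, 12, 10, 64, 64)             # (B, T, C, H, W)
dates = torch.arange(0, 360, 30).repeat(2, 1)  # (B, T) acquisition days
model.eval()                                   # skip dropout and deep-supervision branch
with torch.no_grad():
    logits = model(x, batch_positions=dates)   # (2, 20, 64, 64)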

UpConvBlock(d_in: int, d_out: int, k: int, s: int, p: int, norm: str = 'batch', d_skip: int | None = None, padding_mode: str = 'reflect', attention_type: str = 'coord', use_convnext: bool = True, drop_path: float = 0.0)

Bases: Module

Upsampling block with configurable attention (CBAM, Coordinate, or none).

Source code in src/models/architectures/utae_pp.py (lines 559-593)
def __init__(
    self,
    d_in: int,
    d_out: int,
    k: int,
    s: int,
    p: int,
    norm: str = "batch",
    d_skip: int | None = None,
    padding_mode: str = "reflect",
    attention_type: str = "coord",
    use_convnext: bool = True,
    drop_path: float = 0.0,
):
    super().__init__()
    d = d_out if d_skip is None else d_skip

    self.skip_attn = build_attention(attention_type, d)

    self.up = nn.Sequential(
        nn.ConvTranspose2d(d_in, d_out, kernel_size=k, stride=s, padding=p),
        nn.GroupNorm(1, d_out) if use_convnext else nn.BatchNorm2d(d_out),
        nn.GELU(),
    )

    if use_convnext:
        self.conv1 = ConvNeXtBlock(d_out + d, drop_path=drop_path)
        self.proj = nn.Conv2d(d_out + d, d_out, kernel_size=1)
        self.conv2 = ConvNeXtBlock(d_out, drop_path=drop_path)
    else:
        self.conv1 = ConvLayer(
            nkernels=[d_out + d, d_out], norm=norm, padding_mode=padding_mode,
        )
        self.proj = nn.Identity()
        self.conv2 = ConvLayer(nkernels=[d_out, d_out], norm=norm, padding_mode=padding_mode)
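
A sketch of the forward wiring these modules imply (upsample, attention-gate the skip connection, concatenate, then refine); illustrative only:

def forward(self, x, skip):  # hypothetical forward, for illustration only
    x = self.up(x)                   # transposed conv to the skip resolution
    skip = self.skip_attn(skip)      # CBAM / coordinate / fallback gating
    x = torch.cat([x, skip], dim=1)  # (N, d_out + d, H, W)
    return self.conv2(self.proj(self.conv1(x)))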

build_attention(attention_type: str, channels: int, reduction: int = 16) -> nn.Module

Factory function to build attention modules.

Parameters:

- attention_type (str, required): Type of attention ('cbam', 'coord', or 'none')
- channels (int, required): Number of input channels
- reduction (int, default 16): Channel reduction ratio

Returns:

- Module: Attention module or identity-like fallback

Source code in src/models/architectures/utae_pp.py (lines 533-553)
def build_attention(attention_type: str, channels: int, reduction: int = 16) -> nn.Module:
    """Factory function to build attention modules.

    Args:
        attention_type: Type of attention ('cbam', 'coord', or 'none')
        channels: Number of input channels
        reduction: Channel reduction ratio

    Returns:
        Attention module or identity-like fallback
    """
    if attention_type == "cbam":
        return CBAM(channels, reduction)
    elif attention_type == "coord":
        return CoordinateAttention(channels, reduction)
    else:  # 'none' or any other value
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.BatchNorm2d(channels),
            nn.GELU(),
        )
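
All three branches are shape-preserving, so the decoder can swap them freely; a quick usage sketch:

import torch

attn = build_attention("coord", channels=64)  # or "cbam" / "none"
y = attn(torch.randn(1, 64, 32, 32))          # same shape as the input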