wavedl-1.5.7-py3-none-any.whl → wavedl-1.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. wavedl/__init__.py +1 -1
  2. wavedl/hpo.py +451 -451
  3. wavedl/models/__init__.py +80 -4
  4. wavedl/models/_pretrained_utils.py +366 -0
  5. wavedl/models/base.py +48 -0
  6. wavedl/models/caformer.py +270 -0
  7. wavedl/models/cnn.py +2 -27
  8. wavedl/models/convnext.py +113 -51
  9. wavedl/models/convnext_v2.py +488 -0
  10. wavedl/models/densenet.py +10 -23
  11. wavedl/models/efficientnet.py +6 -6
  12. wavedl/models/efficientnetv2.py +315 -315
  13. wavedl/models/efficientvit.py +398 -0
  14. wavedl/models/fastvit.py +252 -0
  15. wavedl/models/mamba.py +555 -0
  16. wavedl/models/maxvit.py +254 -0
  17. wavedl/models/mobilenetv3.py +295 -295
  18. wavedl/models/regnet.py +406 -406
  19. wavedl/models/resnet.py +19 -61
  20. wavedl/models/resnet3d.py +258 -258
  21. wavedl/models/swin.py +443 -443
  22. wavedl/models/tcn.py +393 -409
  23. wavedl/models/unet.py +2 -6
  24. wavedl/models/unireplknet.py +491 -0
  25. wavedl/models/vit.py +9 -9
  26. wavedl/train.py +1430 -1425
  27. wavedl/utils/config.py +367 -367
  28. wavedl/utils/cross_validation.py +530 -530
  29. wavedl/utils/data.py +39 -6
  30. wavedl/utils/losses.py +216 -216
  31. wavedl/utils/optimizers.py +216 -216
  32. wavedl/utils/schedulers.py +251 -251
  33. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/METADATA +150 -82
  34. wavedl-1.6.1.dist-info/RECORD +46 -0
  35. wavedl-1.5.7.dist-info/RECORD +0 -38
  36. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/LICENSE +0 -0
  37. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/WHEEL +0 -0
  38. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/entry_points.txt +0 -0
  39. {wavedl-1.5.7.dist-info → wavedl-1.6.1.dist-info}/top_level.txt +0 -0
wavedl/models/maxvit.py (new file)
@@ -0,0 +1,254 @@
+ """
+ MaxViT: Multi-Axis Vision Transformer
+ =====================================
+
+ MaxViT combines local and global attention with O(n) complexity using
+ multi-axis attention: block attention (local) + grid attention (global sparse).
+
+ **Key Features**:
+ - Multi-axis attention for both local and global context
+ - Hybrid design with MBConv + attention
+ - Linear O(n) complexity
+ - Hierarchical multi-scale features
+
+ **Variants**:
+ - maxvit_tiny: 31M params
+ - maxvit_small: 69M params
+ - maxvit_base: 120M params
+
+ **Requirements**:
+ - timm (for the architecture and pretrained weights; this wrapper raises
+   ImportError without it, so there is no torchvision fallback)
+
+ Reference:
+     Tu, Z., et al. (2022). MaxViT: Multi-Axis Vision Transformer.
+     ECCV 2022. https://arxiv.org/abs/2204.01697
+
+ Author: Ductho Le (ductho.le@outlook.com)
+ """
+
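The block/grid decomposition described in the docstring is easiest to see as tensor reshapes. The toy sketch below (illustrative shapes only, not code from timm or this package) shows how the two attention axes group tokens: block attention gathers P × P adjacent neighbours, while grid attention gathers a G × G set of tokens strided across the whole feature map, which is what makes the global branch sparse while keeping total cost linear.

    # Toy illustration of multi-axis token grouping (assumed shapes, not timm code)
    import torch

    B, H, W, C = 2, 28, 28, 64
    P = G = 7  # block size and grid size, as in the MaxViT paper
    x = torch.randn(B, H, W, C)

    # Block attention: (H/P * W/P) local windows, each holding P*P adjacent tokens
    blocks = (
        x.view(B, H // P, P, W // P, P, C)
        .permute(0, 1, 3, 2, 4, 5)
        .reshape(-1, P * P, C)  # attention would run over the P*P axis
    )

    # Grid attention: G*G windows whose tokens sit H/G positions apart,
    # i.e. a sparse, global receptive field at the same per-window cost
    grid = (
        x.view(B, G, H // G, G, W // G, C)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(-1, G * G, C)  # attention would run over the G*G axis
    )

    print(blocks.shape, grid.shape)  # both torch.Size([32, 49, 64])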
+ import torch
+ import torch.nn.functional as F
+
+ from wavedl.models._pretrained_utils import build_regression_head
+ from wavedl.models.base import BaseModel
+ from wavedl.models.registry import register_model
+
+
+ __all__ = [
+     "MaxViTBase",
+     "MaxViTBaseLarge",
+     "MaxViTSmall",
+     "MaxViTTiny",
+ ]
+
+
+ # =============================================================================
+ # MAXVIT BASE CLASS
+ # =============================================================================
+
+
+ class MaxViTBase(BaseModel):
+     """
+     MaxViT base class wrapping the timm implementation.
+
+     Multi-axis attention with local block and global grid attention.
+     2D only due to attention structure.
+
+     Note:
+         MaxViT requires input dimensions divisible by 28 (4x stem downsample × 7 window).
+         This implementation automatically resizes inputs to the nearest compatible size.
+     """
+
+     # MaxViT stem downsamples by 4x, then requires divisibility by 7 (window size),
+     # so the original input must be divisible by 4 * 7 = 28
+     _DIVISOR = 28
+
+     def __init__(
+         self,
+         in_shape: tuple[int, int],
+         out_size: int,
+         model_name: str = "maxvit_tiny_tf_224",
+         pretrained: bool = True,
+         freeze_backbone: bool = False,
+         dropout_rate: float = 0.3,
+         **kwargs,
+     ):
+         super().__init__(in_shape, out_size)
+
+         if len(in_shape) != 2:
+             raise ValueError(f"MaxViT requires 2D input (H, W), got {len(in_shape)}D")
+
+         self.pretrained = pretrained
+         self.freeze_backbone = freeze_backbone
+         self.model_name = model_name
+
+         # Compute compatible input size for MaxViT attention windows
+         self._target_size = self._compute_compatible_size(in_shape)
+
+         # Try to load from timm
+         try:
+             import timm
+
+             self.backbone = timm.create_model(
+                 model_name,
+                 pretrained=pretrained,
+                 num_classes=0,  # Remove classifier
+             )
+
+             # Get feature dimension using compatible size
+             with torch.no_grad():
+                 dummy = torch.zeros(1, 3, *self._target_size)
+                 features = self.backbone(dummy)
+                 in_features = features.shape[-1]
+
+         except ImportError:
+             raise ImportError(
+                 "timm is required for MaxViT. Install with: pip install timm"
+             )
+         except Exception as e:
+             raise RuntimeError(f"Failed to load MaxViT model '{model_name}': {e}")
+
+         # Adapt input channels (3 -> 1)
+         self._adapt_input_channels()
+
+         # Regression head
+         self.head = build_regression_head(in_features, out_size, dropout_rate)
+
+         if freeze_backbone:
+             self._freeze_backbone()
+
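The feature-width probe in __init__ can be reproduced standalone, which is handy when checking a new timm backbone. A minimal sketch (requires timm; pretrained=False skips the weight download):

    # Probe the pooled feature width the regression head will receive
    import timm
    import torch

    backbone = timm.create_model("maxvit_tiny_tf_224", pretrained=False, num_classes=0)
    with torch.no_grad():
        feats = backbone(torch.zeros(1, 3, 224, 224))
    print(feats.shape)  # (1, num_features) for the chosen variant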
+     def _adapt_input_channels(self):
+         """Adapt first conv layer for single-channel input."""
+         from wavedl.models._pretrained_utils import find_and_adapt_input_convs
+
+         adapted_count = find_and_adapt_input_convs(
+             self.backbone, pretrained=self.pretrained, adapt_all=False
+         )
+
+         if adapted_count == 0:
+             import warnings
+
+             warnings.warn(
+                 "Could not adapt MaxViT input channels. Model may fail.", stacklevel=2
+             )
+
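find_and_adapt_input_convs is defined in wavedl/models/_pretrained_utils.py (also added in this release) and is not shown in this hunk. The standard technique for a 3 -> 1 adaptation of a pretrained stem is to collapse the RGB kernels of the first conv; a rough sketch under that assumption (hypothetical helper, assuming the stem starts with a plain Conv2d):

    # Common 3 -> 1 input adaptation: collapse pretrained RGB kernels (sketch only;
    # the real logic lives in wavedl.models._pretrained_utils)
    import torch
    import torch.nn as nn

    def adapt_first_conv(conv: nn.Conv2d) -> nn.Conv2d:
        new_conv = nn.Conv2d(
            1, conv.out_channels, conv.kernel_size,
            stride=conv.stride, padding=conv.padding, bias=conv.bias is not None,
        )
        with torch.no_grad():
            # Summing over the RGB axis keeps the expected activation scale
            new_conv.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
            if conv.bias is not None:
                new_conv.bias.copy_(conv.bias)
        return new_conv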
+     def _freeze_backbone(self):
+         """Freeze backbone parameters."""
+         for param in self.backbone.parameters():
+             param.requires_grad = False
+
+     def _compute_compatible_size(self, in_shape: tuple[int, int]) -> tuple[int, int]:
+         """
+         Compute the nearest input size compatible with MaxViT attention windows.
+
+         MaxViT requires input dimensions divisible by 28 (4x stem downsample × 7 window).
+         This rounds up to the nearest compatible size.
+
+         Args:
+             in_shape: Original (H, W) input shape
+
+         Returns:
+             Compatible (H, W) shape divisible by 28
+         """
+         import math
+
+         h, w = in_shape
+         target_h = math.ceil(h / self._DIVISOR) * self._DIVISOR
+         target_w = math.ceil(w / self._DIVISOR) * self._DIVISOR
+         return (target_h, target_w)
+
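The rounding above is a plain ceiling to the next multiple of 28. A quick check with illustrative sizes:

    import math

    DIVISOR = 28  # 4x stem downsample * 7 window size
    for h, w in [(224, 224), (100, 257), (28, 30)]:
        target = (math.ceil(h / DIVISOR) * DIVISOR, math.ceil(w / DIVISOR) * DIVISOR)
        print((h, w), "->", target)
    # (224, 224) -> (224, 224)
    # (100, 257) -> (112, 280)
    # (28, 30) -> (28, 56)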
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Resize input to compatible size if needed
+         _, _, h, w = x.shape
+         if (h, w) != self._target_size:
+             x = F.interpolate(
+                 x,
+                 size=self._target_size,
+                 mode="bilinear",
+                 align_corners=False,
+             )
+         features = self.backbone(x)
+         return self.head(features)
+
+
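Since forward resizes internally, callers do not need to pad inputs to a multiple of 28 themselves. A usage sketch against this module (pretrained=False avoids the weight download; MaxViTTiny is registered just below):

    # A (100, 100) input is interpolated to (112, 112) inside forward
    import torch
    from wavedl.models.maxvit import MaxViTTiny

    model = MaxViTTiny(in_shape=(100, 100), out_size=3, pretrained=False)
    out = model(torch.randn(4, 1, 100, 100))
    print(out.shape)  # torch.Size([4, 3])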
+ # =============================================================================
+ # REGISTERED VARIANTS
+ # =============================================================================
+
+
+ @register_model("maxvit_tiny")
+ class MaxViTTiny(MaxViTBase):
+     """
+     MaxViT Tiny: ~30.1M backbone parameters.
+
+     Multi-axis attention with local+global context.
+     2D only.
+
+     Example:
+         >>> model = MaxViTTiny(in_shape=(224, 224), out_size=3)
+         >>> x = torch.randn(4, 1, 224, 224)
+         >>> out = model(x)  # (4, 3)
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="maxvit_tiny_tf_224",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"MaxViT_Tiny(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("maxvit_small")
+ class MaxViTSmall(MaxViTBase):
+     """
+     MaxViT Small: ~67.6M backbone parameters.
+
+     Multi-axis attention with local+global context.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="maxvit_small_tf_224",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"MaxViT_Small(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
+
+
+ @register_model("maxvit_base")
+ class MaxViTBaseLarge(MaxViTBase):
+     """
+     MaxViT Base: ~118.1M backbone parameters.
+
+     Multi-axis attention with local+global context.
+     2D only.
+     """
+
+     def __init__(self, in_shape: tuple[int, int], out_size: int, **kwargs):
+         super().__init__(
+             in_shape=in_shape,
+             out_size=out_size,
+             model_name="maxvit_base_tf_224",
+             **kwargs,
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"MaxViT_Base(in_shape={self.in_shape}, out_size={self.out_size}, "
+             f"pretrained={self.pretrained})"
+         )
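All three registered variants accept the base-class keyword arguments, so a frozen-backbone fine-tuning setup is one flag away. A sketch (only the regression head built by build_regression_head remains trainable):

    # Train only the regression head on top of a frozen MaxViT backbone
    import torch
    from wavedl.models.maxvit import MaxViTSmall

    model = MaxViTSmall(
        in_shape=(224, 224), out_size=3,
        pretrained=False,  # set True to start from ImageNet weights
        freeze_backbone=True,
    )
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    print(all(n.startswith("head.") for n in trainable))  # True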