wavedl 1.6.2__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/models/mamba.py CHANGED
@@ -56,6 +56,14 @@ __all__ = [
 # SELECTIVE SSM CORE (Pure PyTorch Implementation)
 # =============================================================================
 
+# Maximum sequence length for stable parallel scan without chunking
+# Beyond this, the chunked implementation is used automatically
+MAX_SAFE_SEQUENCE_LENGTH = 512
+
+# Recommended maximum for this pure-PyTorch implementation
+# For longer sequences, consider using the optimized mamba-ssm package
+MAX_RECOMMENDED_SEQUENCE_LENGTH = 2048
+
 
 class SelectiveSSM(nn.Module):
     """
@@ -64,8 +72,17 @@ class SelectiveSSM(nn.Module):
     The key innovation is making the SSM parameters (B, C, Δ) input-dependent,
     allowing the model to selectively focus on or ignore inputs.
 
-    This is a simplified pure-PyTorch implementation. For production use,
-    consider the optimized mamba-ssm package.
+    This is a pure-PyTorch implementation with chunked parallel scan for
+    numerical stability. For sequences > 2048 or production use, consider
+    the optimized mamba-ssm package with CUDA kernels.
+
+    Args:
+        d_model: Model dimension
+        d_state: SSM state dimension (default: 16)
+        d_conv: Local convolution width (default: 4)
+        expand: Expansion factor for inner dimension (default: 2)
+        chunk_size: Chunk size for parallel scan (default: 256).
+            Smaller = more stable but slower. Larger = faster but may overflow.
     """
 
     def __init__(
@@ -74,12 +91,14 @@ class SelectiveSSM(nn.Module):
         d_state: int = 16,
         d_conv: int = 4,
         expand: int = 2,
+        chunk_size: int = 256,
     ):
         super().__init__()
 
         self.d_model = d_model
         self.d_state = d_state
         self.d_inner = d_model * expand
+        self.chunk_size = chunk_size
 
         # Input projection
         self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
@@ -116,6 +135,17 @@ class SelectiveSSM(nn.Module):
         """
         _B, L, _D = x.shape
 
+        # Warn for very long sequences
+        if L > MAX_RECOMMENDED_SEQUENCE_LENGTH and self.training:
+            import warnings
+
+            warnings.warn(
+                f"Sequence length {L} > {MAX_RECOMMENDED_SEQUENCE_LENGTH}. "
+                "Consider using mamba-ssm package for better performance.",
+                UserWarning,
+                stacklevel=2,
+            )
+
         # Input projection and split
         xz = self.in_proj(x) # (B, L, 2*d_inner)
         x, z = xz.chunk(2, dim=-1) # Each: (B, L, d_inner)
@@ -135,8 +165,11 @@ class SelectiveSSM(nn.Module):
         # Discretize A
         A = -torch.exp(self.A_log) # (d_state,)
 
-        # Selective scan (simplified, not optimized)
-        y = self._selective_scan(x, delta, A, B_param, C_param, self.D)
+        # Use chunked scan for long sequences, direct scan for short
+        if L > MAX_SAFE_SEQUENCE_LENGTH:
+            y = self._chunked_selective_scan(x, delta, A, B_param, C_param, self.D)
+        else:
+            y = self._selective_scan_single(x, delta, A, B_param, C_param, self.D)
 
         # Gating
         y = y * F.silu(z)
@@ -144,7 +177,7 @@ class SelectiveSSM(nn.Module):
         # Output projection
         return self.out_proj(y)
 
-    def _selective_scan(
+    def _selective_scan_single(
         self,
         x: torch.Tensor,
         delta: torch.Tensor,
@@ -154,54 +187,109 @@ class SelectiveSSM(nn.Module):
         D: torch.Tensor,
     ) -> torch.Tensor:
         """
-        Vectorized selective scan using parallel associative scan.
+        Single-chunk parallel scan for short sequences (L <= MAX_SAFE_SEQUENCE_LENGTH).
 
-        This implementation avoids the sequential for-loop by computing
-        all timesteps in parallel using cumulative products and sums.
-        ~100x faster than the naive sequential implementation.
+        Uses log-space cumsum which is stable for short sequences.
         """
+        # Compute discretized A_bar: (B, L, d_inner, d_state)
+        A_bar = torch.exp(delta.unsqueeze(-1) * A)
 
-        # Compute discretized A_bar for all timesteps: (B, L, d_inner, d_state)
-        A_bar = torch.exp(delta.unsqueeze(-1) * A) # (B, L, d_inner, d_state)
-
-        # Compute input contribution: delta * B * x for all timesteps
-        # B: (B, L, d_state), x: (B, L, d_inner), delta: (B, L, d_inner)
-        # Result: (B, L, d_inner, d_state)
+        # Input contribution: (B, L, d_inner, d_state)
         BX = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)
 
-        # Parallel scan using log-space cumulative products for numerical stability
-        # For SSM: h[t] = A_bar[t] * h[t-1] + BX[t]
-        # This is a linear recurrence that can be solved with associative scan
-
-        # Use chunked approach for memory efficiency with parallel scan
-        # Compute cumulative product of A_bar (in log space for stability)
+        # Log-space parallel scan
         log_A_bar = torch.log(A_bar.clamp(min=1e-10))
-        log_A_cumsum = torch.cumsum(log_A_bar, dim=1) # (B, L, d_inner, d_state)
-        A_cumsum = torch.exp(log_A_cumsum)
+        log_A_cumsum = torch.cumsum(log_A_bar, dim=1)
+        A_cumsum = torch.exp(log_A_cumsum.clamp(max=80)) # Prevent overflow
 
-        # For each timestep t, we need: sum_{s=0}^{t} (prod_{k=s+1}^{t} A_bar[k]) * BX[s]
-        # = sum_{s=0}^{t} (A_cumsum[t] / A_cumsum[s]) * BX[s]
-        # = A_cumsum[t] * sum_{s=0}^{t} (BX[s] / A_cumsum[s])
-
-        # Compute BX / A_cumsum (use A_cumsum shifted by 1 for proper indexing)
-        # A_cumsum[s] represents prod_{k=0}^{s} A_bar[k], but we need prod_{k=0}^{s-1}
-        # So we shift: use A_cumsum from previous timestep
+        # Shifted cumsum for proper indexing
         A_cumsum_shifted = F.pad(A_cumsum[:, :-1], (0, 0, 0, 0, 1, 0), value=1.0)
 
-        # Weighted input: BX[s] / A_cumsum[s-1] = BX[s] * exp(-log_A_cumsum[s-1])
+        # Weighted input and cumsum
         weighted_BX = BX / A_cumsum_shifted.clamp(min=1e-10)
-
-        # Cumulative sum of weighted inputs
         weighted_BX_cumsum = torch.cumsum(weighted_BX, dim=1)
 
-        # Final state at each timestep: h[t] = A_cumsum[t] * weighted_BX_cumsum[t]
-        # But A_cumsum includes A_bar[0], so adjust
+        # Final state
        h = A_cumsum * weighted_BX_cumsum / A_bar.clamp(min=1e-10)
 
-        # Output: y = C * h + D * x
-        # h: (B, L, d_inner, d_state), C: (B, L, d_state)
-        y = (C.unsqueeze(2) * h).sum(-1) + D * x # (B, L, d_inner)
+        # Output
+        y = (C.unsqueeze(2) * h).sum(-1) + D * x
+        return y
+
+    def _chunked_selective_scan(
+        self,
+        x: torch.Tensor,
+        delta: torch.Tensor,
+        A: torch.Tensor,
+        B: torch.Tensor,
+        C: torch.Tensor,
+        D: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Chunked parallel scan for long sequences.
+
+        Processes in chunks of self.chunk_size, carrying state between chunks.
+        This prevents log-cumsum from growing unbounded while maintaining
+        reasonable parallelism within each chunk.
+        """
+        batch_size, seq_len, d_inner = x.shape
+        d_state = self.d_state
+        chunk_size = self.chunk_size
+
+        # Initialize output and state
+        y_chunks = []
+        h_state = torch.zeros(
+            batch_size, d_inner, d_state, device=x.device, dtype=x.dtype
+        )
+
+        # Process in chunks
+        for start in range(0, seq_len, chunk_size):
+            end = min(start + chunk_size, seq_len)
+
+            # Extract chunk
+            x_chunk = x[:, start:end]
+            delta_chunk = delta[:, start:end]
+            B_chunk = B[:, start:end]
+            C_chunk = C[:, start:end]
+
+            # Compute A_bar for chunk: (B, chunk_len, d_inner, d_state)
+            A_bar = torch.exp(delta_chunk.unsqueeze(-1) * A)
+
+            # Input contribution
+            BX = (
+                delta_chunk.unsqueeze(-1) * B_chunk.unsqueeze(2) * x_chunk.unsqueeze(-1)
+            )
+
+            # Within-chunk parallel scan (short enough to be stable)
+            log_A_bar = torch.log(A_bar.clamp(min=1e-10))
+            log_A_cumsum = torch.cumsum(log_A_bar, dim=1)
+            A_cumsum = torch.exp(log_A_cumsum.clamp(max=80))
+
+            A_cumsum_shifted = F.pad(A_cumsum[:, :-1], (0, 0, 0, 0, 1, 0), value=1.0)
+            weighted_BX = BX / A_cumsum_shifted.clamp(min=1e-10)
+            weighted_BX_cumsum = torch.cumsum(weighted_BX, dim=1)
+
+            # Chunk-internal state (without carry-over)
+            h_chunk_internal = A_cumsum * weighted_BX_cumsum / A_bar.clamp(min=1e-10)
+
+            # Add contribution from previous state
+            # h_state: (B, d_inner, d_state) -> (B, 1, d_inner, d_state)
+            # A_cumsum: (B, chunk_len, d_inner, d_state)
+            h_state_contribution = h_state.unsqueeze(1) * A_cumsum
+
+            # Total state for this chunk
+            h_chunk = h_chunk_internal + h_state_contribution
+
+            # Output for this chunk
+            y_chunk = (C_chunk.unsqueeze(2) * h_chunk).sum(-1) + D * x_chunk
+            y_chunks.append(y_chunk)
+
+            # Update carry-over state for next chunk
+            # Final state of this chunk: h_chunk[:, -1]
+            h_state = h_chunk[:, -1]
 
+        # Concatenate all chunks
+        y = torch.cat(y_chunks, dim=1)
         return y
 
 
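Both scan paths above implement the same linear recurrence spelled out in the removed comments, h[t] = A_bar[t] * h[t-1] + BX[t], with output y[t] = C[t]·h[t] + D * x[t]; the chunked variant simply restarts the log-space cumulative sums every chunk_size steps and carries h across chunk boundaries. A naive sequential reference (illustrative only, not wavedl code) that both implementations should match numerically:

import torch

def reference_selective_scan(x, delta, A, B, C, D):
    # Shapes follow the comments in the diff:
    # x, delta: (batch, L, d_inner); B, C: (batch, L, d_state); A: (d_state,); D: (d_inner,)
    A_bar = torch.exp(delta.unsqueeze(-1) * A)                       # (batch, L, d_inner, d_state)
    BX = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)      # (batch, L, d_inner, d_state)
    h = torch.zeros_like(A_bar[:, 0])
    ys = []
    for t in range(x.shape[1]):
        h = A_bar[:, t] * h + BX[:, t]                               # h[t] = A_bar[t] * h[t-1] + BX[t]
        ys.append((C[:, t].unsqueeze(1) * h).sum(-1) + D * x[:, t])  # y[t] = C·h[t] + D * x[t]
    return torch.stack(ys, dim=1)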
wavedl/models/resnet3d.py CHANGED
@@ -136,6 +136,28 @@ class ResNet3DBase(BaseModel):
         if freeze_backbone:
             self._freeze_backbone()
 
+        # Adapt first conv for single-channel input (instead of expand in forward)
+        self._adapt_stem_for_single_channel()
+
+    def _adapt_stem_for_single_channel(self):
+        """Modify stem conv to accept 1 channel, averaging pretrained RGB weights."""
+        old_conv = self.backbone.stem[0]
+        new_conv = nn.Conv3d(
+            1,
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                # Average RGB weights for grayscale initialization
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        self.backbone.stem[0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the fc head."""
         for name, param in self.backbone.named_parameters():
@@ -147,15 +169,11 @@ class ResNet3DBase(BaseModel):
         Forward pass.
 
         Args:
-            x: Input tensor of shape (B, C, D, H, W) where C is 1 or 3
+            x: Input tensor of shape (B, 1, D, H, W)
 
         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1, -1)
-
         return self.backbone(x)
 
     @classmethod
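A note on the stem adaptation above: convolution is linear over input channels, so the removed expand-to-three-channels path is exactly reproduced by a single-channel conv whose kernel is the sum of the pretrained RGB kernels, while averaging them (as _adapt_stem_for_single_channel does) yields one third of that response. An illustrative check with a plain Conv3d (not wavedl code):

import torch
import torch.nn as nn

rgb_conv = nn.Conv3d(3, 8, kernel_size=3, padding=1, bias=False)        # stands in for the pretrained stem
x = torch.randn(2, 1, 4, 16, 16)                                        # (B, 1, D, H, W)
y_expand = rgb_conv(x.expand(-1, 3, -1, -1, -1))                        # behaviour of the removed forward path

gray_conv = nn.Conv3d(1, 8, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    gray_conv.weight.copy_(rgb_conv.weight.mean(dim=1, keepdim=True))   # averaging, as in the diff
print(torch.allclose(gray_conv(x) * 3, y_expand, atol=1e-5))            # True: mean-init gives 1/3 of the old response

with torch.no_grad():
    gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))    # sum-init instead
print(torch.allclose(gray_conv(x), y_expand, atol=1e-5))                # True: sum-init reproduces it exactly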
@@ -37,6 +37,7 @@ import torch
 import torch.nn as nn
 
 from wavedl.models._pretrained_utils import (
+    DropPath,
     LayerNormNd,
     get_conv_layer,
     get_grn_layer,
@@ -133,24 +134,6 @@ class SEBlock(nn.Module):
         return x * scale
 
 
-class DropPath(nn.Module):
-    """Stochastic Depth (drop path) regularization."""
-
-    def __init__(self, drop_prob: float = 0.0):
-        super().__init__()
-        self.drop_prob = drop_prob
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.drop_prob == 0.0 or not self.training:
-            return x
-
-        keep_prob = 1 - self.drop_prob
-        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
-        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
-        random_tensor.floor_()
-        return x.div(keep_prob) * random_tensor
-
-
 # =============================================================================
 # UNIREPLKNET BLOCK
 # =============================================================================
wavedl/models/vit.py CHANGED
@@ -150,17 +150,22 @@ class PatchEmbed(nn.Module):
 
 
 class MultiHeadAttention(nn.Module):
-    """Multi-head self-attention mechanism."""
+    """
+    Multi-head self-attention mechanism.
+
+    Uses F.scaled_dot_product_attention (PyTorch 2.0+) for efficient,
+    fused attention with automatic Flash Attention support when available.
+    """
 
     def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
-        self.scale = self.head_dim**-0.5
+        self.dropout_p = dropout
 
         self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
         self.proj = nn.Linear(embed_dim, embed_dim)
-        self.dropout = nn.Dropout(dropout)
+        self.proj_dropout = nn.Dropout(dropout)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
@@ -169,13 +174,18 @@ class MultiHeadAttention(nn.Module):
         qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, heads, N, head_dim)
         q, k, v = qkv[0], qkv[1], qkv[2]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.dropout(attn)
+        # Use fused SDPA (PyTorch 2.0+) for efficiency + Flash Attention
+        x = torch.nn.functional.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            dropout_p=self.dropout_p if self.training else 0.0,
+            is_causal=False,
+        )
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
-        x = self.dropout(x)
+        x = self.proj_dropout(x)
         return x
 
 
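For context on the attention rewrite above: F.scaled_dot_product_attention applies the 1/sqrt(head_dim) scaling internally, which is why self.scale is dropped, and with dropout disabled it reproduces the removed softmax path; the single nn.Dropout previously reused for both the attention weights and the output projection is now split into dropout_p (inside the fused call) and proj_dropout. A quick illustrative equivalence check (not part of the package):

import torch
import torch.nn.functional as F

q, k, v = torch.randn(3, 2, 4, 10, 16).unbind(0)          # (B, heads, N, head_dim)
scale = q.shape[-1] ** -0.5

attn = (q @ k.transpose(-2, -1)) * scale                   # removed manual path
manual = attn.softmax(dim=-1) @ v

fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
print(torch.allclose(manual, fused, atol=1e-5))            # True when dropout is off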
wavedl/test.py CHANGED
@@ -33,35 +33,17 @@ Author: Ductho Le (ductho.le@outlook.com)
 # Uses current working directory as fallback - works on HPC and local machines.
 import os
 
+# Import and call HPC cache setup before any library imports
+from wavedl.utils import setup_hpc_cache_dirs
 
-def _setup_cache_dir(env_var: str, subdir: str) -> None:
-    """Set cache directory to CWD if home is not writable."""
-    if env_var in os.environ:
-        return # User already set, respect their choice
 
-    # Check if home is writable
-    home = os.path.expanduser("~")
-    if os.access(home, os.W_OK):
-        return # Home is writable, let library use defaults
-
-    # Home not writable - use current working directory
-    cache_path = os.path.join(os.getcwd(), f".{subdir}")
-    os.makedirs(cache_path, exist_ok=True)
-    os.environ[env_var] = cache_path
-
-
-# Configure cache directories (before any library imports)
-_setup_cache_dir("TORCH_HOME", "torch_cache")
-_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
-_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
-_setup_cache_dir("XDG_DATA_HOME", "local/share")
-_setup_cache_dir("XDG_STATE_HOME", "local/state")
-_setup_cache_dir("XDG_CACHE_HOME", "cache")
+setup_hpc_cache_dirs()
 
 import argparse # noqa: E402
 import logging # noqa: E402
 import pickle # noqa: E402
 from pathlib import Path # noqa: E402
+from typing import Any # noqa: E402
 
 import matplotlib.pyplot as plt # noqa: E402
 import numpy as np # noqa: E402
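The per-variable cache helper removed above is replaced by a single call into wavedl.utils; the implementation of setup_hpc_cache_dirs is not shown in this diff, but a sketch consistent with the removed logic would be (hypothetical, for orientation only):

import os

def setup_hpc_cache_dirs() -> None:
    # Sketch only: mirrors the removed _setup_cache_dir calls; the real
    # wavedl.utils implementation may differ.
    defaults = {
        "TORCH_HOME": "torch_cache",
        "MPLCONFIGDIR": "matplotlib",
        "FONTCONFIG_CACHE": "fontconfig",
        "XDG_DATA_HOME": "local/share",
        "XDG_STATE_HOME": "local/state",
        "XDG_CACHE_HOME": "cache",
    }
    if os.access(os.path.expanduser("~"), os.W_OK):
        return  # home is writable; let libraries use their defaults
    for env_var, subdir in defaults.items():
        if env_var in os.environ:
            continue  # respect values the user already set
        cache_path = os.path.join(os.getcwd(), f".{subdir}")
        os.makedirs(cache_path, exist_ok=True)
        os.environ[env_var] = cache_path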
@@ -314,7 +296,7 @@ def load_checkpoint(
     in_shape: tuple[int, ...],
     out_size: int,
     model_name: str | None = None,
-) -> tuple[nn.Module, any]:
+) -> tuple[nn.Module, Any]:
     """
     Load model and scaler from Accelerate checkpoint directory.
 
@@ -398,6 +380,14 @@ def load_checkpoint(
 
     if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
         state_dict = load_safetensors(str(weight_path))
+    elif weight_path.suffix == ".safetensors":
+        # Safetensors file exists but library not installed
+        raise ImportError(
+            f"Checkpoint uses safetensors format ({weight_path.name}) but "
+            f"'safetensors' package is not installed. Install it with:\n"
+            f" pip install safetensors\n"
+            f"Or convert the checkpoint to PyTorch format (model.bin)."
+        )
     else:
         state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)