wavedl 1.6.3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +115 -9
- wavedl/models/_pretrained_utils.py +72 -0
- wavedl/models/_template.py +7 -6
- wavedl/models/cnn.py +20 -0
- wavedl/models/convnext.py +3 -70
- wavedl/models/convnext_v2.py +1 -18
- wavedl/models/mamba.py +126 -38
- wavedl/models/resnet3d.py +23 -5
- wavedl/models/unireplknet.py +1 -18
- wavedl/models/vit.py +18 -8
- wavedl/test.py +5 -23
- wavedl/train.py +492 -26
- wavedl/utils/__init__.py +49 -9
- wavedl/utils/config.py +6 -8
- wavedl/utils/cross_validation.py +17 -4
- wavedl/utils/data.py +140 -174
- wavedl/utils/metrics.py +26 -5
- wavedl/utils/schedulers.py +2 -2
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/METADATA +35 -14
- wavedl-1.7.0.dist-info/RECORD +46 -0
- wavedl-1.6.3.dist-info/RECORD +0 -46
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/LICENSE +0 -0
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/WHEEL +0 -0
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.3.dist-info → wavedl-1.7.0.dist-info}/top_level.txt +0 -0
wavedl/models/mamba.py CHANGED

@@ -56,6 +56,14 @@ __all__ = [
 # SELECTIVE SSM CORE (Pure PyTorch Implementation)
 # =============================================================================

+# Maximum sequence length for stable parallel scan without chunking
+# Beyond this, the chunked implementation is used automatically
+MAX_SAFE_SEQUENCE_LENGTH = 512
+
+# Recommended maximum for this pure-PyTorch implementation
+# For longer sequences, consider using the optimized mamba-ssm package
+MAX_RECOMMENDED_SEQUENCE_LENGTH = 2048
+

 class SelectiveSSM(nn.Module):
     """
@@ -64,8 +72,17 @@ class SelectiveSSM(nn.Module):
     The key innovation is making the SSM parameters (B, C, Δ) input-dependent,
     allowing the model to selectively focus on or ignore inputs.

-    This is a
-
+    This is a pure-PyTorch implementation with chunked parallel scan for
+    numerical stability. For sequences > 2048 or production use, consider
+    the optimized mamba-ssm package with CUDA kernels.
+
+    Args:
+        d_model: Model dimension
+        d_state: SSM state dimension (default: 16)
+        d_conv: Local convolution width (default: 4)
+        expand: Expansion factor for inner dimension (default: 2)
+        chunk_size: Chunk size for parallel scan (default: 256).
+            Smaller = more stable but slower. Larger = faster but may overflow.
     """

     def __init__(
@@ -74,12 +91,14 @@ class SelectiveSSM(nn.Module):
         d_state: int = 16,
         d_conv: int = 4,
         expand: int = 2,
+        chunk_size: int = 256,
     ):
         super().__init__()

         self.d_model = d_model
         self.d_state = d_state
         self.d_inner = d_model * expand
+        self.chunk_size = chunk_size

         # Input projection
         self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
@@ -116,6 +135,17 @@ class SelectiveSSM(nn.Module):
         """
         _B, L, _D = x.shape

+        # Warn for very long sequences
+        if L > MAX_RECOMMENDED_SEQUENCE_LENGTH and self.training:
+            import warnings
+
+            warnings.warn(
+                f"Sequence length {L} > {MAX_RECOMMENDED_SEQUENCE_LENGTH}. "
+                "Consider using mamba-ssm package for better performance.",
+                UserWarning,
+                stacklevel=2,
+            )
+
         # Input projection and split
         xz = self.in_proj(x)  # (B, L, 2*d_inner)
         x, z = xz.chunk(2, dim=-1)  # Each: (B, L, d_inner)
@@ -135,8 +165,11 @@ class SelectiveSSM(nn.Module):
         # Discretize A
         A = -torch.exp(self.A_log)  # (d_state,)

-        #
-
+        # Use chunked scan for long sequences, direct scan for short
+        if L > MAX_SAFE_SEQUENCE_LENGTH:
+            y = self._chunked_selective_scan(x, delta, A, B_param, C_param, self.D)
+        else:
+            y = self._selective_scan_single(x, delta, A, B_param, C_param, self.D)

         # Gating
         y = y * F.silu(z)
@@ -144,7 +177,7 @@ class SelectiveSSM(nn.Module):
         # Output projection
         return self.out_proj(y)

-    def
+    def _selective_scan_single(
         self,
         x: torch.Tensor,
         delta: torch.Tensor,
@@ -154,54 +187,109 @@ class SelectiveSSM(nn.Module):
         D: torch.Tensor,
     ) -> torch.Tensor:
         """
-
+        Single-chunk parallel scan for short sequences (L <= MAX_SAFE_SEQUENCE_LENGTH).

-
-        all timesteps in parallel using cumulative products and sums.
-        ~100x faster than the naive sequential implementation.
+        Uses log-space cumsum which is stable for short sequences.
         """
+        # Compute discretized A_bar: (B, L, d_inner, d_state)
+        A_bar = torch.exp(delta.unsqueeze(-1) * A)

-        #
-        A_bar = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, d_inner, d_state)
-
-        # Compute input contribution: delta * B * x for all timesteps
-        # B: (B, L, d_state), x: (B, L, d_inner), delta: (B, L, d_inner)
-        # Result: (B, L, d_inner, d_state)
+        # Input contribution: (B, L, d_inner, d_state)
         BX = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)

-        #
-        # For SSM: h[t] = A_bar[t] * h[t-1] + BX[t]
-        # This is a linear recurrence that can be solved with associative scan
-
-        # Use chunked approach for memory efficiency with parallel scan
-        # Compute cumulative product of A_bar (in log space for stability)
+        # Log-space parallel scan
         log_A_bar = torch.log(A_bar.clamp(min=1e-10))
-        log_A_cumsum = torch.cumsum(log_A_bar, dim=1)
-        A_cumsum = torch.exp(log_A_cumsum)
+        log_A_cumsum = torch.cumsum(log_A_bar, dim=1)
+        A_cumsum = torch.exp(log_A_cumsum.clamp(max=80))  # Prevent overflow

-        #
-        # = sum_{s=0}^{t} (A_cumsum[t] / A_cumsum[s]) * BX[s]
-        # = A_cumsum[t] * sum_{s=0}^{t} (BX[s] / A_cumsum[s])
-
-        # Compute BX / A_cumsum (use A_cumsum shifted by 1 for proper indexing)
-        # A_cumsum[s] represents prod_{k=0}^{s} A_bar[k], but we need prod_{k=0}^{s-1}
-        # So we shift: use A_cumsum from previous timestep
+        # Shifted cumsum for proper indexing
         A_cumsum_shifted = F.pad(A_cumsum[:, :-1], (0, 0, 0, 0, 1, 0), value=1.0)

-        # Weighted input
+        # Weighted input and cumsum
         weighted_BX = BX / A_cumsum_shifted.clamp(min=1e-10)
-
-        # Cumulative sum of weighted inputs
         weighted_BX_cumsum = torch.cumsum(weighted_BX, dim=1)

-        # Final state
-        # But A_cumsum includes A_bar[0], so adjust
+        # Final state
         h = A_cumsum * weighted_BX_cumsum / A_bar.clamp(min=1e-10)

-        # Output
-
-        y
+        # Output
+        y = (C.unsqueeze(2) * h).sum(-1) + D * x
+        return y
+
+    def _chunked_selective_scan(
+        self,
+        x: torch.Tensor,
+        delta: torch.Tensor,
+        A: torch.Tensor,
+        B: torch.Tensor,
+        C: torch.Tensor,
+        D: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Chunked parallel scan for long sequences.
+
+        Processes in chunks of self.chunk_size, carrying state between chunks.
+        This prevents log-cumsum from growing unbounded while maintaining
+        reasonable parallelism within each chunk.
+        """
+        batch_size, seq_len, d_inner = x.shape
+        d_state = self.d_state
+        chunk_size = self.chunk_size
+
+        # Initialize output and state
+        y_chunks = []
+        h_state = torch.zeros(
+            batch_size, d_inner, d_state, device=x.device, dtype=x.dtype
+        )
+
+        # Process in chunks
+        for start in range(0, seq_len, chunk_size):
+            end = min(start + chunk_size, seq_len)
+
+            # Extract chunk
+            x_chunk = x[:, start:end]
+            delta_chunk = delta[:, start:end]
+            B_chunk = B[:, start:end]
+            C_chunk = C[:, start:end]
+
+            # Compute A_bar for chunk: (B, chunk_len, d_inner, d_state)
+            A_bar = torch.exp(delta_chunk.unsqueeze(-1) * A)
+
+            # Input contribution
+            BX = (
+                delta_chunk.unsqueeze(-1) * B_chunk.unsqueeze(2) * x_chunk.unsqueeze(-1)
+            )
+
+            # Within-chunk parallel scan (short enough to be stable)
+            log_A_bar = torch.log(A_bar.clamp(min=1e-10))
+            log_A_cumsum = torch.cumsum(log_A_bar, dim=1)
+            A_cumsum = torch.exp(log_A_cumsum.clamp(max=80))
+
+            A_cumsum_shifted = F.pad(A_cumsum[:, :-1], (0, 0, 0, 0, 1, 0), value=1.0)
+            weighted_BX = BX / A_cumsum_shifted.clamp(min=1e-10)
+            weighted_BX_cumsum = torch.cumsum(weighted_BX, dim=1)
+
+            # Chunk-internal state (without carry-over)
+            h_chunk_internal = A_cumsum * weighted_BX_cumsum / A_bar.clamp(min=1e-10)
+
+            # Add contribution from previous state
+            # h_state: (B, d_inner, d_state) -> (B, 1, d_inner, d_state)
+            # A_cumsum: (B, chunk_len, d_inner, d_state)
+            h_state_contribution = h_state.unsqueeze(1) * A_cumsum
+
+            # Total state for this chunk
+            h_chunk = h_chunk_internal + h_state_contribution
+
+            # Output for this chunk
+            y_chunk = (C_chunk.unsqueeze(2) * h_chunk).sum(-1) + D * x_chunk
+            y_chunks.append(y_chunk)
+
+            # Update carry-over state for next chunk
+            # Final state of this chunk: h_chunk[:, -1]
+            h_state = h_chunk[:, -1]

+        # Concatenate all chunks
+        y = torch.cat(y_chunks, dim=1)
         return y

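Note on the scan rewrite: the removed comments stated the identity the parallel scan exploits. For the recurrence h[t] = A_bar[t] * h[t-1] + BX[t], unrolling gives h[t] = A_cumsum[t] * sum_{s<=t} BX[s] / A_cumsum[s], where A_cumsum[t] = prod_{k<=t} A_bar[k], so one torch.cumsum replaces the O(L) loop. The 1.7.0 change limits how far that cumulative product can drift by clamping the log-cumsum and chunking past 512 steps. A minimal standalone check of the identity (illustrative only, not wavedl's exact code; the mild decay factors and float64 here deliberately sidestep the over/underflow the chunked variant guards against):

```python
import torch

torch.manual_seed(0)
B, L, d = 2, 32, 8
# Decay factors in (0.9, 1.0) keep cumulative products well-scaled;
# float64 avoids the precision loss that motivates the chunked scan.
A_bar = 0.9 + 0.1 * torch.rand(B, L, d, dtype=torch.float64)
BX = torch.randn(B, L, d, dtype=torch.float64)

# Sequential reference: h[t] = A_bar[t] * h[t-1] + BX[t]
h, hs = torch.zeros(B, d, dtype=torch.float64), []
for t in range(L):
    h = A_bar[:, t] * h + BX[:, t]
    hs.append(h)
h_seq = torch.stack(hs, dim=1)

# Parallel form: h[t] = P[t] * sum_{s<=t} BX[s] / P[s], P[t] = prod_{k<=t} A_bar[k]
P = torch.exp(torch.cumsum(torch.log(A_bar), dim=1))
h_par = P * torch.cumsum(BX / P, dim=1)

print(torch.allclose(h_seq, h_par))  # True
```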
wavedl/models/resnet3d.py CHANGED

@@ -136,6 +136,28 @@ class ResNet3DBase(BaseModel):
         if freeze_backbone:
             self._freeze_backbone()

+        # Adapt first conv for single-channel input (instead of expand in forward)
+        self._adapt_stem_for_single_channel()
+
+    def _adapt_stem_for_single_channel(self):
+        """Modify stem conv to accept 1 channel, averaging pretrained RGB weights."""
+        old_conv = self.backbone.stem[0]
+        new_conv = nn.Conv3d(
+            1,
+            old_conv.out_channels,
+            kernel_size=old_conv.kernel_size,
+            stride=old_conv.stride,
+            padding=old_conv.padding,
+            bias=old_conv.bias is not None,
+        )
+        if self.pretrained:
+            with torch.no_grad():
+                # Average RGB weights for grayscale initialization
+                new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))
+                if old_conv.bias is not None:
+                    new_conv.bias.copy_(old_conv.bias)
+        self.backbone.stem[0] = new_conv
+
     def _freeze_backbone(self):
         """Freeze all backbone parameters except the fc head."""
         for name, param in self.backbone.named_parameters():
@@ -147,15 +169,11 @@ class ResNet3DBase(BaseModel):
         Forward pass.

         Args:
-            x: Input tensor of shape (B,
+            x: Input tensor of shape (B, 1, D, H, W)

         Returns:
             Output tensor of shape (B, out_size)
         """
-        # Expand single channel to 3 channels for pretrained weights compatibility
-        if x.size(1) == 1:
-            x = x.expand(-1, 3, -1, -1, -1)
-
         return self.backbone(x)

     @classmethod
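The stem adaptation replaces the old expand-to-RGB forward path with a one-time surgery on the first conv. One subtlety: averaging the RGB kernels (rather than summing them) scales the first layer's pre-activations by 1/3 relative to the 1.6.3 expand path, since the three expanded channels were identical copies; downstream normalization can absorb this during fine-tuning. A standalone sketch with hypothetical layer sizes (not wavedl's actual stem) making that relationship explicit:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
old_conv = nn.Conv3d(3, 8, kernel_size=3, padding=1, bias=False)

# 1.7.0-style adaptation: average the RGB kernels into one input channel
new_conv = nn.Conv3d(1, 8, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    new_conv.weight.copy_(old_conv.weight.mean(dim=1, keepdim=True))

x = torch.randn(2, 1, 4, 8, 8)  # (B, 1, D, H, W) single-channel volume

out_old = old_conv(x.expand(-1, 3, -1, -1, -1))  # 1.6.3 path: expand to RGB
out_new = new_conv(x)                            # 1.7.0 path: adapted stem

# With identical input channels, summed RGB kernels equal 3x their mean:
print(torch.allclose(out_old, 3 * out_new, atol=1e-5))  # True
```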
wavedl/models/unireplknet.py CHANGED

@@ -37,6 +37,7 @@ import torch
 import torch.nn as nn

 from wavedl.models._pretrained_utils import (
+    DropPath,
     LayerNormNd,
     get_conv_layer,
     get_grn_layer,
@@ -133,24 +134,6 @@ class SEBlock(nn.Module):
         return x * scale


-class DropPath(nn.Module):
-    """Stochastic Depth (drop path) regularization."""
-
-    def __init__(self, drop_prob: float = 0.0):
-        super().__init__()
-        self.drop_prob = drop_prob
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.drop_prob == 0.0 or not self.training:
-            return x
-
-        keep_prob = 1 - self.drop_prob
-        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
-        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
-        random_tensor.floor_()
-        return x.div(keep_prob) * random_tensor
-
-
 # =============================================================================
 # UNIREPLKNET BLOCK
 # =============================================================================
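DropPath now lives in wavedl.models._pretrained_utils (part of the new +72-line module listed above), deduplicating the copies previously embedded in the backbone files. A quick behavioral check of the relocated module, based on the implementation visible in the removed lines:

```python
import torch

from wavedl.models._pretrained_utils import DropPath  # new shared location

dp = DropPath(drop_prob=0.5)
x = torch.ones(8, 4)

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference

dp.train()
out = dp(x)
# Whole samples are dropped or rescaled by 1/keep_prob = 2.0:
print(out[:, 0])  # each entry is 0.0 or 2.0
```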
wavedl/models/vit.py CHANGED

@@ -150,17 +150,22 @@ class PatchEmbed(nn.Module):


 class MultiHeadAttention(nn.Module):
-    """
+    """
+    Multi-head self-attention mechanism.
+
+    Uses F.scaled_dot_product_attention (PyTorch 2.0+) for efficient,
+    fused attention with automatic Flash Attention support when available.
+    """

     def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
-        self.
+        self.dropout_p = dropout

         self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
         self.proj = nn.Linear(embed_dim, embed_dim)
-        self.
+        self.proj_dropout = nn.Dropout(dropout)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
@@ -169,13 +174,18 @@ class MultiHeadAttention(nn.Module):
         qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
         q, k, v = qkv[0], qkv[1], qkv[2]

-
-
-
+        # Use fused SDPA (PyTorch 2.0+) for efficiency + Flash Attention
+        x = torch.nn.functional.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            dropout_p=self.dropout_p if self.training else 0.0,
+            is_causal=False,
+        )

-        x =
+        x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
-        x = self.
+        x = self.proj_dropout(x)
         return x

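The attention rewrite replaces the previous formulation (elided in this diff's removed lines) with the fused kernel, which dispatches to Flash Attention or memory-efficient backends when available. At inference, the fused call is numerically interchangeable with the classic explicit softmax(QK^T / sqrt(d))V computation, which a standalone check confirms (illustrative, not wavedl's module):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, d = 2, 4, 16, 32
q, k, v = (torch.randn(B, H, N, d) for _ in range(3))

# Fused kernel now used in MultiHeadAttention.forward (dropout off, as at eval)
out_fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# The classic explicit formulation
attn = torch.softmax(q @ k.transpose(-2, -1) / d**0.5, dim=-1)
out_manual = attn @ v

print(torch.allclose(out_fused, out_manual, atol=1e-5))  # True
```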
wavedl/test.py CHANGED

@@ -33,35 +33,17 @@ Author: Ductho Le (ductho.le@outlook.com)
 # Uses current working directory as fallback - works on HPC and local machines.
 import os

+# Import and call HPC cache setup before any library imports
+from wavedl.utils import setup_hpc_cache_dirs

-def _setup_cache_dir(env_var: str, subdir: str) -> None:
-    """Set cache directory to CWD if home is not writable."""
-    if env_var in os.environ:
-        return  # User already set, respect their choice

-
-    home = os.path.expanduser("~")
-    if os.access(home, os.W_OK):
-        return  # Home is writable, let library use defaults
-
-    # Home not writable - use current working directory
-    cache_path = os.path.join(os.getcwd(), f".{subdir}")
-    os.makedirs(cache_path, exist_ok=True)
-    os.environ[env_var] = cache_path
-
-
-# Configure cache directories (before any library imports)
-_setup_cache_dir("TORCH_HOME", "torch_cache")
-_setup_cache_dir("MPLCONFIGDIR", "matplotlib")
-_setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
-_setup_cache_dir("XDG_DATA_HOME", "local/share")
-_setup_cache_dir("XDG_STATE_HOME", "local/state")
-_setup_cache_dir("XDG_CACHE_HOME", "cache")
+setup_hpc_cache_dirs()

 import argparse  # noqa: E402
 import logging  # noqa: E402
 import pickle  # noqa: E402
 from pathlib import Path  # noqa: E402
+from typing import Any  # noqa: E402

 import matplotlib.pyplot as plt  # noqa: E402
 import numpy as np  # noqa: E402
@@ -314,7 +296,7 @@ def load_checkpoint(
     in_shape: tuple[int, ...],
     out_size: int,
     model_name: str | None = None,
-) -> tuple[nn.Module,
+) -> tuple[nn.Module, Any]:
     """
     Load model and scaler from Accelerate checkpoint directory.

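The inline cache-directory logic is consolidated into wavedl.utils.setup_hpc_cache_dirs, presumably among the +49 lines in utils/__init__.py above, so it can be shared across entry points. The helper's body is not shown in this diff; the following is a sketch assuming it mirrors the logic removed from test.py:

```python
import os


def setup_hpc_cache_dirs() -> None:
    """Redirect library caches to CWD when $HOME is read-only (sketch).

    The real implementation lives in wavedl/utils and is not shown in this
    diff; this mirrors the per-variable logic removed from test.py above.
    """
    if os.access(os.path.expanduser("~"), os.W_OK):
        return  # home is writable; let libraries use their defaults
    for env_var, subdir in [
        ("TORCH_HOME", "torch_cache"),
        ("MPLCONFIGDIR", "matplotlib"),
        ("FONTCONFIG_CACHE", "fontconfig"),
        ("XDG_DATA_HOME", "local/share"),
        ("XDG_STATE_HOME", "local/state"),
        ("XDG_CACHE_HOME", "cache"),
    ]:
        if env_var in os.environ:
            continue  # user already set it; respect their choice
        cache_path = os.path.join(os.getcwd(), f".{subdir}")
        os.makedirs(cache_path, exist_ok=True)
        os.environ[env_var] = cache_path
```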
|