wavedl 1.5.4__py3-none-any.whl → 1.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +2 -1
- wavedl/models/vit.py +85 -25
- wavedl/train.py +46 -14
- wavedl/utils/data.py +135 -49
- wavedl/utils/metrics.py +22 -1
- {wavedl-1.5.4.dist-info → wavedl-1.5.6.dist-info}/METADATA +24 -23
- {wavedl-1.5.4.dist-info → wavedl-1.5.6.dist-info}/RECORD +12 -12
- {wavedl-1.5.4.dist-info → wavedl-1.5.6.dist-info}/LICENSE +0 -0
- {wavedl-1.5.4.dist-info → wavedl-1.5.6.dist-info}/WHEEL +0 -0
- {wavedl-1.5.4.dist-info → wavedl-1.5.6.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.4.dist-info → wavedl-1.5.6.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpo.py
CHANGED
|
@@ -175,13 +175,14 @@ def create_objective(args):
|
|
|
175
175
|
env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
|
176
176
|
|
|
177
177
|
# Run training
|
|
178
|
+
# Note: We inherit the user's cwd instead of setting cwd=Path(__file__).parent
|
|
179
|
+
# because site-packages may be read-only and train.py creates cache directories
|
|
178
180
|
try:
|
|
179
181
|
result = subprocess.run(
|
|
180
182
|
cmd,
|
|
181
183
|
capture_output=True,
|
|
182
184
|
text=True,
|
|
183
185
|
timeout=args.timeout,
|
|
184
|
-
cwd=Path(__file__).parent,
|
|
185
186
|
env=env,
|
|
186
187
|
)
|
|
187
188
|
|
wavedl/models/vit.py
CHANGED
|
@@ -42,47 +42,89 @@ class PatchEmbed(nn.Module):
|
|
|
42
42
|
Supports 1D and 2D inputs:
|
|
43
43
|
- 1D: Input (B, 1, L) → (B, num_patches, embed_dim)
|
|
44
44
|
- 2D: Input (B, 1, H, W) → (B, num_patches, embed_dim)
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
|
|
48
|
+
patch_size: Size of each patch
|
|
49
|
+
embed_dim: Embedding dimension
|
|
50
|
+
pad_if_needed: If True, zero-pad input up to the next patch-aligned size instead of
|
|
51
|
+
dropping edge pixels. Important for NDE/QUS applications where edge
|
|
52
|
+
effects matter. Default: False (original behavior with warning).
|
|
45
53
|
"""
|
|
46
54
|
|
|
47
|
-
def __init__(
|
|
55
|
+
def __init__(
|
|
56
|
+
self,
|
|
57
|
+
in_shape: SpatialShape,
|
|
58
|
+
patch_size: int,
|
|
59
|
+
embed_dim: int,
|
|
60
|
+
pad_if_needed: bool = False,
|
|
61
|
+
):
|
|
48
62
|
super().__init__()
|
|
49
63
|
|
|
50
64
|
self.dim = len(in_shape)
|
|
51
65
|
self.patch_size = patch_size
|
|
52
66
|
self.embed_dim = embed_dim
|
|
67
|
+
self.pad_if_needed = pad_if_needed
|
|
68
|
+
self._padding = None # Will be set if padding is needed
|
|
53
69
|
|
|
54
70
|
if self.dim == 1:
|
|
55
71
|
# 1D: segment patches
|
|
56
72
|
L = in_shape[0]
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
73
|
+
remainder = L % patch_size
|
|
74
|
+
if remainder != 0:
|
|
75
|
+
if pad_if_needed:
|
|
76
|
+
# Pad to next multiple of patch_size
|
|
77
|
+
pad_amount = patch_size - remainder
|
|
78
|
+
self._padding = (0, pad_amount) # (left, right)
|
|
79
|
+
L_padded = L + pad_amount
|
|
80
|
+
self.num_patches = L_padded // patch_size
|
|
81
|
+
else:
|
|
82
|
+
import warnings
|
|
83
|
+
|
|
84
|
+
warnings.warn(
|
|
85
|
+
f"Input length {L} not divisible by patch_size {patch_size}. "
|
|
86
|
+
f"Last {remainder} elements will be dropped. "
|
|
87
|
+
f"Consider using pad_if_needed=True or padding input to "
|
|
88
|
+
f"{((L // patch_size) + 1) * patch_size}.",
|
|
89
|
+
UserWarning,
|
|
90
|
+
stacklevel=2,
|
|
91
|
+
)
|
|
92
|
+
self.num_patches = L // patch_size
|
|
93
|
+
else:
|
|
94
|
+
self.num_patches = L // patch_size
|
|
68
95
|
self.proj = nn.Conv1d(
|
|
69
96
|
1, embed_dim, kernel_size=patch_size, stride=patch_size
|
|
70
97
|
)
|
|
71
98
|
elif self.dim == 2:
|
|
72
99
|
# 2D: grid patches
|
|
73
100
|
H, W = in_shape
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
101
|
+
h_rem, w_rem = H % patch_size, W % patch_size
|
|
102
|
+
if h_rem != 0 or w_rem != 0:
|
|
103
|
+
if pad_if_needed:
|
|
104
|
+
# Pad to next multiple of patch_size
|
|
105
|
+
h_pad = (patch_size - h_rem) % patch_size
|
|
106
|
+
w_pad = (patch_size - w_rem) % patch_size
|
|
107
|
+
# Padding format: (left, right, top, bottom)
|
|
108
|
+
self._padding = (0, w_pad, 0, h_pad)
|
|
109
|
+
H_padded, W_padded = H + h_pad, W + w_pad
|
|
110
|
+
self.num_patches = (H_padded // patch_size) * (
|
|
111
|
+
W_padded // patch_size
|
|
112
|
+
)
|
|
113
|
+
else:
|
|
114
|
+
import warnings
|
|
115
|
+
|
|
116
|
+
warnings.warn(
|
|
117
|
+
f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
|
|
118
|
+
f"Border pixels will be dropped (H: {h_rem}, W: {w_rem}). "
|
|
119
|
+
f"Consider using pad_if_needed=True or padding to "
|
|
120
|
+
f"({((H // patch_size) + 1) * patch_size}, "
|
|
121
|
+
f"{((W // patch_size) + 1) * patch_size}).",
|
|
122
|
+
UserWarning,
|
|
123
|
+
stacklevel=2,
|
|
124
|
+
)
|
|
125
|
+
self.num_patches = (H // patch_size) * (W // patch_size)
|
|
126
|
+
else:
|
|
127
|
+
self.num_patches = (H // patch_size) * (W // patch_size)
|
|
86
128
|
self.proj = nn.Conv2d(
|
|
87
129
|
1, embed_dim, kernel_size=patch_size, stride=patch_size
|
|
88
130
|
)
|
|
@@ -97,6 +139,10 @@ class PatchEmbed(nn.Module):
|
|
|
97
139
|
Returns:
|
|
98
140
|
Patch embeddings (B, num_patches, embed_dim)
|
|
99
141
|
"""
|
|
142
|
+
# Apply padding if configured
|
|
143
|
+
if self._padding is not None:
|
|
144
|
+
x = nn.functional.pad(x, self._padding, mode="constant", value=0)
|
|
145
|
+
|
|
100
146
|
x = self.proj(x) # (B, embed_dim, ..reduced_spatial..)
|
|
101
147
|
x = x.flatten(2) # (B, embed_dim, num_patches)
|
|
102
148
|
x = x.transpose(1, 2) # (B, num_patches, embed_dim)
|
|
@@ -185,6 +231,18 @@ class ViTBase(BaseModel):
|
|
|
185
231
|
3. Transformer encoder blocks
|
|
186
232
|
4. Extract CLS token
|
|
187
233
|
5. Regression head
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
|
|
237
|
+
out_size: Number of regression targets
|
|
238
|
+
patch_size: Size of each patch (default: 16)
|
|
239
|
+
embed_dim: Embedding dimension (default: 768)
|
|
240
|
+
depth: Number of transformer blocks (default: 12)
|
|
241
|
+
num_heads: Number of attention heads (default: 12)
|
|
242
|
+
mlp_ratio: MLP hidden dim multiplier (default: 4.0)
|
|
243
|
+
dropout_rate: Dropout rate (default: 0.1)
|
|
244
|
+
pad_if_needed: If True, pad input to nearest patch-aligned size instead
|
|
245
|
+
of dropping edge pixels. Important for NDE/QUS applications.
|
|
188
246
|
"""
|
|
189
247
|
|
|
190
248
|
def __init__(
|
|
@@ -197,6 +255,7 @@ class ViTBase(BaseModel):
|
|
|
197
255
|
num_heads: int = 12,
|
|
198
256
|
mlp_ratio: float = 4.0,
|
|
199
257
|
dropout_rate: float = 0.1,
|
|
258
|
+
pad_if_needed: bool = False,
|
|
200
259
|
**kwargs,
|
|
201
260
|
):
|
|
202
261
|
super().__init__(in_shape, out_size)
|
|
@@ -207,9 +266,10 @@ class ViTBase(BaseModel):
|
|
|
207
266
|
self.num_heads = num_heads
|
|
208
267
|
self.dropout_rate = dropout_rate
|
|
209
268
|
self.dim = len(in_shape)
|
|
269
|
+
self.pad_if_needed = pad_if_needed
|
|
210
270
|
|
|
211
271
|
# Patch embedding
|
|
212
|
-
self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim)
|
|
272
|
+
self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim, pad_if_needed)
|
|
213
273
|
num_patches = self.patch_embed.num_patches
|
|
214
274
|
|
|
215
275
|
# Learnable CLS token and position embeddings
|
wavedl/train.py
CHANGED
|
@@ -162,12 +162,22 @@ except ImportError:
|
|
|
162
162
|
os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
|
|
163
163
|
os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
|
|
164
164
|
|
|
165
|
-
# Suppress
|
|
166
|
-
|
|
165
|
+
# Suppress warnings from known-noisy libraries, but preserve legitimate warnings
|
|
166
|
+
# from torch/numpy about NaN, dtype, and numerical issues.
|
|
167
167
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
|
168
168
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
169
|
+
# Pydantic v1/v2 compatibility warnings
|
|
169
170
|
warnings.filterwarnings("ignore", module="pydantic")
|
|
170
171
|
warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
|
|
172
|
+
# Hugging Face `transformers` library warnings (loading configs, etc.)
|
|
173
|
+
warnings.filterwarnings("ignore", module="transformers")
|
|
174
|
+
# Accelerate verbose messages
|
|
175
|
+
warnings.filterwarnings("ignore", module="accelerate")
|
|
176
|
+
# torch.compile backend selection warnings
|
|
177
|
+
warnings.filterwarnings("ignore", message=".*TorchDynamo.*")
|
|
178
|
+
warnings.filterwarnings("ignore", message=".*Dynamo is not supported.*")
|
|
179
|
+
# Note: UserWarning from torch/numpy core is NOT suppressed to preserve
|
|
180
|
+
# legitimate warnings about NaN values, dtype mismatches, etc.
|
|
171
181
|
|
|
172
182
|
# ==============================================================================
|
|
173
183
|
# GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
|
|
@@ -228,6 +238,18 @@ def parse_args() -> argparse.Namespace:
|
|
|
228
238
|
default=[],
|
|
229
239
|
help="Python modules to import before training (for custom models)",
|
|
230
240
|
)
|
|
241
|
+
parser.add_argument(
|
|
242
|
+
"--pretrained",
|
|
243
|
+
action="store_true",
|
|
244
|
+
default=True,
|
|
245
|
+
help="Use pretrained weights (default: True)",
|
|
246
|
+
)
|
|
247
|
+
parser.add_argument(
|
|
248
|
+
"--no_pretrained",
|
|
249
|
+
dest="pretrained",
|
|
250
|
+
action="store_false",
|
|
251
|
+
help="Train from scratch without pretrained weights",
|
|
252
|
+
)
|
|
231
253
|
|
|
232
254
|
# Configuration File
|
|
233
255
|
parser.add_argument(
|
|
@@ -543,15 +565,11 @@ def main():
|
|
|
543
565
|
data_format = DataSource.detect_format(args.data_path)
|
|
544
566
|
source = get_data_source(data_format)
|
|
545
567
|
|
|
546
|
-
# Use memory-mapped loading when available
|
|
568
|
+
# Use memory-mapped loading when available (now returns LazyDataHandle for all formats)
|
|
547
569
|
_cv_handle = None
|
|
548
570
|
if hasattr(source, "load_mmap"):
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
_cv_handle = result
|
|
552
|
-
X, y = result.inputs, result.outputs
|
|
553
|
-
else:
|
|
554
|
-
X, y = result # NPZ returns tuple directly
|
|
571
|
+
_cv_handle = source.load_mmap(args.data_path)
|
|
572
|
+
X, y = _cv_handle.inputs, _cv_handle.outputs
|
|
555
573
|
else:
|
|
556
574
|
X, y = source.load(args.data_path)
|
|
557
575
|
|
|
@@ -684,7 +702,9 @@ def main():
|
|
|
684
702
|
)
|
|
685
703
|
|
|
686
704
|
# Build model using registry
|
|
687
|
-
model = build_model(
|
|
705
|
+
model = build_model(
|
|
706
|
+
args.model, in_shape=in_shape, out_size=out_dim, pretrained=args.pretrained
|
|
707
|
+
)
|
|
688
708
|
|
|
689
709
|
if accelerator.is_main_process:
|
|
690
710
|
param_info = model.parameter_summary()
|
|
@@ -861,10 +881,22 @@ def main():
|
|
|
861
881
|
milestones=milestones,
|
|
862
882
|
warmup_epochs=args.warmup_epochs,
|
|
863
883
|
)
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
884
|
+
|
|
885
|
+
# For ReduceLROnPlateau: DON'T include scheduler in accelerator.prepare()
|
|
886
|
+
# because accelerator wraps scheduler.step() to sync across processes,
|
|
887
|
+
# which defeats our rank-0-only stepping for correct patience counting.
|
|
888
|
+
# Other schedulers are safe to prepare (no internal state affected by multi-call).
|
|
889
|
+
if args.scheduler == "plateau":
|
|
890
|
+
model, optimizer, train_dl, val_dl = accelerator.prepare(
|
|
891
|
+
model, optimizer, train_dl, val_dl
|
|
892
|
+
)
|
|
893
|
+
# Scheduler stays unwrapped - we handle sync manually in training loop
|
|
894
|
+
# But register it for checkpointing so state is saved/loaded on resume
|
|
895
|
+
accelerator.register_for_checkpointing(scheduler)
|
|
896
|
+
else:
|
|
897
|
+
model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
|
|
898
|
+
model, optimizer, train_dl, val_dl, scheduler
|
|
899
|
+
)
|
|
868
900
|
|
|
869
901
|
# ==========================================================================
|
|
870
902
|
# AUTO-RESUME / RESUME FROM CHECKPOINT
|
wavedl/utils/data.py
CHANGED
|
@@ -13,6 +13,7 @@ Version: 1.0.0
|
|
|
13
13
|
"""
|
|
14
14
|
|
|
15
15
|
import gc
|
|
16
|
+
import hashlib
|
|
16
17
|
import logging
|
|
17
18
|
import os
|
|
18
19
|
import pickle
|
|
@@ -49,6 +50,29 @@ INPUT_KEYS = ["input_train", "input_test", "X", "data", "inputs", "features", "x
|
|
|
49
50
|
OUTPUT_KEYS = ["output_train", "output_test", "Y", "labels", "outputs", "targets", "y"]
|
|
50
51
|
|
|
51
52
|
|
|
53
|
+
def _compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:
|
|
54
|
+
"""
|
|
55
|
+
Compute SHA256 hash of a file for cache validation.
|
|
56
|
+
|
|
57
|
+
Uses chunked reading to handle large files efficiently without loading
|
|
58
|
+
the entire file into memory. This is more reliable than mtime for detecting
|
|
59
|
+
actual content changes, especially with cloud sync services (Dropbox, etc.)
|
|
60
|
+
that may touch files without modifying content.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
path: Path to file to hash
|
|
64
|
+
chunk_size: Read buffer size (default 8MB for fast I/O)
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
Hex string of SHA256 hash
|
|
68
|
+
"""
|
|
69
|
+
hasher = hashlib.sha256()
|
|
70
|
+
with open(path, "rb") as f:
|
|
71
|
+
while chunk := f.read(chunk_size):
|
|
72
|
+
hasher.update(chunk)
|
|
73
|
+
return hasher.hexdigest()
|
|
74
|
+
|
|
75
|
+
|
|
52
76
|
class LazyDataHandle:
|
|
53
77
|
"""
|
|
54
78
|
Context manager wrapper for memory-mapped data handles.
|
|
@@ -207,6 +231,10 @@ class NPZSource(DataSource):
|
|
|
207
231
|
|
|
208
232
|
The error for object arrays happens at ACCESS time, not load time.
|
|
209
233
|
So we need to probe the keys to detect if pickle is required.
|
|
234
|
+
|
|
235
|
+
WARNING: When mmap_mode is not None, the returned NpzFile must be kept
|
|
236
|
+
open for arrays to remain valid. Caller is responsible for closing.
|
|
237
|
+
For non-mmap loading, use _load_and_copy() instead to avoid leaks.
|
|
210
238
|
"""
|
|
211
239
|
data = np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
|
|
212
240
|
try:
|
|
@@ -222,6 +250,26 @@ class NPZSource(DataSource):
|
|
|
222
250
|
return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
|
|
223
251
|
raise
|
|
224
252
|
|
|
253
|
+
@staticmethod
|
|
254
|
+
def _load_and_copy(path: str, keys: list[str]) -> dict[str, np.ndarray]:
|
|
255
|
+
"""Load NPZ and copy arrays, ensuring file is properly closed.
|
|
256
|
+
|
|
257
|
+
This prevents file descriptor leaks by copying arrays before closing.
|
|
258
|
+
Use this for eager loading; use _safe_load for memory-mapped access.
|
|
259
|
+
"""
|
|
260
|
+
data = NPZSource._safe_load(path, keys, mmap_mode=None)
|
|
261
|
+
try:
|
|
262
|
+
result = {}
|
|
263
|
+
for key in keys:
|
|
264
|
+
if key in data:
|
|
265
|
+
arr = data[key]
|
|
266
|
+
# Copy so the result holds no reference to the NpzFile that is closed below
|
|
267
|
+
result[key] = arr.copy() if hasattr(arr, "copy") else arr
|
|
268
|
+
return result
|
|
269
|
+
finally:
|
|
270
|
+
if hasattr(data, "close"):
|
|
271
|
+
data.close()
|
|
272
|
+
|
|
225
273
|
def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
|
|
226
274
|
"""Load NPZ file (pickle enabled only for sparse matrices)."""
|
|
227
275
|
# First pass to find keys without loading data
|
|
@@ -238,7 +286,7 @@ class NPZSource(DataSource):
|
|
|
238
286
|
f"Found: {keys}"
|
|
239
287
|
)
|
|
240
288
|
|
|
241
|
-
data = self.
|
|
289
|
+
data = self._load_and_copy(path, [input_key, output_key])
|
|
242
290
|
inp = data[input_key]
|
|
243
291
|
outp = data[output_key]
|
|
244
292
|
|
|
@@ -248,13 +296,21 @@ class NPZSource(DataSource):
|
|
|
248
296
|
|
|
249
297
|
return inp, outp
|
|
250
298
|
|
|
251
|
-
def load_mmap(self, path: str) ->
|
|
299
|
+
def load_mmap(self, path: str) -> LazyDataHandle:
|
|
252
300
|
"""
|
|
253
301
|
Load data using memory-mapped mode for zero-copy access.
|
|
254
302
|
|
|
255
303
|
This allows processing large datasets without loading them entirely
|
|
256
304
|
into RAM. Critical for HPC environments with memory constraints.
|
|
257
305
|
|
|
306
|
+
Returns a LazyDataHandle for consistent API across all data sources.
|
|
307
|
+
The NpzFile is kept open for lazy access.
|
|
308
|
+
|
|
309
|
+
Usage:
|
|
310
|
+
with source.load_mmap(path) as (inputs, outputs):
|
|
311
|
+
# Use inputs and outputs
|
|
312
|
+
pass # File automatically closed
|
|
313
|
+
|
|
258
314
|
Note: Returns memory-mapped arrays - do NOT modify them.
|
|
259
315
|
"""
|
|
260
316
|
# First pass to find keys without loading data
|
|
@@ -271,11 +327,13 @@ class NPZSource(DataSource):
|
|
|
271
327
|
f"Found: {keys}"
|
|
272
328
|
)
|
|
273
329
|
|
|
330
|
+
# Keep NpzFile open for lazy access (like HDF5/MATSource)
|
|
274
331
|
data = self._safe_load(path, [input_key, output_key], mmap_mode="r")
|
|
275
332
|
inp = data[input_key]
|
|
276
333
|
outp = data[output_key]
|
|
277
334
|
|
|
278
|
-
|
|
335
|
+
# Return LazyDataHandle for consistent API with HDF5Source/MATSource
|
|
336
|
+
return LazyDataHandle(inp, outp, file_handle=data)
|
|
279
337
|
|
|
280
338
|
def load_outputs_only(self, path: str) -> np.ndarray:
|
|
281
339
|
"""Load only targets from NPZ (avoids loading large input arrays)."""
|
|
@@ -290,7 +348,7 @@ class NPZSource(DataSource):
|
|
|
290
348
|
f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
|
|
291
349
|
)
|
|
292
350
|
|
|
293
|
-
data = self.
|
|
351
|
+
data = self._load_and_copy(path, [output_key])
|
|
294
352
|
return data[output_key]
|
|
295
353
|
|
|
296
354
|
|
|
@@ -527,9 +585,17 @@ class MATSource(DataSource):
|
|
|
527
585
|
inp = self._load_dataset(f, input_key)
|
|
528
586
|
outp = self._load_dataset(f, output_key)
|
|
529
587
|
|
|
530
|
-
# Handle
|
|
531
|
-
|
|
532
|
-
|
|
588
|
+
# Handle transposed outputs from MATLAB.
|
|
589
|
+
# Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
|
|
590
|
+
# Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
|
|
591
|
+
num_samples = inp.shape[0] # inp is already transposed
|
|
592
|
+
if outp.ndim == 2:
|
|
593
|
+
if outp.shape[0] == 1 and outp.shape[1] == num_samples:
|
|
594
|
+
# 1D vector: (1, N) → (N, 1)
|
|
595
|
+
outp = outp.T
|
|
596
|
+
elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
|
|
597
|
+
# Single sample with multiple targets: (T, 1) → (1, T)
|
|
598
|
+
outp = outp.T
|
|
533
599
|
|
|
534
600
|
except OSError as e:
|
|
535
601
|
raise ValueError(
|
|
@@ -614,7 +680,10 @@ class MATSource(DataSource):
|
|
|
614
680
|
# Load with sparse matrix support
|
|
615
681
|
outp = self._load_dataset(f, output_key)
|
|
616
682
|
|
|
617
|
-
# Handle 1D outputs
|
|
683
|
+
# Handle 1D outputs that become (1, N) after transpose.
|
|
684
|
+
# Note: This method has no input to compare against, so we can't
|
|
685
|
+
# distinguish single-sample outputs. This is acceptable for training
|
|
686
|
+
# data where single-sample is unlikely. For inference, use load_test_data.
|
|
618
687
|
if outp.ndim == 2 and outp.shape[0] == 1:
|
|
619
688
|
outp = outp.T
|
|
620
689
|
|
|
@@ -775,7 +844,7 @@ def load_test_data(
|
|
|
775
844
|
raise KeyError(
|
|
776
845
|
f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
|
|
777
846
|
)
|
|
778
|
-
data = NPZSource.
|
|
847
|
+
data = NPZSource._load_and_copy(
|
|
779
848
|
path, [inp_key] + ([out_key] if out_key else [])
|
|
780
849
|
)
|
|
781
850
|
inp = data[inp_key]
|
|
@@ -824,8 +893,17 @@ def load_test_data(
|
|
|
824
893
|
inp = mat_source._load_dataset(f, inp_key)
|
|
825
894
|
if out_key:
|
|
826
895
|
outp = mat_source._load_dataset(f, out_key)
|
|
827
|
-
|
|
828
|
-
|
|
896
|
+
# Handle transposed outputs from MATLAB
|
|
897
|
+
# Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
|
|
898
|
+
# Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
|
|
899
|
+
num_samples = inp.shape[0]
|
|
900
|
+
if outp.ndim == 2:
|
|
901
|
+
if outp.shape[0] == 1 and outp.shape[1] == num_samples:
|
|
902
|
+
# 1D vector: (1, N) → (N, 1)
|
|
903
|
+
outp = outp.T
|
|
904
|
+
elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
|
|
905
|
+
# Single sample with multiple targets: (T, 1) → (1, T)
|
|
906
|
+
outp = outp.T
|
|
829
907
|
else:
|
|
830
908
|
outp = None
|
|
831
909
|
else:
|
|
@@ -844,7 +922,7 @@ def load_test_data(
|
|
|
844
922
|
)
|
|
845
923
|
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
846
924
|
keys_to_probe = [inp_key] + ([out_key] if out_key else [])
|
|
847
|
-
data = NPZSource.
|
|
925
|
+
data = NPZSource._load_and_copy(path, keys_to_probe)
|
|
848
926
|
inp = data[inp_key]
|
|
849
927
|
if inp.dtype == object:
|
|
850
928
|
inp = np.array(
|
|
@@ -894,9 +972,17 @@ def load_test_data(
|
|
|
894
972
|
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
895
973
|
if out_key:
|
|
896
974
|
outp = mat_source._load_dataset(f, out_key)
|
|
897
|
-
# Handle
|
|
898
|
-
|
|
899
|
-
|
|
975
|
+
# Handle transposed outputs from MATLAB
|
|
976
|
+
# Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
|
|
977
|
+
# Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
|
|
978
|
+
num_samples = inp.shape[0]
|
|
979
|
+
if outp.ndim == 2:
|
|
980
|
+
if outp.shape[0] == 1 and outp.shape[1] == num_samples:
|
|
981
|
+
# 1D vector: (1, N) → (N, 1)
|
|
982
|
+
outp = outp.T
|
|
983
|
+
elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
|
|
984
|
+
# Single sample with multiple targets: (T, 1) → (1, T)
|
|
985
|
+
outp = outp.T
|
|
900
986
|
else:
|
|
901
987
|
outp = None
|
|
902
988
|
else:
|
|
@@ -1096,32 +1182,21 @@ def prepare_data(
|
|
|
1096
1182
|
and os.path.exists(META_FILE)
|
|
1097
1183
|
)
|
|
1098
1184
|
|
|
1099
|
-
# Validate cache
|
|
1185
|
+
# Validate cache using content hash (portable across folders/machines)
|
|
1186
|
+
# File size is a fast pre-check, content hash is definitive validation
|
|
1100
1187
|
if cache_exists:
|
|
1101
1188
|
try:
|
|
1102
1189
|
with open(META_FILE, "rb") as f:
|
|
1103
1190
|
meta = pickle.load(f)
|
|
1104
|
-
cached_data_path = meta.get("data_path", None)
|
|
1105
1191
|
cached_file_size = meta.get("file_size", None)
|
|
1106
|
-
|
|
1192
|
+
cached_content_hash = meta.get("content_hash", None)
|
|
1107
1193
|
|
|
1108
1194
|
# Get current file stats
|
|
1109
1195
|
current_stats = os.stat(args.data_path)
|
|
1110
1196
|
current_size = current_stats.st_size
|
|
1111
|
-
current_mtime = current_stats.st_mtime
|
|
1112
1197
|
|
|
1113
|
-
# Check if
|
|
1114
|
-
if
|
|
1115
|
-
if accelerator.is_main_process:
|
|
1116
|
-
logger.warning(
|
|
1117
|
-
f"⚠️ Cache was created from different data file!\n"
|
|
1118
|
-
f" Cached: {cached_data_path}\n"
|
|
1119
|
-
f" Current: {os.path.abspath(args.data_path)}\n"
|
|
1120
|
-
f" Invalidating cache and regenerating..."
|
|
1121
|
-
)
|
|
1122
|
-
cache_exists = False
|
|
1123
|
-
# Check if file was modified (size or mtime changed)
|
|
1124
|
-
elif cached_file_size is not None and cached_file_size != current_size:
|
|
1198
|
+
# Check if file size changed (fast check before expensive hash)
|
|
1199
|
+
if cached_file_size is not None and cached_file_size != current_size:
|
|
1125
1200
|
if accelerator.is_main_process:
|
|
1126
1201
|
logger.warning(
|
|
1127
1202
|
f"⚠️ Data file size changed!\n"
|
|
@@ -1130,13 +1205,16 @@ def prepare_data(
|
|
|
1130
1205
|
f" Invalidating cache and regenerating..."
|
|
1131
1206
|
)
|
|
1132
1207
|
cache_exists = False
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1208
|
+
# Content hash check (robust against cloud sync mtime changes)
|
|
1209
|
+
elif cached_content_hash is not None:
|
|
1210
|
+
current_hash = _compute_file_hash(args.data_path)
|
|
1211
|
+
if cached_content_hash != current_hash:
|
|
1212
|
+
if accelerator.is_main_process:
|
|
1213
|
+
logger.warning(
|
|
1214
|
+
"⚠️ Data file content changed!\n"
|
|
1215
|
+
" Cache is stale, regenerating..."
|
|
1216
|
+
)
|
|
1217
|
+
cache_exists = False
|
|
1140
1218
|
except Exception:
|
|
1141
1219
|
cache_exists = False
|
|
1142
1220
|
|
|
@@ -1153,6 +1231,18 @@ def prepare_data(
|
|
|
1153
1231
|
logger.warning(
|
|
1154
1232
|
f" Failed to remove stale cache {stale_file}: {e}"
|
|
1155
1233
|
)
|
|
1234
|
+
|
|
1235
|
+
# Fail explicitly if stale cache files couldn't be removed
|
|
1236
|
+
# This prevents silent reuse of outdated data
|
|
1237
|
+
remaining_stale = [
|
|
1238
|
+
f for f in [CACHE_FILE, SCALER_FILE] if os.path.exists(f)
|
|
1239
|
+
]
|
|
1240
|
+
if remaining_stale:
|
|
1241
|
+
raise RuntimeError(
|
|
1242
|
+
f"Cannot regenerate cache: stale files could not be removed. "
|
|
1243
|
+
f"Please manually delete: {remaining_stale}"
|
|
1244
|
+
)
|
|
1245
|
+
|
|
1156
1246
|
# RANK 0: Create cache (can take a long time for large datasets)
|
|
1157
1247
|
# Other ranks will wait at the barrier below
|
|
1158
1248
|
|
|
@@ -1170,16 +1260,11 @@ def prepare_data(
|
|
|
1170
1260
|
|
|
1171
1261
|
# Load raw data using memory-mapped mode for all formats
|
|
1172
1262
|
# This avoids loading the entire dataset into RAM at once
|
|
1263
|
+
# All load_mmap() methods now return LazyDataHandle consistently
|
|
1264
|
+
_lazy_handle = None
|
|
1173
1265
|
try:
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
inp, outp = source.load_mmap(args.data_path)
|
|
1177
|
-
elif data_format == "hdf5":
|
|
1178
|
-
source = HDF5Source()
|
|
1179
|
-
_lazy_handle = source.load_mmap(args.data_path)
|
|
1180
|
-
inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
|
|
1181
|
-
elif data_format == "mat":
|
|
1182
|
-
source = MATSource()
|
|
1266
|
+
source = get_data_source(data_format)
|
|
1267
|
+
if hasattr(source, "load_mmap"):
|
|
1183
1268
|
_lazy_handle = source.load_mmap(args.data_path)
|
|
1184
1269
|
inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
|
|
1185
1270
|
else:
|
|
@@ -1243,8 +1328,9 @@ def prepare_data(
|
|
|
1243
1328
|
f" Shape Detected: {full_shape} [{dim_type}] | Output Dim: {out_dim}"
|
|
1244
1329
|
)
|
|
1245
1330
|
|
|
1246
|
-
# Save metadata (including data path, size,
|
|
1331
|
+
# Save metadata (including data path, size, content hash for cache validation)
|
|
1247
1332
|
file_stats = os.stat(args.data_path)
|
|
1333
|
+
content_hash = _compute_file_hash(args.data_path)
|
|
1248
1334
|
with open(META_FILE, "wb") as f:
|
|
1249
1335
|
pickle.dump(
|
|
1250
1336
|
{
|
|
@@ -1252,7 +1338,7 @@ def prepare_data(
|
|
|
1252
1338
|
"out_dim": out_dim,
|
|
1253
1339
|
"data_path": os.path.abspath(args.data_path),
|
|
1254
1340
|
"file_size": file_stats.st_size,
|
|
1255
|
-
"
|
|
1341
|
+
"content_hash": content_hash,
|
|
1256
1342
|
},
|
|
1257
1343
|
f,
|
|
1258
1344
|
)
|
wavedl/utils/metrics.py
CHANGED
|
@@ -815,7 +815,28 @@ def plot_qq(
|
|
|
815
815
|
|
|
816
816
|
# Standardize errors for QQ plot
|
|
817
817
|
err = errors[:, i]
|
|
818
|
-
|
|
818
|
+
std_err = np.std(err)
|
|
819
|
+
|
|
820
|
+
# Guard against zero variance (constant errors)
|
|
821
|
+
if std_err < 1e-10:
|
|
822
|
+
title = (
|
|
823
|
+
param_names[i] if param_names and i < len(param_names) else f"Param {i}"
|
|
824
|
+
)
|
|
825
|
+
ax.text(
|
|
826
|
+
0.5,
|
|
827
|
+
0.5,
|
|
828
|
+
"Zero variance\n(constant errors)",
|
|
829
|
+
ha="center",
|
|
830
|
+
va="center",
|
|
831
|
+
fontsize=10,
|
|
832
|
+
transform=ax.transAxes,
|
|
833
|
+
)
|
|
834
|
+
ax.set_title(f"{title}\n(zero variance)")
|
|
835
|
+
ax.set_xlabel("Theoretical Quantiles")
|
|
836
|
+
ax.set_ylabel("Sample Quantiles")
|
|
837
|
+
continue
|
|
838
|
+
|
|
839
|
+
standardized = (err - np.mean(err)) / std_err
|
|
819
840
|
|
|
820
841
|
# Calculate theoretical quantiles and sample quantiles
|
|
821
842
|
(osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.5.4
|
|
3
|
+
Version: 1.5.6
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -388,7 +388,7 @@ WaveDL/
|
|
|
388
388
|
├── configs/ # YAML config templates
|
|
389
389
|
├── examples/ # Ready-to-run examples
|
|
390
390
|
├── notebooks/ # Jupyter notebooks
|
|
391
|
-
├── unit_tests/ # Pytest test suite (
|
|
391
|
+
├── unit_tests/ # Pytest test suite (731 tests)
|
|
392
392
|
│
|
|
393
393
|
├── pyproject.toml # Package config, dependencies
|
|
394
394
|
├── CHANGELOG.md # Version history
|
|
@@ -470,6 +470,7 @@ WaveDL/
|
|
|
470
470
|
⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
|
|
471
471
|
- **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
|
|
472
472
|
- **Size**: ~20–350 MB per model depending on architecture
|
|
473
|
+
- **Train from scratch**: Use `--no_pretrained` to disable pretrained weights
|
|
473
474
|
|
|
474
475
|
**💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
|
|
475
476
|
|
|
@@ -1030,7 +1031,7 @@ print(f"✓ Output: {data['output_train'].shape} {data['output_train'].dtype}")
|
|
|
1030
1031
|
|
|
1031
1032
|
## 📦 Examples [Open in Colab](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
|
|
1032
1033
|
|
|
1033
|
-
The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained
|
|
1034
|
+
The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained MobileNetV3 predicts three physical parameters from Lamb wave dispersion curves:
|
|
1034
1035
|
|
|
1035
1036
|
| Parameter | Unit | Description |
|
|
1036
1037
|
|-----------|------|-------------|
|
|
@@ -1045,22 +1046,22 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater
|
|
|
1045
1046
|
|
|
1046
1047
|
```bash
|
|
1047
1048
|
# Run inference on the example data
|
|
1048
|
-
python -m wavedl.test --checkpoint ./examples/
|
|
1049
|
-
--data_path ./examples/
|
|
1050
|
-
--plot --save_predictions --output_dir ./examples/
|
|
1049
|
+
python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
|
|
1050
|
+
--data_path ./examples/elasticity_prediction/Test_data_100.mat \
|
|
1051
|
+
--plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results
|
|
1051
1052
|
|
|
1052
1053
|
# Export to ONNX (already included as model.onnx)
|
|
1053
|
-
python -m wavedl.test --checkpoint ./examples/
|
|
1054
|
-
--data_path ./examples/
|
|
1055
|
-
--export onnx --export_path ./examples/
|
|
1054
|
+
python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
|
|
1055
|
+
--data_path ./examples/elasticity_prediction/Test_data_100.mat \
|
|
1056
|
+
--export onnx --export_path ./examples/elasticity_prediction/model.onnx
|
|
1056
1057
|
```
|
|
1057
1058
|
|
|
1058
1059
|
**What's Included:**
|
|
1059
1060
|
|
|
1060
1061
|
| File | Description |
|
|
1061
1062
|
|------|-------------|
|
|
1062
|
-
| `best_checkpoint/` | Pre-trained
|
|
1063
|
-
| `
|
|
1063
|
+
| `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
|
|
1064
|
+
| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
|
|
1064
1065
|
| `model.onnx` | ONNX export with embedded de-normalization |
|
|
1065
1066
|
| `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
|
|
1066
1067
|
| `training_curves.png` | Training/validation loss and learning rate plot |
|
|
@@ -1070,59 +1071,59 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
|
|
|
1070
1071
|
**Training Progress:**
|
|
1071
1072
|
|
|
1072
1073
|
<p align="center">
|
|
1073
|
-
<img src="examples/
|
|
1074
|
-
<em>Training and validation loss
|
|
1074
|
+
<img src="examples/elasticity_prediction/training_curves.png" alt="Training curves" width="600"><br>
|
|
1075
|
+
<em>Training and validation loss with <code>plateau</code> learning rate schedule</em>
|
|
1075
1076
|
</p>
|
|
1076
1077
|
|
|
1077
1078
|
**Inference Results:**
|
|
1078
1079
|
|
|
1079
1080
|
<p align="center">
|
|
1080
|
-
<img src="examples/
|
|
1081
|
+
<img src="examples/elasticity_prediction/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
|
|
1081
1082
|
<em>Figure 1: Predictions vs ground truth for all three elastic parameters</em>
|
|
1082
1083
|
</p>
|
|
1083
1084
|
|
|
1084
1085
|
<p align="center">
|
|
1085
|
-
<img src="examples/
|
|
1086
|
+
<img src="examples/elasticity_prediction/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
|
|
1086
1087
|
<em>Figure 2: Distribution of prediction errors showing near-zero mean bias</em>
|
|
1087
1088
|
</p>
|
|
1088
1089
|
|
|
1089
1090
|
<p align="center">
|
|
1090
|
-
<img src="examples/
|
|
1091
|
+
<img src="examples/elasticity_prediction/test_results/residuals.png" alt="Residual plot" width="700"><br>
|
|
1091
1092
|
<em>Figure 3: Residuals vs predicted values (no heteroscedasticity detected)</em>
|
|
1092
1093
|
</p>
|
|
1093
1094
|
|
|
1094
1095
|
<p align="center">
|
|
1095
|
-
<img src="examples/
|
|
1096
|
+
<img src="examples/elasticity_prediction/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
|
|
1096
1097
|
<em>Figure 4: Bland-Altman analysis with ±1.96 SD limits of agreement</em>
|
|
1097
1098
|
</p>
|
|
1098
1099
|
|
|
1099
1100
|
<p align="center">
|
|
1100
|
-
<img src="examples/
|
|
1101
|
+
<img src="examples/elasticity_prediction/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
|
|
1101
1102
|
<em>Figure 5: Q-Q plots confirming normally distributed prediction errors</em>
|
|
1102
1103
|
</p>
|
|
1103
1104
|
|
|
1104
1105
|
<p align="center">
|
|
1105
|
-
<img src="examples/
|
|
1106
|
+
<img src="examples/elasticity_prediction/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
|
|
1106
1107
|
<em>Figure 6: Error correlation matrix between parameters</em>
|
|
1107
1108
|
</p>
|
|
1108
1109
|
|
|
1109
1110
|
<p align="center">
|
|
1110
|
-
<img src="examples/
|
|
1111
|
+
<img src="examples/elasticity_prediction/test_results/relative_error.png" alt="Relative error" width="700"><br>
|
|
1111
1112
|
<em>Figure 7: Relative error (%) vs true value for each parameter</em>
|
|
1112
1113
|
</p>
|
|
1113
1114
|
|
|
1114
1115
|
<p align="center">
|
|
1115
|
-
<img src="examples/
|
|
1116
|
+
<img src="examples/elasticity_prediction/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
|
|
1116
1117
|
<em>Figure 8: Cumulative error distribution — 95% of predictions within indicated bounds</em>
|
|
1117
1118
|
</p>
|
|
1118
1119
|
|
|
1119
1120
|
<p align="center">
|
|
1120
|
-
<img src="examples/
|
|
1121
|
+
<img src="examples/elasticity_prediction/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
|
|
1121
1122
|
<em>Figure 9: True vs predicted values by sample index</em>
|
|
1122
1123
|
</p>
|
|
1123
1124
|
|
|
1124
1125
|
<p align="center">
|
|
1125
|
-
<img src="examples/
|
|
1126
|
+
<img src="examples/elasticity_prediction/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
|
|
1126
1127
|
<em>Figure 10: Error distribution summary (median, quartiles, outliers)</em>
|
|
1127
1128
|
</p>
|
|
1128
1129
|
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256=
|
|
1
|
+
wavedl/__init__.py,sha256=qesevvzcBx9pJrvfW07e7PB9_sjb1eOL1BrWpUF-wZM,1177
|
|
2
2
|
wavedl/hpc.py,sha256=6rV38nozzMt0-jKZbVJNwvQZXK0wUsIZmr9lgWN_XUw,9212
|
|
3
|
-
wavedl/hpo.py,sha256=
|
|
3
|
+
wavedl/hpo.py,sha256=CZF0MZwTGMOrPGDveUXZFbGHwLHj1FcJTCBKVVEtLWg,15105
|
|
4
4
|
wavedl/test.py,sha256=WIHG3HWT-uF399FQApPpxjggBVFn59cC54HAL4990QU,38550
|
|
5
|
-
wavedl/train.py,sha256=
|
|
5
|
+
wavedl/train.py,sha256=JlSXWyTdU4S_PTgvANqXN4ceCS9KONOybbRksDPPcuo,57570
|
|
6
6
|
wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
|
|
7
7
|
wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
|
|
8
8
|
wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
|
|
@@ -19,20 +19,20 @@ wavedl/models/resnet3d.py,sha256=C7CL4XeSnRlIBuwf5Ei-z183uzIBObrXfkM9Iwuc5e0,874
|
|
|
19
19
|
wavedl/models/swin.py,sha256=cbV_iqIS4no-EAUR8j_93gqd59AkAkfM5DYo6VryLEg,13937
|
|
20
20
|
wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
|
|
21
21
|
wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
|
|
22
|
-
wavedl/models/vit.py,sha256=
|
|
22
|
+
wavedl/models/vit.py,sha256=D4jlYAlvegb3O19jCPpUHYmt5q0SZ7EGVBIWiYbq0GA,14816
|
|
23
23
|
wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
|
|
24
24
|
wavedl/utils/config.py,sha256=AsGwb3XtxmbTLb59BLl5AA4wzMNgVTpl7urOJ6IGqfM,10901
|
|
25
25
|
wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
|
|
26
26
|
wavedl/utils/cross_validation.py,sha256=gwXSFTx5oxWndPjWLJAJzB6nnq2f1t9f86SbjbF-jNI,18475
|
|
27
|
-
wavedl/utils/data.py,sha256=
|
|
27
|
+
wavedl/utils/data.py,sha256=l8aqC7mtnUyXPOS0cCbgE-jS8TKndXnHU2WiP1VU1Zk,58361
|
|
28
28
|
wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
|
|
29
29
|
wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
30
|
-
wavedl/utils/metrics.py,sha256=
|
|
30
|
+
wavedl/utils/metrics.py,sha256=El2NYsulH5jxBhC1gCAMcS8C-yxEjuSC930LhsKYQrY,40059
|
|
31
31
|
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
32
32
|
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
33
|
-
wavedl-1.5.
|
|
34
|
-
wavedl-1.5.
|
|
35
|
-
wavedl-1.5.
|
|
36
|
-
wavedl-1.5.
|
|
37
|
-
wavedl-1.5.
|
|
38
|
-
wavedl-1.5.
|
|
33
|
+
wavedl-1.5.6.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
34
|
+
wavedl-1.5.6.dist-info/METADATA,sha256=lLNnw1m1vOKEvFKr-9-v3xb71RHqDUUJxqu4VzNR8eI,45715
|
|
35
|
+
wavedl-1.5.6.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
36
|
+
wavedl-1.5.6.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
37
|
+
wavedl-1.5.6.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
38
|
+
wavedl-1.5.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|