wavedl 1.5.5__py3-none-any.whl → 1.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/vit.py +85 -25
- wavedl/train.py +46 -14
- wavedl/utils/data.py +58 -36
- {wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/METADATA +23 -22
- {wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/RECORD +10 -10
- {wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/LICENSE +0 -0
- {wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/WHEEL +0 -0
- {wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/models/vit.py
CHANGED
```diff
@@ -42,47 +42,89 @@ class PatchEmbed(nn.Module):
     Supports 1D and 2D inputs:
     - 1D: Input (B, 1, L) → (B, num_patches, embed_dim)
     - 2D: Input (B, 1, H, W) → (B, num_patches, embed_dim)
+
+    Args:
+        in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
+        patch_size: Size of each patch
+        embed_dim: Embedding dimension
+        pad_if_needed: If True, pad input to nearest patch-aligned size instead of
+            dropping edge pixels. Important for NDE/QUS applications where edge
+            effects matter. Default: False (original behavior with warning).
     """

-    def __init__(self, in_shape: SpatialShape, patch_size: int, embed_dim: int):
+    def __init__(
+        self,
+        in_shape: SpatialShape,
+        patch_size: int,
+        embed_dim: int,
+        pad_if_needed: bool = False,
+    ):
         super().__init__()

         self.dim = len(in_shape)
         self.patch_size = patch_size
         self.embed_dim = embed_dim
+        self.pad_if_needed = pad_if_needed
+        self._padding = None  # Will be set if padding is needed

         if self.dim == 1:
             # 1D: segment patches
             L = in_shape[0]
+            remainder = L % patch_size
+            if remainder != 0:
+                if pad_if_needed:
+                    # Pad to next multiple of patch_size
+                    pad_amount = patch_size - remainder
+                    self._padding = (0, pad_amount)  # (left, right)
+                    L_padded = L + pad_amount
+                    self.num_patches = L_padded // patch_size
+                else:
+                    import warnings
+
+                    warnings.warn(
+                        f"Input length {L} not divisible by patch_size {patch_size}. "
+                        f"Last {remainder} elements will be dropped. "
+                        f"Consider using pad_if_needed=True or padding input to "
+                        f"{((L // patch_size) + 1) * patch_size}.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self.num_patches = L // patch_size
+            else:
+                self.num_patches = L // patch_size
             self.proj = nn.Conv1d(
                 1, embed_dim, kernel_size=patch_size, stride=patch_size
             )
         elif self.dim == 2:
             # 2D: grid patches
             H, W = in_shape
+            h_rem, w_rem = H % patch_size, W % patch_size
+            if h_rem != 0 or w_rem != 0:
+                if pad_if_needed:
+                    # Pad to next multiple of patch_size
+                    h_pad = (patch_size - h_rem) % patch_size
+                    w_pad = (patch_size - w_rem) % patch_size
+                    # Padding format: (left, right, top, bottom)
+                    self._padding = (0, w_pad, 0, h_pad)
+                    H_padded, W_padded = H + h_pad, W + w_pad
+                    self.num_patches = (H_padded // patch_size) * (
+                        W_padded // patch_size
+                    )
+                else:
+                    import warnings
+
+                    warnings.warn(
+                        f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
+                        f"Border pixels will be dropped (H: {h_rem}, W: {w_rem}). "
+                        f"Consider using pad_if_needed=True or padding to "
+                        f"({((H // patch_size) + 1) * patch_size}, "
+                        f"{((W // patch_size) + 1) * patch_size}).",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self.num_patches = (H // patch_size) * (W // patch_size)
+            else:
+                self.num_patches = (H // patch_size) * (W // patch_size)
             self.proj = nn.Conv2d(
                 1, embed_dim, kernel_size=patch_size, stride=patch_size
             )
@@ -97,6 +139,10 @@ class PatchEmbed(nn.Module):
         Returns:
             Patch embeddings (B, num_patches, embed_dim)
         """
+        # Apply padding if configured
+        if self._padding is not None:
+            x = nn.functional.pad(x, self._padding, mode="constant", value=0)
+
         x = self.proj(x)  # (B, embed_dim, ..reduced_spatial..)
         x = x.flatten(2)  # (B, embed_dim, num_patches)
         x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
```
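To make the new padding arithmetic concrete, here is a standalone sketch (illustrative only, not part of wavedl) that mirrors the 1D branch above:

```python
# Mirrors PatchEmbed's 1D pad_if_needed logic; names are illustrative.
def patch_layout_1d(L: int, patch_size: int, pad_if_needed: bool) -> tuple[int, int]:
    """Return (pad_amount, num_patches) for a 1D input of length L."""
    remainder = L % patch_size
    if remainder and pad_if_needed:
        pad_amount = patch_size - remainder  # pad up to the next multiple
        return pad_amount, (L + pad_amount) // patch_size
    return 0, L // patch_size  # without padding, edge elements are dropped

print(patch_layout_1d(100, 16, pad_if_needed=False))  # (0, 6): last 4 samples dropped
print(patch_layout_1d(100, 16, pad_if_needed=True))   # (12, 7): zero-padded to 112
```

The trade-off: padding preserves edge content (the stated motivation for NDE/QUS data) at the cost of up to patch_size - 1 zero elements per dimension.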
```diff
@@ -185,6 +231,18 @@ class ViTBase(BaseModel):
     3. Transformer encoder blocks
     4. Extract CLS token
     5. Regression head
+
+    Args:
+        in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
+        out_size: Number of regression targets
+        patch_size: Size of each patch (default: 16)
+        embed_dim: Embedding dimension (default: 768)
+        depth: Number of transformer blocks (default: 12)
+        num_heads: Number of attention heads (default: 12)
+        mlp_ratio: MLP hidden dim multiplier (default: 4.0)
+        dropout_rate: Dropout rate (default: 0.1)
+        pad_if_needed: If True, pad input to nearest patch-aligned size instead
+            of dropping edge pixels. Important for NDE/QUS applications.
     """

     def __init__(
@@ -197,6 +255,7 @@ class ViTBase(BaseModel):
         num_heads: int = 12,
         mlp_ratio: float = 4.0,
         dropout_rate: float = 0.1,
+        pad_if_needed: bool = False,
         **kwargs,
     ):
         super().__init__(in_shape, out_size)
@@ -207,9 +266,10 @@ class ViTBase(BaseModel):
         self.num_heads = num_heads
         self.dropout_rate = dropout_rate
         self.dim = len(in_shape)
+        self.pad_if_needed = pad_if_needed

         # Patch embedding
-        self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim)
+        self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim, pad_if_needed)
         num_patches = self.patch_embed.num_patches

         # Learnable CLS token and position embeddings
```
wavedl/train.py
CHANGED
```diff
@@ -162,12 +162,22 @@ except ImportError:
 os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
 os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")

-# Suppress
+# Suppress warnings from known-noisy libraries, but preserve legitimate warnings
+# from torch/numpy about NaN, dtype, and numerical issues.
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", category=DeprecationWarning)
+# Pydantic v1/v2 compatibility warnings
 warnings.filterwarnings("ignore", module="pydantic")
 warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
+# Transformer library warnings (loading configs, etc.)
+warnings.filterwarnings("ignore", module="transformers")
+# Accelerate verbose messages
+warnings.filterwarnings("ignore", module="accelerate")
+# torch.compile backend selection warnings
+warnings.filterwarnings("ignore", message=".*TorchDynamo.*")
+warnings.filterwarnings("ignore", message=".*Dynamo is not supported.*")
+# Note: UserWarning from torch/numpy core is NOT suppressed to preserve
+# legitimate warnings about NaN values, dtype mismatches, etc.

 # ==============================================================================
 # GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
```
```diff
@@ -228,6 +238,18 @@ def parse_args() -> argparse.Namespace:
         default=[],
         help="Python modules to import before training (for custom models)",
     )
+    parser.add_argument(
+        "--pretrained",
+        action="store_true",
+        default=True,
+        help="Use pretrained weights (default: True)",
+    )
+    parser.add_argument(
+        "--no_pretrained",
+        dest="pretrained",
+        action="store_false",
+        help="Train from scratch without pretrained weights",
+    )

     # Configuration File
     parser.add_argument(
```
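The two declarations form the standard argparse paired-flag pattern: both write to the same `pretrained` destination, so `--no_pretrained` flips the True default while `--pretrained` is effectively a no-op kept for explicit command lines. A self-contained sketch:

```python
import argparse

# Paired flags sharing one destination, as in the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained", action="store_true", default=True)
parser.add_argument("--no_pretrained", dest="pretrained", action="store_false")

print(parser.parse_args([]).pretrained)                   # True (the default)
print(parser.parse_args(["--no_pretrained"]).pretrained)  # False
```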
```diff
@@ -543,15 +565,11 @@ def main():
     data_format = DataSource.detect_format(args.data_path)
     source = get_data_source(data_format)

-    # Use memory-mapped loading when available
+    # Use memory-mapped loading when available (now returns LazyDataHandle for all formats)
     _cv_handle = None
     if hasattr(source, "load_mmap"):
-            _cv_handle = result
-            X, y = result.inputs, result.outputs
-        else:
-            X, y = result  # NPZ returns tuple directly
+        _cv_handle = source.load_mmap(args.data_path)
+        X, y = _cv_handle.inputs, _cv_handle.outputs
     else:
         X, y = source.load(args.data_path)
```
```diff
@@ -684,7 +702,9 @@ def main():
     )

     # Build model using registry
-    model = build_model(args.model, in_shape=in_shape, out_size=out_dim)
+    model = build_model(
+        args.model, in_shape=in_shape, out_size=out_dim, pretrained=args.pretrained
+    )

     if accelerator.is_main_process:
         param_info = model.parameter_summary()
```
```diff
@@ -861,10 +881,22 @@ def main():
         milestones=milestones,
         warmup_epochs=args.warmup_epochs,
     )
-
-    model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
-        model, optimizer, train_dl, val_dl, scheduler
-    )
+
+    # For ReduceLROnPlateau: DON'T include scheduler in accelerator.prepare()
+    # because accelerator wraps scheduler.step() to sync across processes,
+    # which defeats our rank-0-only stepping for correct patience counting.
+    # Other schedulers are safe to prepare (no internal state affected by multi-call).
+    if args.scheduler == "plateau":
+        model, optimizer, train_dl, val_dl = accelerator.prepare(
+            model, optimizer, train_dl, val_dl
+        )
+        # Scheduler stays unwrapped - we handle sync manually in training loop
+        # But register it for checkpointing so state is saved/loaded on resume
+        accelerator.register_for_checkpointing(scheduler)
+    else:
+        model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
+            model, optimizer, train_dl, val_dl, scheduler
+        )

     # ==========================================================================
     # AUTO-RESUME / RESUME FROM CHECKPOINT
```
wavedl/utils/data.py
CHANGED
```diff
@@ -13,6 +13,7 @@ Version: 1.0.0
 """

 import gc
+import hashlib
 import logging
 import os
 import pickle
@@ -49,6 +50,29 @@ INPUT_KEYS = ["input_train", "input_test", "X", "data", "inputs", "features", "x
 OUTPUT_KEYS = ["output_train", "output_test", "Y", "labels", "outputs", "targets", "y"]


+def _compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:
+    """
+    Compute SHA256 hash of a file for cache validation.
+
+    Uses chunked reading to handle large files efficiently without loading
+    the entire file into memory. This is more reliable than mtime for detecting
+    actual content changes, especially with cloud sync services (Dropbox, etc.)
+    that may touch files without modifying content.
+
+    Args:
+        path: Path to file to hash
+        chunk_size: Read buffer size (default 8MB for fast I/O)
+
+    Returns:
+        Hex string of SHA256 hash
+    """
+    hasher = hashlib.sha256()
+    with open(path, "rb") as f:
+        while chunk := f.read(chunk_size):
+            hasher.update(chunk)
+    return hasher.hexdigest()
+
+
 class LazyDataHandle:
     """
     Context manager wrapper for memory-mapped data handles.
```
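A quick standalone demonstration of why the content hash beats mtime for cache validation (compute_file_hash here re-implements the private helper above under an illustrative public name):

```python
import hashlib
import os
import tempfile

def compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:
    # Same chunked SHA256 as the helper in the hunk above.
    hasher = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            hasher.update(chunk)
    return hasher.hexdigest()

with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "data.bin")
    with open(p, "wb") as f:
        f.write(b"\x00" * 1024)
    h1 = compute_file_hash(p)
    os.utime(p)  # "touch": mtime changes, content does not
    assert compute_file_hash(p) == h1  # hash unaffected, cache stays valid
```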
```diff
@@ -272,13 +296,21 @@ class NPZSource(DataSource):

         return inp, outp

-    def load_mmap(self, path: str) ->
+    def load_mmap(self, path: str) -> LazyDataHandle:
         """
         Load data using memory-mapped mode for zero-copy access.

         This allows processing large datasets without loading them entirely
         into RAM. Critical for HPC environments with memory constraints.

+        Returns a LazyDataHandle for consistent API across all data sources.
+        The NpzFile is kept open for lazy access.
+
+        Usage:
+            with source.load_mmap(path) as (inputs, outputs):
+                # Use inputs and outputs
+                pass  # File automatically closed
+
         Note: Returns memory-mapped arrays - do NOT modify them.
         """
         # First pass to find keys without loading data
@@ -295,11 +327,13 @@ class NPZSource(DataSource):
             f"Found: {keys}"
         )

+        # Keep NpzFile open for lazy access (like HDF5/MATSource)
         data = self._safe_load(path, [input_key, output_key], mmap_mode="r")
         inp = data[input_key]
         outp = data[output_key]

-        return inp, outp
+        # Return LazyDataHandle for consistent API with HDF5Source/MATSource
+        return LazyDataHandle(inp, outp, file_handle=data)

     def load_outputs_only(self, path: str) -> np.ndarray:
         """Load only targets from NPZ (avoids loading large input arrays)."""
```
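The diff uses LazyDataHandle but never shows its body. A minimal sketch of what it plausibly looks like, inferred from the usage above (inputs/outputs attributes, a file_handle keyword, and the context-manager protocol in the docstring); the real class in wavedl.utils.data may differ:

```python
class LazyDataHandle:
    """Sketch only: holds mmap-backed arrays and closes the file on exit."""

    def __init__(self, inputs, outputs, file_handle=None):
        self.inputs = inputs      # memory-mapped array, do not modify
        self.outputs = outputs
        self._file_handle = file_handle

    def __enter__(self):
        return self.inputs, self.outputs

    def __exit__(self, exc_type, exc, tb):
        if self._file_handle is not None:
            self._file_handle.close()  # closes the underlying NpzFile/HDF5 file
```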
```diff
@@ -1148,32 +1182,21 @@ def prepare_data(
         and os.path.exists(META_FILE)
     )

-    # Validate cache
+    # Validate cache using content hash (portable across folders/machines)
+    # File size is a fast pre-check, content hash is definitive validation
     if cache_exists:
         try:
             with open(META_FILE, "rb") as f:
                 meta = pickle.load(f)
-            cached_data_path = meta.get("data_path", None)
             cached_file_size = meta.get("file_size", None)
+            cached_content_hash = meta.get("content_hash", None)

             # Get current file stats
             current_stats = os.stat(args.data_path)
             current_size = current_stats.st_size
-            current_mtime = current_stats.st_mtime

-            # Check if
-            if
-                if accelerator.is_main_process:
-                    logger.warning(
-                        f"⚠️ Cache was created from different data file!\n"
-                        f"  Cached: {cached_data_path}\n"
-                        f"  Current: {os.path.abspath(args.data_path)}\n"
-                        f"  Invalidating cache and regenerating..."
-                    )
-                cache_exists = False
-            # Check if file was modified (size or mtime changed)
-            elif cached_file_size is not None and cached_file_size != current_size:
+            # Check if file size changed (fast check before expensive hash)
+            if cached_file_size is not None and cached_file_size != current_size:
                 if accelerator.is_main_process:
                     logger.warning(
                         f"⚠️ Data file size changed!\n"
@@ -1182,13 +1205,16 @@ def prepare_data(
                         f"  Invalidating cache and regenerating..."
                     )
                 cache_exists = False
+            # Content hash check (robust against cloud sync mtime changes)
+            elif cached_content_hash is not None:
+                current_hash = _compute_file_hash(args.data_path)
+                if cached_content_hash != current_hash:
+                    if accelerator.is_main_process:
+                        logger.warning(
+                            "⚠️ Data file content changed!\n"
+                            "  Cache is stale, regenerating..."
+                        )
+                    cache_exists = False
         except Exception:
             cache_exists = False
@@ -1234,16 +1260,11 @@ def prepare_data(

     # Load raw data using memory-mapped mode for all formats
     # This avoids loading the entire dataset into RAM at once
+    # All load_mmap() methods now return LazyDataHandle consistently
+    _lazy_handle = None
     try:
-            inp, outp = source.load_mmap(args.data_path)
-        elif data_format == "hdf5":
-            source = HDF5Source()
-            _lazy_handle = source.load_mmap(args.data_path)
-            inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
-        elif data_format == "mat":
-            source = MATSource()
+        source = get_data_source(data_format)
+        if hasattr(source, "load_mmap"):
             _lazy_handle = source.load_mmap(args.data_path)
             inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
         else:
@@ -1307,8 +1328,9 @@ def prepare_data(
         f"  Shape Detected: {full_shape} [{dim_type}] | Output Dim: {out_dim}"
     )

-    # Save metadata (including data path, size,
+    # Save metadata (including data path, size, content hash for cache validation)
     file_stats = os.stat(args.data_path)
+    content_hash = _compute_file_hash(args.data_path)
     with open(META_FILE, "wb") as f:
         pickle.dump(
             {
@@ -1316,7 +1338,7 @@ def prepare_data(
                 "out_dim": out_dim,
                 "data_path": os.path.abspath(args.data_path),
                 "file_size": file_stats.st_size,
-                "
+                "content_hash": content_hash,
             },
             f,
         )
```
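Condensed, the new validation is a two-stage check. A sketch (function name illustrative; assumes _compute_file_hash from the hunk above is in scope):

```python
import os

def cache_is_valid(meta: dict, data_path: str) -> bool:
    # Stage 1: cheap size comparison rejects most stale caches immediately.
    if meta.get("file_size") != os.stat(data_path).st_size:
        return False
    # Stage 2: SHA256 content hash is the definitive check; older cache
    # metadata without a stored hash is accepted as-is, matching the diff.
    cached = meta.get("content_hash")
    return cached is None or cached == _compute_file_hash(data_path)
```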
{wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/METADATA
CHANGED

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.5
+Version: 1.5.6
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -470,6 +470,7 @@ WaveDL/
 ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
 - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
 - **Size**: ~20–350 MB per model depending on architecture
+- **Train from scratch**: Use `--no_pretrained` to disable pretrained weights

 **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:

@@ -1030,7 +1031,7 @@ print(f"✓ Output: {data['output_train'].shape} {data['output_train'].dtype}")

 ## 📦 Examples [](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)

-The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained
+The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained MobileNetV3 predicts three physical parameters from Lamb wave dispersion curves:

 | Parameter | Unit | Description |
 |-----------|------|-------------|
@@ -1045,22 +1046,22 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater

 ```bash
 # Run inference on the example data
-python -m wavedl.test --checkpoint ./examples/
-    --data_path ./examples/
-    --plot --save_predictions --output_dir ./examples/
+python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
+    --data_path ./examples/elasticity_prediction/Test_data_100.mat \
+    --plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results

 # Export to ONNX (already included as model.onnx)
-python -m wavedl.test --checkpoint ./examples/
-    --data_path ./examples/
-    --export onnx --export_path ./examples/
+python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
+    --data_path ./examples/elasticity_prediction/Test_data_100.mat \
+    --export onnx --export_path ./examples/elasticity_prediction/model.onnx
 ```

 **What's Included:**

 | File | Description |
 |------|-------------|
-| `best_checkpoint/` | Pre-trained
-| `
+| `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
+| `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
 | `model.onnx` | ONNX export with embedded de-normalization |
 | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
 | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -1070,59 +1071,59 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
 **Training Progress:**

 <p align="center">
-  <img src="examples/
-  <em>Training and validation loss
+  <img src="examples/elasticity_prediction/training_curves.png" alt="Training curves" width="600"><br>
+  <em>Training and validation loss with <code>plateau</code> learning rate schedule</em>
 </p>

 **Inference Results:**

 <p align="center">
-  <img src="examples/
+  <img src="examples/elasticity_prediction/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
   <em>Figure 1: Predictions vs ground truth for all three elastic parameters</em>
 </p>

 <p align="center">
-  <img src="examples/
+  <img src="examples/elasticity_prediction/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
   <em>Figure 2: Distribution of prediction errors showing near-zero mean bias</em>
 </p>

 <p align="center">
-  <img src="examples/
+  <img src="examples/elasticity_prediction/test_results/residuals.png" alt="Residual plot" width="700"><br>
   <em>Figure 3: Residuals vs predicted values (no heteroscedasticity detected)</em>
 </p>

 <p align="center">
-  <img src="examples/
+  <img src="examples/elasticity_prediction/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
   <em>Figure 4: Bland-Altman analysis with ±1.96 SD limits of agreement</em>
 </p>

 <p align="center">
-  <img src="examples/
+  <img src="examples/elasticity_prediction/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
   <em>Figure 5: Q-Q plots confirming normally distributed prediction errors</em>
 </p>

 <p align="center">
-  <img src="examples/
+  <img src="examples/elasticity_prediction/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
   <em>Figure 6: Error correlation matrix between parameters</em>
 </p>

 <p align="center">
-  <img src="examples/
+  <img src="examples/elasticity_prediction/test_results/relative_error.png" alt="Relative error" width="700"><br>
   <em>Figure 7: Relative error (%) vs true value for each parameter</em>
 </p>

 <p align="center">
-  <img src="examples/
+  <img src="examples/elasticity_prediction/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
   <em>Figure 8: Cumulative error distribution — 95% of predictions within indicated bounds</em>
 </p>

 <p align="center">
-  <img src="examples/
+  <img src="examples/elasticity_prediction/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
   <em>Figure 9: True vs predicted values by sample index</em>
 </p>

 <p align="center">
-  <img src="examples/
+  <img src="examples/elasticity_prediction/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
   <em>Figure 10: Error distribution summary (median, quartiles, outliers)</em>
 </p>
````
{wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/RECORD
CHANGED

```diff
@@ -1,8 +1,8 @@
-wavedl/__init__.py,sha256=
+wavedl/__init__.py,sha256=qesevvzcBx9pJrvfW07e7PB9_sjb1eOL1BrWpUF-wZM,1177
 wavedl/hpc.py,sha256=6rV38nozzMt0-jKZbVJNwvQZXK0wUsIZmr9lgWN_XUw,9212
 wavedl/hpo.py,sha256=CZF0MZwTGMOrPGDveUXZFbGHwLHj1FcJTCBKVVEtLWg,15105
 wavedl/test.py,sha256=WIHG3HWT-uF399FQApPpxjggBVFn59cC54HAL4990QU,38550
-wavedl/train.py,sha256=
+wavedl/train.py,sha256=JlSXWyTdU4S_PTgvANqXN4ceCS9KONOybbRksDPPcuo,57570
 wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
 wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
 wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -19,20 +19,20 @@ wavedl/models/resnet3d.py,sha256=C7CL4XeSnRlIBuwf5Ei-z183uzIBObrXfkM9Iwuc5e0,874
 wavedl/models/swin.py,sha256=cbV_iqIS4no-EAUR8j_93gqd59AkAkfM5DYo6VryLEg,13937
 wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
 wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
-wavedl/models/vit.py,sha256=
+wavedl/models/vit.py,sha256=D4jlYAlvegb3O19jCPpUHYmt5q0SZ7EGVBIWiYbq0GA,14816
 wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
 wavedl/utils/config.py,sha256=AsGwb3XtxmbTLb59BLl5AA4wzMNgVTpl7urOJ6IGqfM,10901
 wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
 wavedl/utils/cross_validation.py,sha256=gwXSFTx5oxWndPjWLJAJzB6nnq2f1t9f86SbjbF-jNI,18475
-wavedl/utils/data.py,sha256=
+wavedl/utils/data.py,sha256=l8aqC7mtnUyXPOS0cCbgE-jS8TKndXnHU2WiP1VU1Zk,58361
 wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
 wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
 wavedl/utils/metrics.py,sha256=El2NYsulH5jxBhC1gCAMcS8C-yxEjuSC930LhsKYQrY,40059
 wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
 wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
+wavedl-1.5.6.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+wavedl-1.5.6.dist-info/METADATA,sha256=lLNnw1m1vOKEvFKr-9-v3xb71RHqDUUJxqu4VzNR8eI,45715
+wavedl-1.5.6.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+wavedl-1.5.6.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
+wavedl-1.5.6.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+wavedl-1.5.6.dist-info/RECORD,,
```
{wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/LICENSE
File without changes

{wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/WHEEL
File without changes

{wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/entry_points.txt
File without changes

{wavedl-1.5.5.dist-info → wavedl-1.5.6.dist-info}/top_level.txt
File without changes