wavedl-1.5.5-py3-none-any.whl → wavedl-1.5.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
      # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
  """

- __version__ = "1.5.5"
+ __version__ = "1.5.6"
  __author__ = "Ductho Le"
  __email__ = "ductho.le@outlook.com"

wavedl/models/vit.py CHANGED
@@ -42,47 +42,89 @@ class PatchEmbed(nn.Module):
      Supports 1D and 2D inputs:
      - 1D: Input (B, 1, L) → (B, num_patches, embed_dim)
      - 2D: Input (B, 1, H, W) → (B, num_patches, embed_dim)
+
+     Args:
+         in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
+         patch_size: Size of each patch
+         embed_dim: Embedding dimension
+         pad_if_needed: If True, pad input to nearest patch-aligned size instead of
+             dropping edge pixels. Important for NDE/QUS applications where edge
+             effects matter. Default: False (original behavior with warning).
      """

-     def __init__(self, in_shape: SpatialShape, patch_size: int, embed_dim: int):
+     def __init__(
+         self,
+         in_shape: SpatialShape,
+         patch_size: int,
+         embed_dim: int,
+         pad_if_needed: bool = False,
+     ):
          super().__init__()

          self.dim = len(in_shape)
          self.patch_size = patch_size
          self.embed_dim = embed_dim
+         self.pad_if_needed = pad_if_needed
+         self._padding = None  # Will be set if padding is needed

          if self.dim == 1:
              # 1D: segment patches
              L = in_shape[0]
-             if L % patch_size != 0:
-                 import warnings
-
-                 warnings.warn(
-                     f"Input length {L} not divisible by patch_size {patch_size}. "
-                     f"Last {L % patch_size} elements will be dropped. "
-                     f"Consider padding input to {((L // patch_size) + 1) * patch_size}.",
-                     UserWarning,
-                     stacklevel=2,
-                 )
-             self.num_patches = L // patch_size
+             remainder = L % patch_size
+             if remainder != 0:
+                 if pad_if_needed:
+                     # Pad to next multiple of patch_size
+                     pad_amount = patch_size - remainder
+                     self._padding = (0, pad_amount)  # (left, right)
+                     L_padded = L + pad_amount
+                     self.num_patches = L_padded // patch_size
+                 else:
+                     import warnings
+
+                     warnings.warn(
+                         f"Input length {L} not divisible by patch_size {patch_size}. "
+                         f"Last {remainder} elements will be dropped. "
+                         f"Consider using pad_if_needed=True or padding input to "
+                         f"{((L // patch_size) + 1) * patch_size}.",
+                         UserWarning,
+                         stacklevel=2,
+                     )
+                     self.num_patches = L // patch_size
+             else:
+                 self.num_patches = L // patch_size
              self.proj = nn.Conv1d(
                  1, embed_dim, kernel_size=patch_size, stride=patch_size
              )
          elif self.dim == 2:
              # 2D: grid patches
              H, W = in_shape
-             if H % patch_size != 0 or W % patch_size != 0:
-                 import warnings
-
-                 warnings.warn(
-                     f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
-                     f"Border pixels will be dropped (H: {H % patch_size}, W: {W % patch_size}). "
-                     f"Consider padding to ({((H // patch_size) + 1) * patch_size}, "
-                     f"{((W // patch_size) + 1) * patch_size}).",
-                     UserWarning,
-                     stacklevel=2,
-                 )
-             self.num_patches = (H // patch_size) * (W // patch_size)
+             h_rem, w_rem = H % patch_size, W % patch_size
+             if h_rem != 0 or w_rem != 0:
+                 if pad_if_needed:
+                     # Pad to next multiple of patch_size
+                     h_pad = (patch_size - h_rem) % patch_size
+                     w_pad = (patch_size - w_rem) % patch_size
+                     # Padding format: (left, right, top, bottom)
+                     self._padding = (0, w_pad, 0, h_pad)
+                     H_padded, W_padded = H + h_pad, W + w_pad
+                     self.num_patches = (H_padded // patch_size) * (
+                         W_padded // patch_size
+                     )
+                 else:
+                     import warnings
+
+                     warnings.warn(
+                         f"Input shape ({H}, {W}) not divisible by patch_size {patch_size}. "
+                         f"Border pixels will be dropped (H: {h_rem}, W: {w_rem}). "
+                         f"Consider using pad_if_needed=True or padding to "
+                         f"({((H // patch_size) + 1) * patch_size}, "
+                         f"{((W // patch_size) + 1) * patch_size}).",
+                         UserWarning,
+                         stacklevel=2,
+                     )
+                     self.num_patches = (H // patch_size) * (W // patch_size)
+             else:
+                 self.num_patches = (H // patch_size) * (W // patch_size)
              self.proj = nn.Conv2d(
                  1, embed_dim, kernel_size=patch_size, stride=patch_size
              )
@@ -97,6 +139,10 @@ class PatchEmbed(nn.Module):
          Returns:
              Patch embeddings (B, num_patches, embed_dim)
          """
+         # Apply padding if configured
+         if self._padding is not None:
+             x = nn.functional.pad(x, self._padding, mode="constant", value=0)
+
          x = self.proj(x)  # (B, embed_dim, ..reduced_spatial..)
          x = x.flatten(2)  # (B, embed_dim, num_patches)
          x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
@@ -185,6 +231,18 @@ class ViTBase(BaseModel):
      3. Transformer encoder blocks
      4. Extract CLS token
      5. Regression head
+
+     Args:
+         in_shape: Spatial shape (L,) for 1D or (H, W) for 2D
+         out_size: Number of regression targets
+         patch_size: Size of each patch (default: 16)
+         embed_dim: Embedding dimension (default: 768)
+         depth: Number of transformer blocks (default: 12)
+         num_heads: Number of attention heads (default: 12)
+         mlp_ratio: MLP hidden dim multiplier (default: 4.0)
+         dropout_rate: Dropout rate (default: 0.1)
+         pad_if_needed: If True, pad input to nearest patch-aligned size instead
+             of dropping edge pixels. Important for NDE/QUS applications.
      """

      def __init__(
@@ -197,6 +255,7 @@ class ViTBase(BaseModel):
          num_heads: int = 12,
          mlp_ratio: float = 4.0,
          dropout_rate: float = 0.1,
+         pad_if_needed: bool = False,
          **kwargs,
      ):
          super().__init__(in_shape, out_size)
@@ -207,9 +266,10 @@ class ViTBase(BaseModel):
          self.num_heads = num_heads
          self.dropout_rate = dropout_rate
          self.dim = len(in_shape)
+         self.pad_if_needed = pad_if_needed

          # Patch embedding
-         self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim)
+         self.patch_embed = PatchEmbed(in_shape, patch_size, embed_dim, pad_if_needed)
          num_patches = self.patch_embed.num_patches

          # Learnable CLS token and position embeddings
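
A quick sketch of what the new `pad_if_needed` option does in practice, assuming `PatchEmbed` is importable from `wavedl.models.vit` as shown in the diff; the shapes are illustrative:

```python
import torch
from wavedl.models.vit import PatchEmbed

# Length 1000 is not divisible by patch_size 16 (1000 = 62*16 + 8)
embed = PatchEmbed(in_shape=(1000,), patch_size=16, embed_dim=768, pad_if_needed=True)
x = torch.randn(4, 1, 1000)

# Input is right-padded to 1008 before projection, so the trailing 8 samples
# contribute to a 63rd patch instead of being silently dropped
assert embed(x).shape == (4, 63, 768)
```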
wavedl/train.py CHANGED
@@ -162,12 +162,22 @@ except ImportError:
  os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
  os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")

- # Suppress non-critical warnings for cleaner training logs
- warnings.filterwarnings("ignore", category=UserWarning)
+ # Suppress warnings from known-noisy libraries, but preserve legitimate warnings
+ # from torch/numpy about NaN, dtype, and numerical issues.
  warnings.filterwarnings("ignore", category=FutureWarning)
  warnings.filterwarnings("ignore", category=DeprecationWarning)
+ # Pydantic v1/v2 compatibility warnings
  warnings.filterwarnings("ignore", module="pydantic")
  warnings.filterwarnings("ignore", message=".*UnsupportedFieldAttributeWarning.*")
+ # Transformers library warnings (loading configs, etc.)
+ warnings.filterwarnings("ignore", module="transformers")
+ # Accelerate verbose messages
+ warnings.filterwarnings("ignore", module="accelerate")
+ # torch.compile backend selection warnings
+ warnings.filterwarnings("ignore", message=".*TorchDynamo.*")
+ warnings.filterwarnings("ignore", message=".*Dynamo is not supported.*")
+ # Note: UserWarning from torch/numpy core is NOT suppressed to preserve
+ # legitimate warnings about NaN values, dtype mismatches, etc.

  # ==============================================================================
  # GPU PERFORMANCE OPTIMIZATIONS (Ampere/Hopper: A100, H100)
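
An illustrative self-check (not part of the package) that the narrowed filters keep `UserWarning`s visible while the noisy categories stay silenced:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.resetwarnings()  # start from an empty filter list
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.filterwarnings("always", category=UserWarning)
    warnings.warn("NaN values detected in input", UserWarning)  # surfaces
    warnings.warn("this API is deprecated", FutureWarning)      # silenced

assert [str(w.message) for w in caught] == ["NaN values detected in input"]
```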
@@ -228,6 +238,18 @@ def parse_args() -> argparse.Namespace:
          default=[],
          help="Python modules to import before training (for custom models)",
      )
+     parser.add_argument(
+         "--pretrained",
+         action="store_true",
+         default=True,
+         help="Use pretrained weights (default: True)",
+     )
+     parser.add_argument(
+         "--no_pretrained",
+         dest="pretrained",
+         action="store_false",
+         help="Train from scratch without pretrained weights",
+     )

      # Configuration File
      parser.add_argument(
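
The `--pretrained`/`--no_pretrained` pair is the classic default-on boolean flag pattern; a standalone repro of the behavior (toy parser, not wavedl's full CLI):

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument("--pretrained", action="store_true", default=True)
p.add_argument("--no_pretrained", dest="pretrained", action="store_false")

assert p.parse_args([]).pretrained is True                    # default: use pretrained weights
assert p.parse_args(["--no_pretrained"]).pretrained is False  # opt out
```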
@@ -543,15 +565,11 @@ def main():
      data_format = DataSource.detect_format(args.data_path)
      source = get_data_source(data_format)

-     # Use memory-mapped loading when available
+     # Use memory-mapped loading when available (now returns LazyDataHandle for all formats)
      _cv_handle = None
      if hasattr(source, "load_mmap"):
-         result = source.load_mmap(args.data_path)
-         if hasattr(result, "inputs"):
-             _cv_handle = result
-             X, y = result.inputs, result.outputs
-         else:
-             X, y = result  # NPZ returns tuple directly
+         _cv_handle = source.load_mmap(args.data_path)
+         X, y = _cv_handle.inputs, _cv_handle.outputs
      else:
          X, y = source.load(args.data_path)

@@ -684,7 +702,9 @@ def main():
      )

      # Build model using registry
-     model = build_model(args.model, in_shape=in_shape, out_size=out_dim)
+     model = build_model(
+         args.model, in_shape=in_shape, out_size=out_dim, pretrained=args.pretrained
+     )

      if accelerator.is_main_process:
          param_info = model.parameter_summary()
@@ -861,10 +881,22 @@ def main():
          milestones=milestones,
          warmup_epochs=args.warmup_epochs,
      )
-     # Prepare everything together
-     model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
-         model, optimizer, train_dl, val_dl, scheduler
-     )
+
+     # For ReduceLROnPlateau: DON'T include scheduler in accelerator.prepare()
+     # because accelerator wraps scheduler.step() to sync across processes,
+     # which defeats our rank-0-only stepping for correct patience counting.
+     # Other schedulers are safe to prepare (no internal state affected by multi-call).
+     if args.scheduler == "plateau":
+         model, optimizer, train_dl, val_dl = accelerator.prepare(
+             model, optimizer, train_dl, val_dl
+         )
+         # Scheduler stays unwrapped - we handle sync manually in training loop
+         # But register it for checkpointing so state is saved/loaded on resume
+         accelerator.register_for_checkpointing(scheduler)
+     else:
+         model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(
+             model, optimizer, train_dl, val_dl, scheduler
+         )

  # ==========================================================================
  # AUTO-RESUME / RESUME FROM CHECKPOINT
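
The rank-0-only stepping this branch protects could look roughly like the sketch below; it is an illustration of the idea, not the package's actual training loop, and `val_loss` is assumed to be an already-gathered validation metric:

```python
import torch
import torch.distributed as dist

def step_plateau(scheduler, val_loss: float, accelerator):
    """Step ReduceLROnPlateau once on rank 0, then broadcast the resulting LR."""
    if accelerator.is_main_process:
        scheduler.step(val_loss)  # patience counter advances exactly once per epoch
    if accelerator.num_processes > 1:
        # Replicate rank 0's LR decision to every optimizer replica
        lr = torch.tensor(
            scheduler.optimizer.param_groups[0]["lr"], device=accelerator.device
        )
        dist.broadcast(lr, src=0)
        for group in scheduler.optimizer.param_groups:
            group["lr"] = lr.item()
```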
wavedl/utils/data.py CHANGED
@@ -13,6 +13,7 @@ Version: 1.0.0
  """

  import gc
+ import hashlib
  import logging
  import os
  import pickle
@@ -49,6 +50,29 @@ INPUT_KEYS = ["input_train", "input_test", "X", "data", "inputs", "features", "x
  OUTPUT_KEYS = ["output_train", "output_test", "Y", "labels", "outputs", "targets", "y"]


+ def _compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:
+     """
+     Compute SHA256 hash of a file for cache validation.
+
+     Uses chunked reading to handle large files efficiently without loading
+     the entire file into memory. This is more reliable than mtime for detecting
+     actual content changes, especially with cloud sync services (Dropbox, etc.)
+     that may touch files without modifying content.
+
+     Args:
+         path: Path to file to hash
+         chunk_size: Read buffer size (default 8MB for fast I/O)
+
+     Returns:
+         Hex string of SHA256 hash
+     """
+     hasher = hashlib.sha256()
+     with open(path, "rb") as f:
+         while chunk := f.read(chunk_size):
+             hasher.update(chunk)
+     return hasher.hexdigest()
+
+
  class LazyDataHandle:
      """
      Context manager wrapper for memory-mapped data handles.
@@ -272,13 +296,21 @@ class NPZSource(DataSource):
          return inp, outp

-     def load_mmap(self, path: str) -> tuple[np.ndarray, np.ndarray]:
+     def load_mmap(self, path: str) -> LazyDataHandle:
          """
          Load data using memory-mapped mode for zero-copy access.

          This allows processing large datasets without loading them entirely
          into RAM. Critical for HPC environments with memory constraints.

+         Returns a LazyDataHandle for consistent API across all data sources.
+         The NpzFile is kept open for lazy access.
+
+         Usage:
+             with source.load_mmap(path) as (inputs, outputs):
+                 # Use inputs and outputs
+                 pass  # File automatically closed
+
          Note: Returns memory-mapped arrays - do NOT modify them.
          """
          # First pass to find keys without loading data
@@ -295,11 +327,13 @@ class NPZSource(DataSource):
              f"Found: {keys}"
          )

+         # Keep NpzFile open for lazy access (like HDF5/MATSource)
          data = self._safe_load(path, [input_key, output_key], mmap_mode="r")
          inp = data[input_key]
          outp = data[output_key]

-         return inp, outp
+         # Return LazyDataHandle for consistent API with HDF5Source/MATSource
+         return LazyDataHandle(inp, outp, file_handle=data)

      def load_outputs_only(self, path: str) -> np.ndarray:
          """Load only targets from NPZ (avoids loading large input arrays)."""
@@ -1148,32 +1182,21 @@ def prepare_data(
          and os.path.exists(META_FILE)
      )

-     # Validate cache matches current data_path (prevents stale cache corruption)
+     # Validate cache using content hash (portable across folders/machines)
+     # File size is a fast pre-check, content hash is definitive validation
      if cache_exists:
          try:
              with open(META_FILE, "rb") as f:
                  meta = pickle.load(f)
-             cached_data_path = meta.get("data_path", None)
              cached_file_size = meta.get("file_size", None)
-             cached_file_mtime = meta.get("file_mtime", None)
+             cached_content_hash = meta.get("content_hash", None)

              # Get current file stats
              current_stats = os.stat(args.data_path)
              current_size = current_stats.st_size
-             current_mtime = current_stats.st_mtime

-             # Check if data path changed
-             if cached_data_path != os.path.abspath(args.data_path):
-                 if accelerator.is_main_process:
-                     logger.warning(
-                         f"⚠️ Cache was created from different data file!\n"
-                         f"   Cached: {cached_data_path}\n"
-                         f"   Current: {os.path.abspath(args.data_path)}\n"
-                         f"   Invalidating cache and regenerating..."
-                     )
-                 cache_exists = False
-             # Check if file was modified (size or mtime changed)
-             elif cached_file_size is not None and cached_file_size != current_size:
+             # Check if file size changed (fast check before expensive hash)
+             if cached_file_size is not None and cached_file_size != current_size:
                  if accelerator.is_main_process:
                      logger.warning(
                          f"⚠️ Data file size changed!\n"
@@ -1182,13 +1205,16 @@ def prepare_data(
                          f"   Invalidating cache and regenerating..."
                      )
                  cache_exists = False
-             elif cached_file_mtime is not None and cached_file_mtime != current_mtime:
-                 if accelerator.is_main_process:
-                     logger.warning(
-                         "⚠️ Data file was modified!\n"
-                         "   Cache may be stale, regenerating..."
-                     )
-                 cache_exists = False
+             # Content hash check (robust against cloud sync mtime changes)
+             elif cached_content_hash is not None:
+                 current_hash = _compute_file_hash(args.data_path)
+                 if cached_content_hash != current_hash:
+                     if accelerator.is_main_process:
+                         logger.warning(
+                             "⚠️ Data file content changed!\n"
+                             "   Cache is stale, regenerating..."
+                         )
+                     cache_exists = False
          except Exception:
              cache_exists = False

@@ -1234,16 +1260,11 @@ def prepare_data(

      # Load raw data using memory-mapped mode for all formats
      # This avoids loading the entire dataset into RAM at once
+     # All load_mmap() methods now return LazyDataHandle consistently
+     _lazy_handle = None
      try:
-         if data_format == "npz":
-             source = NPZSource()
-             inp, outp = source.load_mmap(args.data_path)
-         elif data_format == "hdf5":
-             source = HDF5Source()
-             _lazy_handle = source.load_mmap(args.data_path)
-             inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
-         elif data_format == "mat":
-             source = MATSource()
+         source = get_data_source(data_format)
+         if hasattr(source, "load_mmap"):
              _lazy_handle = source.load_mmap(args.data_path)
              inp, outp = _lazy_handle.inputs, _lazy_handle.outputs
          else:
@@ -1307,8 +1328,9 @@ def prepare_data(
          f"   Shape Detected: {full_shape} [{dim_type}] | Output Dim: {out_dim}"
      )

-     # Save metadata (including data path, size, mtime for cache validation)
+     # Save metadata (including data path, size, content hash for cache validation)
      file_stats = os.stat(args.data_path)
+     content_hash = _compute_file_hash(args.data_path)
      with open(META_FILE, "wb") as f:
          pickle.dump(
              {
@@ -1316,7 +1338,7 @@ def prepare_data(
                  "out_dim": out_dim,
                  "data_path": os.path.abspath(args.data_path),
                  "file_size": file_stats.st_size,
-                 "file_mtime": file_stats.st_mtime,
+                 "content_hash": content_hash,
              },
              f,
          )
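
Taken together, the size-then-hash validation reduces to something like this standalone sketch (`cache_is_valid` is illustrative, not a wavedl function):

```python
import hashlib
import os
import pickle

def _compute_file_hash(path: str, chunk_size: int = 8 * 1024 * 1024) -> str:
    hasher = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            hasher.update(chunk)
    return hasher.hexdigest()

def cache_is_valid(meta_file: str, data_path: str) -> bool:
    try:
        with open(meta_file, "rb") as f:
            meta = pickle.load(f)
    except Exception:
        return False
    # Fast pre-check: a size mismatch always means the cache is stale
    if meta.get("file_size") != os.stat(data_path).st_size:
        return False
    # Definitive check: hash only when sizes agree (hashing is the slow part)
    cached = meta.get("content_hash")
    return cached is None or cached == _compute_file_hash(data_path)
```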
wavedl-1.5.5.dist-info/METADATA → wavedl-1.5.6.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: wavedl
- Version: 1.5.5
+ Version: 1.5.6
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
  Author: Ductho Le
  License: MIT
@@ -470,6 +470,7 @@ WaveDL/
  ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
  - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
  - **Size**: ~20–350 MB per model depending on architecture
+ - **Train from scratch**: Use `--no_pretrained` to disable pretrained weights

  **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:

@@ -1030,7 +1031,7 @@ print(f"✓ Output: {data['output_train'].shape} {data['output_train'].dtype}")

  ## 📦 Examples [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)

- The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained CNN predicts three physical parameters from Lamb wave dispersion curves:
+ The `examples/` folder contains a **complete, ready-to-run example** for **material characterization of isotropic plates**. The pre-trained MobileNetV3 predicts three physical parameters from Lamb wave dispersion curves:

  | Parameter | Unit | Description |
  |-----------|------|-------------|
@@ -1045,22 +1046,22 @@ The `examples/` folder contains a **complete, ready-to-run example** for **mater

  ```bash
  # Run inference on the example data
- python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
-     --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
-     --plot --save_predictions --output_dir ./examples/elastic_cnn_example/test_results
+ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
+     --data_path ./examples/elasticity_prediction/Test_data_100.mat \
+     --plot --save_predictions --output_dir ./examples/elasticity_prediction/test_results

  # Export to ONNX (already included as model.onnx)
- python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoint \
-     --data_path ./examples/elastic_cnn_example/Test_data_500.mat \
-     --export onnx --export_path ./examples/elastic_cnn_example/model.onnx
+ python -m wavedl.test --checkpoint ./examples/elasticity_prediction/best_checkpoint \
+     --data_path ./examples/elasticity_prediction/Test_data_100.mat \
+     --export onnx --export_path ./examples/elasticity_prediction/model.onnx
  ```

  **What's Included:**

  | File | Description |
  |------|-------------|
- | `best_checkpoint/` | Pre-trained CNN checkpoint |
- | `Test_data_500.mat` | 500 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
+ | `best_checkpoint/` | Pre-trained MobileNetV3 checkpoint |
+ | `Test_data_100.mat` | 100 sample test set (500×500 dispersion curves → *h*, √(*E*/ρ), *ν*) |
  | `model.onnx` | ONNX export with embedded de-normalization |
  | `training_history.csv` | Epoch-by-epoch training metrics (loss, R², LR, etc.) |
  | `training_curves.png` | Training/validation loss and learning rate plot |
@@ -1070,59 +1071,59 @@ python -m wavedl.test --checkpoint ./examples/elastic_cnn_example/best_checkpoin
  **Training Progress:**

  <p align="center">
-   <img src="examples/elastic_cnn_example/training_curves.png" alt="Training curves" width="600"><br>
-   <em>Training and validation loss over 227 epochs with <code>onecycle</code> learning rate schedule</em>
+   <img src="examples/elasticity_prediction/training_curves.png" alt="Training curves" width="600"><br>
+   <em>Training and validation loss with <code>plateau</code> learning rate schedule</em>
  </p>

  **Inference Results:**

  <p align="center">
-   <img src="examples/elastic_cnn_example/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
+   <img src="examples/elasticity_prediction/test_results/scatter_all.png" alt="Scatter plot" width="700"><br>
    <em>Figure 1: Predictions vs ground truth for all three elastic parameters</em>
  </p>

  <p align="center">
-   <img src="examples/elastic_cnn_example/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
+   <img src="examples/elasticity_prediction/test_results/error_histogram.png" alt="Error histogram" width="700"><br>
    <em>Figure 2: Distribution of prediction errors showing near-zero mean bias</em>
  </p>

  <p align="center">
-   <img src="examples/elastic_cnn_example/test_results/residuals.png" alt="Residual plot" width="700"><br>
+   <img src="examples/elasticity_prediction/test_results/residuals.png" alt="Residual plot" width="700"><br>
    <em>Figure 3: Residuals vs predicted values (no heteroscedasticity detected)</em>
  </p>

  <p align="center">
-   <img src="examples/elastic_cnn_example/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
+   <img src="examples/elasticity_prediction/test_results/bland_altman.png" alt="Bland-Altman plot" width="700"><br>
    <em>Figure 4: Bland-Altman analysis with ±1.96 SD limits of agreement</em>
  </p>

  <p align="center">
-   <img src="examples/elastic_cnn_example/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
+   <img src="examples/elasticity_prediction/test_results/qq_plot.png" alt="Q-Q plot" width="700"><br>
    <em>Figure 5: Q-Q plots confirming normally distributed prediction errors</em>
  </p>

  <p align="center">
-   <img src="examples/elastic_cnn_example/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
+   <img src="examples/elasticity_prediction/test_results/error_correlation.png" alt="Error correlation" width="300"><br>
    <em>Figure 6: Error correlation matrix between parameters</em>
  </p>

  <p align="center">
-   <img src="examples/elastic_cnn_example/test_results/relative_error.png" alt="Relative error" width="700"><br>
+   <img src="examples/elasticity_prediction/test_results/relative_error.png" alt="Relative error" width="700"><br>
    <em>Figure 7: Relative error (%) vs true value for each parameter</em>
  </p>

  <p align="center">
-   <img src="examples/elastic_cnn_example/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
+   <img src="examples/elasticity_prediction/test_results/error_cdf.png" alt="Error CDF" width="500"><br>
    <em>Figure 8: Cumulative error distribution — 95% of predictions within indicated bounds</em>
  </p>

  <p align="center">
-   <img src="examples/elastic_cnn_example/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
+   <img src="examples/elasticity_prediction/test_results/prediction_vs_index.png" alt="Prediction vs index" width="700"><br>
    <em>Figure 9: True vs predicted values by sample index</em>
  </p>

  <p align="center">
-   <img src="examples/elastic_cnn_example/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
+   <img src="examples/elasticity_prediction/test_results/error_boxplot.png" alt="Error box plot" width="400"><br>
    <em>Figure 10: Error distribution summary (median, quartiles, outliers)</em>
  </p>

wavedl-1.5.5.dist-info/RECORD → wavedl-1.5.6.dist-info/RECORD RENAMED
@@ -1,8 +1,8 @@
- wavedl/__init__.py,sha256=RTePiYlzCrUofbGSYWAAqoKeeyYjqEPzuXyze6ai324,1177
+ wavedl/__init__.py,sha256=qesevvzcBx9pJrvfW07e7PB9_sjb1eOL1BrWpUF-wZM,1177
  wavedl/hpc.py,sha256=6rV38nozzMt0-jKZbVJNwvQZXK0wUsIZmr9lgWN_XUw,9212
  wavedl/hpo.py,sha256=CZF0MZwTGMOrPGDveUXZFbGHwLHj1FcJTCBKVVEtLWg,15105
  wavedl/test.py,sha256=WIHG3HWT-uF399FQApPpxjggBVFn59cC54HAL4990QU,38550
- wavedl/train.py,sha256=7AVaCORFUv2_IgdYSPKdHLxbi11GzMOyu4RcNc4Uf_I,55963
+ wavedl/train.py,sha256=JlSXWyTdU4S_PTgvANqXN4ceCS9KONOybbRksDPPcuo,57570
  wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
  wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
  wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -19,20 +19,20 @@ wavedl/models/resnet3d.py,sha256=C7CL4XeSnRlIBuwf5Ei-z183uzIBObrXfkM9Iwuc5e0,874
  wavedl/models/swin.py,sha256=cbV_iqIS4no-EAUR8j_93gqd59AkAkfM5DYo6VryLEg,13937
  wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
  wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
- wavedl/models/vit.py,sha256=68o9nNjkftvHFArAPupU2ew5e5yCsI2AYaT9TQinVMk,12075
+ wavedl/models/vit.py,sha256=D4jlYAlvegb3O19jCPpUHYmt5q0SZ7EGVBIWiYbq0GA,14816
  wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
  wavedl/utils/config.py,sha256=AsGwb3XtxmbTLb59BLl5AA4wzMNgVTpl7urOJ6IGqfM,10901
  wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
  wavedl/utils/cross_validation.py,sha256=gwXSFTx5oxWndPjWLJAJzB6nnq2f1t9f86SbjbF-jNI,18475
- wavedl/utils/data.py,sha256=cmJ6tUw4Tcxj-l3Xsphs1Dnlx1MzxOPvk8etD5KXFNs,57686
+ wavedl/utils/data.py,sha256=l8aqC7mtnUyXPOS0cCbgE-jS8TKndXnHU2WiP1VU1Zk,58361
  wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
  wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
  wavedl/utils/metrics.py,sha256=El2NYsulH5jxBhC1gCAMcS8C-yxEjuSC930LhsKYQrY,40059
  wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
  wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
- wavedl-1.5.5.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
- wavedl-1.5.5.dist-info/METADATA,sha256=0e7E8zLd-GlcR5Hbgp5VDYGGr36_9NzKBTShsG4xuQs,45604
- wavedl-1.5.5.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
- wavedl-1.5.5.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
- wavedl-1.5.5.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
- wavedl-1.5.5.dist-info/RECORD,,
+ wavedl-1.5.6.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+ wavedl-1.5.6.dist-info/METADATA,sha256=lLNnw1m1vOKEvFKr-9-v3xb71RHqDUUJxqu4VzNR8eI,45715
+ wavedl-1.5.6.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+ wavedl-1.5.6.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
+ wavedl-1.5.6.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+ wavedl-1.5.6.dist-info/RECORD,,