wavedl 1.5.4__py3-none-any.whl → 1.5.5__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.5.4"
21
+ __version__ = "1.5.5"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
wavedl/hpo.py CHANGED
@@ -175,13 +175,14 @@ def create_objective(args):
175
175
  env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
176
176
 
177
177
  # Run training
178
+ # Note: We inherit the user's cwd instead of setting cwd=Path(__file__).parent
179
+ # because site-packages may be read-only and train.py creates cache directories
178
180
  try:
179
181
  result = subprocess.run(
180
182
  cmd,
181
183
  capture_output=True,
182
184
  text=True,
183
185
  timeout=args.timeout,
184
- cwd=Path(__file__).parent,
185
186
  env=env,
186
187
  )
187
188
 
wavedl/utils/data.py CHANGED
@@ -207,6 +207,10 @@ class NPZSource(DataSource):
207
207
 
208
208
  The error for object arrays happens at ACCESS time, not load time.
209
209
  So we need to probe the keys to detect if pickle is required.
210
+
211
+ WARNING: When mmap_mode is not None, the returned NpzFile must be kept
212
+ open for arrays to remain valid. Caller is responsible for closing.
213
+ For non-mmap loading, use _load_and_copy() instead to avoid leaks.
210
214
  """
211
215
  data = np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
212
216
  try:
@@ -222,6 +226,26 @@ class NPZSource(DataSource):
222
226
  return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
223
227
  raise
224
228
 
229
+ @staticmethod
230
+ def _load_and_copy(path: str, keys: list[str]) -> dict[str, np.ndarray]:
231
+ """Load NPZ and copy arrays, ensuring file is properly closed.
232
+
233
+ This prevents file descriptor leaks by copying arrays before closing.
234
+ Use this for eager loading; use _safe_load for memory-mapped access.
235
+ """
236
+ data = NPZSource._safe_load(path, keys, mmap_mode=None)
237
+ try:
238
+ result = {}
239
+ for key in keys:
240
+ if key in data:
241
+ arr = data[key]
242
+ # Copy ensures we don't hold reference to mmap
243
+ result[key] = arr.copy() if hasattr(arr, "copy") else arr
244
+ return result
245
+ finally:
246
+ if hasattr(data, "close"):
247
+ data.close()
248
+
225
249
  def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
226
250
  """Load NPZ file (pickle enabled only for sparse matrices)."""
227
251
  # First pass to find keys without loading data
@@ -238,7 +262,7 @@ class NPZSource(DataSource):
238
262
  f"Found: {keys}"
239
263
  )
240
264
 
241
- data = self._safe_load(path, [input_key, output_key])
265
+ data = self._load_and_copy(path, [input_key, output_key])
242
266
  inp = data[input_key]
243
267
  outp = data[output_key]
244
268
 
@@ -290,7 +314,7 @@ class NPZSource(DataSource):
290
314
  f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
291
315
  )
292
316
 
293
- data = self._safe_load(path, [output_key])
317
+ data = self._load_and_copy(path, [output_key])
294
318
  return data[output_key]
295
319
 
296
320
 
@@ -527,9 +551,17 @@ class MATSource(DataSource):
527
551
  inp = self._load_dataset(f, input_key)
528
552
  outp = self._load_dataset(f, output_key)
529
553
 
530
- # Handle 1D outputs that become (1, N) after transpose
531
- if outp.ndim == 2 and outp.shape[0] == 1:
532
- outp = outp.T
554
+ # Handle transposed outputs from MATLAB.
555
+ # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
556
+ # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
557
+ num_samples = inp.shape[0] # inp is already transposed
558
+ if outp.ndim == 2:
559
+ if outp.shape[0] == 1 and outp.shape[1] == num_samples:
560
+ # 1D vector: (1, N) → (N, 1)
561
+ outp = outp.T
562
+ elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
563
+ # Single sample with multiple targets: (T, 1) → (1, T)
564
+ outp = outp.T
533
565
 
534
566
  except OSError as e:
535
567
  raise ValueError(
@@ -614,7 +646,10 @@ class MATSource(DataSource):
614
646
  # Load with sparse matrix support
615
647
  outp = self._load_dataset(f, output_key)
616
648
 
617
- # Handle 1D outputs
649
+ # Handle 1D outputs that become (1, N) after transpose.
650
+ # Note: This method has no input to compare against, so we can't
651
+ # distinguish single-sample outputs. This is acceptable for training
652
+ # data where single-sample is unlikely. For inference, use load_test_data.
618
653
  if outp.ndim == 2 and outp.shape[0] == 1:
619
654
  outp = outp.T
620
655
 
@@ -775,7 +810,7 @@ def load_test_data(
775
810
  raise KeyError(
776
811
  f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
777
812
  )
778
- data = NPZSource._safe_load(
813
+ data = NPZSource._load_and_copy(
779
814
  path, [inp_key] + ([out_key] if out_key else [])
780
815
  )
781
816
  inp = data[inp_key]
@@ -824,8 +859,17 @@ def load_test_data(
824
859
  inp = mat_source._load_dataset(f, inp_key)
825
860
  if out_key:
826
861
  outp = mat_source._load_dataset(f, out_key)
827
- if outp.ndim == 2 and outp.shape[0] == 1:
828
- outp = outp.T
862
+ # Handle transposed outputs from MATLAB
863
+ # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
864
+ # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
865
+ num_samples = inp.shape[0]
866
+ if outp.ndim == 2:
867
+ if outp.shape[0] == 1 and outp.shape[1] == num_samples:
868
+ # 1D vector: (1, N) → (N, 1)
869
+ outp = outp.T
870
+ elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
871
+ # Single sample with multiple targets: (T, 1) → (1, T)
872
+ outp = outp.T
829
873
  else:
830
874
  outp = None
831
875
  else:
@@ -844,7 +888,7 @@ def load_test_data(
844
888
  )
845
889
  out_key = DataSource._find_key(keys, custom_output_keys)
846
890
  keys_to_probe = [inp_key] + ([out_key] if out_key else [])
847
- data = NPZSource._safe_load(path, keys_to_probe)
891
+ data = NPZSource._load_and_copy(path, keys_to_probe)
848
892
  inp = data[inp_key]
849
893
  if inp.dtype == object:
850
894
  inp = np.array(
@@ -894,9 +938,17 @@ def load_test_data(
894
938
  out_key = DataSource._find_key(keys, custom_output_keys)
895
939
  if out_key:
896
940
  outp = mat_source._load_dataset(f, out_key)
897
- # Handle 1D outputs that become (1, N) after transpose
898
- if outp.ndim == 2 and outp.shape[0] == 1:
899
- outp = outp.T
941
+ # Handle transposed outputs from MATLAB
942
+ # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
943
+ # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
944
+ num_samples = inp.shape[0]
945
+ if outp.ndim == 2:
946
+ if outp.shape[0] == 1 and outp.shape[1] == num_samples:
947
+ # 1D vector: (1, N) → (N, 1)
948
+ outp = outp.T
949
+ elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
950
+ # Single sample with multiple targets: (T, 1) → (1, T)
951
+ outp = outp.T
900
952
  else:
901
953
  outp = None
902
954
  else:
@@ -1153,6 +1205,18 @@ def prepare_data(
1153
1205
  logger.warning(
1154
1206
  f" Failed to remove stale cache {stale_file}: {e}"
1155
1207
  )
1208
+
1209
+ # Fail explicitly if stale cache files couldn't be removed
1210
+ # This prevents silent reuse of outdated data
1211
+ remaining_stale = [
1212
+ f for f in [CACHE_FILE, SCALER_FILE] if os.path.exists(f)
1213
+ ]
1214
+ if remaining_stale:
1215
+ raise RuntimeError(
1216
+ f"Cannot regenerate cache: stale files could not be removed. "
1217
+ f"Please manually delete: {remaining_stale}"
1218
+ )
1219
+
1156
1220
  # RANK 0: Create cache (can take a long time for large datasets)
1157
1221
  # Other ranks will wait at the barrier below
1158
1222
 
wavedl/utils/metrics.py CHANGED
@@ -815,7 +815,28 @@ def plot_qq(
815
815
 
816
816
  # Standardize errors for QQ plot
817
817
  err = errors[:, i]
818
- standardized = (err - np.mean(err)) / np.std(err)
818
+ std_err = np.std(err)
819
+
820
+ # Guard against zero variance (constant errors)
821
+ if std_err < 1e-10:
822
+ title = (
823
+ param_names[i] if param_names and i < len(param_names) else f"Param {i}"
824
+ )
825
+ ax.text(
826
+ 0.5,
827
+ 0.5,
828
+ "Zero variance\n(constant errors)",
829
+ ha="center",
830
+ va="center",
831
+ fontsize=10,
832
+ transform=ax.transAxes,
833
+ )
834
+ ax.set_title(f"{title}\n(zero variance)")
835
+ ax.set_xlabel("Theoretical Quantiles")
836
+ ax.set_ylabel("Sample Quantiles")
837
+ continue
838
+
839
+ standardized = (err - np.mean(err)) / std_err
819
840
 
820
841
  # Calculate theoretical quantiles and sample quantiles
821
842
  (osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")
wavedl-1.5.4.dist-info/METADATA → wavedl-1.5.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.5.4
3
+ Version: 1.5.5
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -388,7 +388,7 @@ WaveDL/
388
388
  ├── configs/ # YAML config templates
389
389
  ├── examples/ # Ready-to-run examples
390
390
  ├── notebooks/ # Jupyter notebooks
391
- ├── unit_tests/ # Pytest test suite (725 tests)
391
+ ├── unit_tests/ # Pytest test suite (731 tests)
392
392
 
393
393
  ├── pyproject.toml # Package config, dependencies
394
394
  ├── CHANGELOG.md # Version history
wavedl-1.5.4.dist-info/RECORD → wavedl-1.5.5.dist-info/RECORD CHANGED
@@ -1,6 +1,6 @@
1
- wavedl/__init__.py,sha256=L3ckuWk3BDr6h9oiADkGP_JKcGSF669qDkuzofh86IU,1177
1
+ wavedl/__init__.py,sha256=RTePiYlzCrUofbGSYWAAqoKeeyYjqEPzuXyze6ai324,1177
2
2
  wavedl/hpc.py,sha256=6rV38nozzMt0-jKZbVJNwvQZXK0wUsIZmr9lgWN_XUw,9212
3
- wavedl/hpo.py,sha256=DGCGyt2yhr3WAifAuljhE26gg07CHdaQW4wpDaTKbyo,14968
3
+ wavedl/hpo.py,sha256=CZF0MZwTGMOrPGDveUXZFbGHwLHj1FcJTCBKVVEtLWg,15105
4
4
  wavedl/test.py,sha256=WIHG3HWT-uF399FQApPpxjggBVFn59cC54HAL4990QU,38550
5
5
  wavedl/train.py,sha256=7AVaCORFUv2_IgdYSPKdHLxbi11GzMOyu4RcNc4Uf_I,55963
6
6
  wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
@@ -24,15 +24,15 @@ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
24
24
  wavedl/utils/config.py,sha256=AsGwb3XtxmbTLb59BLl5AA4wzMNgVTpl7urOJ6IGqfM,10901
25
25
  wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
26
26
  wavedl/utils/cross_validation.py,sha256=gwXSFTx5oxWndPjWLJAJzB6nnq2f1t9f86SbjbF-jNI,18475
27
- wavedl/utils/data.py,sha256=JusSrIZd98t9oiN0xTy2V2mfVyuBCIu0MLAQGcaC0vQ,54194
27
+ wavedl/utils/data.py,sha256=cmJ6tUw4Tcxj-l3Xsphs1Dnlx1MzxOPvk8etD5KXFNs,57686
28
28
  wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
29
29
  wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
30
- wavedl/utils/metrics.py,sha256=EJmJvF7gACQsUoKYldlladN_SbnRiuE-Smj0eSnbraQ,39394
30
+ wavedl/utils/metrics.py,sha256=El2NYsulH5jxBhC1gCAMcS8C-yxEjuSC930LhsKYQrY,40059
31
31
  wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
32
32
  wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
33
- wavedl-1.5.4.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
34
- wavedl-1.5.4.dist-info/METADATA,sha256=D7_MbjGWyVEIEH2m23GrJInZO4pcfHAINlY1FIUgD-A,45604
35
- wavedl-1.5.4.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
36
- wavedl-1.5.4.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
37
- wavedl-1.5.4.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
38
- wavedl-1.5.4.dist-info/RECORD,,
33
+ wavedl-1.5.5.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
34
+ wavedl-1.5.5.dist-info/METADATA,sha256=0e7E8zLd-GlcR5Hbgp5VDYGGr36_9NzKBTShsG4xuQs,45604
35
+ wavedl-1.5.5.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
36
+ wavedl-1.5.5.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
37
+ wavedl-1.5.5.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
38
+ wavedl-1.5.5.dist-info/RECORD,,
File without changes