wavedl-1.5.1-py3-none-any.whl → wavedl-1.5.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
 # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
 """
 
-__version__ = "1.5.1"
+__version__ = "1.5.2"
 __author__ = "Ductho Le"
 __email__ = "ductho.le@outlook.com"
 
wavedl/utils/data.py CHANGED
@@ -202,18 +202,31 @@ class NPZSource(DataSource):
     """Load data from NumPy .npz archives."""
 
     @staticmethod
-    def _safe_load(path: str, mmap_mode: str | None = None):
-        """Load NPZ with pickle only if needed (sparse matrix support)."""
+    def _safe_load(path: str, keys_to_probe: list[str], mmap_mode: str | None = None):
+        """Load NPZ with pickle only if needed (sparse matrix support).
+
+        The error for object arrays happens at ACCESS time, not load time.
+        So we need to probe the keys to detect if pickle is required.
+        """
+        data = np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
         try:
-            return np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
-        except ValueError:
-            # Fallback for sparse matrices stored as object arrays
-            return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
+            # Probe keys to trigger error if object arrays exist
+            for key in keys_to_probe:
+                if key in data:
+                    _ = data[key]  # This raises ValueError for object arrays
+            return data
+        except ValueError as e:
+            if "allow_pickle=False" in str(e):
+                # Fallback for sparse matrices stored as object arrays
+                data.close() if hasattr(data, "close") else None
+                return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
+            raise
 
     def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
         """Load NPZ file (pickle enabled only for sparse matrices)."""
-        data = self._safe_load(path)
-        keys = list(data.keys())
+        # First pass to find keys without loading data
+        with np.load(path, allow_pickle=False) as probe:
+            keys = list(probe.keys())
 
         input_key = self._find_key(keys, INPUT_KEYS)
         output_key = self._find_key(keys, OUTPUT_KEYS)
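
The probing in `_safe_load` above works around a NumPy subtlety: `np.load(path, allow_pickle=False)` opens an `.npz` archive without error even when it contains object arrays, and the `ValueError` only surfaces when such a key is read. A minimal standalone sketch of that behaviour and of the fallback path (the file name `demo.npz` and the dummy arrays are illustrative, not part of wavedl):

import numpy as np

# Object arrays (e.g. ragged data or sparse matrices) are pickled by np.savez.
obj = np.empty(2, dtype=object)
obj[0] = np.arange(3)
obj[1] = np.arange(5)
np.savez("demo.npz", x=obj, y=np.zeros(2))

data = np.load("demo.npz", allow_pickle=False)  # opening the archive succeeds
try:
    _ = data["x"]  # the ValueError is raised here, at access time
except ValueError as e:
    assert "allow_pickle=False" in str(e)  # the string check _safe_load relies on
    data = np.load("demo.npz", allow_pickle=True)  # pickle-enabled fallback
print(data["x"].dtype)  # object

Probing only the keys that will actually be read (the new `keys_to_probe` argument) means an unrelated object-typed entry in the archive does not force the pickle fallback.
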
@@ -225,6 +238,7 @@ class NPZSource(DataSource):
                 f"Found: {keys}"
             )
 
+        data = self._safe_load(path, [input_key, output_key])
         inp = data[input_key]
         outp = data[output_key]
 
@@ -243,8 +257,9 @@ class NPZSource(DataSource):
 
         Note: Returns memory-mapped arrays - do NOT modify them.
         """
-        data = self._safe_load(path, mmap_mode="r")
-        keys = list(data.keys())
+        # First pass to find keys without loading data
+        with np.load(path, allow_pickle=False) as probe:
+            keys = list(probe.keys())
 
         input_key = self._find_key(keys, INPUT_KEYS)
         output_key = self._find_key(keys, OUTPUT_KEYS)
@@ -256,6 +271,7 @@ class NPZSource(DataSource):
                 f"Found: {keys}"
             )
 
+        data = self._safe_load(path, [input_key, output_key], mmap_mode="r")
         inp = data[input_key]
         outp = data[output_key]
 
@@ -263,8 +279,9 @@ class NPZSource(DataSource):
 
     def load_outputs_only(self, path: str) -> np.ndarray:
         """Load only targets from NPZ (avoids loading large input arrays)."""
-        data = self._safe_load(path)
-        keys = list(data.keys())
+        # First pass to find keys without loading data
+        with np.load(path, allow_pickle=False) as probe:
+            keys = list(probe.keys())
 
         output_key = self._find_key(keys, OUTPUT_KEYS)
         if output_key is None:
@@ -273,6 +290,7 @@ class NPZSource(DataSource):
                 f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
             )
 
+        data = self._safe_load(path, [output_key])
         return data[output_key]
 
 
@@ -751,19 +769,22 @@ def load_test_data(
     except KeyError:
         # Try with just inputs if outputs not found (inference-only mode)
         if format == "npz":
-            data = NPZSource._safe_load(path)
-            keys = list(data.keys())
+            # First pass to find keys
+            with np.load(path, allow_pickle=False) as probe:
+                keys = list(probe.keys())
             inp_key = DataSource._find_key(keys, custom_input_keys)
             if inp_key is None:
                 raise KeyError(
                     f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
                 )
+            out_key = DataSource._find_key(keys, custom_output_keys)
+            keys_to_probe = [inp_key] + ([out_key] if out_key else [])
+            data = NPZSource._safe_load(path, keys_to_probe)
             inp = data[inp_key]
             if inp.dtype == object:
                 inp = np.array(
                     [x.toarray() if hasattr(x, "toarray") else x for x in inp]
                 )
-            out_key = DataSource._find_key(keys, custom_output_keys)
             outp = data[out_key] if out_key else None
         elif format == "hdf5":
             # HDF5: input-only loading for inference
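
In the object-array branch of `load_test_data`, each element is densified via `toarray()` when it has one, which covers SciPy sparse matrices. A short standalone sketch of that step, assuming `scipy.sparse` objects purely for illustration:

import numpy as np
from scipy import sparse  # assumption for the demo: stored objects are scipy.sparse matrices

# Object array of sparse matrices, as np.load(..., allow_pickle=True) would return it
inp = np.empty(2, dtype=object)
inp[0] = sparse.random(4, 4, density=0.5, format="csr")
inp[1] = sparse.random(4, 4, density=0.5, format="csr")

# Densify element-wise, mirroring the fallback branch above
dense = np.array([x.toarray() if hasattr(x, "toarray") else x for x in inp])
print(dense.shape, dense.dtype)  # (2, 4, 4) float64

The `hasattr(x, "toarray")` guard leaves already-dense elements untouched, so archives mixing sparse and dense entries load without a special case.
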
wavedl-1.5.1.dist-info/METADATA → wavedl-1.5.2.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.1
+Version: 1.5.2
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
wavedl-1.5.1.dist-info/RECORD → wavedl-1.5.2.dist-info/RECORD
@@ -1,4 +1,4 @@
-wavedl/__init__.py,sha256=9RY06pDbdsRrjTpxRYGYMTMZl7jB6-Tm-elfcPtvY3Y,1177
+wavedl/__init__.py,sha256=-Xrp2nvZZmC3EqroLASOAApSKxBBqAL8YO0BQQYKRTs,1177
 wavedl/hpc.py,sha256=-iOjjKkXPcV_quj4vAsMBJN_zWKtD1lMRfIZZBhyGms,8756
 wavedl/hpo.py,sha256=DGCGyt2yhr3WAifAuljhE26gg07CHdaQW4wpDaTKbyo,14968
 wavedl/test.py,sha256=WIHG3HWT-uF399FQApPpxjggBVFn59cC54HAL4990QU,38550
@@ -24,15 +24,15 @@ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
 wavedl/utils/config.py,sha256=jGW-K7AYB6zrD2BfVm2XPnSY9rbfL_EkM4bwxhBLuwM,10859
 wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
 wavedl/utils/cross_validation.py,sha256=tXiBOY1T7eyO9FwOcxvOkPlhMDdm5rCH1TGDPE-jZak,17961
-wavedl/utils/data.py,sha256=11N7Y6w5PYUSBcGANIB1P8JOapqPrVHjTj9CuUeTLac,49172
+wavedl/utils/data.py,sha256=KYw6YJQKPrcqwd_iqJSedws8bAWBOz5fqd3RSb2Wff4,50334
 wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
 wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
 wavedl/utils/metrics.py,sha256=EJmJvF7gACQsUoKYldlladN_SbnRiuE-Smj0eSnbraQ,39394
 wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
 wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
-wavedl-1.5.1.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
-wavedl-1.5.1.dist-info/METADATA,sha256=pzrIe4ZaDrRRGPIFlgQayG9L_k2q2g8TBP_puOQoTio,45512
-wavedl-1.5.1.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
-wavedl-1.5.1.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
-wavedl-1.5.1.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
-wavedl-1.5.1.dist-info/RECORD,,
+wavedl-1.5.2.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+wavedl-1.5.2.dist-info/METADATA,sha256=1iOcg_SkMZW6ju4rDDhyH_dcyOHe4yZqGQEynVT73b0,45512
+wavedl-1.5.2.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+wavedl-1.5.2.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
+wavedl-1.5.2.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+wavedl-1.5.2.dist-info/RECORD,,