wavedl 1.5.1__py3-none-any.whl → 1.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/utils/data.py +36 -15
- {wavedl-1.5.1.dist-info → wavedl-1.5.2.dist-info}/METADATA +1 -1
- {wavedl-1.5.1.dist-info → wavedl-1.5.2.dist-info}/RECORD +8 -8
- {wavedl-1.5.1.dist-info → wavedl-1.5.2.dist-info}/LICENSE +0 -0
- {wavedl-1.5.1.dist-info → wavedl-1.5.2.dist-info}/WHEEL +0 -0
- {wavedl-1.5.1.dist-info → wavedl-1.5.2.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.1.dist-info → wavedl-1.5.2.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/utils/data.py
CHANGED
|
@@ -202,18 +202,31 @@ class NPZSource(DataSource):
|
|
|
202
202
|
"""Load data from NumPy .npz archives."""
|
|
203
203
|
|
|
204
204
|
@staticmethod
|
|
205
|
-
def _safe_load(path: str, mmap_mode: str | None = None):
|
|
206
|
-
"""Load NPZ with pickle only if needed (sparse matrix support).
|
|
205
|
+
def _safe_load(path: str, keys_to_probe: list[str], mmap_mode: str | None = None):
|
|
206
|
+
"""Load NPZ with pickle only if needed (sparse matrix support).
|
|
207
|
+
|
|
208
|
+
The error for object arrays happens at ACCESS time, not load time.
|
|
209
|
+
So we need to probe the keys to detect if pickle is required.
|
|
210
|
+
"""
|
|
211
|
+
data = np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
|
|
207
212
|
try:
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
213
|
+
# Probe keys to trigger error if object arrays exist
|
|
214
|
+
for key in keys_to_probe:
|
|
215
|
+
if key in data:
|
|
216
|
+
_ = data[key] # This raises ValueError for object arrays
|
|
217
|
+
return data
|
|
218
|
+
except ValueError as e:
|
|
219
|
+
if "allow_pickle=False" in str(e):
|
|
220
|
+
# Fallback for sparse matrices stored as object arrays
|
|
221
|
+
data.close() if hasattr(data, "close") else None
|
|
222
|
+
return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
|
|
223
|
+
raise
|
|
212
224
|
|
|
213
225
|
def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
|
|
214
226
|
"""Load NPZ file (pickle enabled only for sparse matrices)."""
|
|
215
|
-
|
|
216
|
-
|
|
227
|
+
# First pass to find keys without loading data
|
|
228
|
+
with np.load(path, allow_pickle=False) as probe:
|
|
229
|
+
keys = list(probe.keys())
|
|
217
230
|
|
|
218
231
|
input_key = self._find_key(keys, INPUT_KEYS)
|
|
219
232
|
output_key = self._find_key(keys, OUTPUT_KEYS)
|
|
@@ -225,6 +238,7 @@ class NPZSource(DataSource):
|
|
|
225
238
|
f"Found: {keys}"
|
|
226
239
|
)
|
|
227
240
|
|
|
241
|
+
data = self._safe_load(path, [input_key, output_key])
|
|
228
242
|
inp = data[input_key]
|
|
229
243
|
outp = data[output_key]
|
|
230
244
|
|
|
@@ -243,8 +257,9 @@ class NPZSource(DataSource):
|
|
|
243
257
|
|
|
244
258
|
Note: Returns memory-mapped arrays - do NOT modify them.
|
|
245
259
|
"""
|
|
246
|
-
|
|
247
|
-
|
|
260
|
+
# First pass to find keys without loading data
|
|
261
|
+
with np.load(path, allow_pickle=False) as probe:
|
|
262
|
+
keys = list(probe.keys())
|
|
248
263
|
|
|
249
264
|
input_key = self._find_key(keys, INPUT_KEYS)
|
|
250
265
|
output_key = self._find_key(keys, OUTPUT_KEYS)
|
|
@@ -256,6 +271,7 @@ class NPZSource(DataSource):
|
|
|
256
271
|
f"Found: {keys}"
|
|
257
272
|
)
|
|
258
273
|
|
|
274
|
+
data = self._safe_load(path, [input_key, output_key], mmap_mode="r")
|
|
259
275
|
inp = data[input_key]
|
|
260
276
|
outp = data[output_key]
|
|
261
277
|
|
|
@@ -263,8 +279,9 @@ class NPZSource(DataSource):
|
|
|
263
279
|
|
|
264
280
|
def load_outputs_only(self, path: str) -> np.ndarray:
|
|
265
281
|
"""Load only targets from NPZ (avoids loading large input arrays)."""
|
|
266
|
-
|
|
267
|
-
|
|
282
|
+
# First pass to find keys without loading data
|
|
283
|
+
with np.load(path, allow_pickle=False) as probe:
|
|
284
|
+
keys = list(probe.keys())
|
|
268
285
|
|
|
269
286
|
output_key = self._find_key(keys, OUTPUT_KEYS)
|
|
270
287
|
if output_key is None:
|
|
@@ -273,6 +290,7 @@ class NPZSource(DataSource):
|
|
|
273
290
|
f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
|
|
274
291
|
)
|
|
275
292
|
|
|
293
|
+
data = self._safe_load(path, [output_key])
|
|
276
294
|
return data[output_key]
|
|
277
295
|
|
|
278
296
|
|
|
@@ -751,19 +769,22 @@ def load_test_data(
|
|
|
751
769
|
except KeyError:
|
|
752
770
|
# Try with just inputs if outputs not found (inference-only mode)
|
|
753
771
|
if format == "npz":
|
|
754
|
-
|
|
755
|
-
|
|
772
|
+
# First pass to find keys
|
|
773
|
+
with np.load(path, allow_pickle=False) as probe:
|
|
774
|
+
keys = list(probe.keys())
|
|
756
775
|
inp_key = DataSource._find_key(keys, custom_input_keys)
|
|
757
776
|
if inp_key is None:
|
|
758
777
|
raise KeyError(
|
|
759
778
|
f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
|
|
760
779
|
)
|
|
780
|
+
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
781
|
+
keys_to_probe = [inp_key] + ([out_key] if out_key else [])
|
|
782
|
+
data = NPZSource._safe_load(path, keys_to_probe)
|
|
761
783
|
inp = data[inp_key]
|
|
762
784
|
if inp.dtype == object:
|
|
763
785
|
inp = np.array(
|
|
764
786
|
[x.toarray() if hasattr(x, "toarray") else x for x in inp]
|
|
765
787
|
)
|
|
766
|
-
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
767
788
|
outp = data[out_key] if out_key else None
|
|
768
789
|
elif format == "hdf5":
|
|
769
790
|
# HDF5: input-only loading for inference
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256
|
|
1
|
+
wavedl/__init__.py,sha256=-Xrp2nvZZmC3EqroLASOAApSKxBBqAL8YO0BQQYKRTs,1177
|
|
2
2
|
wavedl/hpc.py,sha256=-iOjjKkXPcV_quj4vAsMBJN_zWKtD1lMRfIZZBhyGms,8756
|
|
3
3
|
wavedl/hpo.py,sha256=DGCGyt2yhr3WAifAuljhE26gg07CHdaQW4wpDaTKbyo,14968
|
|
4
4
|
wavedl/test.py,sha256=WIHG3HWT-uF399FQApPpxjggBVFn59cC54HAL4990QU,38550
|
|
@@ -24,15 +24,15 @@ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
|
|
|
24
24
|
wavedl/utils/config.py,sha256=jGW-K7AYB6zrD2BfVm2XPnSY9rbfL_EkM4bwxhBLuwM,10859
|
|
25
25
|
wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
|
|
26
26
|
wavedl/utils/cross_validation.py,sha256=tXiBOY1T7eyO9FwOcxvOkPlhMDdm5rCH1TGDPE-jZak,17961
|
|
27
|
-
wavedl/utils/data.py,sha256=
|
|
27
|
+
wavedl/utils/data.py,sha256=KYw6YJQKPrcqwd_iqJSedws8bAWBOz5fqd3RSb2Wff4,50334
|
|
28
28
|
wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
|
|
29
29
|
wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
30
30
|
wavedl/utils/metrics.py,sha256=EJmJvF7gACQsUoKYldlladN_SbnRiuE-Smj0eSnbraQ,39394
|
|
31
31
|
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
32
32
|
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
33
|
-
wavedl-1.5.
|
|
34
|
-
wavedl-1.5.
|
|
35
|
-
wavedl-1.5.
|
|
36
|
-
wavedl-1.5.
|
|
37
|
-
wavedl-1.5.
|
|
38
|
-
wavedl-1.5.
|
|
33
|
+
wavedl-1.5.2.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
34
|
+
wavedl-1.5.2.dist-info/METADATA,sha256=1iOcg_SkMZW6ju4rDDhyH_dcyOHe4yZqGQEynVT73b0,45512
|
|
35
|
+
wavedl-1.5.2.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
36
|
+
wavedl-1.5.2.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
37
|
+
wavedl-1.5.2.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
38
|
+
wavedl-1.5.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|