wavedl 1.6.2__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/__init__.py +22 -0
- wavedl/test.py +8 -0
- wavedl/train.py +2 -2
- wavedl/utils/data.py +36 -6
- {wavedl-1.6.2.dist-info → wavedl-1.6.3.dist-info}/METADATA +3 -5
- {wavedl-1.6.2.dist-info → wavedl-1.6.3.dist-info}/RECORD +11 -11
- {wavedl-1.6.2.dist-info → wavedl-1.6.3.dist-info}/LICENSE +0 -0
- {wavedl-1.6.2.dist-info → wavedl-1.6.3.dist-info}/WHEEL +0 -0
- {wavedl-1.6.2.dist-info → wavedl-1.6.3.dist-info}/entry_points.txt +0 -0
- {wavedl-1.6.2.dist-info → wavedl-1.6.3.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/models/__init__.py
CHANGED
|
@@ -77,6 +77,15 @@ from .unet import UNetRegression
|
|
|
77
77
|
from .vit import ViTBase_, ViTSmall, ViTTiny
|
|
78
78
|
|
|
79
79
|
|
|
80
|
+
# Optional RATENet (unpublished, may be gitignored)
|
|
81
|
+
try:
|
|
82
|
+
from .ratenet import RATENet, RATENetLite, RATENetTiny, RATENetV2
|
|
83
|
+
|
|
84
|
+
_HAS_RATENET = True
|
|
85
|
+
except ImportError:
|
|
86
|
+
_HAS_RATENET = False
|
|
87
|
+
|
|
88
|
+
|
|
80
89
|
# Optional timm-based models (imported conditionally)
|
|
81
90
|
try:
|
|
82
91
|
from .caformer import CaFormerS18, CaFormerS36, PoolFormerS12
|
|
@@ -111,6 +120,7 @@ __all__ = [
|
|
|
111
120
|
"MC3_18",
|
|
112
121
|
"MODEL_REGISTRY",
|
|
113
122
|
"TCN",
|
|
123
|
+
# Classes (uppercase first, alphabetically)
|
|
114
124
|
"BaseModel",
|
|
115
125
|
"ConvNeXtBase_",
|
|
116
126
|
"ConvNeXtSmall",
|
|
@@ -152,6 +162,7 @@ __all__ = [
|
|
|
152
162
|
"VimBase",
|
|
153
163
|
"VimSmall",
|
|
154
164
|
"VimTiny",
|
|
165
|
+
# Functions (lowercase, alphabetically)
|
|
155
166
|
"build_model",
|
|
156
167
|
"get_model",
|
|
157
168
|
"list_models",
|
|
@@ -186,3 +197,14 @@ if _HAS_TIMM_MODELS:
|
|
|
186
197
|
"UniRepLKNetTiny",
|
|
187
198
|
]
|
|
188
199
|
)
|
|
200
|
+
|
|
201
|
+
# Add RATENet models to __all__ if available (unpublished)
|
|
202
|
+
if _HAS_RATENET:
|
|
203
|
+
__all__.extend(
|
|
204
|
+
[
|
|
205
|
+
"RATENet",
|
|
206
|
+
"RATENetLite",
|
|
207
|
+
"RATENetTiny",
|
|
208
|
+
"RATENetV2",
|
|
209
|
+
]
|
|
210
|
+
)
|
wavedl/test.py
CHANGED
|
@@ -398,6 +398,14 @@ def load_checkpoint(
|
|
|
398
398
|
|
|
399
399
|
if HAS_SAFETENSORS and weight_path.suffix == ".safetensors":
|
|
400
400
|
state_dict = load_safetensors(str(weight_path))
|
|
401
|
+
elif weight_path.suffix == ".safetensors":
|
|
402
|
+
# Safetensors file exists but library not installed
|
|
403
|
+
raise ImportError(
|
|
404
|
+
f"Checkpoint uses safetensors format ({weight_path.name}) but "
|
|
405
|
+
f"'safetensors' package is not installed. Install it with:\n"
|
|
406
|
+
f" pip install safetensors\n"
|
|
407
|
+
f"Or convert the checkpoint to PyTorch format (model.bin)."
|
|
408
|
+
)
|
|
401
409
|
else:
|
|
402
410
|
state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
|
|
403
411
|
|
wavedl/train.py
CHANGED
|
@@ -1266,10 +1266,10 @@ def main():
|
|
|
1266
1266
|
os.path.join(args.output_dir, "best_model_weights.pth"),
|
|
1267
1267
|
)
|
|
1268
1268
|
|
|
1269
|
-
# Copy scaler to checkpoint for portability
|
|
1269
|
+
# Copy scaler to checkpoint for portability (always overwrite to stay current)
|
|
1270
1270
|
scaler_src = os.path.join(args.output_dir, "scaler.pkl")
|
|
1271
1271
|
scaler_dst = os.path.join(ckpt_dir, "scaler.pkl")
|
|
1272
|
-
if os.path.exists(scaler_src) and not os.path.exists(scaler_dst):
|
|
1272
|
+
if os.path.exists(scaler_src):
|
|
1273
1273
|
shutil.copy2(scaler_src, scaler_dst)
|
|
1274
1274
|
|
|
1275
1275
|
logger.info(
|
wavedl/utils/data.py
CHANGED
|
@@ -984,7 +984,15 @@ def load_test_data(
|
|
|
984
984
|
f"Available keys depend on file format. Original error: {e}"
|
|
985
985
|
) from e
|
|
986
986
|
|
|
987
|
-
#
|
|
987
|
+
# Also fail-fast if explicit input_key was provided but not found
|
|
988
|
+
# This prevents silently loading a different tensor when user mistyped key
|
|
989
|
+
if input_key is not None:
|
|
990
|
+
raise KeyError(
|
|
991
|
+
f"Explicit --input_key '{input_key}' not found in file. "
|
|
992
|
+
f"Original error: {e}"
|
|
993
|
+
) from e
|
|
994
|
+
|
|
995
|
+
# Legitimate fallback: no explicit keys, outputs just not present
|
|
988
996
|
if format == "npz":
|
|
989
997
|
# First pass to find keys
|
|
990
998
|
with np.load(path, allow_pickle=False) as probe:
|
|
@@ -1524,21 +1532,43 @@ def prepare_data(
|
|
|
1524
1532
|
|
|
1525
1533
|
logger.info(" ✔ Cache creation complete, synchronizing ranks...")
|
|
1526
1534
|
else:
|
|
1527
|
-
# NON-MAIN RANKS: Wait for cache creation
|
|
1528
|
-
#
|
|
1535
|
+
# NON-MAIN RANKS: Wait for cache creation with timeout
|
|
1536
|
+
# Use monotonic clock (immune to system clock changes)
|
|
1529
1537
|
import time
|
|
1530
1538
|
|
|
1531
|
-
wait_start = time.time()
|
|
1539
|
+
wait_start = time.monotonic()
|
|
1540
|
+
|
|
1541
|
+
# Robust env parsing with guards for invalid/non-positive values
|
|
1542
|
+
DEFAULT_CACHE_TIMEOUT = 3600 # 1 hour default
|
|
1543
|
+
try:
|
|
1544
|
+
env_timeout = os.environ.get("WAVEDL_CACHE_TIMEOUT", "")
|
|
1545
|
+
CACHE_TIMEOUT = (
|
|
1546
|
+
int(env_timeout) if env_timeout else DEFAULT_CACHE_TIMEOUT
|
|
1547
|
+
)
|
|
1548
|
+
if CACHE_TIMEOUT <= 0:
|
|
1549
|
+
CACHE_TIMEOUT = DEFAULT_CACHE_TIMEOUT
|
|
1550
|
+
except ValueError:
|
|
1551
|
+
CACHE_TIMEOUT = DEFAULT_CACHE_TIMEOUT
|
|
1552
|
+
|
|
1532
1553
|
while not (
|
|
1533
1554
|
os.path.exists(CACHE_FILE)
|
|
1534
1555
|
and os.path.exists(SCALER_FILE)
|
|
1535
1556
|
and os.path.exists(META_FILE)
|
|
1536
1557
|
):
|
|
1537
1558
|
time.sleep(5) # Check every 5 seconds
|
|
1538
|
-
elapsed = time.time() - wait_start
|
|
1559
|
+
elapsed = time.monotonic() - wait_start
|
|
1560
|
+
|
|
1561
|
+
if elapsed > CACHE_TIMEOUT:
|
|
1562
|
+
raise RuntimeError(
|
|
1563
|
+
f"[Rank {accelerator.process_index}] Timeout waiting for cache "
|
|
1564
|
+
f"files after {CACHE_TIMEOUT}s. Rank 0 may have failed during "
|
|
1565
|
+
f"cache generation. Check rank 0 logs for errors."
|
|
1566
|
+
)
|
|
1567
|
+
|
|
1539
1568
|
if elapsed > 60 and int(elapsed) % 60 < 5: # Log every ~minute
|
|
1540
1569
|
logger.info(
|
|
1541
|
-
f" [Rank {accelerator.process_index}] Waiting for cache
|
|
1570
|
+
f" [Rank {accelerator.process_index}] Waiting for cache "
|
|
1571
|
+
f"creation... ({int(elapsed)}s / {CACHE_TIMEOUT}s max)"
|
|
1542
1572
|
)
|
|
1543
1573
|
# Small delay to ensure files are fully written
|
|
1544
1574
|
time.sleep(2)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.6.2
|
|
3
|
+
Version: 1.6.3
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -214,11 +214,11 @@ This installs everything you need: training, inference, HPO, ONNX export.
|
|
|
214
214
|
```bash
|
|
215
215
|
git clone https://github.com/ductho-le/WaveDL.git
|
|
216
216
|
cd WaveDL
|
|
217
|
-
pip install -e .
|
|
217
|
+
pip install -e ".[dev]"
|
|
218
218
|
```
|
|
219
219
|
|
|
220
220
|
> [!NOTE]
|
|
221
|
-
> Python 3.11+ required. For
|
|
221
|
+
> Python 3.11+ required. For contributor setup (pre-commit hooks), see [CONTRIBUTING.md](.github/CONTRIBUTING.md).
|
|
222
222
|
|
|
223
223
|
### Quick Start
|
|
224
224
|
|
|
@@ -273,8 +273,6 @@ accelerate launch --num_machines 2 --main_process_ip <ip> -m wavedl.train --mode
|
|
|
273
273
|
|
|
274
274
|
### Testing & Inference
|
|
275
275
|
|
|
276
|
-
After training, use `wavedl-test` to evaluate your model on test data:
|
|
277
|
-
|
|
278
276
|
```bash
|
|
279
277
|
# Basic inference
|
|
280
278
|
wavedl-test --checkpoint <checkpoint_folder> --data_path <test_data>
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256=
|
|
1
|
+
wavedl/__init__.py,sha256=CdjeIFEZh4ccwPiNsUxpac5l3pxwb9e1SmCNiLiYG1I,1177
|
|
2
2
|
wavedl/hpo.py,sha256=nEiy-2O_5EhxF5hU8X5TviSAiXfVrTQx0-VE6baW7JQ,14633
|
|
3
3
|
wavedl/launcher.py,sha256=_CFlgpKgHrtZebl1yQbJZJEcob06Y9-fqnRYzwW7UJQ,11776
|
|
4
|
-
wavedl/test.py,sha256=
|
|
5
|
-
wavedl/train.py,sha256=
|
|
6
|
-
wavedl/models/__init__.py,sha256=
|
|
4
|
+
wavedl/test.py,sha256=5MzBtEH2lWWYG23Fz-VpMFAWR5SfZbFomBbu8ptsZRU,39208
|
|
5
|
+
wavedl/train.py,sha256=DizXhi9BFL8heLmO8ENiNm2QubAMm9mdpDiaBlULeKM,57824
|
|
6
|
+
wavedl/models/__init__.py,sha256=hyR__h_D8PsUQCBSM5tj94yYK00uG8ABjEmj_RR8SGE,5719
|
|
7
7
|
wavedl/models/_pretrained_utils.py,sha256=VPdU1DwJB93ZBf_GFIgb8-6BbAt18Phs4yorwlhLw70,12404
|
|
8
8
|
wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
|
|
9
9
|
wavedl/models/base.py,sha256=bDoHYFli-aR8amcFYXbF98QYaKSCEwZWpvOhN21ODro,9075
|
|
@@ -32,15 +32,15 @@ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
|
|
|
32
32
|
wavedl/utils/config.py,sha256=MXkaVc1_zo8sDro8mjtK1MV65t2z8b1Z6fviwSorNiY,10534
|
|
33
33
|
wavedl/utils/constraints.py,sha256=V9Gyi8-uIMbLUWb2cOaHZD0SliWLxVrHZHFyo4HWK7g,18031
|
|
34
34
|
wavedl/utils/cross_validation.py,sha256=HfInyZ8gUROc_AyihYKzzUE0vnoPt_mFvAI2OPK4P54,17945
|
|
35
|
-
wavedl/utils/data.py,sha256=
|
|
35
|
+
wavedl/utils/data.py,sha256=HXod6i6g76oFAjLz7xepBPQEFHRgQ7E1M-YSKwUya-I,64799
|
|
36
36
|
wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
|
|
37
37
|
wavedl/utils/losses.py,sha256=KWpU5S5noFzp3bLbcH9RNpkFPajy6fyTIh5cNjI-BYA,7038
|
|
38
38
|
wavedl/utils/metrics.py,sha256=YoqiXWOsUB9Y4_alj8CmHcTgnV4MFcH5PH4XlIC13HY,40304
|
|
39
39
|
wavedl/utils/optimizers.py,sha256=ZoETDSOK1fWUT2dx69PyYebeM8Vcqf9zOIKUERWk5HY,6107
|
|
40
40
|
wavedl/utils/schedulers.py,sha256=K6YCiyiMM9rb0cCRXTp89noXeXcAyUEiePr27O5Cozs,7408
|
|
41
|
-
wavedl-1.6.2.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
42
|
-
wavedl-1.6.
|
|
43
|
-
wavedl-1.6.2.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
44
|
-
wavedl-1.6.2.dist-info/entry_points.txt,sha256=NuAvdiG93EYYpqv-_1wf6PN0WqBfABanDKalNKe2GOs,148
|
|
45
|
-
wavedl-1.6.2.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
46
|
-
wavedl-1.6.2.dist-info/RECORD,,
|
|
41
|
+
wavedl-1.6.3.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
42
|
+
wavedl-1.6.3.dist-info/METADATA,sha256=BtPoAiMwvE58b-dmR0TfPPLLBqYzaAiJ2GBUd0FBSY0,47613
|
|
43
|
+
wavedl-1.6.3.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
44
|
+
wavedl-1.6.3.dist-info/entry_points.txt,sha256=NuAvdiG93EYYpqv-_1wf6PN0WqBfABanDKalNKe2GOs,148
|
|
45
|
+
wavedl-1.6.3.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
46
|
+
wavedl-1.6.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|