wavedl 1.4.2__tar.gz → 1.4.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {wavedl-1.4.2/src/wavedl.egg-info → wavedl-1.4.3}/PKG-INFO +2 -1
  2. {wavedl-1.4.2 → wavedl-1.4.3}/README.md +1 -0
  3. {wavedl-1.4.2 → wavedl-1.4.3}/pyproject.toml +1 -1
  4. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/__init__.py +1 -1
  5. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/train.py +60 -7
  6. {wavedl-1.4.2 → wavedl-1.4.3/src/wavedl.egg-info}/PKG-INFO +2 -1
  7. {wavedl-1.4.2 → wavedl-1.4.3}/LICENSE +0 -0
  8. {wavedl-1.4.2 → wavedl-1.4.3}/setup.cfg +0 -0
  9. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/hpc.py +0 -0
  10. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/hpo.py +0 -0
  11. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/__init__.py +0 -0
  12. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/_template.py +0 -0
  13. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/base.py +0 -0
  14. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/cnn.py +0 -0
  15. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/convnext.py +0 -0
  16. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/densenet.py +0 -0
  17. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/efficientnet.py +0 -0
  18. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/efficientnetv2.py +0 -0
  19. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/mobilenetv3.py +0 -0
  20. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/registry.py +0 -0
  21. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/regnet.py +0 -0
  22. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/resnet.py +0 -0
  23. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/resnet3d.py +0 -0
  24. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/swin.py +0 -0
  25. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/tcn.py +0 -0
  26. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/unet.py +0 -0
  27. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/models/vit.py +0 -0
  28. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/test.py +0 -0
  29. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/utils/__init__.py +0 -0
  30. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/utils/config.py +0 -0
  31. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/utils/cross_validation.py +0 -0
  32. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/utils/data.py +0 -0
  33. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/utils/distributed.py +0 -0
  34. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/utils/losses.py +0 -0
  35. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/utils/metrics.py +0 -0
  36. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/utils/optimizers.py +0 -0
  37. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl/utils/schedulers.py +0 -0
  38. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl.egg-info/SOURCES.txt +0 -0
  39. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl.egg-info/dependency_links.txt +0 -0
  40. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl.egg-info/entry_points.txt +0 -0
  41. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl.egg-info/requires.txt +0 -0
  42. {wavedl-1.4.2 → wavedl-1.4.3}/src/wavedl.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.4.2
3
+ Version: 1.4.3
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -57,6 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
57
57
  [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
58
58
  [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
59
59
  <br>
60
+ [![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
60
61
  [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
61
62
  [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
62
63
 
@@ -12,6 +12,7 @@
12
12
  [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
13
13
  [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
14
14
  <br>
15
+ [![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
15
16
  [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
16
17
  [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
17
18
 
@@ -175,7 +175,7 @@ lines-after-imports = 2
175
175
  # Allow assert statements in tests
176
176
  "unit_tests/*" = ["B011", "S101"]
177
177
  # Allow unused imports for availability checks
178
- "src/wavedl/train.py" = ["F401"]
178
+ "src/wavedl/train.py" = ["F401", "E402"]
179
179
  "src/wavedl/test.py" = ["F401"]
180
180
  "src/wavedl/utils/data.py" = ["F401"]
181
181
 
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.4.2"
21
+ __version__ = "1.4.3"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
@@ -37,9 +37,54 @@ Author: Ductho Le (ductho.le@outlook.com)
37
37
 
38
38
  from __future__ import annotations
39
39
 
40
+ # =============================================================================
41
+ # HPC Environment Setup (MUST be before any library imports)
42
+ # =============================================================================
43
+ # Set writable cache directories for matplotlib and fontconfig ONLY when
44
+ # the default paths are not writable (common on HPC clusters).
45
+ import os
46
+ import tempfile
47
+
48
+
49
+ def _setup_cache_dir(env_var: str, default_subpath: str) -> None:
50
+ """Set cache directory only if default path is not writable."""
51
+ if env_var in os.environ:
52
+ return # User already set, respect their choice
53
+
54
+ # Check if default home config path is writable
55
+ home = os.path.expanduser("~")
56
+ default_path = os.path.join(home, ".config", default_subpath)
57
+ default_parent = os.path.dirname(default_path)
58
+
59
+ # If default path or its parent is writable, let the library use defaults
60
+ if (
61
+ os.access(default_path, os.W_OK)
62
+ or (os.path.exists(default_parent) and os.access(default_parent, os.W_OK))
63
+ or os.access(os.path.join(home, ".config"), os.W_OK)
64
+ ):
65
+ return
66
+
67
+ # Default not writable - find alternative location
68
+ for cache_base in [
69
+ os.environ.get("SCRATCH"),
70
+ os.environ.get("SLURM_TMPDIR"),
71
+ tempfile.gettempdir(),
72
+ ]:
73
+ if cache_base and os.access(cache_base, os.W_OK):
74
+ cache_path = os.path.join(cache_base, f".{default_subpath}")
75
+ os.makedirs(cache_path, exist_ok=True)
76
+ os.environ[env_var] = cache_path
77
+ return
78
+
79
+
80
+ _setup_cache_dir("MPLCONFIGDIR", "matplotlib")
81
+ _setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
82
+
83
+ # =============================================================================
84
+ # Standard imports (after environment setup)
85
+ # =============================================================================
40
86
  import argparse
41
87
  import logging
42
- import os
43
88
  import pickle
44
89
  import shutil
45
90
  import sys
@@ -47,6 +92,10 @@ import time
47
92
  import warnings
48
93
  from typing import Any
49
94
 
95
+
96
+ # Suppress Pydantic warnings from accelerate's internal Field() usage
97
+ warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
98
+
50
99
  import matplotlib.pyplot as plt
51
100
  import numpy as np
52
101
  import pandas as pd
@@ -582,9 +631,9 @@ def main():
582
631
  # Torch 2.0 compilation (requires compatible Triton on GPU)
583
632
  if args.compile:
584
633
  try:
585
- # Test if Triton is available AND compatible with this PyTorch version
586
- # PyTorch needs triton_key from triton.compiler.compiler
587
- from triton.compiler.compiler import triton_key
634
+ # Test if Triton is available - just import the package
635
+ # Different Triton versions have different internal APIs, so just check base import
636
+ import triton
588
637
 
589
638
  model = torch.compile(model)
590
639
  if accelerator.is_main_process:
@@ -875,9 +924,13 @@ def main():
875
924
  cpu_preds = torch.cat(local_preds)
876
925
  cpu_targets = torch.cat(local_targets)
877
926
 
878
- # Gather to rank 0 only via gather_object (avoids all-gather to every rank)
879
- # gather_object returns list of objects from each rank: [(preds0, targs0), (preds1, targs1), ...]
880
- gathered = accelerator.gather_object((cpu_preds, cpu_targets))
927
+ # Gather predictions and targets across all ranks
928
+ # Use accelerator.gather (works with all accelerate versions)
929
+ gpu_preds = cpu_preds.to(accelerator.device)
930
+ gpu_targets = cpu_targets.to(accelerator.device)
931
+ all_preds_gathered = accelerator.gather(gpu_preds).cpu()
932
+ all_targets_gathered = accelerator.gather(gpu_targets).cpu()
933
+ gathered = [(all_preds_gathered, all_targets_gathered)]
881
934
 
882
935
  # Synchronize validation metrics (scalars only - efficient)
883
936
  val_loss_scalar = val_loss_sum.item()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.4.2
3
+ Version: 1.4.3
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -57,6 +57,7 @@ Requires-Dist: triton>=2.0.0; sys_platform == "linux"
57
57
  [![Lint](https://img.shields.io/github/actions/workflow/status/ductho-le/WaveDL/lint.yml?branch=main&style=plastic&logo=ruff&logoColor=white&label=Lint)](https://github.com/ductho-le/WaveDL/actions/workflows/lint.yml)
58
58
  [![Try it on Colab](https://img.shields.io/badge/Try_it_on_Colab-8E44AD?style=plastic&logo=googlecolab&logoColor=white)](https://colab.research.google.com/github/ductho-le/WaveDL/blob/main/notebooks/demo.ipynb)
59
59
  <br>
60
+ [![Downloads](https://img.shields.io/pepy/dt/wavedl?style=plastic&logo=pypi&logoColor=white&color=9ACD32)](https://pepy.tech/project/wavedl)
60
61
  [![License: MIT](https://img.shields.io/badge/License-MIT-orange.svg?style=plastic)](LICENSE)
61
62
  [![DOI](https://img.shields.io/badge/DOI-10.5281/zenodo.18012338-008080.svg?style=plastic)](https://doi.org/10.5281/zenodo.18012338)
62
63
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes