wavedl 1.4.3__py3-none-any.whl → 1.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
18
18
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
19
19
  """
20
20
 
21
- __version__ = "1.4.3"
21
+ __version__ = "1.4.5"
22
22
  __author__ = "Ductho Le"
23
23
  __email__ = "ductho.le@outlook.com"
24
24
 
wavedl/hpc.py CHANGED
@@ -26,7 +26,6 @@ import os
26
26
  import shutil
27
27
  import subprocess
28
28
  import sys
29
- import tempfile
30
29
  from pathlib import Path
31
30
 
32
31
 
@@ -60,26 +59,40 @@ def setup_hpc_environment() -> None:
60
59
  Handles restricted home directories (e.g., Compute Canada) and
61
60
  offline logging configurations.
62
61
  """
63
- # Use SLURM_TMPDIR if available, otherwise system temp
64
- tmpdir = os.environ.get(
65
- "SLURM_TMPDIR", os.environ.get("TMPDIR", tempfile.gettempdir())
66
- )
67
-
68
- # Configure directories for systems with restricted home directories
69
- os.environ.setdefault("MPLCONFIGDIR", f"{tmpdir}/matplotlib")
70
- os.environ.setdefault(
71
- "FONTCONFIG_PATH", os.environ.get("FONTCONFIG_PATH", "/etc/fonts")
72
- )
73
- os.environ.setdefault("XDG_CACHE_HOME", f"{tmpdir}/.cache")
74
-
75
- # Ensure matplotlib config dir exists
76
- Path(os.environ["MPLCONFIGDIR"]).mkdir(parents=True, exist_ok=True)
62
+ # Check if home is writable
63
+ home = os.path.expanduser("~")
64
+ home_writable = os.access(home, os.W_OK)
65
+
66
+ # Use SLURM_TMPDIR if available, otherwise CWD for HPC, or system temp
67
+ if home_writable:
68
+ # Local machine - let libraries use defaults
69
+ cache_base = None
70
+ else:
71
+ # HPC with restricted home - use CWD for persistent caches
72
+ cache_base = os.getcwd()
73
+
74
+ # Only set environment variables if home is not writable
75
+ if cache_base:
76
+ os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
77
+ os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
78
+ os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
79
+ os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
80
+
81
+ # Ensure directories exist
82
+ for env_var in [
83
+ "TORCH_HOME",
84
+ "MPLCONFIGDIR",
85
+ "FONTCONFIG_CACHE",
86
+ "XDG_CACHE_HOME",
87
+ ]:
88
+ Path(os.environ[env_var]).mkdir(parents=True, exist_ok=True)
77
89
 
78
90
  # WandB configuration (offline by default for HPC)
79
91
  os.environ.setdefault("WANDB_MODE", "offline")
80
- os.environ.setdefault("WANDB_DIR", f"{tmpdir}/wandb")
81
- os.environ.setdefault("WANDB_CACHE_DIR", f"{tmpdir}/wandb_cache")
82
- os.environ.setdefault("WANDB_CONFIG_DIR", f"{tmpdir}/wandb_config")
92
+ if cache_base:
93
+ os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
94
+ os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
95
+ os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
83
96
 
84
97
  # Suppress non-critical warnings
85
98
  os.environ.setdefault(
wavedl/test.py CHANGED
@@ -29,29 +29,52 @@ Author: Ductho Le (ductho.le@outlook.com)
29
29
  # ==============================================================================
30
30
  # ENVIRONMENT CONFIGURATION (must be before matplotlib import)
31
31
  # ==============================================================================
32
+ # Auto-configure writable cache directories when home is not writable.
33
+ # Uses current working directory as fallback - works on HPC and local machines.
32
34
  import os
33
35
 
34
36
 
35
- os.environ.setdefault("MPLCONFIGDIR", os.getenv("TMPDIR", "/tmp") + "/matplotlib")
36
- os.environ.setdefault("FONTCONFIG_PATH", "/etc/fonts")
37
-
38
- import argparse
39
- import logging
40
- import pickle
41
- from pathlib import Path
42
-
43
- import matplotlib.pyplot as plt
44
- import numpy as np
45
- import pandas as pd
46
- import torch
47
- import torch.nn as nn
48
- from sklearn.metrics import mean_absolute_error, r2_score
49
- from torch.utils.data import DataLoader, TensorDataset
50
- from tqdm.auto import tqdm
37
+ def _setup_cache_dir(env_var: str, subdir: str) -> None:
38
+ """Set cache directory to CWD if home is not writable."""
39
+ if env_var in os.environ:
40
+ return # User already set, respect their choice
41
+
42
+ # Check if home is writable
43
+ home = os.path.expanduser("~")
44
+ if os.access(home, os.W_OK):
45
+ return # Home is writable, let library use defaults
46
+
47
+ # Home not writable - use current working directory
48
+ cache_path = os.path.join(os.getcwd(), f".{subdir}")
49
+ os.makedirs(cache_path, exist_ok=True)
50
+ os.environ[env_var] = cache_path
51
+
52
+
53
+ # Configure cache directories (before any library imports)
54
+ _setup_cache_dir("TORCH_HOME", "torch_cache")
55
+ _setup_cache_dir("MPLCONFIGDIR", "matplotlib")
56
+ _setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
57
+ _setup_cache_dir("XDG_DATA_HOME", "local/share")
58
+ _setup_cache_dir("XDG_STATE_HOME", "local/state")
59
+ _setup_cache_dir("XDG_CACHE_HOME", "cache")
60
+
61
+ import argparse # noqa: E402
62
+ import logging # noqa: E402
63
+ import pickle # noqa: E402
64
+ from pathlib import Path # noqa: E402
65
+
66
+ import matplotlib.pyplot as plt # noqa: E402
67
+ import numpy as np # noqa: E402
68
+ import pandas as pd # noqa: E402
69
+ import torch # noqa: E402
70
+ import torch.nn as nn # noqa: E402
71
+ from sklearn.metrics import mean_absolute_error, r2_score # noqa: E402
72
+ from torch.utils.data import DataLoader, TensorDataset # noqa: E402
73
+ from tqdm.auto import tqdm # noqa: E402
51
74
 
52
75
  # Local imports
53
- from wavedl.models import build_model, list_models
54
- from wavedl.utils import (
76
+ from wavedl.models import build_model, list_models # noqa: E402
77
+ from wavedl.utils import ( # noqa: E402
55
78
  FIGURE_DPI,
56
79
  calc_pearson,
57
80
  load_test_data,
@@ -356,10 +379,18 @@ def load_checkpoint(
356
379
  else:
357
380
  state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
358
381
 
359
- # Remove 'module.' prefix from DDP checkpoints (leading only, not all occurrences)
360
- state_dict = {
361
- (k[7:] if k.startswith("module.") else k): v for k, v in state_dict.items()
362
- }
382
+ # Remove wrapper prefixes from checkpoints:
383
+ # - 'module.' from DDP (DistributedDataParallel)
384
+ # - '_orig_mod.' from torch.compile()
385
+ cleaned_dict = {}
386
+ for k, v in state_dict.items():
387
+ key = k
388
+ if key.startswith("module."):
389
+ key = key[7:] # Remove 'module.' (7 chars)
390
+ if key.startswith("_orig_mod."):
391
+ key = key[10:] # Remove '_orig_mod.' (10 chars)
392
+ cleaned_dict[key] = v
393
+ state_dict = cleaned_dict
363
394
 
364
395
  model.load_state_dict(state_dict)
365
396
  model.eval()
wavedl/train.py CHANGED
@@ -40,45 +40,34 @@ from __future__ import annotations
40
40
  # =============================================================================
41
41
  # HPC Environment Setup (MUST be before any library imports)
42
42
  # =============================================================================
43
- # Set writable cache directories for matplotlib and fontconfig ONLY when
44
- # the default paths are not writable (common on HPC clusters).
43
+ # Auto-configure writable cache directories when home is not writable.
44
+ # Uses current working directory as fallback - works on HPC and local machines.
45
45
  import os
46
- import tempfile
47
46
 
48
47
 
49
- def _setup_cache_dir(env_var: str, default_subpath: str) -> None:
50
- """Set cache directory only if default path is not writable."""
48
+ def _setup_cache_dir(env_var: str, subdir: str) -> None:
49
+ """Set cache directory to CWD if home is not writable."""
51
50
  if env_var in os.environ:
52
51
  return # User already set, respect their choice
53
52
 
54
- # Check if default home config path is writable
53
+ # Check if home is writable
55
54
  home = os.path.expanduser("~")
56
- default_path = os.path.join(home, ".config", default_subpath)
57
- default_parent = os.path.dirname(default_path)
55
+ if os.access(home, os.W_OK):
56
+ return # Home is writable, let library use defaults
58
57
 
59
- # If default path or its parent is writable, let the library use defaults
60
- if (
61
- os.access(default_path, os.W_OK)
62
- or (os.path.exists(default_parent) and os.access(default_parent, os.W_OK))
63
- or os.access(os.path.join(home, ".config"), os.W_OK)
64
- ):
65
- return
66
-
67
- # Default not writable - find alternative location
68
- for cache_base in [
69
- os.environ.get("SCRATCH"),
70
- os.environ.get("SLURM_TMPDIR"),
71
- tempfile.gettempdir(),
72
- ]:
73
- if cache_base and os.access(cache_base, os.W_OK):
74
- cache_path = os.path.join(cache_base, f".{default_subpath}")
75
- os.makedirs(cache_path, exist_ok=True)
76
- os.environ[env_var] = cache_path
77
- return
58
+ # Home not writable - use current working directory
59
+ cache_path = os.path.join(os.getcwd(), f".{subdir}")
60
+ os.makedirs(cache_path, exist_ok=True)
61
+ os.environ[env_var] = cache_path
78
62
 
79
63
 
64
+ # Configure cache directories (before any library imports)
65
+ _setup_cache_dir("TORCH_HOME", "torch_cache")
80
66
  _setup_cache_dir("MPLCONFIGDIR", "matplotlib")
81
67
  _setup_cache_dir("FONTCONFIG_CACHE", "fontconfig")
68
+ _setup_cache_dir("XDG_DATA_HOME", "local/share")
69
+ _setup_cache_dir("XDG_STATE_HOME", "local/state")
70
+ _setup_cache_dir("XDG_CACHE_HOME", "cache")
82
71
 
83
72
  # =============================================================================
84
73
  # Standard imports (after environment setup)
@@ -1065,7 +1054,17 @@ def main():
1065
1054
  f,
1066
1055
  )
1067
1056
 
1068
- unwrapped = accelerator.unwrap_model(model)
1057
+ # Unwrap model for saving (handle torch.compile compatibility)
1058
+ try:
1059
+ unwrapped = accelerator.unwrap_model(model)
1060
+ except KeyError:
1061
+ # torch.compile model may not have _orig_mod in expected location
1062
+ # Fall back to getting the module directly
1063
+ unwrapped = model.module if hasattr(model, "module") else model
1064
+ # If still compiled, try to get the underlying model
1065
+ if hasattr(unwrapped, "_orig_mod"):
1066
+ unwrapped = unwrapped._orig_mod
1067
+
1069
1068
  torch.save(
1070
1069
  unwrapped.state_dict(),
1071
1070
  os.path.join(args.output_dir, "best_model_weights.pth"),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: wavedl
3
- Version: 1.4.3
3
+ Version: 1.4.5
4
4
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
5
5
  Author: Ductho Le
6
6
  License: MIT
@@ -462,7 +462,43 @@ WaveDL/
462
462
  | **U-Net** — U-shaped Network |||
463
463
  | `unet_regression` | 31.1M | 1D/2D/3D |
464
464
 
465
- > ⭐ = Pretrained on ImageNet. Recommended for smaller datasets.
465
+ ⭐ = **Pretrained on ImageNet** (recommended for smaller datasets). Weights are downloaded automatically on first use.
466
+ - **Cache location**: `~/.cache/torch/hub/checkpoints/` (or `./.torch_cache/` on HPC if home is not writable)
467
+ - **Size**: ~20–350 MB per model depending on architecture
468
+
469
+ **💡 HPC Users**: If compute nodes block internet, pre-download weights on the login node:
470
+
471
+ ```bash
472
+ # Run once on login node (with internet) — downloads ALL pretrained weights (~1.5 GB total)
473
+ python -c "
474
+ import os
475
+ os.environ['TORCH_HOME'] = '.torch_cache' # Match WaveDL's HPC cache location
476
+
477
+ from torchvision import models as m
478
+ from torchvision.models import video as v
479
+
480
+ # Model name -> Weights class mapping
481
+ weights = {
482
+ 'resnet18': m.ResNet18_Weights, 'resnet50': m.ResNet50_Weights,
483
+ 'efficientnet_b0': m.EfficientNet_B0_Weights, 'efficientnet_b1': m.EfficientNet_B1_Weights,
484
+ 'efficientnet_b2': m.EfficientNet_B2_Weights, 'efficientnet_v2_s': m.EfficientNet_V2_S_Weights,
485
+ 'efficientnet_v2_m': m.EfficientNet_V2_M_Weights, 'efficientnet_v2_l': m.EfficientNet_V2_L_Weights,
486
+ 'mobilenet_v3_small': m.MobileNet_V3_Small_Weights, 'mobilenet_v3_large': m.MobileNet_V3_Large_Weights,
487
+ 'regnet_y_400mf': m.RegNet_Y_400MF_Weights, 'regnet_y_800mf': m.RegNet_Y_800MF_Weights,
488
+ 'regnet_y_1_6gf': m.RegNet_Y_1_6GF_Weights, 'regnet_y_3_2gf': m.RegNet_Y_3_2GF_Weights,
489
+ 'regnet_y_8gf': m.RegNet_Y_8GF_Weights, 'swin_t': m.Swin_T_Weights, 'swin_s': m.Swin_S_Weights,
490
+ 'swin_b': m.Swin_B_Weights, 'convnext_tiny': m.ConvNeXt_Tiny_Weights, 'densenet121': m.DenseNet121_Weights,
491
+ }
492
+ for name, w in weights.items():
493
+ getattr(m, name)(weights=w.DEFAULT); print(f'✓ {name}')
494
+
495
+ # 3D video models
496
+ v.r3d_18(weights=v.R3D_18_Weights.DEFAULT); print('✓ r3d_18')
497
+ v.mc3_18(weights=v.MC3_18_Weights.DEFAULT); print('✓ mc3_18')
498
+ print('\\n✓ All pretrained weights cached!')
499
+ "
500
+ ```
501
+
466
502
 
467
503
  </details>
468
504
 
@@ -687,7 +723,6 @@ compile: false
687
723
  seed: 2025
688
724
  ```
689
725
 
690
- > [!TIP]
691
726
  > See [`configs/config.yaml`](configs/config.yaml) for the complete template with all available options documented.
692
727
 
693
728
  </details>
@@ -753,7 +788,7 @@ accelerate launch -m wavedl.train --data_path train.npz --model cnn --lr 3.2e-4
753
788
  | `--max_epochs` | `50` | Max epochs per trial |
754
789
  | `--output` | `hpo_results.json` | Output file |
755
790
 
756
- > [!TIP]
791
+
757
792
  > See [Available Models](#available-models) for all 38 architectures you can search.
758
793
 
759
794
  </details>
@@ -1,8 +1,8 @@
1
- wavedl/__init__.py,sha256=7uvgr9r21qLGu6RIZDUfQBhg1vXNAfmIWl_P6BMj_KQ,1177
2
- wavedl/hpc.py,sha256=de_GKERX8GS10sXRX9yXiGzMnk1jjq8JPzRw7QDs6d4,7967
1
+ wavedl/__init__.py,sha256=2ro7SYQ3wCmq-ejiAm5sd6BeXf6sZgixC9U2vS7Ckbs,1177
2
+ wavedl/hpc.py,sha256=0h8IZzOT0EzmEv3fU9cKyRVE9V1ivtBzbjuBCaxYadc,8445
3
3
  wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
4
- wavedl/test.py,sha256=jZmRJaivYYTMMTaccCi0yQjHOfp0a9YWR1wAPeKFH-k,36246
5
- wavedl/train.py,sha256=TjLABBPCqu9r7FEWlxJlKsT7uAMo6hiDxRfii2SKXe4,49052
4
+ wavedl/test.py,sha256=81al6vQBDAJ3CpSEtxZn6xzR1c4-jo28R7tX_84KROc,37642
5
+ wavedl/train.py,sha256=_pW7prvlNqfUGrGweHO2QelS87UiAYKvyJwqMAIj6yI,49292
6
6
  wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
7
7
  wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
8
8
  wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -29,9 +29,9 @@ wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
29
29
  wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
30
30
  wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
31
31
  wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
32
- wavedl-1.4.3.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
33
- wavedl-1.4.3.dist-info/METADATA,sha256=vs3nt8R2O5lD7q-si9M5ChyrriIKm0fFzDwi4HVIYxw,40386
34
- wavedl-1.4.3.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
35
- wavedl-1.4.3.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
36
- wavedl-1.4.3.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
37
- wavedl-1.4.3.dist-info/RECORD,,
32
+ wavedl-1.4.5.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
33
+ wavedl-1.4.5.dist-info/METADATA,sha256=4ltxFDaqPqh4XUAW_K8nkFmvqBzPcL2cxmghH11GMWg,42191
34
+ wavedl-1.4.5.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
35
+ wavedl-1.4.5.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
36
+ wavedl-1.4.5.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
37
+ wavedl-1.4.5.dist-info/RECORD,,
File without changes