wavedl 1.5.2__py3-none-any.whl → 1.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/hpc.py +22 -18
- wavedl/models/resnet.py +38 -9
- wavedl/train.py +57 -55
- wavedl/utils/config.py +3 -1
- wavedl/utils/cross_validation.py +11 -0
- wavedl/utils/data.py +51 -2
- {wavedl-1.5.2.dist-info → wavedl-1.5.3.dist-info}/METADATA +23 -19
- {wavedl-1.5.2.dist-info → wavedl-1.5.3.dist-info}/RECORD +13 -13
- {wavedl-1.5.2.dist-info → wavedl-1.5.3.dist-info}/LICENSE +0 -0
- {wavedl-1.5.2.dist-info → wavedl-1.5.3.dist-info}/WHEEL +0 -0
- {wavedl-1.5.2.dist-info → wavedl-1.5.3.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.2.dist-info → wavedl-1.5.3.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpc.py
CHANGED
|
@@ -57,30 +57,35 @@ def setup_hpc_environment() -> None:
|
|
|
57
57
|
"""Configure environment variables for HPC systems.
|
|
58
58
|
|
|
59
59
|
Handles restricted home directories (e.g., Compute Canada) and
|
|
60
|
-
offline logging configurations.
|
|
60
|
+
offline logging configurations. Always uses CWD-based TORCH_HOME
|
|
61
|
+
since compute nodes typically lack internet access.
|
|
61
62
|
"""
|
|
62
|
-
#
|
|
63
|
+
# Use CWD for cache base since HPC compute nodes typically lack internet
|
|
64
|
+
cache_base = os.getcwd()
|
|
65
|
+
|
|
66
|
+
# TORCH_HOME always set to CWD - compute nodes need pre-cached weights
|
|
67
|
+
os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
|
|
68
|
+
Path(os.environ["TORCH_HOME"]).mkdir(parents=True, exist_ok=True)
|
|
69
|
+
|
|
70
|
+
# Triton/Inductor caches - prevents permission errors with --compile
|
|
71
|
+
# These MUST be set before any torch.compile calls
|
|
72
|
+
os.environ.setdefault("TRITON_CACHE_DIR", f"{cache_base}/.triton_cache")
|
|
73
|
+
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{cache_base}/.inductor_cache")
|
|
74
|
+
Path(os.environ["TRITON_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
|
|
75
|
+
Path(os.environ["TORCHINDUCTOR_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)
|
|
76
|
+
|
|
77
|
+
# Check if home is writable for other caches
|
|
63
78
|
home = os.path.expanduser("~")
|
|
64
79
|
home_writable = os.access(home, os.W_OK)
|
|
65
80
|
|
|
66
|
-
#
|
|
67
|
-
if home_writable:
|
|
68
|
-
# Local machine - let libraries use defaults
|
|
69
|
-
cache_base = None
|
|
70
|
-
else:
|
|
71
|
-
# HPC with restricted home - use CWD for persistent caches
|
|
72
|
-
cache_base = os.getcwd()
|
|
73
|
-
|
|
74
|
-
# Only set environment variables if home is not writable
|
|
75
|
-
if cache_base:
|
|
76
|
-
os.environ.setdefault("TORCH_HOME", f"{cache_base}/.torch_cache")
|
|
81
|
+
# Other caches only if home is not writable
|
|
82
|
+
if not home_writable:
|
|
77
83
|
os.environ.setdefault("MPLCONFIGDIR", f"{cache_base}/.matplotlib")
|
|
78
84
|
os.environ.setdefault("FONTCONFIG_CACHE", f"{cache_base}/.fontconfig")
|
|
79
85
|
os.environ.setdefault("XDG_CACHE_HOME", f"{cache_base}/.cache")
|
|
80
86
|
|
|
81
87
|
# Ensure directories exist
|
|
82
88
|
for env_var in [
|
|
83
|
-
"TORCH_HOME",
|
|
84
89
|
"MPLCONFIGDIR",
|
|
85
90
|
"FONTCONFIG_CACHE",
|
|
86
91
|
"XDG_CACHE_HOME",
|
|
@@ -89,10 +94,9 @@ def setup_hpc_environment() -> None:
|
|
|
89
94
|
|
|
90
95
|
# WandB configuration (offline by default for HPC)
|
|
91
96
|
os.environ.setdefault("WANDB_MODE", "offline")
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
|
|
97
|
+
os.environ.setdefault("WANDB_DIR", f"{cache_base}/.wandb")
|
|
98
|
+
os.environ.setdefault("WANDB_CACHE_DIR", f"{cache_base}/.wandb_cache")
|
|
99
|
+
os.environ.setdefault("WANDB_CONFIG_DIR", f"{cache_base}/.wandb_config")
|
|
96
100
|
|
|
97
101
|
# Suppress non-critical warnings
|
|
98
102
|
os.environ.setdefault(
|
wavedl/models/resnet.py
CHANGED
|
@@ -49,6 +49,36 @@ def _get_conv_layers(
|
|
|
49
49
|
raise ValueError(f"Unsupported dimensionality: {dim}D. Supported: 1D, 2D, 3D.")
|
|
50
50
|
|
|
51
51
|
|
|
52
|
+
def _get_num_groups(num_channels: int, preferred_groups: int = 32) -> int:
    """
    Get valid num_groups for GroupNorm that divides num_channels evenly.

    Candidates are tried in this order: ``preferred_groups`` first, then
    16, 8, 4, 2 and finally 1. The first candidate that does not exceed
    ``num_channels`` and divides it evenly is returned. Because 1 divides
    every positive integer, a result always exists when ``num_channels >= 1``.

    Args:
        num_channels: Number of channels to normalize (must be >= 1)
        preferred_groups: Preferred number of groups (default: 32, must be >= 1)

    Returns:
        Valid num_groups that divides num_channels

    Raises:
        ValueError: If ``preferred_groups`` or ``num_channels`` is smaller than 1
    """
    # Reject non-positive preferences up front: 0 would trigger a
    # ZeroDivisionError in the modulo below, and a negative value could be
    # returned as a (meaningless) negative group count.
    if preferred_groups < 1:
        raise ValueError(
            f"preferred_groups must be a positive integer, got {preferred_groups}."
        )

    # No group count can divide a non-positive channel count.
    if num_channels < 1:
        raise ValueError(
            f"Cannot find valid num_groups for {num_channels} channels. "
            "Consider using base_width that is a power of 2 (e.g., 32, 64, 128)."
        )

    # Try preferred groups first, then the fixed descending candidates.
    # The trailing 1 guarantees that next() always finds a match here.
    candidates = (preferred_groups, 16, 8, 4, 2, 1)
    return next(
        groups
        for groups in candidates
        if groups <= num_channels and num_channels % groups == 0
    )
|
|
80
|
+
|
|
81
|
+
|
|
52
82
|
class BasicBlock(nn.Module):
|
|
53
83
|
"""
|
|
54
84
|
Basic residual block for ResNet-18/34.
|
|
@@ -77,12 +107,12 @@ class BasicBlock(nn.Module):
|
|
|
77
107
|
padding=1,
|
|
78
108
|
bias=False,
|
|
79
109
|
)
|
|
80
|
-
self.gn1 = nn.GroupNorm(
|
|
110
|
+
self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
|
|
81
111
|
self.relu = nn.ReLU(inplace=True)
|
|
82
112
|
self.conv2 = Conv(
|
|
83
113
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
|
|
84
114
|
)
|
|
85
|
-
self.gn2 = nn.GroupNorm(
|
|
115
|
+
self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
|
|
86
116
|
self.downsample = downsample
|
|
87
117
|
|
|
88
118
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
@@ -125,7 +155,7 @@ class Bottleneck(nn.Module):
|
|
|
125
155
|
|
|
126
156
|
# 1x1 reduce
|
|
127
157
|
self.conv1 = Conv(in_channels, out_channels, kernel_size=1, bias=False)
|
|
128
|
-
self.gn1 = nn.GroupNorm(
|
|
158
|
+
self.gn1 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
|
|
129
159
|
|
|
130
160
|
# 3x3 conv
|
|
131
161
|
self.conv2 = Conv(
|
|
@@ -136,15 +166,14 @@ class Bottleneck(nn.Module):
|
|
|
136
166
|
padding=1,
|
|
137
167
|
bias=False,
|
|
138
168
|
)
|
|
139
|
-
self.gn2 = nn.GroupNorm(
|
|
169
|
+
self.gn2 = nn.GroupNorm(_get_num_groups(out_channels), out_channels)
|
|
140
170
|
|
|
141
171
|
# 1x1 expand
|
|
142
172
|
self.conv3 = Conv(
|
|
143
173
|
out_channels, out_channels * self.expansion, kernel_size=1, bias=False
|
|
144
174
|
)
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
)
|
|
175
|
+
expanded_channels = out_channels * self.expansion
|
|
176
|
+
self.gn3 = nn.GroupNorm(_get_num_groups(expanded_channels), expanded_channels)
|
|
148
177
|
|
|
149
178
|
self.relu = nn.ReLU(inplace=True)
|
|
150
179
|
self.downsample = downsample
|
|
@@ -200,7 +229,7 @@ class ResNetBase(BaseModel):
|
|
|
200
229
|
|
|
201
230
|
# Stem: 7x7 conv (or equivalent for 1D/3D)
|
|
202
231
|
self.conv1 = Conv(1, base_width, kernel_size=7, stride=2, padding=3, bias=False)
|
|
203
|
-
self.gn1 = nn.GroupNorm(
|
|
232
|
+
self.gn1 = nn.GroupNorm(_get_num_groups(base_width), base_width)
|
|
204
233
|
self.relu = nn.ReLU(inplace=True)
|
|
205
234
|
self.maxpool = MaxPool(kernel_size=3, stride=2, padding=1)
|
|
206
235
|
|
|
@@ -246,7 +275,7 @@ class ResNetBase(BaseModel):
|
|
|
246
275
|
bias=False,
|
|
247
276
|
),
|
|
248
277
|
nn.GroupNorm(
|
|
249
|
-
|
|
278
|
+
_get_num_groups(out_channels * block.expansion),
|
|
250
279
|
out_channels * block.expansion,
|
|
251
280
|
),
|
|
252
281
|
)
|
wavedl/train.py
CHANGED
|
@@ -69,6 +69,39 @@ _setup_cache_dir("XDG_DATA_HOME", "local/share")
|
|
|
69
69
|
_setup_cache_dir("XDG_STATE_HOME", "local/state")
|
|
70
70
|
_setup_cache_dir("XDG_CACHE_HOME", "cache")
|
|
71
71
|
|
|
72
|
+
|
|
73
|
+
def _setup_per_rank_compile_cache() -> None:
    """Set per-GPU Triton/Inductor cache to prevent multi-process race warnings.

    When using torch.compile with multiple GPUs, all processes try to write to
    the same cache directory, causing 'Directory is not empty - skipping!' warnings.
    This gives each GPU rank its own isolated cache subdirectory.

    For each of ``TRITON_CACHE_DIR`` and ``TORCHINDUCTOR_CACHE_DIR`` the
    current value (or ``<cwd>/.triton_cache`` / ``<cwd>/.inductor_cache`` when
    unset) is extended with ``rank_<LOCAL_RANK>``, written back to the
    environment, and the directory is created.

    The function is idempotent: if a variable already ends with the rank
    subdirectory (for example because this function ran earlier in this
    process, or in a parent process whose environment was inherited), the
    suffix is not appended a second time, so no nested ``rank_N/rank_N``
    directories are created.
    """
    # Get local rank from environment (set by accelerate/torchrun)
    local_rank = os.environ.get("LOCAL_RANK", "0")
    rank_dir = f"rank_{local_rank}"

    # (environment variable, default directory name under CWD)
    cache_defaults = (
        ("TRITON_CACHE_DIR", ".triton_cache"),
        ("TORCHINDUCTOR_CACHE_DIR", ".inductor_cache"),
    )

    # Resolve both per-rank paths before touching the environment
    per_rank_dirs = {}
    for env_var, default_name in cache_defaults:
        base = os.environ.get(env_var, os.path.join(os.getcwd(), default_name))
        # Only append the rank suffix when it is not already the last component
        if os.path.basename(os.path.normpath(base)) != rank_dir:
            base = os.path.join(base, rank_dir)
        per_rank_dirs[env_var] = base

    # Set per-rank cache directories
    for env_var, path in per_rank_dirs.items():
        os.environ[env_var] = path

    # Create directories
    for path in per_rank_dirs.values():
        os.makedirs(path, exist_ok=True)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Setup per-rank compile caches (before torch imports)
|
|
103
|
+
_setup_per_rank_compile_cache()
|
|
104
|
+
|
|
72
105
|
# =============================================================================
|
|
73
106
|
# Standard imports (after environment setup)
|
|
74
107
|
# =============================================================================
|
|
@@ -908,7 +941,6 @@ def main():
|
|
|
908
941
|
logger.info("=" * len(header))
|
|
909
942
|
|
|
910
943
|
try:
|
|
911
|
-
time.time()
|
|
912
944
|
total_training_time = 0.0
|
|
913
945
|
|
|
914
946
|
for epoch in range(start_epoch, args.epochs):
|
|
@@ -1002,49 +1034,29 @@ def main():
|
|
|
1002
1034
|
local_preds.append(pred.detach().cpu())
|
|
1003
1035
|
local_targets.append(y.detach().cpu())
|
|
1004
1036
|
|
|
1005
|
-
# Concatenate locally on
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
# Gather predictions and targets
|
|
1010
|
-
#
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
torch.zeros_like(gpu_targets)
|
|
1025
|
-
for _ in range(accelerator.num_processes)
|
|
1026
|
-
]
|
|
1027
|
-
torch.distributed.gather(
|
|
1028
|
-
gpu_preds, gather_list=all_preds_list, dst=0
|
|
1029
|
-
)
|
|
1030
|
-
torch.distributed.gather(
|
|
1031
|
-
gpu_targets, gather_list=all_targets_list, dst=0
|
|
1032
|
-
)
|
|
1033
|
-
# Move back to CPU for metric computation
|
|
1034
|
-
gathered = [
|
|
1035
|
-
(
|
|
1036
|
-
torch.cat(all_preds_list).cpu(),
|
|
1037
|
-
torch.cat(all_targets_list).cpu(),
|
|
1038
|
-
)
|
|
1039
|
-
]
|
|
1040
|
-
else:
|
|
1041
|
-
# Other ranks: send to rank 0, don't allocate gather buffers
|
|
1042
|
-
torch.distributed.gather(gpu_preds, gather_list=None, dst=0)
|
|
1043
|
-
torch.distributed.gather(gpu_targets, gather_list=None, dst=0)
|
|
1044
|
-
gathered = [(cpu_preds, cpu_targets)] # Placeholder, not used
|
|
1037
|
+
# Concatenate locally (keep on GPU for gather_for_metrics compatibility)
|
|
1038
|
+
local_preds_cat = torch.cat(local_preds)
|
|
1039
|
+
local_targets_cat = torch.cat(local_targets)
|
|
1040
|
+
|
|
1041
|
+
# Gather predictions and targets using Accelerate's CPU-efficient utility
|
|
1042
|
+
# gather_for_metrics handles:
|
|
1043
|
+
# - DDP padding removal (no need to trim manually)
|
|
1044
|
+
# - Efficient cross-rank gathering without GPU memory spike
|
|
1045
|
+
# - Returns concatenated tensors on CPU for metric computation
|
|
1046
|
+
if accelerator.num_processes > 1:
|
|
1047
|
+
# Move to GPU for gather (required by NCCL), then back to CPU
|
|
1048
|
+
# gather_for_metrics is more memory-efficient than manual gather
|
|
1049
|
+
# as it processes in chunks internally
|
|
1050
|
+
gathered_preds = accelerator.gather_for_metrics(
|
|
1051
|
+
local_preds_cat.to(accelerator.device)
|
|
1052
|
+
).cpu()
|
|
1053
|
+
gathered_targets = accelerator.gather_for_metrics(
|
|
1054
|
+
local_targets_cat.to(accelerator.device)
|
|
1055
|
+
).cpu()
|
|
1045
1056
|
else:
|
|
1046
1057
|
# Single-GPU mode: no gathering needed
|
|
1047
|
-
|
|
1058
|
+
gathered_preds = local_preds_cat
|
|
1059
|
+
gathered_targets = local_targets_cat
|
|
1048
1060
|
|
|
1049
1061
|
# Synchronize validation metrics (scalars only - efficient)
|
|
1050
1062
|
val_loss_scalar = val_loss_sum.item()
|
|
@@ -1069,20 +1081,10 @@ def main():
|
|
|
1069
1081
|
|
|
1070
1082
|
# ==================== LOGGING & CHECKPOINTING ====================
|
|
1071
1083
|
if accelerator.is_main_process:
|
|
1072
|
-
# Concatenate gathered tensors from all ranks (only on rank 0)
|
|
1073
|
-
# gathered is list of tuples: [(preds_rank0, targs_rank0), (preds_rank1, targs_rank1), ...]
|
|
1074
|
-
all_preds = torch.cat([item[0] for item in gathered])
|
|
1075
|
-
all_targets = torch.cat([item[1] for item in gathered])
|
|
1076
|
-
|
|
1077
1084
|
# Scientific metrics - cast to float32 before numpy
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
# Trim DDP padding
|
|
1082
|
-
real_len = len(val_dl.dataset)
|
|
1083
|
-
if len(y_pred) > real_len:
|
|
1084
|
-
y_pred = y_pred[:real_len]
|
|
1085
|
-
y_true = y_true[:real_len]
|
|
1085
|
+
# gather_for_metrics already handles DDP padding removal
|
|
1086
|
+
y_pred = gathered_preds.float().numpy()
|
|
1087
|
+
y_true = gathered_targets.float().numpy()
|
|
1086
1088
|
|
|
1087
1089
|
# Guard against tiny validation sets (R² undefined for <2 samples)
|
|
1088
1090
|
if len(y_true) >= 2:
|
wavedl/utils/config.py
CHANGED
|
@@ -183,9 +183,11 @@ def save_config(
|
|
|
183
183
|
config[key] = value
|
|
184
184
|
|
|
185
185
|
# Add metadata
|
|
186
|
+
from wavedl import __version__
|
|
187
|
+
|
|
186
188
|
config["_metadata"] = {
|
|
187
189
|
"saved_at": datetime.now().isoformat(),
|
|
188
|
-
"wavedl_version":
|
|
190
|
+
"wavedl_version": __version__,
|
|
189
191
|
}
|
|
190
192
|
|
|
191
193
|
output_path = Path(output_path)
|
wavedl/utils/cross_validation.py
CHANGED
|
@@ -337,6 +337,17 @@ def run_cross_validation(
|
|
|
337
337
|
torch.cuda.manual_seed_all(seed)
|
|
338
338
|
|
|
339
339
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
340
|
+
|
|
341
|
+
# Auto-detect optimal DataLoader workers if not specified (matches train.py behavior)
|
|
342
|
+
if workers < 0:
|
|
343
|
+
cpu_count = os.cpu_count() or 4
|
|
344
|
+
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
|
|
345
|
+
# Heuristic: 4-16 workers per GPU, bounded by available CPU cores
|
|
346
|
+
workers = min(16, max(2, (cpu_count - 2) // max(1, num_gpus)))
|
|
347
|
+
logger.info(
|
|
348
|
+
f"⚙️ Auto-detected workers: {workers} (CPUs: {cpu_count}, GPUs: {num_gpus})"
|
|
349
|
+
)
|
|
350
|
+
|
|
340
351
|
logger.info(f"🚀 K-Fold Cross-Validation ({folds} folds)")
|
|
341
352
|
logger.info(f" Model: {model_name} | Device: {device}")
|
|
342
353
|
logger.info(
|
wavedl/utils/data.py
CHANGED
|
@@ -763,9 +763,58 @@ def load_test_data(
|
|
|
763
763
|
k for k in OUTPUT_KEYS if k != "output_test"
|
|
764
764
|
]
|
|
765
765
|
|
|
766
|
-
# Load data using appropriate source
|
|
766
|
+
# Load data using appropriate source with test-key priority
|
|
767
|
+
# We detect keys first to ensure input_test/output_test are used when present
|
|
767
768
|
try:
|
|
768
|
-
|
|
769
|
+
if format == "npz":
|
|
770
|
+
with np.load(path, allow_pickle=False) as probe:
|
|
771
|
+
keys = list(probe.keys())
|
|
772
|
+
inp_key = DataSource._find_key(keys, custom_input_keys)
|
|
773
|
+
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
774
|
+
if inp_key is None:
|
|
775
|
+
raise KeyError(
|
|
776
|
+
f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
|
|
777
|
+
)
|
|
778
|
+
data = NPZSource._safe_load(
|
|
779
|
+
path, [inp_key] + ([out_key] if out_key else [])
|
|
780
|
+
)
|
|
781
|
+
inp = data[inp_key]
|
|
782
|
+
if inp.dtype == object:
|
|
783
|
+
inp = np.array(
|
|
784
|
+
[x.toarray() if hasattr(x, "toarray") else x for x in inp]
|
|
785
|
+
)
|
|
786
|
+
outp = data[out_key] if out_key else None
|
|
787
|
+
elif format == "hdf5":
|
|
788
|
+
with h5py.File(path, "r") as f:
|
|
789
|
+
keys = list(f.keys())
|
|
790
|
+
inp_key = DataSource._find_key(keys, custom_input_keys)
|
|
791
|
+
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
792
|
+
if inp_key is None:
|
|
793
|
+
raise KeyError(
|
|
794
|
+
f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
|
|
795
|
+
)
|
|
796
|
+
inp = f[inp_key][:]
|
|
797
|
+
outp = f[out_key][:] if out_key else None
|
|
798
|
+
elif format == "mat":
|
|
799
|
+
mat_source = MATSource()
|
|
800
|
+
with h5py.File(path, "r") as f:
|
|
801
|
+
keys = list(f.keys())
|
|
802
|
+
inp_key = DataSource._find_key(keys, custom_input_keys)
|
|
803
|
+
out_key = DataSource._find_key(keys, custom_output_keys)
|
|
804
|
+
if inp_key is None:
|
|
805
|
+
raise KeyError(
|
|
806
|
+
f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
|
|
807
|
+
)
|
|
808
|
+
inp = mat_source._load_dataset(f, inp_key)
|
|
809
|
+
if out_key:
|
|
810
|
+
outp = mat_source._load_dataset(f, out_key)
|
|
811
|
+
if outp.ndim == 2 and outp.shape[0] == 1:
|
|
812
|
+
outp = outp.T
|
|
813
|
+
else:
|
|
814
|
+
outp = None
|
|
815
|
+
else:
|
|
816
|
+
# Fallback to default source.load() for unknown formats
|
|
817
|
+
inp, outp = source.load(path)
|
|
769
818
|
except KeyError:
|
|
770
819
|
# Try with just inputs if outputs not found (inference-only mode)
|
|
771
820
|
if format == "npz":
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: wavedl
|
|
3
|
-
Version: 1.5.2
|
|
3
|
+
Version: 1.5.3
|
|
4
4
|
Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
|
|
5
5
|
Author: Ductho Le
|
|
6
6
|
License: MIT
|
|
@@ -301,8 +301,8 @@ python -m wavedl.test --checkpoint <checkpoint_folder> --data_path <test_data> \
|
|
|
301
301
|
|
|
302
302
|
**Requirements** (your model must):
|
|
303
303
|
1. Inherit from `BaseModel`
|
|
304
|
-
2. Accept `
|
|
305
|
-
3. Return a tensor of shape `(batch,
|
|
304
|
+
2. Accept `in_shape`, `out_size` in `__init__`
|
|
305
|
+
3. Return a tensor of shape `(batch, out_size)` from `forward()`
|
|
306
306
|
|
|
307
307
|
---
|
|
308
308
|
|
|
@@ -315,23 +315,22 @@ from wavedl.models import BaseModel, register_model
|
|
|
315
315
|
|
|
316
316
|
@register_model("my_model") # This name is used with --model flag
|
|
317
317
|
class MyModel(BaseModel):
|
|
318
|
-
def __init__(self,
|
|
319
|
-
#
|
|
320
|
-
#
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
|
|
318
|
+
def __init__(self, in_shape, out_size, **kwargs):
|
|
319
|
+
# in_shape: spatial dimensions, e.g., (128,) or (64, 64) or (32, 32, 32)
|
|
320
|
+
# out_size: number of parameters to predict (auto-detected from data)
|
|
321
|
+
super().__init__(in_shape, out_size)
|
|
322
|
+
|
|
323
|
+
# Define your layers (this is just an example for 2D)
|
|
324
|
+
self.conv1 = nn.Conv2d(1, 64, 3, padding=1) # Input always has 1 channel
|
|
326
325
|
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
|
|
327
|
-
self.fc = nn.Linear(128,
|
|
326
|
+
self.fc = nn.Linear(128, out_size)
|
|
328
327
|
|
|
329
328
|
def forward(self, x):
|
|
330
|
-
# Input x has shape: (batch,
|
|
329
|
+
# Input x has shape: (batch, 1, *in_shape)
|
|
331
330
|
x = F.relu(self.conv1(x))
|
|
332
331
|
x = F.relu(self.conv2(x))
|
|
333
332
|
x = x.mean(dim=[-2, -1]) # Global average pooling
|
|
334
|
-
return self.fc(x) # Output shape: (batch,
|
|
333
|
+
return self.fc(x) # Output shape: (batch, out_size)
|
|
335
334
|
```
|
|
336
335
|
|
|
337
336
|
**Step 2: Train**
|
|
@@ -573,14 +572,19 @@ WaveDL automatically enables performance optimizations for modern GPUs:
|
|
|
573
572
|
</details>
|
|
574
573
|
|
|
575
574
|
<details>
|
|
576
|
-
<summary><b>
|
|
575
|
+
<summary><b>HPC CLI Arguments (wavedl-hpc)</b></summary>
|
|
576
|
+
|
|
577
|
+
| Argument | Default | Description |
|
|
578
|
+
|----------|---------|-------------|
|
|
579
|
+
| `--num_gpus` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override |
|
|
580
|
+
| `--num_machines` | `1` | Number of machines in distributed setup |
|
|
581
|
+
| `--mixed_precision` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
|
|
582
|
+
| `--dynamo_backend` | `no` | PyTorch Dynamo backend |
|
|
583
|
+
|
|
584
|
+
**Environment Variables (for logging):**
|
|
577
585
|
|
|
578
586
|
| Variable | Default | Description |
|
|
579
587
|
|----------|---------|-------------|
|
|
580
|
-
| `NUM_GPUS` | **Auto-detected** | Number of GPUs to use. By default, automatically detected via `nvidia-smi`. Set explicitly to override (e.g., `NUM_GPUS=2`) |
|
|
581
|
-
| `NUM_MACHINES` | `1` | Number of machines in distributed setup |
|
|
582
|
-
| `MIXED_PRECISION` | `bf16` | Precision mode: `bf16`, `fp16`, or `no` |
|
|
583
|
-
| `DYNAMO_BACKEND` | `no` | PyTorch Dynamo backend |
|
|
584
588
|
| `WANDB_MODE` | `offline` | WandB mode: `offline` or `online` |
|
|
585
589
|
|
|
586
590
|
</details>
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
wavedl/__init__.py,sha256
|
|
2
|
-
wavedl/hpc.py,sha256
|
|
1
|
+
wavedl/__init__.py,sha256=1h6l9c3ms45mYhJZskUm28my7Lrq9tXMUs4BtMTiK_s,1177
|
|
2
|
+
wavedl/hpc.py,sha256=6rV38nozzMt0-jKZbVJNwvQZXK0wUsIZmr9lgWN_XUw,9212
|
|
3
3
|
wavedl/hpo.py,sha256=DGCGyt2yhr3WAifAuljhE26gg07CHdaQW4wpDaTKbyo,14968
|
|
4
4
|
wavedl/test.py,sha256=WIHG3HWT-uF399FQApPpxjggBVFn59cC54HAL4990QU,38550
|
|
5
|
-
wavedl/train.py,sha256=
|
|
5
|
+
wavedl/train.py,sha256=Aao8ofyYALqPrMTQarRn4rPWzDLZD-PeuKNVJ76IrVQ,54344
|
|
6
6
|
wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
|
|
7
7
|
wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
|
|
8
8
|
wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
|
|
@@ -14,25 +14,25 @@ wavedl/models/efficientnetv2.py,sha256=rP8y1ZAWyNyi0PXGPXg-4HjgzoELZ-CjMFgr8WnSX
|
|
|
14
14
|
wavedl/models/mobilenetv3.py,sha256=h3f6TiNSyHRH9Qidce7dCGTbdEWYfYF5kbU-TFoTg0U,9490
|
|
15
15
|
wavedl/models/registry.py,sha256=InYAXX2xbRvsFDFnYUPCptJh0F9lHlFPN77A9kqHRT0,2980
|
|
16
16
|
wavedl/models/regnet.py,sha256=Yf9gAoDLv0j4uEuoKC822gizHNh59LCbvFCMP11Q1C0,13116
|
|
17
|
-
wavedl/models/resnet.py,sha256=
|
|
17
|
+
wavedl/models/resnet.py,sha256=laePTbIgINijh-Xkcp4iui8-1F17NJAjyAuA4T11eG4,18027
|
|
18
18
|
wavedl/models/resnet3d.py,sha256=C7CL4XeSnRlIBuwf5Ei-z183uzIBObrXfkM9Iwuc5e0,8746
|
|
19
19
|
wavedl/models/swin.py,sha256=p-okfq3Qm4_neJTxCcMzoHoVzC0BHW3BMnbpr_Ri2U0,13224
|
|
20
20
|
wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
|
|
21
21
|
wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
|
|
22
22
|
wavedl/models/vit.py,sha256=68o9nNjkftvHFArAPupU2ew5e5yCsI2AYaT9TQinVMk,12075
|
|
23
23
|
wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
|
|
24
|
-
wavedl/utils/config.py,sha256=
|
|
24
|
+
wavedl/utils/config.py,sha256=AsGwb3XtxmbTLb59BLl5AA4wzMNgVTpl7urOJ6IGqfM,10901
|
|
25
25
|
wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
|
|
26
|
-
wavedl/utils/cross_validation.py,sha256=
|
|
27
|
-
wavedl/utils/data.py,sha256=
|
|
26
|
+
wavedl/utils/cross_validation.py,sha256=gwXSFTx5oxWndPjWLJAJzB6nnq2f1t9f86SbjbF-jNI,18475
|
|
27
|
+
wavedl/utils/data.py,sha256=H5crttnSfJZBMWQOvM7Cq7nkefnhVlgO0O6J71zJdgI,52651
|
|
28
28
|
wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
|
|
29
29
|
wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
|
|
30
30
|
wavedl/utils/metrics.py,sha256=EJmJvF7gACQsUoKYldlladN_SbnRiuE-Smj0eSnbraQ,39394
|
|
31
31
|
wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
|
|
32
32
|
wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
|
|
33
|
-
wavedl-1.5.
|
|
34
|
-
wavedl-1.5.
|
|
35
|
-
wavedl-1.5.
|
|
36
|
-
wavedl-1.5.
|
|
37
|
-
wavedl-1.5.
|
|
38
|
-
wavedl-1.5.
|
|
33
|
+
wavedl-1.5.3.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
|
|
34
|
+
wavedl-1.5.3.dist-info/METADATA,sha256=bPNcR8sYE9U7a001lvMFn9oHfmcmkpHUDdGRowLjJEs,45488
|
|
35
|
+
wavedl-1.5.3.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
36
|
+
wavedl-1.5.3.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
|
|
37
|
+
wavedl-1.5.3.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
|
|
38
|
+
wavedl-1.5.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|