wavedl 1.5.3-py3-none-any.whl → 1.5.5-py3-none-any.whl
- wavedl/__init__.py +1 -1
- wavedl/hpo.py +2 -1
- wavedl/models/swin.py +31 -10
- wavedl/train.py +32 -4
- wavedl/utils/data.py +104 -13
- wavedl/utils/metrics.py +22 -1
- {wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/METADATA +11 -10
- {wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/RECORD +12 -12
- {wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/LICENSE +0 -0
- {wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/WHEEL +0 -0
- {wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/top_level.txt +0 -0
wavedl/__init__.py
CHANGED
wavedl/hpo.py
CHANGED
@@ -175,13 +175,14 @@ def create_objective(args):
         env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
 
         # Run training
+        # Note: We inherit the user's cwd instead of setting cwd=Path(__file__).parent
+        # because site-packages may be read-only and train.py creates cache directories
         try:
             result = subprocess.run(
                 cmd,
                 capture_output=True,
                 text=True,
                 timeout=args.timeout,
-                cwd=Path(__file__).parent,
                 env=env,
             )
 
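The net effect of this hunk is that each HPO trial subprocess now runs in whatever directory the user launched from, so caches land somewhere writable. A minimal sketch of the resulting launch pattern (`run_trial`, `cmd`, and `timeout_s` are illustrative names, not wavedl API):

```python
# Sketch, assuming a generic training command list. Pin the trial to one GPU
# via the environment and inherit the caller's cwd, so cache directories are
# created in a writable location rather than inside site-packages.
import os
import subprocess

def run_trial(cmd: list[str], gpu_id: int, timeout_s: int = 3600):
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)  # isolate the trial to one GPU
    # No cwd= argument: the subprocess runs in the user's working directory.
    return subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        timeout=timeout_s,
        env=env,
    )
```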
wavedl/models/swin.py
CHANGED
@@ -191,22 +191,33 @@ class SwinTransformerBase(BaseModel):
         Returns:
             List of parameter group dictionaries
         """
-        # Separate parameters
+        # Separate parameters into 4 groups for proper LR decay:
+        # 1. Head params with decay (full LR)
+        # 2. Backbone params with decay (0.1× LR)
+        # 3. Head bias/norm without decay (full LR)
+        # 4. Backbone bias/norm without decay (0.1× LR)
         head_params = []
         backbone_params = []
-
+        head_no_decay = []
+        backbone_no_decay = []
 
         for name, param in self.backbone.named_parameters():
             if not param.requires_grad:
                 continue
 
-
-
-
-
-
+            is_head = "head" in name
+            is_no_decay = "bias" in name or "norm" in name
+
+            if is_head:
+                if is_no_decay:
+                    head_no_decay.append(param)
+                else:
+                    head_params.append(param)
             else:
-
+                if is_no_decay:
+                    backbone_no_decay.append(param)
+                else:
+                    backbone_params.append(param)
 
         groups = []
 
@@ -229,15 +240,25 @@ class SwinTransformerBase(BaseModel):
             }
         )
 
-        if
+        if head_no_decay:
             groups.append(
                 {
-                    "params":
+                    "params": head_no_decay,
                     "lr": base_lr,
                     "weight_decay": 0.0,
                 }
             )
 
+        if backbone_no_decay:
+            # Backbone bias/norm also gets 0.1× LR to match intended decay
+            groups.append(
+                {
+                    "params": backbone_no_decay,
+                    "lr": base_lr * 0.1,
+                    "weight_decay": 0.0,
+                }
+            )
+
         return groups if groups else [{"params": self.parameters(), "lr": base_lr}]
 
 
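For context, the groups returned above plug straight into a PyTorch optimizer. A minimal sketch reproducing the grouping on a toy module; the toy backbone and the `0.05` decay value are assumptions for illustration (the diff only shows the grouping and LR logic, not the decay constant):

```python
# Sketch: four parameter groups feeding torch.optim.AdamW. The toy module
# stands in for the Swin backbone; names containing "head"/"bias"/"norm"
# drive the same routing as in the diff.
import torch
import torch.nn as nn

backbone = nn.Sequential()
backbone.add_module("norm", nn.LayerNorm(16))
backbone.add_module("head", nn.Linear(16, 4))

base_lr = 1e-3
head_params, backbone_params = [], []
head_no_decay, backbone_no_decay = [], []
for name, param in backbone.named_parameters():
    if not param.requires_grad:
        continue
    is_head = "head" in name
    is_no_decay = "bias" in name or "norm" in name
    if is_head:
        (head_no_decay if is_no_decay else head_params).append(param)
    else:
        (backbone_no_decay if is_no_decay else backbone_params).append(param)

# Decay applies only to weight matrices; backbone groups get 0.1× LR.
groups = [
    {"params": head_params, "lr": base_lr, "weight_decay": 0.05},
    {"params": backbone_params, "lr": base_lr * 0.1, "weight_decay": 0.05},
    {"params": head_no_decay, "lr": base_lr, "weight_decay": 0.0},
    {"params": backbone_no_decay, "lr": base_lr * 0.1, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW([g for g in groups if g["params"]])  # drop empty groups
```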
wavedl/train.py
CHANGED
@@ -122,6 +122,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import torch
+import torch.distributed as dist
 from accelerate import Accelerator
 from accelerate.utils import set_seed
 from sklearn.metrics import r2_score
@@ -470,15 +471,19 @@ def main():
     try:
         # Handle both module names (my_model) and file paths (./my_model.py)
         if module_name.endswith(".py"):
-            # Import from file path
+            # Import from file path with unique module name
             import importlib.util
 
+            # Derive unique module name from filename to avoid collisions
+            base_name = os.path.splitext(os.path.basename(module_name))[0]
+            unique_name = f"wavedl_custom_{base_name}"
+
             spec = importlib.util.spec_from_file_location(
-
+                unique_name, module_name
             )
             if spec and spec.loader:
                 module = importlib.util.module_from_spec(spec)
-                sys.modules[
+                sys.modules[unique_name] = module
                 spec.loader.exec_module(module)
                 print(f"✓ Imported custom module from: {module_name}")
             else:
@@ -1250,9 +1255,32 @@ def main():
         )
 
         # Learning rate scheduling (epoch-based schedulers only)
+        # NOTE: For ReduceLROnPlateau with DDP, we must step only on main process
+        # to avoid patience counter being incremented by all GPU processes.
+        # Then we sync the new LR to all processes to keep them consistent.
         if not scheduler_step_per_batch:
             if args.scheduler == "plateau":
-
+                # Step only on main process to avoid multi-GPU patience bug
+                if accelerator.is_main_process:
+                    scheduler.step(avg_val_loss)
+
+                # Sync LR across all processes after main process updates it
+                accelerator.wait_for_everyone()
+
+                # Broadcast new LR from rank 0 to all processes
+                if dist.is_initialized():
+                    if accelerator.is_main_process:
+                        new_lr = optimizer.param_groups[0]["lr"]
+                    else:
+                        new_lr = 0.0
+                    new_lr_tensor = torch.tensor(
+                        new_lr, device=accelerator.device, dtype=torch.float32
+                    )
+                    dist.broadcast(new_lr_tensor, src=0)
+                    # Update LR on non-main processes
+                    if not accelerator.is_main_process:
+                        for param_group in optimizer.param_groups:
+                            param_group["lr"] = new_lr_tensor.item()
             else:
                 scheduler.step()
 
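Extracted on its own, the new unique-name import logic amounts to the helper below. This is a standalone sketch, not wavedl's actual function (`import_custom` is an invented name); it shows why the `wavedl_custom_` prefix matters: registering two different files under the same `sys.modules` key would silently clobber the first:

```python
# Sketch of loading a custom model file under a collision-free module name.
import importlib.util
import os
import sys

def import_custom(path: str):
    base_name = os.path.splitext(os.path.basename(path))[0]
    unique_name = f"wavedl_custom_{base_name}"  # e.g. wavedl_custom_my_model
    spec = importlib.util.spec_from_file_location(unique_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {path}")
    module = importlib.util.module_from_spec(spec)
    # Register before exec so decorators/registries inside the file can
    # resolve their own module during import.
    sys.modules[unique_name] = module
    spec.loader.exec_module(module)
    return module
```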
wavedl/utils/data.py
CHANGED
@@ -207,6 +207,10 @@ class NPZSource(DataSource):
 
         The error for object arrays happens at ACCESS time, not load time.
         So we need to probe the keys to detect if pickle is required.
+
+        WARNING: When mmap_mode is not None, the returned NpzFile must be kept
+        open for arrays to remain valid. Caller is responsible for closing.
+        For non-mmap loading, use _load_and_copy() instead to avoid leaks.
         """
         data = np.load(path, allow_pickle=False, mmap_mode=mmap_mode)
         try:
@@ -222,6 +226,26 @@ class NPZSource(DataSource):
                 return np.load(path, allow_pickle=True, mmap_mode=mmap_mode)
             raise
 
+    @staticmethod
+    def _load_and_copy(path: str, keys: list[str]) -> dict[str, np.ndarray]:
+        """Load NPZ and copy arrays, ensuring file is properly closed.
+
+        This prevents file descriptor leaks by copying arrays before closing.
+        Use this for eager loading; use _safe_load for memory-mapped access.
+        """
+        data = NPZSource._safe_load(path, keys, mmap_mode=None)
+        try:
+            result = {}
+            for key in keys:
+                if key in data:
+                    arr = data[key]
+                    # Copy ensures we don't hold reference to mmap
+                    result[key] = arr.copy() if hasattr(arr, "copy") else arr
+            return result
+        finally:
+            if hasattr(data, "close"):
+                data.close()
+
     def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
         """Load NPZ file (pickle enabled only for sparse matrices)."""
         # First pass to find keys without loading data
@@ -238,7 +262,7 @@ class NPZSource(DataSource):
                 f"Found: {keys}"
             )
 
-        data = self.
+        data = self._load_and_copy(path, [input_key, output_key])
         inp = data[input_key]
         outp = data[output_key]
 
@@ -290,7 +314,7 @@ class NPZSource(DataSource):
                 f"Supported keys: {OUTPUT_KEYS}. Found: {keys}"
             )
 
-        data = self.
+        data = self._load_and_copy(path, [output_key])
         return data[output_key]
 
 
@@ -527,9 +551,17 @@ class MATSource(DataSource):
                 inp = self._load_dataset(f, input_key)
                 outp = self._load_dataset(f, output_key)
 
-                # Handle
-
-
+                # Handle transposed outputs from MATLAB.
+                # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
+                # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
+                num_samples = inp.shape[0]  # inp is already transposed
+                if outp.ndim == 2:
+                    if outp.shape[0] == 1 and outp.shape[1] == num_samples:
+                        # 1D vector: (1, N) → (N, 1)
+                        outp = outp.T
+                    elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
+                        # Single sample with multiple targets: (T, 1) → (1, T)
+                        outp = outp.T
 
         except OSError as e:
             raise ValueError(
@@ -614,7 +646,10 @@ class MATSource(DataSource):
                 # Load with sparse matrix support
                 outp = self._load_dataset(f, output_key)
 
-                # Handle 1D outputs
+                # Handle 1D outputs that become (1, N) after transpose.
+                # Note: This method has no input to compare against, so we can't
+                # distinguish single-sample outputs. This is acceptable for training
+                # data where single-sample is unlikely. For inference, use load_test_data.
                 if outp.ndim == 2 and outp.shape[0] == 1:
                     outp = outp.T
 
@@ -775,7 +810,7 @@ def load_test_data(
             raise KeyError(
                 f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
             )
-        data = NPZSource.
+        data = NPZSource._load_and_copy(
             path, [inp_key] + ([out_key] if out_key else [])
         )
         inp = data[inp_key]
@@ -793,6 +828,14 @@ def load_test_data(
             raise KeyError(
                 f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
            )
+        # OOM guard: warn if dataset is very large
+        n_samples = f[inp_key].shape[0]
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+            )
         inp = f[inp_key][:]
         outp = f[out_key][:] if out_key else None
     elif format == "mat":
@@ -805,11 +848,28 @@ def load_test_data(
             raise KeyError(
                 f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
             )
+        # OOM guard: warn if dataset is very large (MAT is transposed)
+        n_samples = f[inp_key].shape[-1]
+        if n_samples > 100000:
+            raise ValueError(
+                f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                f"everything into RAM which may cause OOM. For large inference "
+                f"sets, use a DataLoader with MATSource.load_mmap() instead."
+            )
         inp = mat_source._load_dataset(f, inp_key)
         if out_key:
             outp = mat_source._load_dataset(f, out_key)
-
-
+            # Handle transposed outputs from MATLAB
+            # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
+            # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
+            num_samples = inp.shape[0]
+            if outp.ndim == 2:
+                if outp.shape[0] == 1 and outp.shape[1] == num_samples:
+                    # 1D vector: (1, N) → (N, 1)
+                    outp = outp.T
+                elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
+                    # Single sample with multiple targets: (T, 1) → (1, T)
+                    outp = outp.T
         else:
             outp = None
     else:
@@ -828,7 +888,7 @@ def load_test_data(
             )
         out_key = DataSource._find_key(keys, custom_output_keys)
         keys_to_probe = [inp_key] + ([out_key] if out_key else [])
-        data = NPZSource.
+        data = NPZSource._load_and_copy(path, keys_to_probe)
         inp = data[inp_key]
         if inp.dtype == object:
             inp = np.array(
@@ -878,9 +938,17 @@ def load_test_data(
         out_key = DataSource._find_key(keys, custom_output_keys)
         if out_key:
             outp = mat_source._load_dataset(f, out_key)
-            # Handle
-
-
+            # Handle transposed outputs from MATLAB
+            # Case 1: (1, N) - N samples with 1 target → transpose to (N, 1)
+            # Case 2: (T, 1) - 1 sample with T targets → transpose to (1, T)
+            num_samples = inp.shape[0]
+            if outp.ndim == 2:
+                if outp.shape[0] == 1 and outp.shape[1] == num_samples:
+                    # 1D vector: (1, N) → (N, 1)
+                    outp = outp.T
+                elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
+                    # Single sample with multiple targets: (T, 1) → (1, T)
+                    outp = outp.T
         else:
             outp = None
     else:
@@ -1126,6 +1194,29 @@ def prepare_data(
 
     if not cache_exists:
         if accelerator.is_main_process:
+            # Delete stale cache files to force regeneration
+            # This prevents silent reuse of old data when metadata invalidates cache
+            for stale_file in [CACHE_FILE, SCALER_FILE]:
+                if os.path.exists(stale_file):
+                    try:
+                        os.remove(stale_file)
+                        logger.debug(f"  Removed stale cache: {stale_file}")
+                    except OSError as e:
+                        logger.warning(
+                            f"  Failed to remove stale cache {stale_file}: {e}"
+                        )
+
+            # Fail explicitly if stale cache files couldn't be removed
+            # This prevents silent reuse of outdated data
+            remaining_stale = [
+                f for f in [CACHE_FILE, SCALER_FILE] if os.path.exists(f)
+            ]
+            if remaining_stale:
+                raise RuntimeError(
+                    f"Cannot regenerate cache: stale files could not be removed. "
+                    f"Please manually delete: {remaining_stale}"
+                )
+
             # RANK 0: Create cache (can take a long time for large datasets)
             # Other ranks will wait at the barrier below
 
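The transpose heuristic added in three places above can be sanity-checked in isolation. A toy sketch with the same shape rules (`fix_output_orientation` is an illustrative name, not wavedl API):

```python
# Given inputs with num_samples rows, a (1, N) output is an N-sample/1-target
# vector and a (T, 1) output with T != N is a single sample with T targets.
import numpy as np

def fix_output_orientation(inp: np.ndarray, outp: np.ndarray) -> np.ndarray:
    num_samples = inp.shape[0]
    if outp.ndim == 2:
        if outp.shape[0] == 1 and outp.shape[1] == num_samples:
            outp = outp.T  # (1, N) → (N, 1)
        elif outp.shape[1] == 1 and outp.shape[0] != num_samples:
            outp = outp.T  # (T, 1) → (1, T)
    return outp

inp = np.zeros((8, 32))  # 8 samples
assert fix_output_orientation(inp, np.zeros((1, 8))).shape == (8, 1)
assert fix_output_orientation(inp, np.zeros((3, 1))).shape == (1, 3)
assert fix_output_orientation(inp, np.zeros((8, 1))).shape == (8, 1)  # untouched
```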
wavedl/utils/metrics.py
CHANGED
@@ -815,7 +815,28 @@ def plot_qq(
 
         # Standardize errors for QQ plot
         err = errors[:, i]
-
+        std_err = np.std(err)
+
+        # Guard against zero variance (constant errors)
+        if std_err < 1e-10:
+            title = (
+                param_names[i] if param_names and i < len(param_names) else f"Param {i}"
+            )
+            ax.text(
+                0.5,
+                0.5,
+                "Zero variance\n(constant errors)",
+                ha="center",
+                va="center",
+                fontsize=10,
+                transform=ax.transAxes,
+            )
+            ax.set_title(f"{title}\n(zero variance)")
+            ax.set_xlabel("Theoretical Quantiles")
+            ax.set_ylabel("Sample Quantiles")
+            continue
+
+        standardized = (err - np.mean(err)) / std_err
 
         # Calculate theoretical quantiles and sample quantiles
         (osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")
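A minimal sketch of why the guard is needed: standardizing a constant error vector divides by zero before `scipy.stats.probplot` is ever reached (the `1e-10` threshold matches the diff):

```python
# Constant errors have zero variance, so (err - mean) / std would be 0/0.
# The guard short-circuits before scipy is called.
import numpy as np
from scipy import stats

err = np.full(100, 0.25)  # constant errors → zero variance
std_err = np.std(err)
if std_err < 1e-10:
    print("skip QQ panel: zero variance (constant errors)")
else:
    standardized = (err - np.mean(err)) / std_err
    (osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")
```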
{wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: wavedl
-Version: 1.5.3
+Version: 1.5.5
 Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
 Author: Ductho Le
 License: MIT
@@ -37,11 +37,12 @@ Requires-Dist: wandb>=0.15.0
 Requires-Dist: optuna>=3.0.0
 Requires-Dist: onnx>=1.14.0
 Requires-Dist: onnxruntime>=1.15.0
-Requires-Dist: pytest>=7.0.0
-Requires-Dist: pytest-xdist>=3.5.0
-Requires-Dist: ruff>=0.8.0
-Requires-Dist: pre-commit>=3.5.0
 Requires-Dist: triton>=2.0.0; sys_platform == "linux"
+Provides-Extra: dev
+Requires-Dist: pytest>=7.0.0; extra == "dev"
+Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
+Requires-Dist: ruff>=0.8.0; extra == "dev"
+Requires-Dist: pre-commit>=3.5.0; extra == "dev"
 
 <div align="center">
 
@@ -204,7 +205,7 @@ Deploy models anywhere:
 pip install wavedl
 ```
 
-This installs everything you need: training, inference, HPO, ONNX export
+This installs everything you need: training, inference, HPO, ONNX export.
 
 #### From Source (for development)
 
@@ -336,7 +337,7 @@ class MyModel(BaseModel):
 **Step 2: Train**
 
 ```bash
-wavedl-hpc --import my_model --model my_model --data_path train.npz
+wavedl-hpc --import my_model.py --model my_model --data_path train.npz
 ```
 
 WaveDL handles everything else: training loop, logging, checkpoints, multi-GPU, early stopping, etc.
@@ -387,7 +388,7 @@ WaveDL/
 ├── configs/          # YAML config templates
 ├── examples/         # Ready-to-run examples
 ├── notebooks/        # Jupyter notebooks
-├── unit_tests/       # Pytest test suite (
+├── unit_tests/       # Pytest test suite (731 tests)
 │
 ├── pyproject.toml    # Package config, dependencies
 ├── CHANGELOG.md      # Version history
@@ -512,7 +513,7 @@ print('\\n✓ All pretrained weights cached!')
 | Argument | Default | Description |
 |----------|---------|-------------|
 | `--model` | `cnn` | Model architecture |
-| `--import` | - | Python
+| `--import` | - | Python file(s) to import for custom models (supports multiple) |
 | `--batch_size` | `128` | Per-GPU batch size |
 | `--lr` | `1e-3` | Learning rate |
 | `--epochs` | `1000` | Maximum epochs |
@@ -1223,6 +1224,6 @@ This research was enabled in part by support provided by [Compute Ontario](https
 [](https://scholar.google.ca/citations?user=OlwMr9AAAAAJ)
 [](https://www.researchgate.net/profile/Ductho-Le)
 
-<sub>
+<sub>May your signals be strong and your attenuation low 👋</sub>
 
 </div>
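One practical consequence of the metadata change above: pytest, pytest-xdist, ruff, and pre-commit now sit behind a `dev` extra, so a plain `pip install wavedl` no longer pulls in development tooling; contributors would install it with `pip install "wavedl[dev]"`.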
{wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
-wavedl/__init__.py,sha256=
+wavedl/__init__.py,sha256=RTePiYlzCrUofbGSYWAAqoKeeyYjqEPzuXyze6ai324,1177
 wavedl/hpc.py,sha256=6rV38nozzMt0-jKZbVJNwvQZXK0wUsIZmr9lgWN_XUw,9212
-wavedl/hpo.py,sha256=
+wavedl/hpo.py,sha256=CZF0MZwTGMOrPGDveUXZFbGHwLHj1FcJTCBKVVEtLWg,15105
 wavedl/test.py,sha256=WIHG3HWT-uF399FQApPpxjggBVFn59cC54HAL4990QU,38550
-wavedl/train.py,sha256=
+wavedl/train.py,sha256=7AVaCORFUv2_IgdYSPKdHLxbi11GzMOyu4RcNc4Uf_I,55963
 wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
 wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
 wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -16,7 +16,7 @@ wavedl/models/registry.py,sha256=InYAXX2xbRvsFDFnYUPCptJh0F9lHlFPN77A9kqHRT0,298
 wavedl/models/regnet.py,sha256=Yf9gAoDLv0j4uEuoKC822gizHNh59LCbvFCMP11Q1C0,13116
 wavedl/models/resnet.py,sha256=laePTbIgINijh-Xkcp4iui8-1F17NJAjyAuA4T11eG4,18027
 wavedl/models/resnet3d.py,sha256=C7CL4XeSnRlIBuwf5Ei-z183uzIBObrXfkM9Iwuc5e0,8746
-wavedl/models/swin.py,sha256=
+wavedl/models/swin.py,sha256=cbV_iqIS4no-EAUR8j_93gqd59AkAkfM5DYo6VryLEg,13937
 wavedl/models/tcn.py,sha256=RtY13QpFHqz72b4ultv2lStCIDxfvjySVe5JaTx_GaM,12601
 wavedl/models/unet.py,sha256=LqIXhasdBygwP7SZNNmiW1bHMPaJTVBpaeHtPgEHkdU,7790
 wavedl/models/vit.py,sha256=68o9nNjkftvHFArAPupU2ew5e5yCsI2AYaT9TQinVMk,12075
@@ -24,15 +24,15 @@ wavedl/utils/__init__.py,sha256=s5R9bRmJ8GNcJrD3OSAOXzwZJIXZbdYrAkZnus11sVQ,3300
 wavedl/utils/config.py,sha256=AsGwb3XtxmbTLb59BLl5AA4wzMNgVTpl7urOJ6IGqfM,10901
 wavedl/utils/constraints.py,sha256=Pof5hzeTSGsPY_E6Sc8iMQDaXc_zfEasQI2tCszk_gw,17614
 wavedl/utils/cross_validation.py,sha256=gwXSFTx5oxWndPjWLJAJzB6nnq2f1t9f86SbjbF-jNI,18475
-wavedl/utils/data.py,sha256=
+wavedl/utils/data.py,sha256=cmJ6tUw4Tcxj-l3Xsphs1Dnlx1MzxOPvk8etD5KXFNs,57686
 wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
 wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
-wavedl/utils/metrics.py,sha256=
+wavedl/utils/metrics.py,sha256=El2NYsulH5jxBhC1gCAMcS8C-yxEjuSC930LhsKYQrY,40059
 wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
 wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
-wavedl-1.5.
+wavedl-1.5.5.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+wavedl-1.5.5.dist-info/METADATA,sha256=0e7E8zLd-GlcR5Hbgp5VDYGGr36_9NzKBTShsG4xuQs,45604
+wavedl-1.5.5.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+wavedl-1.5.5.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
+wavedl-1.5.5.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+wavedl-1.5.5.dist-info/RECORD,,
{wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/LICENSE
File without changes
{wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/WHEEL
File without changes
{wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/entry_points.txt
File without changes
{wavedl-1.5.3.dist-info → wavedl-1.5.5.dist-info}/top_level.txt
File without changes