wavedl 1.4.1.tar.gz → 1.4.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavedl-1.4.1/src/wavedl.egg-info → wavedl-1.4.2}/PKG-INFO +1 -1
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/__init__.py +1 -1
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/hpo.py +46 -19
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/train.py +20 -13
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/utils/data.py +79 -2
- {wavedl-1.4.1 → wavedl-1.4.2/src/wavedl.egg-info}/PKG-INFO +1 -1
- {wavedl-1.4.1 → wavedl-1.4.2}/LICENSE +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/README.md +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/pyproject.toml +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/setup.cfg +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/hpc.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/__init__.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/_template.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/base.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/cnn.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/convnext.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/densenet.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/efficientnet.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/efficientnetv2.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/mobilenetv3.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/registry.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/regnet.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/resnet.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/resnet3d.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/swin.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/tcn.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/unet.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/models/vit.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/test.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/utils/__init__.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/utils/config.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/utils/cross_validation.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/utils/distributed.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/utils/losses.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/utils/metrics.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/utils/optimizers.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl/utils/schedulers.py +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl.egg-info/SOURCES.txt +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl.egg-info/dependency_links.txt +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl.egg-info/entry_points.txt +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl.egg-info/requires.txt +0 -0
- {wavedl-1.4.1 → wavedl-1.4.2}/src/wavedl.egg-info/top_level.txt +0 -0
--- wavedl-1.4.1/src/wavedl/hpo.py
+++ wavedl-1.4.2/src/wavedl/hpo.py

@@ -145,6 +145,7 @@ def create_objective(args):
         # Use temporary directory for trial output
         with tempfile.TemporaryDirectory() as tmpdir:
             cmd.extend(["--output_dir", tmpdir])
+            history_file = Path(tmpdir) / "training_history.csv"
 
             # Run training
             try:
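The new history_file path only exists for the lifetime of the temporary directory, so it has to be consumed inside the same with block that runs the trial. A minimal, self-contained illustration of that constraint (not taken from wavedl):

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmpdir:
    history_file = Path(tmpdir) / "training_history.csv"
    history_file.write_text("epoch,val_loss\n1,0.5\n2,0.3\n")
    print(history_file.exists())   # True: the CSV is readable here

print(history_file.exists())       # False: tmpdir and the CSV are gone on exit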
@@ -156,29 +157,55 @@ def create_objective(args):
                     cwd=Path(__file__).parent,
                 )
 
-                #
-                # Look for "Best val_loss: X.XXXX" in stdout
+                # Read best val_loss from training_history.csv (reliable machine-readable)
                 val_loss = None
-                # ... 16 removed lines (old stdout-parsing logic) not rendered in this diff ...
+                if history_file.exists():
+                    try:
+                        import csv
+
+                        with open(history_file) as f:
+                            reader = csv.DictReader(f)
+                            val_losses = []
+                            for row in reader:
+                                if "val_loss" in row:
+                                    try:
+                                        val_losses.append(float(row["val_loss"]))
+                                    except (ValueError, TypeError):
+                                        pass
+                            if val_losses:
+                                val_loss = min(val_losses)  # Best (minimum) val_loss
+                    except Exception as e:
+                        print(f"Trial {trial.number}: Error reading history: {e}")
+
+                if val_loss is None:
+                    # Fallback: parse stdout for training log format
+                    # Pattern: "epoch | train_loss | val_loss | ..."
+                    # Use regex to avoid false positives from unrelated lines
+                    import re
+
+                    # Match lines like: "  42 | 0.0123 | 0.0156 | ..."
+                    log_pattern = re.compile(
+                        r"^\s*\d+\s*\|\s*[\d.]+\s*\|\s*([\d.]+)\s*\|"
+                    )
+                    val_losses_stdout = []
+                    for line in result.stdout.split("\n"):
+                        match = log_pattern.match(line)
+                        if match:
+                            try:
+                                val_losses_stdout.append(float(match.group(1)))
+                            except ValueError:
+                                continue
+                    if val_losses_stdout:
+                        val_loss = min(val_losses_stdout)
 
                 if val_loss is None:
                     # Training failed or no loss found
-                    print(f"Trial {trial.number}: Training failed")
+                    print(f"Trial {trial.number}: Training failed (no val_loss found)")
+                    if result.returncode != 0:
+                        # Show last few lines of stderr for debugging
+                        stderr_lines = result.stderr.strip().split("\n")[-3:]
+                        for line in stderr_lines:
+                            print(f"  stderr: {line}")
                     return float("inf")
 
                 print(f"Trial {trial.number}: val_loss={val_loss:.6f}")
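Stripped of the trial plumbing, the new selection logic prefers the machine-readable history CSV and only falls back to scraping the console log. A sketch of that two-stage parse, assuming (as the diff implies) a val_loss column in training_history.csv and pipe-separated epoch rows on stdout; the helper name is illustrative and not part of wavedl:

import csv
import re
from pathlib import Path

LOG_ROW = re.compile(r"^\s*\d+\s*\|\s*[\d.]+\s*\|\s*([\d.]+)\s*\|")

def best_val_loss(history_file: Path, stdout: str) -> float | None:
    """Prefer the CSV written by training; otherwise scrape the console log."""
    losses = []
    if history_file.exists():
        with open(history_file) as f:
            for row in csv.DictReader(f):
                try:
                    losses.append(float(row["val_loss"]))
                except (KeyError, ValueError, TypeError):
                    continue
    if not losses:
        # Fallback: match rows like "  42 | 0.0123 | 0.0156 | ..." and take the val column.
        for line in stdout.splitlines():
            m = LOG_ROW.match(line)
            if m:
                try:
                    losses.append(float(m.group(1)))
                except ValueError:
                    continue
    return min(losses) if losses else None

print(best_val_loss(Path("training_history.csv"), "  1 | 0.0421 | 0.0378 | ..."))  # 0.0378 if no CSV is present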
--- wavedl-1.4.1/src/wavedl/train.py
+++ wavedl-1.4.2/src/wavedl/train.py

@@ -851,7 +851,7 @@ def main():
         val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
         val_samples = 0
 
-        # Accumulate predictions locally
+        # Accumulate predictions locally ON CPU to prevent GPU OOM
         local_preds = []
         local_targets = []
 
@@ -867,17 +867,19 @@ def main():
             mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
             val_mae_sum += mae_batch
 
-            # Store
-            local_preds.append(pred)
-            local_targets.append(y)
+            # Store on CPU (critical for large val sets)
+            local_preds.append(pred.detach().cpu())
+            local_targets.append(y.detach().cpu())
 
-        #
-
-
-        all_preds = accelerator.gather_for_metrics(all_local_preds)
-        all_targets = accelerator.gather_for_metrics(all_local_targets)
+        # Concatenate locally on CPU (no GPU memory spike)
+        cpu_preds = torch.cat(local_preds)
+        cpu_targets = torch.cat(local_targets)
 
-        #
+        # Gather to rank 0 only via gather_object (avoids all-gather to every rank)
+        # gather_object returns list of objects from each rank: [(preds0, targs0), (preds1, targs1), ...]
+        gathered = accelerator.gather_object((cpu_preds, cpu_targets))
+
+        # Synchronize validation metrics (scalars only - efficient)
         val_loss_scalar = val_loss_sum.item()
         val_metrics = torch.cat(
             [
@@ -900,9 +902,14 @@ def main():
 
         # ==================== LOGGING & CHECKPOINTING ====================
         if accelerator.is_main_process:
-            #
-
-
+            # Concatenate gathered tensors from all ranks (only on rank 0)
+            # gathered is list of tuples: [(preds_rank0, targs_rank0), (preds_rank1, targs_rank1), ...]
+            all_preds = torch.cat([item[0] for item in gathered])
+            all_targets = torch.cat([item[1] for item in gathered])
+
+            # Scientific metrics - cast to float32 before numpy
+            y_pred = all_preds.float().numpy()
+            y_true = all_targets.float().numpy()
 
             # Trim DDP padding
             real_len = len(val_dl.dataset)
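The three train.py hunks implement one idea: keep per-batch validation outputs on the CPU, gather the per-rank CPU tensors with accelerator.gather_object, and only materialize the full arrays on the main process. A single-process sketch of the memory pattern with toy shapes (the tensors and the final metric are made up; the cross-rank gather itself is exactly the gather_object call shown above):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
local_preds, local_targets = [], []

for _ in range(3):                              # stand-in for the validation loop
    pred = torch.rand(8, 4, device=device)      # model output for one batch
    y = torch.rand(8, 4, device=device)         # matching targets
    # Moving each batch to the CPU right away keeps GPU memory flat regardless
    # of how large the validation set is.
    local_preds.append(pred.detach().cpu())
    local_targets.append(y.detach().cpu())

# Concatenation also happens on the CPU, so there is no GPU memory spike.
cpu_preds = torch.cat(local_preds)
cpu_targets = torch.cat(local_targets)

# Final metrics work on float32 NumPy arrays, as in the diff.
y_pred = cpu_preds.float().numpy()
y_true = cpu_targets.float().numpy()
print("MAE:", float(abs(y_pred - y_true).mean()))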
--- wavedl-1.4.1/src/wavedl/utils/data.py
+++ wavedl-1.4.2/src/wavedl/utils/data.py

@@ -735,7 +735,7 @@ def load_test_data(
     try:
         inp, outp = source.load(path)
     except KeyError:
-        # Try with just inputs if outputs not found
+        # Try with just inputs if outputs not found (inference-only mode)
         if format == "npz":
             data = np.load(path, allow_pickle=True)
             keys = list(data.keys())
@@ -751,6 +751,54 @@ def load_test_data(
                 )
             out_key = DataSource._find_key(keys, custom_output_keys)
             outp = data[out_key] if out_key else None
+        elif format == "hdf5":
+            # HDF5: input-only loading for inference
+            with h5py.File(path, "r") as f:
+                keys = list(f.keys())
+                inp_key = DataSource._find_key(keys, custom_input_keys)
+                if inp_key is None:
+                    raise KeyError(
+                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+                    )
+                # Check size - load_test_data is eager, large files should use DataLoader
+                n_samples = f[inp_key].shape[0]
+                if n_samples > 100000:
+                    raise ValueError(
+                        f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                        f"everything into RAM which may cause OOM. For large inference "
+                        f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+                    )
+                inp = f[inp_key][:]
+                out_key = DataSource._find_key(keys, custom_output_keys)
+                outp = f[out_key][:] if out_key else None
+        elif format == "mat":
+            # MAT v7.3: input-only loading with proper sparse handling
+            mat_source = MATSource()
+            with h5py.File(path, "r") as f:
+                keys = list(f.keys())
+                inp_key = DataSource._find_key(keys, custom_input_keys)
+                if inp_key is None:
+                    raise KeyError(
+                        f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+                    )
+                # Check size - load_test_data is eager, large files should use DataLoader
+                n_samples = f[inp_key].shape[-1]  # MAT is transposed
+                if n_samples > 100000:
+                    raise ValueError(
+                        f"Dataset has {n_samples:,} samples. load_test_data() loads "
+                        f"everything into RAM which may cause OOM. For large inference "
+                        f"sets, use a DataLoader with MATSource.load_mmap() instead."
+                    )
+                # Use _load_dataset for sparse support and proper transpose
+                inp = mat_source._load_dataset(f, inp_key)
+                out_key = DataSource._find_key(keys, custom_output_keys)
+                if out_key:
+                    outp = mat_source._load_dataset(f, out_key)
+                    # Handle 1D outputs that become (1, N) after transpose
+                    if outp.ndim == 2 and outp.shape[0] == 1:
+                        outp = outp.T
+                else:
+                    outp = None
         else:
             raise
 
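A self-contained sketch of the input-only loading idea added for HDF5: find the first matching input key, refuse to eagerly load very large sets, and return None for the outputs. The candidate key names and demo file below are illustrative; only the 100,000-sample guard mirrors the diff:

import h5py
import numpy as np

INPUT_KEYS = ("x", "inputs", "waves")   # hypothetical candidate key names
MAX_EAGER_SAMPLES = 100_000             # same threshold as the diff

def load_inputs_only(path):
    with h5py.File(path, "r") as f:
        keys = list(f.keys())
        inp_key = next((k for k in INPUT_KEYS if k in keys), None)
        if inp_key is None:
            raise KeyError(f"Input key not found. Tried: {INPUT_KEYS}. Found: {keys}")
        if f[inp_key].shape[0] > MAX_EAGER_SAMPLES:
            raise ValueError("Too many samples for eager loading; use a memory-mapped loader.")
        return f[inp_key][:], None      # inference-only: no targets

# Tiny demo file so the sketch actually runs.
with h5py.File("demo_inference.h5", "w") as f:
    f.create_dataset("waves", data=np.random.rand(16, 128).astype("float32"))

inp, outp = load_inputs_only("demo_inference.h5")
print(inp.shape, outp)                  # (16, 128) None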
@@ -949,6 +997,15 @@ def prepare_data(
         with open(META_FILE, "rb") as f:
             meta = pickle.load(f)
             cached_data_path = meta.get("data_path", None)
+            cached_file_size = meta.get("file_size", None)
+            cached_file_mtime = meta.get("file_mtime", None)
+
+            # Get current file stats
+            current_stats = os.stat(args.data_path)
+            current_size = current_stats.st_size
+            current_mtime = current_stats.st_mtime
+
+            # Check if data path changed
             if cached_data_path != os.path.abspath(args.data_path):
                 if accelerator.is_main_process:
                     logger.warning(
@@ -958,6 +1015,23 @@ def prepare_data(
                         f"  Invalidating cache and regenerating..."
                     )
                 cache_exists = False
+            # Check if file was modified (size or mtime changed)
+            elif cached_file_size is not None and cached_file_size != current_size:
+                if accelerator.is_main_process:
+                    logger.warning(
+                        f"⚠️ Data file size changed!\n"
+                        f"   Cached size: {cached_file_size:,} bytes\n"
+                        f"   Current size: {current_size:,} bytes\n"
+                        f"   Invalidating cache and regenerating..."
+                    )
+                cache_exists = False
+            elif cached_file_mtime is not None and cached_file_mtime != current_mtime:
+                if accelerator.is_main_process:
+                    logger.warning(
+                        "⚠️ Data file was modified!\n"
+                        "   Cache may be stale, regenerating..."
+                    )
+                cache_exists = False
     except Exception:
         cache_exists = False
 
@@ -1053,13 +1127,16 @@ def prepare_data(
                 f"  Shape Detected: {full_shape} [{dim_type}] | Output Dim: {out_dim}"
             )
 
-        # Save metadata (including data path for cache validation)
+        # Save metadata (including data path, size, mtime for cache validation)
+        file_stats = os.stat(args.data_path)
         with open(META_FILE, "wb") as f:
             pickle.dump(
                 {
                     "shape": full_shape,
                     "out_dim": out_dim,
                     "data_path": os.path.abspath(args.data_path),
+                    "file_size": file_stats.st_size,
+                    "file_mtime": file_stats.st_mtime,
                 },
                 f,
             )
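Taken together, the prepare_data hunks implement a simple staleness rule: the preprocessing cache is reused only if the data file's absolute path, size, and mtime all match what was recorded when the cache was written. A standalone sketch of that rule (file names and helper functions are made up):

import os
import pickle

META_FILE = "cache_meta.pkl"

def save_meta(data_path):
    stats = os.stat(data_path)
    with open(META_FILE, "wb") as f:
        pickle.dump(
            {
                "data_path": os.path.abspath(data_path),
                "file_size": stats.st_size,
                "file_mtime": stats.st_mtime,
            },
            f,
        )

def cache_is_valid(data_path):
    try:
        with open(META_FILE, "rb") as f:
            meta = pickle.load(f)
        stats = os.stat(data_path)
        return (
            meta.get("data_path") == os.path.abspath(data_path)
            and meta.get("file_size") == stats.st_size
            and meta.get("file_mtime") == stats.st_mtime
        )
    except Exception:
        return False

# Demo: record stats for a file, then append a byte and watch the cache invalidate.
with open("data.bin", "wb") as f:
    f.write(b"\x00" * 1024)
save_meta("data.bin")
print(cache_is_valid("data.bin"))   # True
with open("data.bin", "ab") as f:
    f.write(b"\x00")
print(cache_is_valid("data.bin"))   # False: size (and mtime) changed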