wavedl 1.4.1__py3-none-any.whl → 1.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/__init__.py CHANGED
@@ -18,7 +18,7 @@ For inference:
  # or: python -m wavedl.test --checkpoint best_checkpoint --data_path test.npz
  """

- __version__ = "1.4.1"
+ __version__ = "1.4.2"
  __author__ = "Ductho Le"
  __email__ = "ductho.le@outlook.com"

wavedl/hpo.py CHANGED
@@ -145,6 +145,7 @@ def create_objective(args):
  # Use temporary directory for trial output
  with tempfile.TemporaryDirectory() as tmpdir:
  cmd.extend(["--output_dir", tmpdir])
+ history_file = Path(tmpdir) / "training_history.csv"

  # Run training
  try:
@@ -156,29 +157,55 @@ def create_objective(args):
  cwd=Path(__file__).parent,
  )

- # Parse validation loss from output
- # Look for "Best val_loss: X.XXXX" in stdout
+ # Read best val_loss from training_history.csv (reliable machine-readable)
  val_loss = None
- for line in result.stdout.split("\n"):
- if "Best val_loss:" in line:
- try:
- val_loss = float(line.split(":")[-1].strip())
- except ValueError:
- pass
- # Also check for final validation loss
- if "val_loss=" in line.lower():
- try:
- # Extract number after val_loss=
- parts = line.lower().split("val_loss=")
- if len(parts) > 1:
- val_str = parts[1].split()[0].strip(",")
- val_loss = float(val_str)
- except (ValueError, IndexError):
- pass
+ if history_file.exists():
+ try:
+ import csv
+
+ with open(history_file) as f:
+ reader = csv.DictReader(f)
+ val_losses = []
+ for row in reader:
+ if "val_loss" in row:
+ try:
+ val_losses.append(float(row["val_loss"]))
+ except (ValueError, TypeError):
+ pass
+ if val_losses:
+ val_loss = min(val_losses) # Best (minimum) val_loss
+ except Exception as e:
+ print(f"Trial {trial.number}: Error reading history: {e}")
+
+ if val_loss is None:
+ # Fallback: parse stdout for training log format
+ # Pattern: "epoch | train_loss | val_loss | ..."
+ # Use regex to avoid false positives from unrelated lines
+ import re
+
+ # Match lines like: " 42 | 0.0123 | 0.0156 | ..."
+ log_pattern = re.compile(
+ r"^\s*\d+\s*\|\s*[\d.]+\s*\|\s*([\d.]+)\s*\|"
+ )
+ val_losses_stdout = []
+ for line in result.stdout.split("\n"):
+ match = log_pattern.match(line)
+ if match:
+ try:
+ val_losses_stdout.append(float(match.group(1)))
+ except ValueError:
+ continue
+ if val_losses_stdout:
+ val_loss = min(val_losses_stdout)

  if val_loss is None:
  # Training failed or no loss found
- print(f"Trial {trial.number}: Training failed")
+ print(f"Trial {trial.number}: Training failed (no val_loss found)")
+ if result.returncode != 0:
+ # Show last few lines of stderr for debugging
+ stderr_lines = result.stderr.strip().split("\n")[-3:]
+ for line in stderr_lines:
+ print(f" stderr: {line}")
  return float("inf")

  print(f"Trial {trial.number}: val_loss={val_loss:.6f}")
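The two-stage parsing above prefers the trial's training_history.csv and only falls back to scraping stdout with the anchored regex shown. A quick standalone illustration of that fallback (the log excerpt below is fabricated; only the regex itself comes from hpo.py):

    import re

    # Same pattern as hpo.py: "epoch | train_loss | val_loss | ..."
    log_pattern = re.compile(r"^\s*\d+\s*\|\s*[\d.]+\s*\|\s*([\d.]+)\s*\|")

    # Fabricated stdout excerpt for demonstration
    stdout = """Epoch | Train Loss | Val Loss | LR
        1 | 0.1234 | 0.0987 | 1e-3
        2 | 0.0456 | 0.0321 | 1e-3
    Best val_loss mentioned in prose: 0.999 (ignored by the anchored pattern)"""

    # Keep only lines that start with an epoch number; capture the third column
    val_losses = [float(m.group(1))
                  for line in stdout.splitlines()
                  if (m := log_pattern.match(line))]
    print(min(val_losses) if val_losses else None)  # -> 0.0321
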
wavedl/train.py CHANGED
@@ -851,7 +851,7 @@ def main():
  val_mae_sum = torch.zeros(out_dim, device=accelerator.device)
  val_samples = 0

- # Accumulate predictions locally, gather ONCE at end (reduces sync overhead)
+ # Accumulate predictions locally ON CPU to prevent GPU OOM
  local_preds = []
  local_targets = []

@@ -867,17 +867,19 @@ def main():
  mae_batch = torch.abs((pred - y) * phys_scale).sum(dim=0)
  val_mae_sum += mae_batch

- # Store locally (no GPU sync per batch)
- local_preds.append(pred)
- local_targets.append(y)
+ # Store on CPU (critical for large val sets)
+ local_preds.append(pred.detach().cpu())
+ local_targets.append(y.detach().cpu())

- # Single gather at end of validation (2 syncs instead of 2×num_batches)
- all_local_preds = torch.cat(local_preds)
- all_local_targets = torch.cat(local_targets)
- all_preds = accelerator.gather_for_metrics(all_local_preds)
- all_targets = accelerator.gather_for_metrics(all_local_targets)
+ # Concatenate locally on CPU (no GPU memory spike)
+ cpu_preds = torch.cat(local_preds)
+ cpu_targets = torch.cat(local_targets)

- # Synchronize validation metrics
+ # Gather to rank 0 only via gather_object (avoids all-gather to every rank)
+ # gather_object returns list of objects from each rank: [(preds0, targs0), (preds1, targs1), ...]
+ gathered = accelerator.gather_object((cpu_preds, cpu_targets))
+
+ # Synchronize validation metrics (scalars only - efficient)
  val_loss_scalar = val_loss_sum.item()
  val_metrics = torch.cat(
  [
@@ -900,9 +902,14 @@ def main():

  # ==================== LOGGING & CHECKPOINTING ====================
  if accelerator.is_main_process:
- # Scientific metrics - cast to float32 before numpy (bf16 can't convert)
- y_pred = all_preds.float().cpu().numpy()
- y_true = all_targets.float().cpu().numpy()
+ # Concatenate gathered tensors from all ranks (only on rank 0)
+ # gathered is list of tuples: [(preds_rank0, targs_rank0), (preds_rank1, targs_rank1), ...]
+ all_preds = torch.cat([item[0] for item in gathered])
+ all_targets = torch.cat([item[1] for item in gathered])
+
+ # Scientific metrics - cast to float32 before numpy
+ y_pred = all_preds.float().numpy()
+ y_true = all_targets.float().numpy()

  # Trim DDP padding
  real_len = len(val_dl.dataset)
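The validation change above has two parts: each batch's predictions are moved to the CPU as they are produced, and the per-rank results are then collected once with gather_object instead of all-gathering GPU tensors. A minimal single-process sketch of the CPU-accumulation half (the model and batches are invented stand-ins; under DDP each rank would hold only its own shard at this point, and the single gather_object call in train.py collects the shards afterwards):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(8, 3).to(device)

    # Fabricated validation batches, for illustration only
    val_batches = [(torch.randn(4, 8), torch.randn(4, 3)) for _ in range(5)]

    local_preds, local_targets = [], []
    with torch.no_grad():
        for x, y in val_batches:
            pred = model(x.to(device))
            # Move each batch off the GPU immediately: peak GPU memory stays at
            # roughly one batch of predictions instead of the whole validation set
            local_preds.append(pred.detach().cpu())
            local_targets.append(y.cpu())

    # One concatenation per rank, entirely on CPU
    cpu_preds = torch.cat(local_preds)
    cpu_targets = torch.cat(local_targets)
    print(cpu_preds.shape, cpu_targets.shape)  # torch.Size([20, 3]) torch.Size([20, 3])
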
wavedl/utils/data.py CHANGED
@@ -735,7 +735,7 @@ def load_test_data(
  try:
  inp, outp = source.load(path)
  except KeyError:
- # Try with just inputs if outputs not found
+ # Try with just inputs if outputs not found (inference-only mode)
  if format == "npz":
  data = np.load(path, allow_pickle=True)
  keys = list(data.keys())
@@ -751,6 +751,54 @@ def load_test_data(
  )
  out_key = DataSource._find_key(keys, custom_output_keys)
  outp = data[out_key] if out_key else None
+ elif format == "hdf5":
+ # HDF5: input-only loading for inference
+ with h5py.File(path, "r") as f:
+ keys = list(f.keys())
+ inp_key = DataSource._find_key(keys, custom_input_keys)
+ if inp_key is None:
+ raise KeyError(
+ f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+ )
+ # Check size - load_test_data is eager, large files should use DataLoader
+ n_samples = f[inp_key].shape[0]
+ if n_samples > 100000:
+ raise ValueError(
+ f"Dataset has {n_samples:,} samples. load_test_data() loads "
+ f"everything into RAM which may cause OOM. For large inference "
+ f"sets, use a DataLoader with HDF5Source.load_mmap() instead."
+ )
+ inp = f[inp_key][:]
+ out_key = DataSource._find_key(keys, custom_output_keys)
+ outp = f[out_key][:] if out_key else None
+ elif format == "mat":
+ # MAT v7.3: input-only loading with proper sparse handling
+ mat_source = MATSource()
+ with h5py.File(path, "r") as f:
+ keys = list(f.keys())
+ inp_key = DataSource._find_key(keys, custom_input_keys)
+ if inp_key is None:
+ raise KeyError(
+ f"Input key not found. Tried: {custom_input_keys}. Found: {keys}"
+ )
+ # Check size - load_test_data is eager, large files should use DataLoader
+ n_samples = f[inp_key].shape[-1] # MAT is transposed
+ if n_samples > 100000:
+ raise ValueError(
+ f"Dataset has {n_samples:,} samples. load_test_data() loads "
+ f"everything into RAM which may cause OOM. For large inference "
+ f"sets, use a DataLoader with MATSource.load_mmap() instead."
+ )
+ # Use _load_dataset for sparse support and proper transpose
+ inp = mat_source._load_dataset(f, inp_key)
+ out_key = DataSource._find_key(keys, custom_output_keys)
+ if out_key:
+ outp = mat_source._load_dataset(f, out_key)
+ # Handle 1D outputs that become (1, N) after transpose
+ if outp.ndim == 2 and outp.shape[0] == 1:
+ outp = outp.T
+ else:
+ outp = None
  else:
  raise
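Stripped of wavedl's key-search helpers, the new inference-only branches reduce to reading whichever datasets exist in the file. A small self-contained h5py sketch of that idea (file and key names are made up for demonstration; the real code resolves keys via DataSource._find_key and guards against oversized eager loads as shown above):

    import h5py
    import numpy as np

    # Write a tiny inference-only file: inputs present, targets absent
    with h5py.File("inference_only.h5", "w") as f:
        f.create_dataset("inputs", data=np.random.rand(16, 32).astype(np.float32))

    with h5py.File("inference_only.h5", "r") as f:
        keys = list(f.keys())
        inp = f["inputs"][:]                                   # eager read into RAM
        outp = f["outputs"][:] if "outputs" in keys else None  # missing -> inference-only

    print(inp.shape, outp)  # (16, 32) None
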
 
@@ -949,6 +997,15 @@ def prepare_data(
  with open(META_FILE, "rb") as f:
  meta = pickle.load(f)
  cached_data_path = meta.get("data_path", None)
+ cached_file_size = meta.get("file_size", None)
+ cached_file_mtime = meta.get("file_mtime", None)
+
+ # Get current file stats
+ current_stats = os.stat(args.data_path)
+ current_size = current_stats.st_size
+ current_mtime = current_stats.st_mtime
+
+ # Check if data path changed
  if cached_data_path != os.path.abspath(args.data_path):
  if accelerator.is_main_process:
  logger.warning(
@@ -958,6 +1015,23 @@ def prepare_data(
  f" Invalidating cache and regenerating..."
  )
  cache_exists = False
+ # Check if file was modified (size or mtime changed)
+ elif cached_file_size is not None and cached_file_size != current_size:
+ if accelerator.is_main_process:
+ logger.warning(
+ f"⚠️ Data file size changed!\n"
+ f" Cached size: {cached_file_size:,} bytes\n"
+ f" Current size: {current_size:,} bytes\n"
+ f" Invalidating cache and regenerating..."
+ )
+ cache_exists = False
+ elif cached_file_mtime is not None and cached_file_mtime != current_mtime:
+ if accelerator.is_main_process:
+ logger.warning(
+ "⚠️ Data file was modified!\n"
+ " Cache may be stale, regenerating..."
+ )
+ cache_exists = False
  except Exception:
  cache_exists = False

@@ -1053,13 +1127,16 @@ def prepare_data(
  f" Shape Detected: {full_shape} [{dim_type}] | Output Dim: {out_dim}"
  )

- # Save metadata (including data path for cache validation)
+ # Save metadata (including data path, size, mtime for cache validation)
+ file_stats = os.stat(args.data_path)
  with open(META_FILE, "wb") as f:
  pickle.dump(
  {
  "shape": full_shape,
  "out_dim": out_dim,
  "data_path": os.path.abspath(args.data_path),
+ "file_size": file_stats.st_size,
+ "file_mtime": file_stats.st_mtime,
  },
  f,
  )
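Taken together, the prepare_data() changes write an os.stat() fingerprint (size and mtime) next to the cached shape metadata and refuse to reuse the cache when the fingerprint no longer matches. The check can be summarized as a small helper; this is a simplified sketch, not wavedl's actual cache layout beyond the data_path / file_size / file_mtime keys shown in the diff:

    import os
    import pickle

    def cache_is_valid(meta_file: str, data_path: str) -> bool:
        """Return True only if cached metadata still matches the data file on disk."""
        try:
            with open(meta_file, "rb") as f:
                meta = pickle.load(f)
        except Exception:
            return False

        stats = os.stat(data_path)
        if meta.get("data_path") != os.path.abspath(data_path):
            return False  # cache was built from a different file
        if meta.get("file_size") is not None and meta["file_size"] != stats.st_size:
            return False  # file grew or shrank since the cache was built
        if meta.get("file_mtime") is not None and meta["file_mtime"] != stats.st_mtime:
            return False  # file was rewritten or touched since the cache was built
        return True
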
wavedl-1.4.1.dist-info/METADATA → wavedl-1.4.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: wavedl
- Version: 1.4.1
+ Version: 1.4.2
  Summary: A Scalable Deep Learning Framework for Wave-Based Inverse Problems
  Author: Ductho Le
  License: MIT
wavedl-1.4.1.dist-info/RECORD → wavedl-1.4.2.dist-info/RECORD CHANGED
@@ -1,8 +1,8 @@
- wavedl/__init__.py,sha256=2LU5rtHKoYgXBAZ4zGNtFcHjrTtmmYskXnaURHEwkNc,1177
+ wavedl/__init__.py,sha256=K52yq0nkj2B3W0ZpR3tb7RUcHMIiANAE2d1WTuGVZLI,1177
  wavedl/hpc.py,sha256=de_GKERX8GS10sXRX9yXiGzMnk1jjq8JPzRw7QDs6d4,7967
- wavedl/hpo.py,sha256=aZoa_Oto_anZpIhz-YM6kN8KxQXTolUvDEyg3NXwBrY,11542
+ wavedl/hpo.py,sha256=YJXsnSGEBSVUqp_2ah7zu3_VClAUqZrdkuzDaSqQUjU,12952
  wavedl/test.py,sha256=jZmRJaivYYTMMTaccCi0yQjHOfp0a9YWR1wAPeKFH-k,36246
- wavedl/train.py,sha256=e0tX7_j2gkuYpPjZJqGoDV8arAe4bc4YVRMyrg-RcRY,46402
+ wavedl/train.py,sha256=Gh02hlfjcote6w1sgUndJzKU_FhkJckTAhlq1aL8FY8,46842
  wavedl/models/__init__.py,sha256=lfSohEnAUztO14nuwayMJhPjpgySzRN3jGiyAUuBmAU,3206
  wavedl/models/_template.py,sha256=J_D8taSPmV8lBaucN_vU-WiG98iFr7CJrZVNNX_Tdts,4600
  wavedl/models/base.py,sha256=T9iDF9IQM2MYucG_ggQd31rieUkB2fob-nkHyNIl2ak,7337
@@ -23,15 +23,15 @@ wavedl/models/vit.py,sha256=0C3GZk11VsYFTl14d86Wtl1Zk1T5rYJjvkaEfEN4N3k,11100
  wavedl/utils/__init__.py,sha256=YMgzuwndjr64kt9k0_6_9PMJYTVdiaH5veSMff_ZycA,3051
  wavedl/utils/config.py,sha256=fMoucikIQHn85mVhGMa7TnXTuFDcEEPjfXk2EjbkJR0,10591
  wavedl/utils/cross_validation.py,sha256=117ac9KDzaIaqhtP8ZRs15Xpqmq5fLpX2-vqkNvtMaU,17487
- wavedl/utils/data.py,sha256=9LrB9MC6jRZzbRSc9xiGzJWoh8FahwP_68REqBAT3Os,44131
+ wavedl/utils/data.py,sha256=_OaWvU5oFVJW0NwM5WyDD0Kb1hy5MgvJIFpzvJGux9w,48214
  wavedl/utils/distributed.py,sha256=7wQ3mRjkp_xjPSxDWMnBf5dSkAGUaTzntxbz0BhC5v0,4145
  wavedl/utils/losses.py,sha256=5762M-TBC_hz6uyj1NPbU1vZeFOJQq7fR3-j7OygJRo,7254
  wavedl/utils/metrics.py,sha256=mkCpqZwl_XUpNvA5Ekjf7y-HqApafR7eR6EuA8cBdM8,37287
  wavedl/utils/optimizers.py,sha256=PyIkJ_hRhFi_Fio81Gy5YQNhcME0JUUEl8OTSyu-0RA,6323
  wavedl/utils/schedulers.py,sha256=e6Sf0yj8VOqkdwkUHLMyUfGfHKTX4NMr-zfgxWqCTYI,7659
- wavedl-1.4.1.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
- wavedl-1.4.1.dist-info/METADATA,sha256=FEafy9hY2su6bB8iS8VNZceLpTs9E7nhVaejsOEHTUM,40245
- wavedl-1.4.1.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
- wavedl-1.4.1.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
- wavedl-1.4.1.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
- wavedl-1.4.1.dist-info/RECORD,,
+ wavedl-1.4.2.dist-info/LICENSE,sha256=cEUCvcvH-9BT9Y-CNGY__PwWONCKu9zsoIqWA-NeHJ4,1066
+ wavedl-1.4.2.dist-info/METADATA,sha256=xHMuRcGdF8Tdju6aBK601mqLYx0uSIdmtjoQrHpPYmI,40245
+ wavedl-1.4.2.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
+ wavedl-1.4.2.dist-info/entry_points.txt,sha256=f1RNDkXFZwBzrBzTMFocJ6xhfTvTmaEDTi5YyDEUaF8,140
+ wavedl-1.4.2.dist-info/top_level.txt,sha256=ccneUt3D5Qzbh3bsBSSrq9bqrhGiogcWKY24ZC4Q6Xw,7
+ wavedl-1.4.2.dist-info/RECORD,,