ins-pricing 0.5.0-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. ins_pricing/cli/BayesOpt_entry.py +15 -5
  2. ins_pricing/cli/BayesOpt_incremental.py +43 -10
  3. ins_pricing/cli/Explain_Run.py +16 -5
  4. ins_pricing/cli/Explain_entry.py +29 -8
  5. ins_pricing/cli/Pricing_Run.py +16 -5
  6. ins_pricing/cli/bayesopt_entry_runner.py +45 -12
  7. ins_pricing/cli/utils/bootstrap.py +23 -0
  8. ins_pricing/cli/utils/cli_config.py +34 -15
  9. ins_pricing/cli/utils/import_resolver.py +14 -14
  10. ins_pricing/cli/utils/notebook_utils.py +120 -106
  11. ins_pricing/cli/watchdog_run.py +15 -5
  12. ins_pricing/frontend/app.py +132 -61
  13. ins_pricing/frontend/config_builder.py +33 -0
  14. ins_pricing/frontend/example_config.json +11 -0
  15. ins_pricing/frontend/runner.py +340 -388
  16. ins_pricing/modelling/README.md +1 -1
  17. ins_pricing/modelling/bayesopt/README.md +29 -11
  18. ins_pricing/modelling/bayesopt/config_components.py +12 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +50 -13
  20. ins_pricing/modelling/bayesopt/core.py +47 -19
  21. ins_pricing/modelling/bayesopt/model_plotting_mixin.py +20 -14
  22. ins_pricing/modelling/bayesopt/models/model_ft_components.py +349 -342
  23. ins_pricing/modelling/bayesopt/models/model_ft_trainer.py +11 -5
  24. ins_pricing/modelling/bayesopt/models/model_gnn.py +20 -14
  25. ins_pricing/modelling/bayesopt/models/model_resn.py +9 -3
  26. ins_pricing/modelling/bayesopt/trainers/trainer_base.py +62 -50
  27. ins_pricing/modelling/bayesopt/trainers/trainer_ft.py +61 -53
  28. ins_pricing/modelling/bayesopt/trainers/trainer_glm.py +9 -3
  29. ins_pricing/modelling/bayesopt/trainers/trainer_gnn.py +40 -32
  30. ins_pricing/modelling/bayesopt/trainers/trainer_resn.py +36 -24
  31. ins_pricing/modelling/bayesopt/trainers/trainer_xgb.py +240 -37
  32. ins_pricing/modelling/bayesopt/utils/distributed_utils.py +193 -186
  33. ins_pricing/modelling/bayesopt/utils/torch_trainer_mixin.py +23 -10
  34. ins_pricing/pricing/factors.py +67 -56
  35. ins_pricing/setup.py +1 -1
  36. ins_pricing/utils/__init__.py +7 -6
  37. ins_pricing/utils/device.py +45 -24
  38. ins_pricing/utils/logging.py +34 -1
  39. ins_pricing/utils/profiling.py +8 -4
  40. {ins_pricing-0.5.0.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +182 -182
  41. {ins_pricing-0.5.0.dist-info → ins_pricing-0.5.1.dist-info}/RECORD +43 -42
  42. {ins_pricing-0.5.0.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  43. {ins_pricing-0.5.0.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
ins_pricing/modelling/bayesopt/models/model_ft_trainer.py
@@ -17,7 +17,7 @@ from torch.nn.utils import clip_grad_norm_
 
 from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
 from ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
-from ins_pricing.utils import EPS
+from ins_pricing.utils import EPS, get_logger, log_print
 from ins_pricing.utils.losses import (
     infer_loss_name_from_model_name,
     normalize_loss_name,
@@ -25,6 +25,12 @@ from ins_pricing.utils.losses import (
 )
 from ins_pricing.modelling.bayesopt.models.model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset
 
+_logger = get_logger("ins_pricing.modelling.bayesopt.models.model_ft_trainer")
+
+
+def _log(*args, **kwargs) -> None:
+    log_print(_logger, *args, **kwargs)
+
 
 # --- Helper functions for reconstruction loss computation ---
 
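Note: the new _log helper simply forwards print-style arguments to the module logger. log_print itself is not part of this diff; a minimal sketch of a compatible implementation (an assumption about its behaviour, not the packaged code) would be:

    # Hypothetical sketch of a print-compatible logging bridge; the real
    # ins_pricing.utils.get_logger / log_print may differ.
    import logging

    def get_logger(name: str) -> logging.Logger:
        # Namespaced logger; handler/formatter setup lives elsewhere in the package.
        return logging.getLogger(name)

    def log_print(logger: logging.Logger, *args, **kwargs) -> None:
        # Accept print()-style arguments (sep=, flush=, ...) so call sites can
        # switch from print(...) to _log(...) without changing their arguments.
        sep = kwargs.get("sep", " ")
        logger.info(sep.join(str(a) for a in args))

Under that assumption, flush=True at the converted call sites is accepted and ignored.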
@@ -281,7 +287,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
             self.use_data_parallel = False
         elif use_dp:
             if self.use_ddp and not self.is_ddp_enabled:
-                print(
+                _log(
                     ">>> DDP requested but not initialized; falling back to DataParallel.")
             core = nn.DataParallel(core, device_ids=list(
                 range(torch.cuda.device_count())))
@@ -699,15 +705,15 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
             should_log = (not dist.is_initialized()
                           or DistributedUtils.is_main_process())
             if should_log:
-                print(msg, flush=True)
-                print(
+                _log(msg, flush=True)
+                _log(
                     f" X_num: finite={bool(torch.isfinite(X_num_b).all())} "
                     f"min={float(X_num_b.min().detach().cpu()) if X_num_b.numel() else 0.0:.3g} "
                     f"max={float(X_num_b.max().detach().cpu()) if X_num_b.numel() else 0.0:.3g}",
                     flush=True,
                 )
                 if X_geo_b is not None:
-                    print(
+                    _log(
                         f" X_geo: finite={bool(torch.isfinite(X_geo_b).all())} "
                         f"min={float(X_geo_b.min().detach().cpu()) if X_geo_b.numel() else 0.0:.3g} "
                         f"max={float(X_geo_b.max().detach().cpu()) if X_geo_b.numel() else 0.0:.3g}",
ins_pricing/modelling/bayesopt/models/model_gnn.py
@@ -18,7 +18,7 @@ from torch.nn.utils import clip_grad_norm_
 
 from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
 from ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
-from ins_pricing.utils import EPS
+from ins_pricing.utils import EPS, get_logger, log_print
 from ins_pricing.utils.io import IOUtils
 from ins_pricing.utils.losses import (
     infer_loss_name_from_model_name,
@@ -45,6 +45,12 @@ except Exception:
 
 _GNN_MPS_WARNED = False
 
+_logger = get_logger("ins_pricing.modelling.bayesopt.models.model_gnn")
+
+
+def _log(*args, **kwargs) -> None:
+    log_print(_logger, *args, **kwargs)
+
 
 # =============================================================================
 # Simplified GNN implementation.
@@ -169,7 +175,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         if use_ddp:
             world_size = int(os.environ.get("WORLD_SIZE", "1"))
             if world_size > 1:
-                print(
+                _log(
                     "[GNN] DDP training is not supported; falling back to single process.",
                     flush=True,
                 )
@@ -194,7 +200,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
             self.device = torch.device('mps')
             global _GNN_MPS_WARNED
             if not _GNN_MPS_WARNED:
-                print(
+                _log(
                     "[GNN] Using MPS backend; will fall back to CPU on unsupported ops.",
                     flush=True,
                 )
@@ -271,7 +277,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         if self.device.type != "mps" or self._mps_fallback_triggered:
             return
         self._mps_fallback_triggered = True
-        print(f"[GNN] MPS op unsupported ({reason}); falling back to CPU.", flush=True)
+        _log(f"[GNN] MPS op unsupported ({reason}); falling back to CPU.", flush=True)
         self.device = torch.device("cpu")
         self.use_pyg_knn = False
         self.data_parallel_enabled = False
@@ -347,7 +353,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         try:
             payload = torch.load(self.graph_cache_path, map_location="cpu")
         except Exception as exc:
-            print(
+            _log(
                 f"[GNN] Failed to load cached graph from {self.graph_cache_path}: {exc}")
             return None
         if isinstance(payload, dict) and "adj" in payload:
@@ -355,19 +361,19 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
             if meta_cached == meta_expected:
                 adj = payload["adj"]
                 if self.device.type == "mps" and getattr(adj, "is_sparse", False):
-                    print(
+                    _log(
                         f"[GNN] Cached sparse graph incompatible with MPS; rebuilding: {self.graph_cache_path}"
                     )
                     return None
                 return adj.to(self.device)
-            print(
+            _log(
                 f"[GNN] Cached graph metadata mismatch; rebuilding: {self.graph_cache_path}")
             return None
         if isinstance(payload, torch.Tensor):
-            print(
+            _log(
                 f"[GNN] Cached graph missing metadata; rebuilding: {self.graph_cache_path}")
             return None
-        print(
+        _log(
             f"[GNN] Invalid cached graph format; rebuilding: {self.graph_cache_path}")
         return None
 
@@ -387,7 +393,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
             )
             indices, _ = nn_index.neighbor_graph
         except Exception as exc:
-            print(
+            _log(
                 f"[GNN] Approximate kNN failed ({exc}); falling back to exact search.")
             use_approx = False
 
@@ -440,7 +446,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
         if self._knn_warning_emitted:
             return
         if (not self.ddp_enabled) or self.local_rank == 0:
-            print(f"[GNN] Falling back to CPU kNN builder: {reason}")
+            _log(f"[GNN] Falling back to CPU kNN builder: {reason}")
         self._knn_warning_emitted = True
 
     def _should_use_gpu_knn(self, n_samples: int, X_tensor: torch.Tensor) -> bool:
@@ -592,7 +598,7 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
             IOUtils.ensure_parent_dir(str(self.graph_cache_path))
             torch.save({"adj": adj_norm.cpu(), "meta": meta_expected}, self.graph_cache_path)
         except Exception as exc:
-            print(
+            _log(
                 f"[GNN] Failed to cache graph to {self.graph_cache_path}: {exc}")
         self._adj_cache_meta = meta_expected
         self._adj_cache_key = None
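Note: together with the load-path checks above, this write path keys the cached adjacency to a metadata dict and rebuilds on any mismatch. A stripped-down illustration of that pattern (hypothetical path and metadata, not the package's exact schema):

    # Illustrative cache-with-metadata pattern mirroring the checks above.
    import torch

    cache_path = "graph_cache.pt"                # hypothetical location
    meta_expected = {"n_nodes": 10_000, "k": 8}  # hypothetical build parameters

    def load_cached_adj(device: torch.device):
        try:
            payload = torch.load(cache_path, map_location="cpu")
        except Exception:
            return None                          # missing or corrupt -> rebuild
        if not (isinstance(payload, dict) and "adj" in payload):
            return None                          # unknown format -> rebuild
        if payload.get("meta") != meta_expected:
            return None                          # stale metadata -> rebuild
        return payload["adj"].to(device)

    def save_adj(adj: torch.Tensor) -> None:
        torch.save({"adj": adj.cpu(), "meta": meta_expected}, cache_path)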
@@ -712,12 +718,12 @@ class GraphNeuralNetSklearn(TorchTrainerMixin, nn.Module):
             if should_log:
                 elapsed = int(time.time() - epoch_start_ts)
                 if val_loss is None:
-                    print(
+                    _log(
                         f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} elapsed={elapsed}s",
                         flush=True,
                     )
                 else:
-                    print(
+                    _log(
                         f"[GNN] Epoch {epoch}/{self.epochs} loss={float(loss):.6f} "
                         f"val_loss={float(val_loss):.6f} elapsed={elapsed}s",
                         flush=True,
ins_pricing/modelling/bayesopt/models/model_resn.py
@@ -13,13 +13,19 @@ from torch.utils.data import TensorDataset
 
 from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
 from ins_pricing.modelling.bayesopt.utils.torch_trainer_mixin import TorchTrainerMixin
-from ins_pricing.utils import EPS
+from ins_pricing.utils import EPS, get_logger, log_print
 from ins_pricing.utils.losses import (
     infer_loss_name_from_model_name,
     normalize_loss_name,
     resolve_tweedie_power,
 )
 
+_logger = get_logger("ins_pricing.modelling.bayesopt.models.model_resn")
+
+
+def _log(*args, **kwargs) -> None:
+    log_print(_logger, *args, **kwargs)
+
 
 # =============================================================================
 # ResNet model and sklearn-style wrapper
@@ -130,7 +136,7 @@ class ResNetSequential(nn.Module):
 
     def forward(self, x):
         if self.training and not hasattr(self, '_printed_device'):
-            print(f">>> ResNetSequential executing on device: {x.device}")
+            _log(f">>> ResNetSequential executing on device: {x.device}")
            self._printed_device = True
         return self.net(x)
 
@@ -220,7 +226,7 @@ class ResNetSklearn(TorchTrainerMixin, nn.Module):
             self.use_data_parallel = False
         elif use_data_parallel and (self.device.type == 'cuda') and (torch.cuda.device_count() > 1):
             if self.use_ddp and not self.is_ddp_enabled:
-                print(
+                _log(
                     ">>> DDP requested but not initialized; falling back to DataParallel.")
             core = nn.DataParallel(core, device_ids=list(
                 range(torch.cuda.device_count())))
ins_pricing/modelling/bayesopt/trainers/trainer_base.py
@@ -26,12 +26,16 @@ from sklearn.preprocessing import StandardScaler
 
 from ins_pricing.modelling.bayesopt.config_preprocess import BayesOptConfig, OutputManager
 from ins_pricing.modelling.bayesopt.utils.distributed_utils import DistributedUtils
-from ins_pricing.utils import EPS, ensure_parent_dir, get_logger, GPUMemoryManager, DeviceManager
+from ins_pricing.utils import EPS, ensure_parent_dir, get_logger, GPUMemoryManager, DeviceManager, log_print
 from ins_pricing.utils.torch_compat import torch_load
 
 # Module-level logger
 _logger = get_logger("ins_pricing.trainer")
 
+
+def _log(*args, **kwargs) -> None:
+    log_print(_logger, *args, **kwargs)
+
 class _OrderSplitter:
     def __init__(self, splitter, order: np.ndarray) -> None:
         self._splitter = splitter
@@ -364,7 +368,7 @@ class TrainerBase:
         try:
             rank = dist.get_rank()
             world = dist.get_world_size()
-            print(f"[DDP][{self.label}] entering barrier({reason}) rank={rank}/{world}", flush=True)
+            _log(f"[DDP][{self.label}] entering barrier({reason}) rank={rank}/{world}", flush=True)
         except Exception:
             debug_barrier = False
         try:
@@ -399,9 +403,9 @@ class TrainerBase:
             else:
                 dist.barrier()
             if debug_barrier:
-                print(f"[DDP][{self.label}] exit barrier({reason}) rank={rank}/{world}", flush=True)
+                _log(f"[DDP][{self.label}] exit barrier({reason}) rank={rank}/{world}", flush=True)
         except Exception as exc:
-            print(
+            _log(
                 f"[DDP][{self.label}] barrier failed during {reason}: {exc}",
                 flush=True,
             )
@@ -433,12 +437,15 @@ class TrainerBase:
             ensure_parent_dir(str(path))
         return f"sqlite:///{path.as_posix()}"
 
-    def _resolve_optuna_study_name(self) -> str:
-        prefix = getattr(self.config, "optuna_study_prefix",
-                         None) or "bayesopt"
-        raw = f"{prefix}_{self.ctx.model_nme}_{self.model_name_prefix}"
-        safe = "".join([c if c.isalnum() or c in "._-" else "_" for c in raw])
-        return safe.lower()
+    def _resolve_optuna_study_name(self) -> str:
+        prefix = getattr(self.config, "optuna_study_prefix",
+                         None) or "bayesopt"
+        raw = f"{prefix}_{self.ctx.model_nme}_{self.model_name_prefix}"
+        safe = "".join([c if c.isalnum() or c in "._-" else "_" for c in raw])
+        return safe.lower()
+
+    def _optuna_cleanup_sync(self) -> bool:
+        return bool(getattr(self.config, "optuna_cleanup_synchronize", False))
 
     def tune(self, max_evals: int, objective_fn=None) -> None:
         # Generic Optuna tuning loop.
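Note: _resolve_optuna_study_name is carried over unchanged; the new _optuna_cleanup_sync hook reads an optuna_cleanup_synchronize attribute from the config (defaulting to False) to decide whether per-trial GPU cleanup should synchronize the device. For reference, the study-name sanitization above keeps alphanumerics and '._-', replaces everything else with underscores, and lower-cases the result:

    # Worked example of the study-name sanitization above
    # (prefix and model names are made up).
    raw = "bayesopt_Motor TPL/2024_FTTransformer"
    safe = "".join([c if c.isalnum() or c in "._-" else "_" for c in raw])
    print(safe.lower())  # bayesopt_motor_tpl_2024_fttransformer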
@@ -457,27 +464,27 @@ class TrainerBase:
             should_log = DistributedUtils.is_main_process()
             if should_log:
                 current_idx = progress_counter["count"] + 1
-                print(
+                _log(
                     f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
                     f"(trial_id={trial.number})."
                 )
             try:
                 result = objective_fn(trial)
-            except RuntimeError as exc:
-                if "out of memory" in str(exc).lower():
-                    print(
-                        f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
-                    )
-                    self._clean_gpu()
-                    raise optuna.TrialPruned() from exc
-                raise
-            finally:
-                self._clean_gpu()
+            except RuntimeError as exc:
+                if "out of memory" in str(exc).lower():
+                    _log(
+                        f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
+                    )
+                    self._clean_gpu(synchronize=True)
+                    raise optuna.TrialPruned() from exc
+                raise
+            finally:
+                self._clean_gpu(synchronize=self._optuna_cleanup_sync())
             if should_log:
                 progress_counter["count"] = progress_counter["count"] + 1
                 trial_state = getattr(trial, "state", None)
                 state_repr = getattr(trial_state, "name", "OK")
-                print(
+                _log(
                     f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
                     f"(status={state_repr})."
                 )
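Note: the trial wrapper keeps the established OOM-handling shape: a CUDA out-of-memory error surfaces as a RuntimeError, the trial is pruned rather than failed, and cleanup always runs before the next trial (now synchronizing only on the OOM path). A generic, self-contained sketch of that pattern (not the trainer's actual objective):

    # Generic Optuna OOM-pruning pattern mirroring the wrapper above.
    import optuna
    import torch

    def train_once(trial: optuna.trial.Trial) -> float:
        # Placeholder objective; real code would build and fit a model here.
        return trial.suggest_float("lr", 1e-5, 1e-1, log=True)

    def guarded_objective(trial: optuna.trial.Trial) -> float:
        try:
            return train_once(trial)
        except RuntimeError as exc:
            if "out of memory" in str(exc).lower():
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()          # free cached blocks after OOM
                raise optuna.TrialPruned() from exc   # prune instead of failing the study
            raise
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()              # best-effort cleanup between trials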
@@ -552,7 +559,7 @@ class TrainerBase:
 
     def save(self) -> None:
         if self.model is None:
-            print(f"[save] Warning: No model to save for {self.label}")
+            _log(f"[save] Warning: No model to save for {self.label}")
             return
 
         path = self.output.model_path(self._get_model_filename())
@@ -615,7 +622,7 @@ class TrainerBase:
     def load(self) -> None:
         path = self.output.model_path(self._get_model_filename())
         if not os.path.exists(path):
-            print(f"[load] Warning: Model file not found: {path}")
+            _log(f"[load] Warning: Model file not found: {path}")
             return
 
         if self.label in ['Xgboost', 'GLM']:
@@ -695,7 +702,7 @@ class TrainerBase:
                     self.model = loaded_model
                 else:
                     # Unknown format
-                    print(f"[load] Warning: Unknown model format in {path}")
+                    _log(f"[load] Warning: Unknown model format in {path}")
             else:
                 # Very old format: direct model object
                 if loaded is not None:
@@ -749,14 +756,14 @@ class TrainerBase:
 
     def _distributed_worker_loop(self, objective_fn: Callable[[Optional[optuna.trial.Trial]], float]) -> None:
         if dist is None:
-            print(
+            _log(
                 f"[Optuna][Worker][{self.label}] torch.distributed unavailable. Worker exit.",
                 flush=True,
             )
             return
         DistributedUtils.setup_ddp()
         if not dist.is_initialized():
-            print(
+            _log(
                 f"[Optuna][Worker][{self.label}] DDP init failed. Worker exit.",
                 flush=True,
             )
@@ -783,16 +790,16 @@ class TrainerBase:
         except optuna.TrialPruned:
             pass
         except Exception as exc:
-            print(
+            _log(
                 f"[Optuna][Worker][{self.label}] Exception: {exc}", flush=True)
-        finally:
-            self._clean_gpu()
-            # STEP 2 (DDP/Optuna): align worker with rank0 after objective_fn returns/raises.
-            self._dist_barrier("worker_end")
+        finally:
+            self._clean_gpu(synchronize=self._optuna_cleanup_sync())
+            # STEP 2 (DDP/Optuna): align worker with rank0 after objective_fn returns/raises.
+            self._dist_barrier("worker_end")
 
     def _distributed_tune(self, max_evals: int, objective_fn: Callable[[optuna.trial.Trial], float]) -> None:
         if dist is None:
-            print(
+            _log(
                 f"[Optuna][{self.label}] torch.distributed unavailable. Fallback to single-process.",
                 flush=True,
             )
@@ -807,12 +814,12 @@ class TrainerBase:
         if not dist.is_initialized():
             rank_env = os.environ.get("RANK", "0")
             if str(rank_env) != "0":
-                print(
+                _log(
                     f"[Optuna][{self.label}] DDP init failed on worker. Skip.",
                     flush=True,
                 )
                 return
-            print(
+            _log(
                 f"[Optuna][{self.label}] DDP init failed. Fallback to single-process.",
                 flush=True,
             )
@@ -834,27 +841,27 @@ class TrainerBase:
             should_log = True
             if should_log:
                 current_idx = progress_counter["count"] + 1
-                print(
+                _log(
                     f"[Optuna][{self.label}] Trial {current_idx}/{total_trials} started "
                     f"(trial_id={trial.number})."
                 )
             try:
                 result = objective_fn(trial)
-            except RuntimeError as exc:
-                if "out of memory" in str(exc).lower():
-                    print(
-                        f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
-                    )
-                    self._clean_gpu()
-                    raise optuna.TrialPruned() from exc
-                raise
-            finally:
-                self._clean_gpu()
+            except RuntimeError as exc:
+                if "out of memory" in str(exc).lower():
+                    _log(
+                        f"[Optuna][{self.label}] OOM detected. Pruning trial and clearing CUDA cache."
+                    )
+                    self._clean_gpu(synchronize=True)
+                    raise optuna.TrialPruned() from exc
+                raise
+            finally:
+                self._clean_gpu(synchronize=self._optuna_cleanup_sync())
             if should_log:
                 progress_counter["count"] = progress_counter["count"] + 1
                 trial_state = getattr(trial, "state", None)
                 state_repr = getattr(trial_state, "name", "OK")
-                print(
+                _log(
                     f"[Optuna][{self.label}] Trial {progress_counter['count']}/{total_trials} finished "
                     f"(status={state_repr})."
                 )
@@ -919,9 +926,14 @@ class TrainerBase:
             self._distributed_send_command(
                 {"type": "STOP", "best_params": self.best_params})
 
-    def _clean_gpu(self):
-        """Clean up GPU memory using shared GPUMemoryManager."""
-        GPUMemoryManager.clean()
+    def _clean_gpu(
+        self,
+        *,
+        synchronize: bool = True,
+        empty_cache: bool = True,
+    ) -> None:
+        """Clean up GPU memory using shared GPUMemoryManager."""
+        GPUMemoryManager.clean(synchronize=synchronize, empty_cache=empty_cache)
 
     def _standardize_fold(self,
                           X_train: pd.DataFrame,
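Note: _clean_gpu now exposes keyword-only synchronize and empty_cache flags and forwards them to GPUMemoryManager.clean, so routine between-trial cleanup can skip the device-wide torch.cuda.synchronize(). The manager itself is not shown in this diff; a plausible shape for it (an assumption, not the shipped code) is:

    # Assumed sketch of a GPUMemoryManager.clean(...) compatible with the call above;
    # the actual ins_pricing.utils implementation may differ.
    import gc
    import torch

    class GPUMemoryManager:
        @staticmethod
        def clean(*, synchronize: bool = True, empty_cache: bool = True) -> None:
            gc.collect()                      # drop Python-side references first
            if not torch.cuda.is_available():
                return
            if synchronize:
                torch.cuda.synchronize()      # wait for in-flight kernels (thorough but slow)
            if empty_cache:
                torch.cuda.empty_cache()      # return cached blocks to the driver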
ins_pricing/modelling/bayesopt/trainers/trainer_ft.py
@@ -11,17 +11,33 @@ from sklearn.model_selection import GroupKFold, TimeSeriesSplit
 from ins_pricing.modelling.bayesopt.trainers.trainer_base import TrainerBase
 from ins_pricing.modelling.bayesopt.models import FTTransformerSklearn
 from ins_pricing.utils.losses import regression_loss
-
-
-class FTTrainer(TrainerBase):
-    def __init__(self, context: "BayesOptModel") -> None:
-        if context.task_type == 'classification':
-            super().__init__(context, 'FTTransformerClassifier', 'FTTransformer')
-        else:
-            super().__init__(context, 'FTTransformer', 'FTTransformer')
-        self.model: Optional[FTTransformerSklearn] = None
-        self.enable_distributed_optuna = bool(context.config.use_ft_ddp)
-        self._cv_geo_warned = False
+from ins_pricing.utils import get_logger, log_print
+
+_logger = get_logger("ins_pricing.trainer.ft")
+
+
+def _log(*args, **kwargs) -> None:
+    log_print(_logger, *args, **kwargs)
+
+
+class FTTrainer(TrainerBase):
+    def __init__(self, context: "BayesOptModel") -> None:
+        if context.task_type == 'classification':
+            super().__init__(context, 'FTTransformerClassifier', 'FTTransformer')
+        else:
+            super().__init__(context, 'FTTransformer', 'FTTransformer')
+        self.model: Optional[FTTransformerSklearn] = None
+        self.enable_distributed_optuna = bool(context.config.use_ft_ddp)
+        self._cv_geo_warned = False
+
+    def _maybe_cleanup_gpu(self, model: Optional[FTTransformerSklearn]) -> None:
+        if not bool(getattr(self.ctx.config, "ft_cleanup_per_fold", False)):
+            return
+        if model is not None:
+            getattr(getattr(model, "ft", None), "to",
+                    lambda *_args, **_kwargs: None)("cpu")
+        synchronize = bool(getattr(self.ctx.config, "ft_cleanup_synchronize", False))
+        self._clean_gpu(synchronize=synchronize)
 
     def _resolve_numeric_tokens(self) -> int:
         requested = getattr(self.ctx.config, "ft_num_numeric_tokens", None)
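Note: _maybe_cleanup_gpu makes the per-fold teardown opt-in via the ft_cleanup_per_fold and ft_cleanup_synchronize config attributes (both read with getattr and defaulting to off), and reuses the null-safe getattr chain for moving the wrapped module to CPU. That chain is equivalent to the following (dummy object for illustration):

    # What the getattr(...)("cpu") chain does: call model.ft.to("cpu") only when
    # both the ft attribute and its .to method exist; otherwise it is a no-op.
    def move_inner_to_cpu(model) -> None:
        ft = getattr(model, "ft", None)
        to = getattr(ft, "to", lambda *_args, **_kwargs: None)
        to("cpu")

    class NoFT:
        pass

    move_inner_to_cpu(NoFT())  # safe: nothing happens when .ft is absent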
@@ -121,7 +137,7 @@ class FTTrainer(TrainerBase):
             if built is not None:
                 geo_train, geo_val, _, _ = built
             elif not self._cv_geo_warned:
-                print(
+                _log(
                     "[FTTrainer] Geo tokens unavailable for CV split; continue without geo tokens.",
                     flush=True,
                 )
@@ -168,22 +184,20 @@ class FTTrainer(TrainerBase):
         )
         model = self._apply_dataloader_overrides(model)
         model.set_params(model_params)
-        try:
-            return float(model.fit_unsupervised(
-                X_train,
-                X_val=X_val,
-                trial=trial,
-                geo_train=geo_train,
-                geo_val=geo_val,
-                mask_prob_num=mask_prob_num,
-                mask_prob_cat=mask_prob_cat,
-                num_loss_weight=num_loss_weight,
-                cat_loss_weight=cat_loss_weight
-            ))
-        finally:
-            getattr(getattr(model, "ft", None), "to",
-                    lambda *_args, **_kwargs: None)("cpu")
-            self._clean_gpu()
+        try:
+            return float(model.fit_unsupervised(
+                X_train,
+                X_val=X_val,
+                trial=trial,
+                geo_train=geo_train,
+                geo_val=geo_val,
+                mask_prob_num=mask_prob_num,
+                mask_prob_cat=mask_prob_cat,
+                num_loss_weight=num_loss_weight,
+                cat_loss_weight=cat_loss_weight
+            ))
+        finally:
+            self._maybe_cleanup_gpu(model)
 
     def cross_val(self, trial: optuna.trial.Trial) -> float:
         # FT-Transformer CV also focuses on memory control:
@@ -229,7 +243,7 @@ class FTTrainer(TrainerBase):
                 token_count += 1
             approx_units = d_model * n_layers * max(1, token_count)
             if approx_units > 12_000_000:
-                print(
+                _log(
                     f"[FTTrainer] Trial pruned early: d_model={d_model}, n_layers={n_layers} -> approx_units={approx_units}")
                 raise optuna.TrialPruned(
                     "config exceeds safe memory budget; prune before training")
@@ -285,7 +299,7 @@ class FTTrainer(TrainerBase):
             if built is not None:
                 geo_train, geo_val, _, _ = built
             elif not self._cv_geo_warned:
-                print(
+                _log(
                     "[FTTrainer] Geo tokens unavailable for CV split; continue without geo tokens.",
                     flush=True,
                 )
@@ -338,7 +352,7 @@ class FTTrainer(TrainerBase):
             requested_heads=resolved_params.get("n_heads")
         )
         if heads_adjusted:
-            print(f"[FTTrainer] Auto-adjusted n_heads from "
+            _log(f"[FTTrainer] Auto-adjusted n_heads from "
                  f"{resolved_params.get('n_heads')} to {adaptive_heads} "
                  f"(d_model={d_model_value}).")
             resolved_params["n_heads"] = adaptive_heads
@@ -378,13 +392,11 @@ class FTTrainer(TrainerBase):
             geo_train=geo_train,
             geo_val=geo_val,
         )
-        refit_epochs = self._resolve_best_epoch(
-            getattr(tmp_model, "training_history", None),
-            default_epochs=int(self.ctx.epochs),
-        )
-        getattr(getattr(tmp_model, "ft", None), "to",
-                lambda *_args, **_kwargs: None)("cpu")
-        self._clean_gpu()
+        refit_epochs = self._resolve_best_epoch(
+            getattr(tmp_model, "training_history", None),
+            default_epochs=int(self.ctx.epochs),
+        )
+        self._maybe_cleanup_gpu(tmp_model)
 
         self.model = FTTransformerSklearn(
             model_nme=self.ctx.model_nme,
@@ -451,7 +463,7 @@ class FTTrainer(TrainerBase):
 
         split_iter, _ = self._resolve_ensemble_splits(X_all, k=k)
         if split_iter is None:
-            print(
+            _log(
                 f"[FT Ensemble] unable to build CV split (n_samples={n_samples}); skip ensemble.",
                 flush=True,
             )
@@ -494,15 +506,13 @@ class FTTrainer(TrainerBase):
 
             pred_train = model.predict(X_all, geo_tokens=geo_train_full)
             pred_test = model.predict(X_test, geo_tokens=geo_test_full)
-            preds_train_sum += np.asarray(pred_train, dtype=np.float64)
-            preds_test_sum += np.asarray(pred_test, dtype=np.float64)
-            getattr(getattr(model, "ft", None), "to",
-                    lambda *_args, **_kwargs: None)("cpu")
-            self._clean_gpu()
-            split_count += 1
+            preds_train_sum += np.asarray(pred_train, dtype=np.float64)
+            preds_test_sum += np.asarray(pred_test, dtype=np.float64)
+            self._maybe_cleanup_gpu(model)
+            split_count += 1
 
         if split_count < 1:
-            print(
+            _log(
                 f"[FT Ensemble] no CV splits generated; skip ensemble.",
                 flush=True,
             )
@@ -591,7 +601,7 @@ class FTTrainer(TrainerBase):
             requested_heads=resolved_params.get("n_heads"),
         )
         if heads_adjusted:
-            print(
+            _log(
                 f"[FTTrainer] Auto-adjusted n_heads from "
                 f"{resolved_params.get('n_heads')} to {adaptive_heads} "
                 f"(d_model={resolved_params.get('d_model', model.d_model)})."
@@ -652,11 +662,9 @@ class FTTrainer(TrainerBase):
             if preds_train is None:
                 preds_train = np.empty(
                     (len(X_all),) + fold_pred.shape[1:], dtype=fold_pred.dtype)
-            preds_train[val_idx] = fold_pred
-
-            getattr(getattr(model, "ft", None), "to",
-                    lambda *_a, **_k: None)("cpu")
-            self._clean_gpu()
+            preds_train[val_idx] = fold_pred
+
+            self._maybe_cleanup_gpu(model)
 
         if preds_train is None:
             return None
@@ -773,7 +781,7 @@ class FTTrainer(TrainerBase):
             requested_heads=resolved_params.get("n_heads")
         )
         if heads_adjusted:
-            print(f"[FTTrainer] Auto-adjusted n_heads from "
+            _log(f"[FTTrainer] Auto-adjusted n_heads from "
                  f"{resolved_params.get('n_heads')} to {adaptive_heads} "
                  f"(d_model={resolved_params.get('d_model', self.model.d_model)}).")
             resolved_params["n_heads"] = adaptive_heads