ins-pricing 0.2.9__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. ins_pricing/CHANGELOG.md +93 -0
  2. ins_pricing/README.md +11 -0
  3. ins_pricing/cli/Explain_entry.py +50 -48
  4. ins_pricing/cli/bayesopt_entry_runner.py +699 -569
  5. ins_pricing/cli/utils/evaluation_context.py +320 -0
  6. ins_pricing/cli/utils/import_resolver.py +350 -0
  7. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
  8. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
  9. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
  10. ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
  11. ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
  12. ins_pricing/modelling/core/bayesopt/core.py +153 -94
  13. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +122 -34
  14. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +298 -142
  15. ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
  16. ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
  17. ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
  18. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
  19. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
  20. ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +591 -0
  21. ins_pricing/modelling/core/bayesopt/utils.py +98 -1496
  22. ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
  23. ins_pricing/setup.py +1 -1
  24. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/METADATA +14 -1
  25. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/RECORD +27 -14
  26. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/WHEEL +0 -0
  27. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,118 @@ from ..utils import DistributedUtils, EPS, TorchTrainerMixin
 from .model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset
 
 
+# --- Helper functions for reconstruction loss computation ---
+
+
+def _compute_numeric_reconstruction_loss(
+    num_pred: Optional[torch.Tensor],
+    num_true: Optional[torch.Tensor],
+    num_mask: Optional[torch.Tensor],
+    loss_weight: float,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute MSE loss for numeric feature reconstruction.
+
+    Args:
+        num_pred: Predicted numeric values (N, num_features)
+        num_true: Ground truth numeric values (N, num_features)
+        num_mask: Boolean mask indicating which values were masked (N, num_features)
+        loss_weight: Weight to apply to the loss
+        device: Target device for computation
+
+    Returns:
+        Weighted MSE loss for masked numeric features
+    """
+    if num_pred is None or num_true is None or num_mask is None:
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    num_mask = num_mask.to(dtype=torch.bool)
+    if not num_mask.any():
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    diff = num_pred - num_true
+    mse = diff * diff
+    return float(loss_weight) * mse[num_mask].mean()
+
+
+def _compute_categorical_reconstruction_loss(
+    cat_logits: Optional[List[torch.Tensor]],
+    cat_true: Optional[torch.Tensor],
+    cat_mask: Optional[torch.Tensor],
+    loss_weight: float,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute cross-entropy loss for categorical feature reconstruction.
+
+    Args:
+        cat_logits: List of logits for each categorical feature
+        cat_true: Ground truth categorical indices (N, num_cat_features)
+        cat_mask: Boolean mask indicating which values were masked (N, num_cat_features)
+        loss_weight: Weight to apply to the loss
+        device: Target device for computation
+
+    Returns:
+        Weighted cross-entropy loss for masked categorical features
+    """
+    if not cat_logits or cat_true is None or cat_mask is None:
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    cat_mask = cat_mask.to(dtype=torch.bool)
+    cat_losses: List[torch.Tensor] = []
+
+    for j, logits in enumerate(cat_logits):
+        mask_j = cat_mask[:, j]
+        if not mask_j.any():
+            continue
+        targets = cat_true[:, j]
+        cat_losses.append(
+            F.cross_entropy(logits, targets, reduction='none')[mask_j].mean()
+        )
+
+    if not cat_losses:
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    return float(loss_weight) * torch.stack(cat_losses).mean()
+
+
+def _compute_reconstruction_loss(
+    num_pred: Optional[torch.Tensor],
+    cat_logits: Optional[List[torch.Tensor]],
+    num_true: Optional[torch.Tensor],
+    num_mask: Optional[torch.Tensor],
+    cat_true: Optional[torch.Tensor],
+    cat_mask: Optional[torch.Tensor],
+    num_loss_weight: float,
+    cat_loss_weight: float,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute combined reconstruction loss for masked tabular data.
+
+    This combines numeric (MSE) and categorical (cross-entropy) reconstruction losses.
+
+    Args:
+        num_pred: Predicted numeric values
+        cat_logits: List of logits for categorical features
+        num_true: Ground truth numeric values
+        num_mask: Mask for numeric features
+        cat_true: Ground truth categorical indices
+        cat_mask: Mask for categorical features
+        num_loss_weight: Weight for numeric loss
+        cat_loss_weight: Weight for categorical loss
+        device: Target device for computation
+
+    Returns:
+        Combined weighted reconstruction loss
+    """
+    num_loss = _compute_numeric_reconstruction_loss(
+        num_pred, num_true, num_mask, num_loss_weight, device
+    )
+    cat_loss = _compute_categorical_reconstruction_loss(
+        cat_logits, cat_true, cat_mask, cat_loss_weight, device
+    )
+    return num_loss + cat_loss
+
+
 # Scikit-Learn style wrapper for FTTransformer.
 
 
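For reference, these extracted helpers can be exercised in isolation: the combined loss is num_loss_weight times the mean squared error over masked numeric cells, plus cat_loss_weight times the per-feature mean cross-entropy over masked categorical cells. A minimal sketch with dummy tensors follows; shapes and values are illustrative only, and the import path is inferred from the file list above (these are private, underscore-prefixed helpers, not a public API):

# Illustrative only: dummy shapes/values; module path inferred from the
# "Files changed" list above.
import torch
from ins_pricing.modelling.core.bayesopt.models.model_ft_trainer import (
    _compute_reconstruction_loss,
)

n, n_num, cardinality = 8, 4, 5
loss = _compute_reconstruction_loss(
    num_pred=torch.randn(n, n_num),
    cat_logits=[torch.randn(n, cardinality) for _ in range(2)],
    num_true=torch.randn(n, n_num),
    num_mask=torch.rand(n, n_num) > 0.5,  # True where a cell was masked
    cat_true=torch.randint(0, cardinality, (n, 2)),
    cat_mask=torch.rand(n, 2) > 0.5,
    num_loss_weight=1.0,
    cat_loss_weight=1.0,
    device=torch.device("cpu"),
)
print(float(loss))  # weighted MSE + cross-entropy over masked cells only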
@@ -508,40 +620,13 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
         )
         scaler = GradScaler(enabled=(device_type == 'cuda'))
 
-        def _batch_recon_loss(num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device):
-            loss = torch.zeros((), device=device, dtype=torch.float32)
-
-            if num_pred is not None and num_true_b is not None and num_mask_b is not None:
-                num_mask_b = num_mask_b.to(dtype=torch.bool)
-                if num_mask_b.any():
-                    diff = num_pred - num_true_b
-                    mse = diff * diff
-                    loss = loss + float(num_loss_weight) * \
-                        mse[num_mask_b].mean()
-
-            if cat_logits and cat_true_b is not None and cat_mask_b is not None:
-                cat_mask_b = cat_mask_b.to(dtype=torch.bool)
-                cat_losses: List[torch.Tensor] = []
-                for j, logits in enumerate(cat_logits):
-                    mask_j = cat_mask_b[:, j]
-                    if not mask_j.any():
-                        continue
-                    targets = cat_true_b[:, j]
-                    cat_losses.append(
-                        F.cross_entropy(logits, targets, reduction='none')[
-                            mask_j].mean()
-                    )
-                if cat_losses:
-                    loss = loss + float(cat_loss_weight) * \
-                        torch.stack(cat_losses).mean()
-            return loss
-
         train_history: List[float] = []
         val_history: List[float] = []
         best_loss = float("inf")
         best_state = None
         patience_counter = 0
         is_ddp_model = isinstance(self.ft, DDP)
+        use_collectives = dist.is_initialized() and is_ddp_model
 
         clip_fn = None
         if self.device.type == 'cuda':
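Besides deleting the now-redundant closure, this hunk tightens the collective-op guard: the hunks below test use_collectives (process group initialized AND the model actually DDP-wrapped) instead of dist.is_initialized() alone. A hypothetical standalone restatement of the guard, with the presumed rationale in comments:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def _should_use_collectives(model: torch.nn.Module) -> bool:
    # Presumed rationale: a process group may be initialized while a given
    # trial still trains an unwrapped model on a single rank; issuing
    # broadcast/all_reduce there could hang or desynchronize the group.
    return dist.is_initialized() and isinstance(model, DDP)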
@@ -579,11 +664,13 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
 
                 num_pred, cat_logits = self.ft(
                     X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
-                batch_loss = _batch_recon_loss(
-                    num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device=X_num_b.device)
+                batch_loss = _compute_reconstruction_loss(
+                    num_pred, cat_logits, num_true_b, num_mask_b,
+                    cat_true_b, cat_mask_b, num_loss_weight, cat_loss_weight,
+                    device=X_num_b.device)
                 local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
                 global_bad = local_bad
-                if dist.is_initialized():
+                if use_collectives:
                     bad = torch.tensor(
                         [local_bad],
                         device=batch_loss.device,
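The trailing context here votes on non-finite losses across ranks. A sketch of that pattern as it appears to work, under the assumption that the flags are MAX-reduced (the reduce op itself falls outside the visible hunk):

import torch
import torch.distributed as dist

def _any_rank_bad(batch_loss: torch.Tensor, use_collectives: bool) -> bool:
    # Each rank reports 1 if its loss is non-finite; MAX-reducing the flag
    # gives every rank the same verdict, so all ranks skip the step together.
    local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
    if not use_collectives:
        return bool(local_bad)
    bad = torch.tensor([local_bad], device=batch_loss.device, dtype=torch.int32)
    dist.all_reduce(bad, op=dist.ReduceOp.MAX)
    return bool(bad.item())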
@@ -672,10 +759,11 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                         self.device, non_blocking=True)
                     num_pred_v, cat_logits_v = self.ft(
                         X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
-                    loss_v = _batch_recon_loss(
+                    loss_v = _compute_reconstruction_loss(
                         num_pred_v, cat_logits_v,
                         X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
                         X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
+                        num_loss_weight, cat_loss_weight,
                         device=X_num_v.device
                     )
                     if not torch.isfinite(loss_v):
@@ -687,7 +775,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                     total_n += float(end - start)
                 val_loss_tensor[0] = total_val / max(total_n, 1.0)
 
-            if dist.is_initialized():
+            if use_collectives:
                 dist.broadcast(val_loss_tensor, src=0)
             val_loss_value = float(val_loss_tensor.item())
             prune_now = False
@@ -719,7 +807,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                 if trial.should_prune():
                     prune_now = True
 
-            if dist.is_initialized():
+            if use_collectives:
                 flag = torch.tensor(
                     [1 if prune_now else 0],
                     device=loss_tensor_device,