ins-pricing 0.2.9-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/CHANGELOG.md +93 -0
- ins_pricing/README.md +11 -0
- ins_pricing/cli/Explain_entry.py +50 -48
- ins_pricing/cli/bayesopt_entry_runner.py +699 -569
- ins_pricing/cli/utils/evaluation_context.py +320 -0
- ins_pricing/cli/utils/import_resolver.py +350 -0
- ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
- ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
- ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
- ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
- ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
- ins_pricing/modelling/core/bayesopt/core.py +153 -94
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +122 -34
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +298 -142
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
- ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
- ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
- ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +591 -0
- ins_pricing/modelling/core/bayesopt/utils.py +98 -1496
- ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/METADATA +14 -1
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/RECORD +27 -14
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/WHEEL +0 -0
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/top_level.txt +0 -0
ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py

@@ -19,6 +19,118 @@ from ..utils import DistributedUtils, EPS, TorchTrainerMixin
 from .model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset
 
 
+# --- Helper functions for reconstruction loss computation ---
+
+
+def _compute_numeric_reconstruction_loss(
+    num_pred: Optional[torch.Tensor],
+    num_true: Optional[torch.Tensor],
+    num_mask: Optional[torch.Tensor],
+    loss_weight: float,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute MSE loss for numeric feature reconstruction.
+
+    Args:
+        num_pred: Predicted numeric values (N, num_features)
+        num_true: Ground truth numeric values (N, num_features)
+        num_mask: Boolean mask indicating which values were masked (N, num_features)
+        loss_weight: Weight to apply to the loss
+        device: Target device for computation
+
+    Returns:
+        Weighted MSE loss for masked numeric features
+    """
+    if num_pred is None or num_true is None or num_mask is None:
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    num_mask = num_mask.to(dtype=torch.bool)
+    if not num_mask.any():
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    diff = num_pred - num_true
+    mse = diff * diff
+    return float(loss_weight) * mse[num_mask].mean()
+
+
+def _compute_categorical_reconstruction_loss(
+    cat_logits: Optional[List[torch.Tensor]],
+    cat_true: Optional[torch.Tensor],
+    cat_mask: Optional[torch.Tensor],
+    loss_weight: float,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute cross-entropy loss for categorical feature reconstruction.
+
+    Args:
+        cat_logits: List of logits for each categorical feature
+        cat_true: Ground truth categorical indices (N, num_cat_features)
+        cat_mask: Boolean mask indicating which values were masked (N, num_cat_features)
+        loss_weight: Weight to apply to the loss
+        device: Target device for computation
+
+    Returns:
+        Weighted cross-entropy loss for masked categorical features
+    """
+    if not cat_logits or cat_true is None or cat_mask is None:
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    cat_mask = cat_mask.to(dtype=torch.bool)
+    cat_losses: List[torch.Tensor] = []
+
+    for j, logits in enumerate(cat_logits):
+        mask_j = cat_mask[:, j]
+        if not mask_j.any():
+            continue
+        targets = cat_true[:, j]
+        cat_losses.append(
+            F.cross_entropy(logits, targets, reduction='none')[mask_j].mean()
+        )
+
+    if not cat_losses:
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    return float(loss_weight) * torch.stack(cat_losses).mean()
+
+
+def _compute_reconstruction_loss(
+    num_pred: Optional[torch.Tensor],
+    cat_logits: Optional[List[torch.Tensor]],
+    num_true: Optional[torch.Tensor],
+    num_mask: Optional[torch.Tensor],
+    cat_true: Optional[torch.Tensor],
+    cat_mask: Optional[torch.Tensor],
+    num_loss_weight: float,
+    cat_loss_weight: float,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute combined reconstruction loss for masked tabular data.
+
+    This combines numeric (MSE) and categorical (cross-entropy) reconstruction losses.
+
+    Args:
+        num_pred: Predicted numeric values
+        cat_logits: List of logits for categorical features
+        num_true: Ground truth numeric values
+        num_mask: Mask for numeric features
+        cat_true: Ground truth categorical indices
+        cat_mask: Mask for categorical features
+        num_loss_weight: Weight for numeric loss
+        cat_loss_weight: Weight for categorical loss
+        device: Target device for computation
+
+    Returns:
+        Combined weighted reconstruction loss
+    """
+    num_loss = _compute_numeric_reconstruction_loss(
+        num_pred, num_true, num_mask, num_loss_weight, device
+    )
+    cat_loss = _compute_categorical_reconstruction_loss(
+        cat_logits, cat_true, cat_mask, cat_loss_weight, device
+    )
+    return num_loss + cat_loss
+
+
 # Scikit-Learn style wrapper for FTTransformer.
 
 
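These three helpers replace an inline closure (removed in the next hunk) so the masked-reconstruction loss is testable in isolation. A minimal usage sketch with invented toy tensors, assuming the helpers and their imports (`torch`, `torch.nn.functional as F`, `typing.Optional`/`List`) are in scope:

```python
import torch

# Toy batch: 4 rows, 2 numeric features, 1 categorical feature with 3 classes.
num_pred = torch.randn(4, 2)
num_true = torch.randn(4, 2)
num_mask = torch.tensor([[1, 0], [0, 0], [1, 1], [0, 1]], dtype=torch.bool)

cat_logits = [torch.randn(4, 3)]                 # one logits tensor per categorical feature
cat_true = torch.tensor([[0], [2], [1], [0]])    # ground-truth class indices, shape (4, 1)
cat_mask = torch.tensor([[1], [0], [1], [1]], dtype=torch.bool)

loss = _compute_reconstruction_loss(
    num_pred, cat_logits, num_true, num_mask, cat_true, cat_mask,
    num_loss_weight=1.0, cat_loss_weight=1.0, device=num_pred.device,
)
# Only masked cells contribute: MSE over the 4 masked numeric cells plus
# cross-entropy over the 3 masked categorical cells, each term weighted.
print(float(loss))
```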
@@ -508,40 +620,13 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
         )
         scaler = GradScaler(enabled=(device_type == 'cuda'))
 
-        def _batch_recon_loss(num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device):
-            loss = torch.zeros((), device=device, dtype=torch.float32)
-
-            if num_pred is not None and num_true_b is not None and num_mask_b is not None:
-                num_mask_b = num_mask_b.to(dtype=torch.bool)
-                if num_mask_b.any():
-                    diff = num_pred - num_true_b
-                    mse = diff * diff
-                    loss = loss + float(num_loss_weight) * \
-                        mse[num_mask_b].mean()
-
-            if cat_logits and cat_true_b is not None and cat_mask_b is not None:
-                cat_mask_b = cat_mask_b.to(dtype=torch.bool)
-                cat_losses: List[torch.Tensor] = []
-                for j, logits in enumerate(cat_logits):
-                    mask_j = cat_mask_b[:, j]
-                    if not mask_j.any():
-                        continue
-                    targets = cat_true_b[:, j]
-                    cat_losses.append(
-                        F.cross_entropy(logits, targets, reduction='none')[
-                            mask_j].mean()
-                    )
-                if cat_losses:
-                    loss = loss + float(cat_loss_weight) * \
-                        torch.stack(cat_losses).mean()
-            return loss
-
         train_history: List[float] = []
         val_history: List[float] = []
         best_loss = float("inf")
         best_state = None
         patience_counter = 0
         is_ddp_model = isinstance(self.ft, DDP)
+        use_collectives = dist.is_initialized() and is_ddp_model
 
         clip_fn = None
         if self.device.type == 'cuda':
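The one functional change in this hunk is `use_collectives`: collective ops such as `all_reduce`/`broadcast` are only legal once a process group is initialized, and only meaningful when the module is DDP-wrapped; naming the guard once keeps the later checks from drifting apart. A standalone sketch of the guard (illustrative names, not the package's API):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def collectives_enabled(model: torch.nn.Module) -> bool:
    """True only when it is safe to call dist.all_reduce / dist.broadcast."""
    # Both conditions matter: a process group without DDP (or vice versa)
    # would make collectives hang or disagree across ranks.
    return dist.is_initialized() and isinstance(model, DDP)

model = torch.nn.Linear(4, 1)
assert collectives_enabled(model) is False  # single-process run: no collectives
```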
@@ -579,11 +664,13 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
 
                 num_pred, cat_logits = self.ft(
                     X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
-                batch_loss = _batch_recon_loss(
-                    num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device=X_num_b.device)
+                batch_loss = _compute_reconstruction_loss(
+                    num_pred, cat_logits, num_true_b, num_mask_b,
+                    cat_true_b, cat_mask_b, num_loss_weight, cat_loss_weight,
+                    device=X_num_b.device)
                 local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
                 global_bad = local_bad
-                if dist.is_initialized() and is_ddp_model:
+                if use_collectives:
                     bad = torch.tensor(
                         [local_bad],
                         device=batch_loss.device,
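Around the new call, `local_bad`/`global_bad` implement skip-bad-batch synchronization: each rank flags a non-finite loss as 0/1 and the ranks agree before anyone calls `backward()`, since one rank skipping a batch that another rank runs would deadlock DDP's gradient sync. A hedged sketch of the pattern; the reduction op is truncated out of the diff, so `MAX` here is an assumption:

```python
import torch
import torch.distributed as dist

def batch_is_globally_bad(batch_loss: torch.Tensor, use_collectives: bool) -> bool:
    local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1  # NaN/Inf check
    global_bad = local_bad
    if use_collectives:
        bad = torch.tensor([local_bad], device=batch_loss.device, dtype=torch.int32)
        # MAX: if any rank saw a bad loss, every rank treats the batch as bad.
        dist.all_reduce(bad, op=dist.ReduceOp.MAX)
        global_bad = int(bad.item())
    return bool(global_bad)
```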
@@ -672,10 +759,11 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                     self.device, non_blocking=True)
                 num_pred_v, cat_logits_v = self.ft(
                     X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
-                loss_v = _batch_recon_loss(
+                loss_v = _compute_reconstruction_loss(
                     num_pred_v, cat_logits_v,
                     X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
                     X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
+                    num_loss_weight, cat_loss_weight,
                     device=X_num_v.device
                 )
                 if not torch.isfinite(loss_v):
@@ -687,7 +775,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                 total_n += float(end - start)
             val_loss_tensor[0] = total_val / max(total_n, 1.0)
 
-            if dist.is_initialized() and is_ddp_model:
+            if use_collectives:
                 dist.broadcast(val_loss_tensor, src=0)
             val_loss_value = float(val_loss_tensor.item())
             prune_now = False
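Here the averaged validation loss is broadcast from rank 0 so that early stopping and history tracking run on identical numbers on every rank. A minimal sketch of the idiom, assuming each rank has accumulated `total_val`/`total_n` locally:

```python
import torch
import torch.distributed as dist

def sync_val_loss(total_val: float, total_n: float,
                  device: torch.device, use_collectives: bool) -> float:
    val_loss_tensor = torch.zeros(1, device=device)
    val_loss_tensor[0] = total_val / max(total_n, 1.0)  # guard against an empty val set
    if use_collectives:
        # Rank 0's value overwrites every other rank's local result, so
        # patience counters and best-loss tracking stay in lockstep.
        dist.broadcast(val_loss_tensor, src=0)
    return float(val_loss_tensor.item())
```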
@@ -719,7 +807,7 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
             if trial.should_prune():
                 prune_now = True
 
-            if dist.is_initialized() and is_ddp_model:
+            if use_collectives:
                 flag = torch.tensor(
                     [1 if prune_now else 0],
                     device=loss_tensor_device,