ins-pricing 0.2.9__py3-none-any.whl → 0.3.0__py3-none-any.whl
- ins_pricing/CHANGELOG.md +93 -0
- ins_pricing/README.md +11 -0
- ins_pricing/cli/bayesopt_entry_runner.py +626 -499
- ins_pricing/cli/utils/evaluation_context.py +320 -0
- ins_pricing/cli/utils/import_resolver.py +350 -0
- ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
- ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
- ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
- ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
- ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
- ins_pricing/modelling/core/bayesopt/core.py +153 -94
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +118 -31
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +294 -139
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
- ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
- ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
- ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +587 -0
- ins_pricing/modelling/core/bayesopt/utils.py +98 -1496
- ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.0.dist-info}/METADATA +162 -149
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.0.dist-info}/RECORD +26 -13
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.0.dist-info}/WHEEL +0 -0
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.0.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,118 @@ from ..utils import DistributedUtils, EPS, TorchTrainerMixin
 from .model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset
 
 
+# --- Helper functions for reconstruction loss computation ---
+
+
+def _compute_numeric_reconstruction_loss(
+    num_pred: Optional[torch.Tensor],
+    num_true: Optional[torch.Tensor],
+    num_mask: Optional[torch.Tensor],
+    loss_weight: float,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute MSE loss for numeric feature reconstruction.
+
+    Args:
+        num_pred: Predicted numeric values (N, num_features)
+        num_true: Ground truth numeric values (N, num_features)
+        num_mask: Boolean mask indicating which values were masked (N, num_features)
+        loss_weight: Weight to apply to the loss
+        device: Target device for computation
+
+    Returns:
+        Weighted MSE loss for masked numeric features
+    """
+    if num_pred is None or num_true is None or num_mask is None:
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    num_mask = num_mask.to(dtype=torch.bool)
+    if not num_mask.any():
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    diff = num_pred - num_true
+    mse = diff * diff
+    return float(loss_weight) * mse[num_mask].mean()
+
+
+def _compute_categorical_reconstruction_loss(
+    cat_logits: Optional[List[torch.Tensor]],
+    cat_true: Optional[torch.Tensor],
+    cat_mask: Optional[torch.Tensor],
+    loss_weight: float,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute cross-entropy loss for categorical feature reconstruction.
+
+    Args:
+        cat_logits: List of logits for each categorical feature
+        cat_true: Ground truth categorical indices (N, num_cat_features)
+        cat_mask: Boolean mask indicating which values were masked (N, num_cat_features)
+        loss_weight: Weight to apply to the loss
+        device: Target device for computation
+
+    Returns:
+        Weighted cross-entropy loss for masked categorical features
+    """
+    if not cat_logits or cat_true is None or cat_mask is None:
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    cat_mask = cat_mask.to(dtype=torch.bool)
+    cat_losses: List[torch.Tensor] = []
+
+    for j, logits in enumerate(cat_logits):
+        mask_j = cat_mask[:, j]
+        if not mask_j.any():
+            continue
+        targets = cat_true[:, j]
+        cat_losses.append(
+            F.cross_entropy(logits, targets, reduction='none')[mask_j].mean()
+        )
+
+    if not cat_losses:
+        return torch.zeros((), device=device, dtype=torch.float32)
+
+    return float(loss_weight) * torch.stack(cat_losses).mean()
+
+
+def _compute_reconstruction_loss(
+    num_pred: Optional[torch.Tensor],
+    cat_logits: Optional[List[torch.Tensor]],
+    num_true: Optional[torch.Tensor],
+    num_mask: Optional[torch.Tensor],
+    cat_true: Optional[torch.Tensor],
+    cat_mask: Optional[torch.Tensor],
+    num_loss_weight: float,
+    cat_loss_weight: float,
+    device: torch.device,
+) -> torch.Tensor:
+    """Compute combined reconstruction loss for masked tabular data.
+
+    This combines numeric (MSE) and categorical (cross-entropy) reconstruction losses.
+
+    Args:
+        num_pred: Predicted numeric values
+        cat_logits: List of logits for categorical features
+        num_true: Ground truth numeric values
+        num_mask: Mask for numeric features
+        cat_true: Ground truth categorical indices
+        cat_mask: Mask for categorical features
+        num_loss_weight: Weight for numeric loss
+        cat_loss_weight: Weight for categorical loss
+        device: Target device for computation
+
+    Returns:
+        Combined weighted reconstruction loss
+    """
+    num_loss = _compute_numeric_reconstruction_loss(
+        num_pred, num_true, num_mask, num_loss_weight, device
+    )
+    cat_loss = _compute_categorical_reconstruction_loss(
+        cat_logits, cat_true, cat_mask, cat_loss_weight, device
+    )
+    return num_loss + cat_loss
+
+
 # Scikit-Learn style wrapper for FTTransformer.
 
 
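The three helpers above are module-level replacements for the `_batch_recon_loss` closure removed in the next hunk, which makes the loss logic unit-testable in isolation. A minimal sketch of calling the combined helper on dummy tensors; the shapes, mask rates, and weights are invented for illustration, and the helper is assumed to be in scope (it is private to model_ft_trainer.py):

    import torch

    # Dummy batch: 8 rows, 3 numeric features, 2 categorical features with 5 and 7 levels.
    N = 8
    num_pred, num_true = torch.randn(N, 3), torch.randn(N, 3)
    num_mask = torch.rand(N, 3) < 0.15           # ~15% of numeric cells were masked
    cat_logits = [torch.randn(N, 5), torch.randn(N, 7)]
    cat_true = torch.stack(
        [torch.randint(0, 5, (N,)), torch.randint(0, 7, (N,))], dim=1)
    cat_mask = torch.rand(N, 2) < 0.15

    loss = _compute_reconstruction_loss(
        num_pred, cat_logits, num_true, num_mask, cat_true, cat_mask,
        num_loss_weight=1.0, cat_loss_weight=1.0, device=num_pred.device)
    assert loss.ndim == 0                        # always a scalar tensor

Returning a zero tensor on the target device (rather than the float 0.0) when nothing is masked keeps the return type uniform across all branches.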
@@ -508,34 +620,6 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
         )
         scaler = GradScaler(enabled=(device_type == 'cuda'))
 
-        def _batch_recon_loss(num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device):
-            loss = torch.zeros((), device=device, dtype=torch.float32)
-
-            if num_pred is not None and num_true_b is not None and num_mask_b is not None:
-                num_mask_b = num_mask_b.to(dtype=torch.bool)
-                if num_mask_b.any():
-                    diff = num_pred - num_true_b
-                    mse = diff * diff
-                    loss = loss + float(num_loss_weight) * \
-                        mse[num_mask_b].mean()
-
-            if cat_logits and cat_true_b is not None and cat_mask_b is not None:
-                cat_mask_b = cat_mask_b.to(dtype=torch.bool)
-                cat_losses: List[torch.Tensor] = []
-                for j, logits in enumerate(cat_logits):
-                    mask_j = cat_mask_b[:, j]
-                    if not mask_j.any():
-                        continue
-                    targets = cat_true_b[:, j]
-                    cat_losses.append(
-                        F.cross_entropy(logits, targets, reduction='none')[
-                            mask_j].mean()
-                    )
-                if cat_losses:
-                    loss = loss + float(cat_loss_weight) * \
-                        torch.stack(cat_losses).mean()
-            return loss
-
         train_history: List[float] = []
         val_history: List[float] = []
         best_loss = float("inf")
@@ -579,8 +663,10 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
 
                     num_pred, cat_logits = self.ft(
                         X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
-                    batch_loss = _batch_recon_loss(
-                        num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device=X_num_b.device)
+                    batch_loss = _compute_reconstruction_loss(
+                        num_pred, cat_logits, num_true_b, num_mask_b,
+                        cat_true_b, cat_mask_b, num_loss_weight, cat_loss_weight,
+                        device=X_num_b.device)
                     local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
                     global_bad = local_bad
                     if dist.is_initialized():
@@ -672,10 +758,11 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                         self.device, non_blocking=True)
                     num_pred_v, cat_logits_v = self.ft(
                         X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
-                    loss_v = _batch_recon_loss(
+                    loss_v = _compute_reconstruction_loss(
                         num_pred_v, cat_logits_v,
                         X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
                         X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
+                        num_loss_weight, cat_loss_weight,
                         device=X_num_v.device
                     )
                     if not torch.isfinite(loss_v):
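Both call sites feed the helper's output into a finiteness check before anything else happens (`local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1`, then a cross-rank reduction when `dist.is_initialized()`). The reduction itself falls outside these hunks; a sketch of the usual pattern, where the `all_reduce` call is an assumption rather than something shown in the diff:

    import torch
    import torch.distributed as dist

    def batch_is_globally_finite(batch_loss: torch.Tensor) -> bool:
        local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
        if dist.is_initialized():
            flag = torch.tensor(local_bad, device=batch_loss.device)
            # If any rank saw a NaN/Inf loss, every rank learns about it and
            # can skip the optimizer step together, keeping DDP ranks in lockstep.
            dist.all_reduce(flag, op=dist.ReduceOp.SUM)
            return int(flag.item()) == 0
        return local_bad == 0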
@@ -42,6 +42,252 @@ class _OrderSplitter:
         for tr_idx, val_idx in self._splitter.split(X_ord, y=y, groups=groups):
             yield order[tr_idx], order[val_idx]
 
+
+# =============================================================================
+# CV Strategy Resolution Helper
+# =============================================================================
+
+
+class CVStrategyResolver:
+    """Helper class to resolve cross-validation splitting strategies.
+
+    This encapsulates the logic for determining how to split data based on the
+    configured strategy (random, time, group). It provides methods to:
+    - Get time-ordered indices for a dataset
+    - Get group values for a dataset
+    - Create appropriate sklearn splitters
+    """
+
+    TIME_STRATEGIES = {"time", "timeseries", "temporal"}
+    GROUP_STRATEGIES = {"group", "grouped"}
+
+    def __init__(self, config, train_data: pd.DataFrame, rand_seed: Optional[int] = None):
+        """Initialize the resolver.
+
+        Args:
+            config: BayesOptConfig with cv_strategy, cv_time_col, cv_group_col, etc.
+            train_data: The training DataFrame (needed for column access)
+            rand_seed: Random seed for reproducible splits
+        """
+        self.config = config
+        self.train_data = train_data
+        self.rand_seed = rand_seed
+        self._strategy = self._normalize_strategy()
+
+    def _normalize_strategy(self) -> str:
+        """Normalize the strategy string to lowercase."""
+        raw = str(getattr(self.config, "cv_strategy", "random") or "random")
+        return raw.strip().lower()
+
+    @property
+    def strategy(self) -> str:
+        """Return the normalized CV strategy."""
+        return self._strategy
+
+    def is_time_strategy(self) -> bool:
+        """Check if using a time-based splitting strategy."""
+        return self._strategy in self.TIME_STRATEGIES
+
+    def is_group_strategy(self) -> bool:
+        """Check if using a group-based splitting strategy."""
+        return self._strategy in self.GROUP_STRATEGIES
+
+    def get_time_col(self) -> str:
+        """Get and validate the time column.
+
+        Raises:
+            ValueError: If time column is not configured
+            KeyError: If time column not found in train_data
+        """
+        time_col = getattr(self.config, "cv_time_col", None)
+        if not time_col:
+            raise ValueError("cv_time_col is required for time cv_strategy.")
+        if time_col not in self.train_data.columns:
+            raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+        return time_col
+
+    def get_time_ascending(self) -> bool:
+        """Get the time ordering preference."""
+        return bool(getattr(self.config, "cv_time_ascending", True))
+
+    def get_group_col(self) -> str:
+        """Get and validate the group column.
+
+        Raises:
+            ValueError: If group column is not configured
+            KeyError: If group column not found in train_data
+        """
+        group_col = getattr(self.config, "cv_group_col", None)
+        if not group_col:
+            raise ValueError("cv_group_col is required for group cv_strategy.")
+        if group_col not in self.train_data.columns:
+            raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
+        return group_col
+
+    def get_time_ordered_indices(self, X_all: pd.DataFrame) -> np.ndarray:
+        """Get indices ordered by time for the given dataset.
+
+        Args:
+            X_all: DataFrame to get indices for (must have index compatible with train_data)
+
+        Returns:
+            Array of positional indices into X_all, ordered by time
+        """
+        time_col = self.get_time_col()
+        ascending = self.get_time_ascending()
+        order_index = self.train_data[time_col].sort_values(ascending=ascending).index
+        index_set = set(X_all.index)
+        order_index = [idx for idx in order_index if idx in index_set]
+        order = X_all.index.get_indexer(order_index)
+        return order[order >= 0]
+
+    def get_groups(self, X_all: pd.DataFrame) -> pd.Series:
+        """Get group labels for the given dataset.
+
+        Args:
+            X_all: DataFrame to get groups for
+
+        Returns:
+            Series of group labels aligned with X_all
+        """
+        group_col = self.get_group_col()
+        return self.train_data.reindex(X_all.index)[group_col]
+
+    def create_train_val_splitter(
+        self,
+        X_all: pd.DataFrame,
+        val_ratio: float,
+    ) -> Tuple[Optional[Tuple[np.ndarray, np.ndarray]], Optional[pd.Series]]:
+        """Create a single train/val split based on strategy.
+
+        Args:
+            X_all: DataFrame to split
+            val_ratio: Fraction of data for validation
+
+        Returns:
+            Tuple of ((train_idx, val_idx), groups) where groups is None for non-group strategies
+        """
+        if self.is_time_strategy():
+            order = self.get_time_ordered_indices(X_all)
+            cutoff = int(len(order) * (1.0 - val_ratio))
+            if cutoff <= 0 or cutoff >= len(order):
+                raise ValueError(f"val_ratio={val_ratio} leaves no data for train/val split.")
+            return (order[:cutoff], order[cutoff:]), None
+
+        if self.is_group_strategy():
+            groups = self.get_groups(X_all)
+            splitter = GroupShuffleSplit(
+                n_splits=1, test_size=val_ratio, random_state=self.rand_seed
+            )
+            train_idx, val_idx = next(splitter.split(X_all, groups=groups))
+            return (train_idx, val_idx), groups
+
+        # Random strategy
+        splitter = ShuffleSplit(
+            n_splits=1, test_size=val_ratio, random_state=self.rand_seed
+        )
+        train_idx, val_idx = next(splitter.split(X_all))
+        return (train_idx, val_idx), None
+
+    def create_cv_splitter(
+        self,
+        X_all: pd.DataFrame,
+        y_all: Optional[pd.Series],
+        n_splits: int,
+        val_ratio: float,
+    ) -> Tuple[Iterable[Tuple[np.ndarray, np.ndarray]], int]:
+        """Create a cross-validation splitter based on strategy.
+
+        Args:
+            X_all: DataFrame to split
+            y_all: Target series (used by some splitters)
+            n_splits: Number of CV folds
+            val_ratio: Validation ratio (for ShuffleSplit)
+
+        Returns:
+            Tuple of (split_iterator, actual_n_splits)
+        """
+        n_splits = max(2, int(n_splits))
+
+        if self.is_group_strategy():
+            groups = self.get_groups(X_all)
+            n_groups = int(groups.nunique(dropna=False))
+            if n_groups < 2:
+                return iter([]), 0
+            n_splits = min(n_splits, n_groups)
+            if n_splits < 2:
+                return iter([]), 0
+            splitter = GroupKFold(n_splits=n_splits)
+            return splitter.split(X_all, y_all, groups=groups), n_splits
+
+        if self.is_time_strategy():
+            order = self.get_time_ordered_indices(X_all)
+            if len(order) < 2:
+                return iter([]), 0
+            n_splits = min(n_splits, max(2, len(order) - 1))
+            if n_splits < 2:
+                return iter([]), 0
+            splitter = TimeSeriesSplit(n_splits=n_splits)
+            return _OrderSplitter(splitter, order).split(X_all), n_splits
+
+        # Random strategy
+        if len(X_all) < n_splits:
+            n_splits = len(X_all)
+        if n_splits < 2:
+            return iter([]), 0
+        splitter = ShuffleSplit(
+            n_splits=n_splits, test_size=val_ratio, random_state=self.rand_seed
+        )
+        return splitter.split(X_all), n_splits
+
+    def create_kfold_splitter(
+        self,
+        X_all: pd.DataFrame,
+        k: int,
+    ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
+        """Create a K-fold splitter for ensemble training.
+
+        Args:
+            X_all: DataFrame to split
+            k: Number of folds
+
+        Returns:
+            Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
+        """
+        k = max(2, int(k))
+        n_samples = len(X_all)
+        if n_samples < 2:
+            return None, 0
+
+        if self.is_group_strategy():
+            groups = self.get_groups(X_all)
+            n_groups = int(groups.nunique(dropna=False))
+            if n_groups < 2:
+                return None, 0
+            k = min(k, n_groups)
+            if k < 2:
+                return None, 0
+            splitter = GroupKFold(n_splits=k)
+            return splitter.split(X_all, y=None, groups=groups), k
+
+        if self.is_time_strategy():
+            order = self.get_time_ordered_indices(X_all)
+            if len(order) < 2:
+                return None, 0
+            k = min(k, max(2, len(order) - 1))
+            if k < 2:
+                return None, 0
+            splitter = TimeSeriesSplit(n_splits=k)
+            return _OrderSplitter(splitter, order).split(X_all), k
+
+        # Random strategy with KFold
+        k = min(k, n_samples)
+        if k < 2:
+            return None, 0
+        splitter = KFold(n_splits=k, shuffle=True, random_state=self.rand_seed)
+        return splitter.split(X_all), k
+
+
 # =============================================================================
 # Trainer system
 # =============================================================================
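`CVStrategyResolver` consolidates strategy handling that was previously duplicated across several `TrainerBase` methods (see the hunks below). A standalone sketch of driving it with a stand-in config object; the DataFrame, the `quote_date` column, and the `SimpleNamespace` config are invented for illustration:

    from types import SimpleNamespace
    import numpy as np
    import pandas as pd

    df = pd.DataFrame({
        "quote_date": pd.date_range("2024-01-01", periods=100, freq="D"),
        "x": np.random.randn(100),
    })
    cfg = SimpleNamespace(cv_strategy="time", cv_time_col="quote_date",
                          cv_time_ascending=True, cv_group_col=None)
    resolver = CVStrategyResolver(cfg, df, rand_seed=42)

    (train_idx, val_idx), groups = resolver.create_train_val_splitter(df, val_ratio=0.25)
    # Time strategy: the earliest 75 rows train, the 25 most recent validate.
    assert len(val_idx) == 25 and groups is None

Keeping validation on the most recent slice is the point of the time strategy: it prevents leakage from the future into training.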
@@ -692,6 +938,15 @@ class TrainerBase:
         *,
         allow_default: bool = False,
     ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
+        """Resolve train/validation split indices based on configured CV strategy.
+
+        Args:
+            X_all: DataFrame to split
+            allow_default: If True, use default val_ratio when config is invalid
+
+        Returns:
+            Tuple of (train_indices, val_indices) or None if not enough data
+        """
         val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
         if not (0.0 < val_ratio < 1.0):
             if not allow_default:
@@ -700,46 +955,8 @@ class TrainerBase:
         if len(X_all) < 10:
             return None
 
-        strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
-        if strategy in {"time", "timeseries", "temporal"}:
-            time_col = getattr(self.ctx.config, "cv_time_col", None)
-            if not time_col:
-                raise ValueError("cv_time_col is required for time cv_strategy.")
-            if time_col not in self.ctx.train_data.columns:
-                raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
-            ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
-            order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
-            index_set = set(X_all.index)
-            order_index = [idx for idx in order_index if idx in index_set]
-            order = X_all.index.get_indexer(order_index)
-            order = order[order >= 0]
-            cutoff = int(len(order) * (1.0 - val_ratio))
-            if cutoff <= 0 or cutoff >= len(order):
-                raise ValueError(
-                    f"prop_test={val_ratio} leaves no data for train/val split.")
-            return order[:cutoff], order[cutoff:]
-
-        if strategy in {"group", "grouped"}:
-            group_col = getattr(self.ctx.config, "cv_group_col", None)
-            if not group_col:
-                raise ValueError("cv_group_col is required for group cv_strategy.")
-            if group_col not in self.ctx.train_data.columns:
-                raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
-            groups = self.ctx.train_data.reindex(X_all.index)[group_col]
-            splitter = GroupShuffleSplit(
-                n_splits=1,
-                test_size=val_ratio,
-                random_state=self.ctx.rand_seed,
-            )
-            train_idx, val_idx = next(splitter.split(X_all, groups=groups))
-            return train_idx, val_idx
-
-        splitter = ShuffleSplit(
-            n_splits=1,
-            test_size=val_ratio,
-            random_state=self.ctx.rand_seed,
-        )
-        train_idx, val_idx = next(splitter.split(X_all))
+        resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+        (train_idx, val_idx), _ = resolver.create_train_val_splitter(X_all, val_ratio)
         return train_idx, val_idx
 
     def _resolve_time_sample_indices(
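After this rewrite `_resolve_train_val_indices` no longer special-cases each strategy inline; for the default random strategy the delegation bottoms out in a single `ShuffleSplit`, equivalent in spirit to the following standalone sketch (data invented):

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import ShuffleSplit

    X = pd.DataFrame({"x": np.arange(20)})
    splitter = ShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
    train_idx, val_idx = next(splitter.split(X))
    assert len(train_idx) == 15 and len(val_idx) == 5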
@@ -747,25 +964,34 @@ class TrainerBase:
         X_all: pd.DataFrame,
         sample_limit: int,
     ) -> Optional[pd.Index]:
+        """Get the most recent indices for time-based sampling.
+
+        For time-based CV strategies, returns the last `sample_limit` indices
+        ordered by time. For other strategies, returns None.
+
+        Args:
+            X_all: DataFrame to sample from
+            sample_limit: Maximum number of samples to return
+
+        Returns:
+            Index of sampled rows, or None if not using time-based strategy
+        """
         if sample_limit <= 0:
             return None
-        strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
-        if strategy not in {"time", "timeseries", "temporal"}:
+
+        resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+        if not resolver.is_time_strategy():
             return None
-
-
-        time_col = getattr(self.ctx.config, "cv_time_col", None)
-        if time_col not in self.ctx.train_data.columns:
-            raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
-        ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
-        order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
-        index_set = set(X_all.index)
-        order_index = [idx for idx in order_index if idx in index_set]
-        if not order_index:
+
+        order = resolver.get_time_ordered_indices(X_all)
+        if len(order) == 0:
             return None
-
-
-
+
+        # Get the last sample_limit indices (most recent in time)
+        if len(order) > sample_limit:
+            order = order[-sample_limit:]
+
+        return X_all.index[order]
 
     def _resolve_ensemble_splits(
         self,
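The tail-sampling step above relies on `get_time_ordered_indices` returning positional indices sorted by the time column, so slicing with `[-sample_limit:]` keeps the most recent rows. The same idea in miniature (invented data):

    import numpy as np
    import pandas as pd

    X = pd.DataFrame(
        {"t": pd.to_datetime(["2024-03-01", "2024-01-01", "2024-02-01"])},
        index=["a", "b", "c"])
    order = np.argsort(X["t"].to_numpy())    # positional order by time: b, c, a
    recent = order[-2:]                      # the two most recent rows
    assert X.index[recent].tolist() == ["c", "a"]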
@@ -773,60 +999,17 @@ class TrainerBase:
         *,
         k: int,
     ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
-
-        n_samples = len(X_all)
-        if n_samples < 2:
-            return None, 0
-
-        strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
-        if strategy in {"group", "grouped"}:
-            group_col = getattr(self.ctx.config, "cv_group_col", None)
-            if not group_col:
-                raise ValueError("cv_group_col is required for group cv_strategy.")
-            if group_col not in self.ctx.train_data.columns:
-                raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
-            groups = self.ctx.train_data.reindex(X_all.index)[group_col]
-            n_groups = int(groups.nunique(dropna=False))
-            if n_groups < 2:
-                return None, 0
-            if k > n_groups:
-                k = n_groups
-            if k < 2:
-                return None, 0
-            splitter = GroupKFold(n_splits=k)
-            return splitter.split(X_all, y=None, groups=groups), k
+        """Resolve K-fold splits for ensemble training based on configured CV strategy.
 
-        if strategy in {"time", "timeseries", "temporal"}:
-            time_col = getattr(self.ctx.config, "cv_time_col", None)
-            if not time_col:
-                raise ValueError("cv_time_col is required for time cv_strategy.")
-            if time_col not in self.ctx.train_data.columns:
-                raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
-            ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
-            order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
-            index_set = set(X_all.index)
-            order_index = [idx for idx in order_index if idx in index_set]
-            order = X_all.index.get_indexer(order_index)
-            order = order[order >= 0]
-            if len(order) < 2:
-                return None, 0
-            if len(order) <= k:
-                k = max(2, len(order) - 1)
-            if k < 2:
-                return None, 0
-            splitter = TimeSeriesSplit(n_splits=k)
-            return _OrderSplitter(splitter, order).split(X_all), k
+        Args:
+            X_all: DataFrame to split
+            k: Number of folds requested
 
-        if k > n_samples:
-            k = n_samples
-        if k < 2:
-            return None, 0
-        splitter = KFold(
-            n_splits=k,
-            shuffle=True,
-            random_state=self.ctx.rand_seed,
-        )
-        return splitter.split(X_all), k
+        Returns:
+            Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
+        """
+        resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+        return resolver.create_kfold_splitter(X_all, k)
 
     def cross_val_generic(
         self,
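`_resolve_ensemble_splits` now inherits its fold-count clamping from `create_kfold_splitter`: under the group strategy, `k` can never exceed the number of distinct groups. The clamp in miniature (invented data):

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import GroupKFold

    X = pd.DataFrame({"x": np.arange(9)})
    groups = pd.Series(list("aaabbbccc"))
    k = min(5, int(groups.nunique()))        # 5 folds requested, only 3 groups -> k = 3
    splits = list(GroupKFold(n_splits=k).split(X, groups=groups))
    assert len(splits) == 3
    # Each validation fold holds exactly one whole group, never a partial one.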
@@ -892,7 +1075,6 @@ class TrainerBase:
             w_all = w_all.loc[sampled_idx] if w_all is not None else None
 
         if splitter is None:
-            strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
             val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
             if not (0.0 < val_ratio < 1.0):
                 val_ratio = 0.25
@@ -901,37 +1083,10 @@ class TrainerBase:
             cv_splits = max(2, int(round(1 / val_ratio)))
             cv_splits = max(2, int(cv_splits))
 
-            if strategy in {"group", "grouped"}:
-                group_col = getattr(self.ctx.config, "cv_group_col", None)
-                if not group_col:
-                    raise ValueError("cv_group_col is required for group cv_strategy.")
-                if group_col not in self.ctx.train_data.columns:
-                    raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
-                groups = self.ctx.train_data.reindex(X_all.index)[group_col]
-                split_iter = GroupKFold(n_splits=cv_splits).split(X_all, y_all, groups=groups)
-            elif strategy in {"time", "timeseries", "temporal"}:
-                time_col = getattr(self.ctx.config, "cv_time_col", None)
-                if not time_col:
-                    raise ValueError("cv_time_col is required for time cv_strategy.")
-                if time_col not in self.ctx.train_data.columns:
-                    raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
-                ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
-                order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
-                index_set = set(X_all.index)
-                order_index = [idx for idx in order_index if idx in index_set]
-                order = X_all.index.get_indexer(order_index)
-                order = order[order >= 0]
-                if len(order) <= cv_splits:
-                    cv_splits = max(2, len(order) - 1)
-                if cv_splits < 2:
-                    raise ValueError("Not enough samples for time-series CV.")
-                split_iter = _OrderSplitter(TimeSeriesSplit(n_splits=cv_splits), order).split(X_all)
-            else:
-                split_iter = ShuffleSplit(
-                    n_splits=cv_splits,
-                    test_size=val_ratio,
-                    random_state=self.ctx.rand_seed
-                ).split(X_all)
+            resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+            split_iter, actual_splits = resolver.create_cv_splitter(X_all, y_all, cv_splits, val_ratio)
+            if actual_splits < 2:
+                raise ValueError("Not enough samples for cross-validation.")
         else:
             if hasattr(splitter, "split"):
                 split_iter = splitter.split(X_all, y_all, groups=None)