ins-pricing 0.2.8__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. ins_pricing/CHANGELOG.md +93 -0
  2. ins_pricing/README.md +11 -0
  3. ins_pricing/cli/bayesopt_entry_runner.py +626 -499
  4. ins_pricing/cli/utils/evaluation_context.py +320 -0
  5. ins_pricing/cli/utils/import_resolver.py +350 -0
  6. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
  7. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
  8. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
  9. ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
  10. ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
  11. ins_pricing/modelling/core/bayesopt/core.py +153 -94
  12. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +118 -31
  13. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +294 -139
  14. ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
  15. ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
  16. ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
  17. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
  18. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
  19. ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +587 -0
  20. ins_pricing/modelling/core/bayesopt/utils.py +98 -1495
  21. ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
  22. ins_pricing/setup.py +1 -1
  23. ins_pricing-0.3.0.dist-info/METADATA +162 -0
  24. {ins_pricing-0.2.8.dist-info → ins_pricing-0.3.0.dist-info}/RECORD +26 -13
  25. ins_pricing-0.2.8.dist-info/METADATA +0 -51
  26. {ins_pricing-0.2.8.dist-info → ins_pricing-0.3.0.dist-info}/WHEEL +0 -0
  27. {ins_pricing-0.2.8.dist-info → ins_pricing-0.3.0.dist-info}/top_level.txt +0 -0
ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py
@@ -19,6 +19,118 @@ from ..utils import DistributedUtils, EPS, TorchTrainerMixin
  from .model_ft_components import FTTransformerCore, MaskedTabularDataset, TabularDataset


+ # --- Helper functions for reconstruction loss computation ---
+
+
+ def _compute_numeric_reconstruction_loss(
+     num_pred: Optional[torch.Tensor],
+     num_true: Optional[torch.Tensor],
+     num_mask: Optional[torch.Tensor],
+     loss_weight: float,
+     device: torch.device,
+ ) -> torch.Tensor:
+     """Compute MSE loss for numeric feature reconstruction.
+
+     Args:
+         num_pred: Predicted numeric values (N, num_features)
+         num_true: Ground truth numeric values (N, num_features)
+         num_mask: Boolean mask indicating which values were masked (N, num_features)
+         loss_weight: Weight to apply to the loss
+         device: Target device for computation
+
+     Returns:
+         Weighted MSE loss for masked numeric features
+     """
+     if num_pred is None or num_true is None or num_mask is None:
+         return torch.zeros((), device=device, dtype=torch.float32)
+
+     num_mask = num_mask.to(dtype=torch.bool)
+     if not num_mask.any():
+         return torch.zeros((), device=device, dtype=torch.float32)
+
+     diff = num_pred - num_true
+     mse = diff * diff
+     return float(loss_weight) * mse[num_mask].mean()
+
+
+ def _compute_categorical_reconstruction_loss(
+     cat_logits: Optional[List[torch.Tensor]],
+     cat_true: Optional[torch.Tensor],
+     cat_mask: Optional[torch.Tensor],
+     loss_weight: float,
+     device: torch.device,
+ ) -> torch.Tensor:
+     """Compute cross-entropy loss for categorical feature reconstruction.
+
+     Args:
+         cat_logits: List of logits for each categorical feature
+         cat_true: Ground truth categorical indices (N, num_cat_features)
+         cat_mask: Boolean mask indicating which values were masked (N, num_cat_features)
+         loss_weight: Weight to apply to the loss
+         device: Target device for computation
+
+     Returns:
+         Weighted cross-entropy loss for masked categorical features
+     """
+     if not cat_logits or cat_true is None or cat_mask is None:
+         return torch.zeros((), device=device, dtype=torch.float32)
+
+     cat_mask = cat_mask.to(dtype=torch.bool)
+     cat_losses: List[torch.Tensor] = []
+
+     for j, logits in enumerate(cat_logits):
+         mask_j = cat_mask[:, j]
+         if not mask_j.any():
+             continue
+         targets = cat_true[:, j]
+         cat_losses.append(
+             F.cross_entropy(logits, targets, reduction='none')[mask_j].mean()
+         )
+
+     if not cat_losses:
+         return torch.zeros((), device=device, dtype=torch.float32)
+
+     return float(loss_weight) * torch.stack(cat_losses).mean()
+
+
+ def _compute_reconstruction_loss(
+     num_pred: Optional[torch.Tensor],
+     cat_logits: Optional[List[torch.Tensor]],
+     num_true: Optional[torch.Tensor],
+     num_mask: Optional[torch.Tensor],
+     cat_true: Optional[torch.Tensor],
+     cat_mask: Optional[torch.Tensor],
+     num_loss_weight: float,
+     cat_loss_weight: float,
+     device: torch.device,
+ ) -> torch.Tensor:
+     """Compute combined reconstruction loss for masked tabular data.
+
+     This combines numeric (MSE) and categorical (cross-entropy) reconstruction losses.
+
+     Args:
+         num_pred: Predicted numeric values
+         cat_logits: List of logits for categorical features
+         num_true: Ground truth numeric values
+         num_mask: Mask for numeric features
+         cat_true: Ground truth categorical indices
+         cat_mask: Mask for categorical features
+         num_loss_weight: Weight for numeric loss
+         cat_loss_weight: Weight for categorical loss
+         device: Target device for computation
+
+     Returns:
+         Combined weighted reconstruction loss
+     """
+     num_loss = _compute_numeric_reconstruction_loss(
+         num_pred, num_true, num_mask, num_loss_weight, device
+     )
+     cat_loss = _compute_categorical_reconstruction_loss(
+         cat_logits, cat_true, cat_mask, cat_loss_weight, device
+     )
+     return num_loss + cat_loss
+
+
  # Scikit-Learn style wrapper for FTTransformer.


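The helpers above are pure functions of their tensor inputs, so they can be exercised in isolation. A minimal sketch, assuming the module path implied by the file list and invented shapes, masks, and weights; `_compute_reconstruction_loss` is a private helper and is imported here purely for illustration:

```python
# Usage sketch (assumptions: import path inferred from the file list above;
# shapes, masks, and weights are invented).
import torch

from ins_pricing.modelling.core.bayesopt.models.model_ft_trainer import (
    _compute_reconstruction_loss,
)

N = 8
num_pred, num_true = torch.randn(N, 3), torch.randn(N, 3)
num_mask = torch.zeros(N, 3, dtype=torch.bool)
num_mask[:4, 0] = True                       # first numeric column masked for 4 rows

cat_logits = [torch.randn(N, 5), torch.randn(N, 7)]   # two categorical features
cat_true = torch.stack(
    [torch.randint(0, 5, (N,)), torch.randint(0, 7, (N,))], dim=1
)
cat_mask = torch.zeros(N, 2, dtype=torch.bool)
cat_mask[:4, 1] = True                       # second categorical feature masked for 4 rows

loss = _compute_reconstruction_loss(
    num_pred, cat_logits, num_true, num_mask, cat_true, cat_mask,
    num_loss_weight=1.0, cat_loss_weight=0.5, device=num_pred.device,
)
print(float(loss))  # masked MSE + 0.5 * masked cross-entropy
```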
@@ -508,34 +620,6 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
          )
          scaler = GradScaler(enabled=(device_type == 'cuda'))

-         def _batch_recon_loss(num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device):
-             loss = torch.zeros((), device=device, dtype=torch.float32)
-
-             if num_pred is not None and num_true_b is not None and num_mask_b is not None:
-                 num_mask_b = num_mask_b.to(dtype=torch.bool)
-                 if num_mask_b.any():
-                     diff = num_pred - num_true_b
-                     mse = diff * diff
-                     loss = loss + float(num_loss_weight) * \
-                         mse[num_mask_b].mean()
-
-             if cat_logits and cat_true_b is not None and cat_mask_b is not None:
-                 cat_mask_b = cat_mask_b.to(dtype=torch.bool)
-                 cat_losses: List[torch.Tensor] = []
-                 for j, logits in enumerate(cat_logits):
-                     mask_j = cat_mask_b[:, j]
-                     if not mask_j.any():
-                         continue
-                     targets = cat_true_b[:, j]
-                     cat_losses.append(
-                         F.cross_entropy(logits, targets, reduction='none')[
-                             mask_j].mean()
-                     )
-                 if cat_losses:
-                     loss = loss + float(cat_loss_weight) * \
-                         torch.stack(cat_losses).mean()
-             return loss
-
          train_history: List[float] = []
          val_history: List[float] = []
          best_loss = float("inf")
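The deleted closure read `num_loss_weight` and `cat_loss_weight` from the enclosing `fit` scope; a module-level function cannot do that, which is why the new `_compute_reconstruction_loss` takes both weights as explicit parameters. A toy contrast, with hypothetical names:

```python
# Toy contrast (hypothetical names): closure capture vs. explicit parameters.
def make_scaler(weight):
    def scaled(x):               # nested function: captures `weight` implicitly
        return weight * x
    return scaled

def scaled_explicit(x, weight):  # module level: the weight must be passed in
    return weight * x

assert make_scaler(0.5)(4.0) == scaled_explicit(4.0, 0.5) == 2.0
```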
@@ -579,8 +663,10 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):

                  num_pred, cat_logits = self.ft(
                      X_num_b, X_cat_b, X_geo_b, return_reconstruction=True)
-                 batch_loss = _batch_recon_loss(
-                     num_pred, cat_logits, num_true_b, num_mask_b, cat_true_b, cat_mask_b, device=X_num_b.device)
+                 batch_loss = _compute_reconstruction_loss(
+                     num_pred, cat_logits, num_true_b, num_mask_b,
+                     cat_true_b, cat_mask_b, num_loss_weight, cat_loss_weight,
+                     device=X_num_b.device)
                  local_bad = 0 if bool(torch.isfinite(batch_loss)) else 1
                  global_bad = local_bad
                  if dist.is_initialized():
@@ -672,10 +758,11 @@ class FTTransformerSklearn(TorchTrainerMixin, nn.Module):
                          self.device, non_blocking=True)
                      num_pred_v, cat_logits_v = self.ft(
                          X_num_v, X_cat_v, X_geo_v, return_reconstruction=True)
-                     loss_v = _batch_recon_loss(
+                     loss_v = _compute_reconstruction_loss(
                          num_pred_v, cat_logits_v,
                          X_num_v_true if X_num_v_true.numel() else None, val_num_mask,
                          X_cat_v_true if X_cat_v_true.numel() else None, val_cat_mask,
+                         num_loss_weight, cat_loss_weight,
                          device=X_num_v.device
                      )
                      if not torch.isfinite(loss_v):
ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py
@@ -42,6 +42,252 @@ class _OrderSplitter:
          for tr_idx, val_idx in self._splitter.split(X_ord, y=y, groups=groups):
              yield order[tr_idx], order[val_idx]

+
+ # =============================================================================
+ # CV Strategy Resolution Helper
+ # =============================================================================
+
+
+ class CVStrategyResolver:
+     """Helper class to resolve cross-validation splitting strategies.
+
+     This encapsulates the logic for determining how to split data based on the
+     configured strategy (random, time, group). It provides methods to:
+     - Get time-ordered indices for a dataset
+     - Get group values for a dataset
+     - Create appropriate sklearn splitters
+     """
+
+     TIME_STRATEGIES = {"time", "timeseries", "temporal"}
+     GROUP_STRATEGIES = {"group", "grouped"}
+
+     def __init__(self, config, train_data: pd.DataFrame, rand_seed: Optional[int] = None):
+         """Initialize the resolver.
+
+         Args:
+             config: BayesOptConfig with cv_strategy, cv_time_col, cv_group_col, etc.
+             train_data: The training DataFrame (needed for column access)
+             rand_seed: Random seed for reproducible splits
+         """
+         self.config = config
+         self.train_data = train_data
+         self.rand_seed = rand_seed
+         self._strategy = self._normalize_strategy()
+
+     def _normalize_strategy(self) -> str:
+         """Normalize the strategy string to lowercase."""
+         raw = str(getattr(self.config, "cv_strategy", "random") or "random")
+         return raw.strip().lower()
+
+     @property
+     def strategy(self) -> str:
+         """Return the normalized CV strategy."""
+         return self._strategy
+
+     def is_time_strategy(self) -> bool:
+         """Check if using a time-based splitting strategy."""
+         return self._strategy in self.TIME_STRATEGIES
+
+     def is_group_strategy(self) -> bool:
+         """Check if using a group-based splitting strategy."""
+         return self._strategy in self.GROUP_STRATEGIES
+
+     def get_time_col(self) -> str:
+         """Get and validate the time column.
+
+         Raises:
+             ValueError: If time column is not configured
+             KeyError: If time column not found in train_data
+         """
+         time_col = getattr(self.config, "cv_time_col", None)
+         if not time_col:
+             raise ValueError("cv_time_col is required for time cv_strategy.")
+         if time_col not in self.train_data.columns:
+             raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+         return time_col
+
+     def get_time_ascending(self) -> bool:
+         """Get the time ordering preference."""
+         return bool(getattr(self.config, "cv_time_ascending", True))
+
+     def get_group_col(self) -> str:
+         """Get and validate the group column.
+
+         Raises:
+             ValueError: If group column is not configured
+             KeyError: If group column not found in train_data
+         """
+         group_col = getattr(self.config, "cv_group_col", None)
+         if not group_col:
+             raise ValueError("cv_group_col is required for group cv_strategy.")
+         if group_col not in self.train_data.columns:
+             raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
+         return group_col
+
+     def get_time_ordered_indices(self, X_all: pd.DataFrame) -> np.ndarray:
+         """Get indices ordered by time for the given dataset.
+
+         Args:
+             X_all: DataFrame to get indices for (must have index compatible with train_data)
+
+         Returns:
+             Array of positional indices into X_all, ordered by time
+         """
+         time_col = self.get_time_col()
+         ascending = self.get_time_ascending()
+         order_index = self.train_data[time_col].sort_values(ascending=ascending).index
+         index_set = set(X_all.index)
+         order_index = [idx for idx in order_index if idx in index_set]
+         order = X_all.index.get_indexer(order_index)
+         return order[order >= 0]
+
+     def get_groups(self, X_all: pd.DataFrame) -> pd.Series:
+         """Get group labels for the given dataset.
+
+         Args:
+             X_all: DataFrame to get groups for
+
+         Returns:
+             Series of group labels aligned with X_all
+         """
+         group_col = self.get_group_col()
+         return self.train_data.reindex(X_all.index)[group_col]
+
+     def create_train_val_splitter(
+         self,
+         X_all: pd.DataFrame,
+         val_ratio: float,
+     ) -> Tuple[Optional[Tuple[np.ndarray, np.ndarray]], Optional[pd.Series]]:
+         """Create a single train/val split based on strategy.
+
+         Args:
+             X_all: DataFrame to split
+             val_ratio: Fraction of data for validation
+
+         Returns:
+             Tuple of ((train_idx, val_idx), groups) where groups is None for non-group strategies
+         """
+         if self.is_time_strategy():
+             order = self.get_time_ordered_indices(X_all)
+             cutoff = int(len(order) * (1.0 - val_ratio))
+             if cutoff <= 0 or cutoff >= len(order):
+                 raise ValueError(f"val_ratio={val_ratio} leaves no data for train/val split.")
+             return (order[:cutoff], order[cutoff:]), None
+
+         if self.is_group_strategy():
+             groups = self.get_groups(X_all)
+             splitter = GroupShuffleSplit(
+                 n_splits=1, test_size=val_ratio, random_state=self.rand_seed
+             )
+             train_idx, val_idx = next(splitter.split(X_all, groups=groups))
+             return (train_idx, val_idx), groups
+
+         # Random strategy
+         splitter = ShuffleSplit(
+             n_splits=1, test_size=val_ratio, random_state=self.rand_seed
+         )
+         train_idx, val_idx = next(splitter.split(X_all))
+         return (train_idx, val_idx), None
+
+     def create_cv_splitter(
+         self,
+         X_all: pd.DataFrame,
+         y_all: Optional[pd.Series],
+         n_splits: int,
+         val_ratio: float,
+     ) -> Tuple[Iterable[Tuple[np.ndarray, np.ndarray]], int]:
+         """Create a cross-validation splitter based on strategy.
+
+         Args:
+             X_all: DataFrame to split
+             y_all: Target series (used by some splitters)
+             n_splits: Number of CV folds
+             val_ratio: Validation ratio (for ShuffleSplit)
+
+         Returns:
+             Tuple of (split_iterator, actual_n_splits)
+         """
+         n_splits = max(2, int(n_splits))
+
+         if self.is_group_strategy():
+             groups = self.get_groups(X_all)
+             n_groups = int(groups.nunique(dropna=False))
+             if n_groups < 2:
+                 return iter([]), 0
+             n_splits = min(n_splits, n_groups)
+             if n_splits < 2:
+                 return iter([]), 0
+             splitter = GroupKFold(n_splits=n_splits)
+             return splitter.split(X_all, y_all, groups=groups), n_splits
+
+         if self.is_time_strategy():
+             order = self.get_time_ordered_indices(X_all)
+             if len(order) < 2:
+                 return iter([]), 0
+             n_splits = min(n_splits, max(2, len(order) - 1))
+             if n_splits < 2:
+                 return iter([]), 0
+             splitter = TimeSeriesSplit(n_splits=n_splits)
+             return _OrderSplitter(splitter, order).split(X_all), n_splits
+
+         # Random strategy
+         if len(X_all) < n_splits:
+             n_splits = len(X_all)
+         if n_splits < 2:
+             return iter([]), 0
+         splitter = ShuffleSplit(
+             n_splits=n_splits, test_size=val_ratio, random_state=self.rand_seed
+         )
+         return splitter.split(X_all), n_splits
+
+     def create_kfold_splitter(
+         self,
+         X_all: pd.DataFrame,
+         k: int,
+     ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
+         """Create a K-fold splitter for ensemble training.
+
+         Args:
+             X_all: DataFrame to split
+             k: Number of folds
+
+         Returns:
+             Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
+         """
+         k = max(2, int(k))
+         n_samples = len(X_all)
+         if n_samples < 2:
+             return None, 0
+
+         if self.is_group_strategy():
+             groups = self.get_groups(X_all)
+             n_groups = int(groups.nunique(dropna=False))
+             if n_groups < 2:
+                 return None, 0
+             k = min(k, n_groups)
+             if k < 2:
+                 return None, 0
+             splitter = GroupKFold(n_splits=k)
+             return splitter.split(X_all, y=None, groups=groups), k
+
+         if self.is_time_strategy():
+             order = self.get_time_ordered_indices(X_all)
+             if len(order) < 2:
+                 return None, 0
+             k = min(k, max(2, len(order) - 1))
+             if k < 2:
+                 return None, 0
+             splitter = TimeSeriesSplit(n_splits=k)
+             return _OrderSplitter(splitter, order).split(X_all), k
+
+         # Random strategy with KFold
+         k = min(k, n_samples)
+         if k < 2:
+             return None, 0
+         splitter = KFold(n_splits=k, shuffle=True, random_state=self.rand_seed)
+         return splitter.split(X_all), k
+
+
  # =============================================================================
  # Trainer system
  # =============================================================================
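Because `CVStrategyResolver` reads its configuration exclusively through `getattr`, any object exposing the `cv_*` attributes can drive it. A minimal sketch of a time-based train/val split, with a hypothetical config object and invented column names:

```python
# Sketch (assumptions: import path inferred from the file list; SimpleNamespace
# stands in for BayesOptConfig; "uw_year" / "premium" are invented columns).
from types import SimpleNamespace
import numpy as np
import pandas as pd

from ins_pricing.modelling.core.bayesopt.trainers.trainer_base import CVStrategyResolver

train = pd.DataFrame({
    "premium": np.random.rand(100),
    "uw_year": np.repeat(np.arange(2015, 2025), 10),
})
cfg = SimpleNamespace(cv_strategy="time", cv_time_col="uw_year", cv_time_ascending=True)

resolver = CVStrategyResolver(cfg, train, rand_seed=42)
(train_idx, val_idx), groups = resolver.create_train_val_splitter(train, val_ratio=0.25)
assert len(val_idx) == 25 and groups is None  # the newest 25% of rows validate
```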
@@ -692,6 +938,15 @@ class TrainerBase:
          *,
          allow_default: bool = False,
      ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
+         """Resolve train/validation split indices based on configured CV strategy.
+
+         Args:
+             X_all: DataFrame to split
+             allow_default: If True, use default val_ratio when config is invalid
+
+         Returns:
+             Tuple of (train_indices, val_indices) or None if not enough data
+         """
          val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
          if not (0.0 < val_ratio < 1.0):
              if not allow_default:
@@ -700,46 +955,8 @@ class TrainerBase:
          if len(X_all) < 10:
              return None

-         strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
-         if strategy in {"time", "timeseries", "temporal"}:
-             time_col = getattr(self.ctx.config, "cv_time_col", None)
-             if not time_col:
-                 raise ValueError("cv_time_col is required for time cv_strategy.")
-             if time_col not in self.ctx.train_data.columns:
-                 raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
-             ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
-             order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
-             index_set = set(X_all.index)
-             order_index = [idx for idx in order_index if idx in index_set]
-             order = X_all.index.get_indexer(order_index)
-             order = order[order >= 0]
-             cutoff = int(len(order) * (1.0 - val_ratio))
-             if cutoff <= 0 or cutoff >= len(order):
-                 raise ValueError(
-                     f"prop_test={val_ratio} leaves no data for train/val split.")
-             return order[:cutoff], order[cutoff:]
-
-         if strategy in {"group", "grouped"}:
-             group_col = getattr(self.ctx.config, "cv_group_col", None)
-             if not group_col:
-                 raise ValueError("cv_group_col is required for group cv_strategy.")
-             if group_col not in self.ctx.train_data.columns:
-                 raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
-             groups = self.ctx.train_data.reindex(X_all.index)[group_col]
-             splitter = GroupShuffleSplit(
-                 n_splits=1,
-                 test_size=val_ratio,
-                 random_state=self.ctx.rand_seed,
-             )
-             train_idx, val_idx = next(splitter.split(X_all, groups=groups))
-             return train_idx, val_idx
-
-         splitter = ShuffleSplit(
-             n_splits=1,
-             test_size=val_ratio,
-             random_state=self.ctx.rand_seed,
-         )
-         train_idx, val_idx = next(splitter.split(X_all))
+         resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+         (train_idx, val_idx), _ = resolver.create_train_val_splitter(X_all, val_ratio)
          return train_idx, val_idx

      def _resolve_time_sample_indices(
@@ -747,25 +964,34 @@ class TrainerBase:
          X_all: pd.DataFrame,
          sample_limit: int,
      ) -> Optional[pd.Index]:
+         """Get the most recent indices for time-based sampling.
+
+         For time-based CV strategies, returns the last `sample_limit` indices
+         ordered by time. For other strategies, returns None.
+
+         Args:
+             X_all: DataFrame to sample from
+             sample_limit: Maximum number of samples to return
+
+         Returns:
+             Index of sampled rows, or None if not using time-based strategy
+         """
          if sample_limit <= 0:
              return None
-         strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
-         if strategy not in {"time", "timeseries", "temporal"}:
+
+         resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+         if not resolver.is_time_strategy():
              return None
-         time_col = getattr(self.ctx.config, "cv_time_col", None)
-         if not time_col:
-             raise ValueError("cv_time_col is required for time cv_strategy.")
-         if time_col not in self.ctx.train_data.columns:
-             raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
-         ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
-         order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
-         index_set = set(X_all.index)
-         order_index = [idx for idx in order_index if idx in index_set]
-         if not order_index:
+
+         order = resolver.get_time_ordered_indices(X_all)
+         if len(order) == 0:
              return None
-         if len(order_index) > sample_limit:
-             order_index = order_index[-sample_limit:]
-         return pd.Index(order_index)
+
+         # Get the last sample_limit indices (most recent in time)
+         if len(order) > sample_limit:
+             order = order[-sample_limit:]
+
+         return X_all.index[order]

      def _resolve_ensemble_splits(
          self,
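Note the index bookkeeping in the rewritten method: the shared helper returns positional indices, so row labels are now recovered with `X_all.index[order]`, whereas the old code worked directly with labels and wrapped them in `pd.Index`. A toy illustration with invented data:

```python
# Toy illustration: `order` holds positions into X_all (the contract of
# get_time_ordered_indices), so row labels come back via X_all.index[order].
import numpy as np
import pandas as pd

X_all = pd.DataFrame({"x": [10.0, 20.0, 30.0, 40.0]}, index=["a", "b", "c", "d"])
order = np.array([3, 0, 2, 1])     # rows in time order, as positions (not labels)
sample_limit = 2
order = order[-sample_limit:]      # keep the most recent positions
print(X_all.index[order])          # Index(['c', 'b'], dtype='object')
```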
@@ -773,60 +999,17 @@ class TrainerBase:
          *,
          k: int,
      ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
-         k = max(2, int(k))
-         n_samples = len(X_all)
-         if n_samples < 2:
-             return None, 0
-
-         strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
-         if strategy in {"group", "grouped"}:
-             group_col = getattr(self.ctx.config, "cv_group_col", None)
-             if not group_col:
-                 raise ValueError("cv_group_col is required for group cv_strategy.")
-             if group_col not in self.ctx.train_data.columns:
-                 raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
-             groups = self.ctx.train_data.reindex(X_all.index)[group_col]
-             n_groups = int(groups.nunique(dropna=False))
-             if n_groups < 2:
-                 return None, 0
-             if k > n_groups:
-                 k = n_groups
-             if k < 2:
-                 return None, 0
-             splitter = GroupKFold(n_splits=k)
-             return splitter.split(X_all, y=None, groups=groups), k
+         """Resolve K-fold splits for ensemble training based on configured CV strategy.

-         if strategy in {"time", "timeseries", "temporal"}:
-             time_col = getattr(self.ctx.config, "cv_time_col", None)
-             if not time_col:
-                 raise ValueError("cv_time_col is required for time cv_strategy.")
-             if time_col not in self.ctx.train_data.columns:
-                 raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
-             ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
-             order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
-             index_set = set(X_all.index)
-             order_index = [idx for idx in order_index if idx in index_set]
-             order = X_all.index.get_indexer(order_index)
-             order = order[order >= 0]
-             if len(order) < 2:
-                 return None, 0
-             if len(order) <= k:
-                 k = max(2, len(order) - 1)
-             if k < 2:
-                 return None, 0
-             splitter = TimeSeriesSplit(n_splits=k)
-             return _OrderSplitter(splitter, order).split(X_all), k
+         Args:
+             X_all: DataFrame to split
+             k: Number of folds requested

-         if n_samples < k:
-             k = n_samples
-         if k < 2:
-             return None, 0
-         splitter = KFold(
-             n_splits=k,
-             shuffle=True,
-             random_state=self.ctx.rand_seed,
-         )
-         return splitter.split(X_all), k
+         Returns:
+             Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
+         """
+         resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+         return resolver.create_kfold_splitter(X_all, k)

      def cross_val_generic(
          self,
@@ -892,7 +1075,6 @@ class TrainerBase:
              w_all = w_all.loc[sampled_idx] if w_all is not None else None

          if splitter is None:
-             strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
              val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
              if not (0.0 < val_ratio < 1.0):
                  val_ratio = 0.25
@@ -901,37 +1083,10 @@ class TrainerBase:
                  cv_splits = max(2, int(round(1 / val_ratio)))
              cv_splits = max(2, int(cv_splits))

-             if strategy in {"group", "grouped"}:
-                 group_col = getattr(self.ctx.config, "cv_group_col", None)
-                 if not group_col:
-                     raise ValueError("cv_group_col is required for group cv_strategy.")
-                 if group_col not in self.ctx.train_data.columns:
-                     raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
-                 groups = self.ctx.train_data.reindex(X_all.index)[group_col]
-                 split_iter = GroupKFold(n_splits=cv_splits).split(X_all, y_all, groups=groups)
-             elif strategy in {"time", "timeseries", "temporal"}:
-                 time_col = getattr(self.ctx.config, "cv_time_col", None)
-                 if not time_col:
-                     raise ValueError("cv_time_col is required for time cv_strategy.")
-                 if time_col not in self.ctx.train_data.columns:
-                     raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
-                 ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
-                 order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
-                 index_set = set(X_all.index)
-                 order_index = [idx for idx in order_index if idx in index_set]
-                 order = X_all.index.get_indexer(order_index)
-                 order = order[order >= 0]
-                 if len(order) <= cv_splits:
-                     cv_splits = max(2, len(order) - 1)
-                 if cv_splits < 2:
-                     raise ValueError("Not enough samples for time-series CV.")
-                 split_iter = _OrderSplitter(TimeSeriesSplit(n_splits=cv_splits), order).split(X_all)
-             else:
-                 split_iter = ShuffleSplit(
-                     n_splits=cv_splits,
-                     test_size=val_ratio,
-                     random_state=self.ctx.rand_seed
-                 ).split(X_all)
+             resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
+             split_iter, actual_splits = resolver.create_cv_splitter(X_all, y_all, cv_splits, val_ratio)
+             if actual_splits < 2:
+                 raise ValueError("Not enough samples for cross-validation.")
          else:
              if hasattr(splitter, "split"):
                  split_iter = splitter.split(X_all, y_all, groups=None)
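With the fallback branch of `cross_val_generic` now routed through the same resolver, a group strategy produces whole-group folds here as well. A minimal sketch with a hypothetical config object and invented column names:

```python
# Sketch (assumptions: import path inferred from the file list; config object
# and "policy_id" column are invented).
from types import SimpleNamespace
import numpy as np
import pandas as pd

from ins_pricing.modelling.core.bayesopt.trainers.trainer_base import CVStrategyResolver

train = pd.DataFrame({"x": np.arange(12.0), "policy_id": np.repeat([1, 2, 3, 4], 3)})
cfg = SimpleNamespace(cv_strategy="group", cv_group_col="policy_id")
resolver = CVStrategyResolver(cfg, train, rand_seed=0)

split_iter, n_splits = resolver.create_cv_splitter(train, None, n_splits=5, val_ratio=0.25)
assert n_splits == 4  # capped at the number of distinct policy_id groups
for tr, va in split_iter:
    held_out = set(train.iloc[va]["policy_id"])
    assert held_out.isdisjoint(train.iloc[tr]["policy_id"])  # whole groups held out
```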