ins-pricing 0.2.9__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/CHANGELOG.md +93 -0
- ins_pricing/README.md +11 -0
- ins_pricing/cli/Explain_entry.py +50 -48
- ins_pricing/cli/bayesopt_entry_runner.py +699 -569
- ins_pricing/cli/utils/evaluation_context.py +320 -0
- ins_pricing/cli/utils/import_resolver.py +350 -0
- ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
- ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
- ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
- ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
- ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
- ins_pricing/modelling/core/bayesopt/core.py +153 -94
- ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +122 -34
- ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +298 -142
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
- ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
- ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
- ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +591 -0
- ins_pricing/modelling/core/bayesopt/utils.py +98 -1496
- ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
- ins_pricing/setup.py +1 -1
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/METADATA +14 -1
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/RECORD +27 -14
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/WHEEL +0 -0
- {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/top_level.txt +0 -0
|
@@ -42,6 +42,252 @@ class _OrderSplitter:
|
|
|
42
42
|
for tr_idx, val_idx in self._splitter.split(X_ord, y=y, groups=groups):
|
|
43
43
|
yield order[tr_idx], order[val_idx]
|
|
44
44
|
|
|
45
|
+
|
|
46
|
+
# =============================================================================
|
|
47
|
+
# CV Strategy Resolution Helper
|
|
48
|
+
# =============================================================================
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class CVStrategyResolver:
    """Resolve cross-validation splits from the configured CV strategy.

    The strategy is read from ``config.cv_strategy`` (default ``"random"``)
    and normalized to a stripped, lower-case string.  Three families are
    recognized:

    - time strategies (``TIME_STRATEGIES``): rows are ordered by
      ``config.cv_time_col`` and split chronologically;
    - group strategies (``GROUP_STRATEGIES``): rows sharing a value of
      ``config.cv_group_col`` are kept on the same side of a split;
    - anything else: treated as a random (shuffled) split.

    All index arrays produced by this class are *positional* indices into
    the ``X_all`` frame passed to the respective method.
    """

    TIME_STRATEGIES = {"time", "timeseries", "temporal"}
    GROUP_STRATEGIES = {"group", "grouped"}

    def __init__(self, config, train_data: pd.DataFrame, rand_seed: Optional[int] = None):
        """Initialize the resolver.

        Args:
            config: Object exposing ``cv_strategy``, ``cv_time_col``,
                ``cv_time_ascending`` and ``cv_group_col`` attributes
                (missing attributes fall back to defaults via ``getattr``).
            train_data: Training DataFrame holding the time/group columns.
            rand_seed: Seed forwarded as ``random_state`` to shuffling splitters.
        """
        self.config = config
        self.train_data = train_data
        self.rand_seed = rand_seed
        self._strategy = self._normalize_strategy()

    def _normalize_strategy(self) -> str:
        """Return the configured strategy stripped and lower-cased.

        A missing or falsy ``cv_strategy`` is treated as ``"random"``.
        """
        raw = str(getattr(self.config, "cv_strategy", "random") or "random")
        return raw.strip().lower()

    @property
    def strategy(self) -> str:
        """Return the normalized CV strategy."""
        return self._strategy

    def is_time_strategy(self) -> bool:
        """Return True when the strategy is one of ``TIME_STRATEGIES``."""
        return self._strategy in self.TIME_STRATEGIES

    def is_group_strategy(self) -> bool:
        """Return True when the strategy is one of ``GROUP_STRATEGIES``."""
        return self._strategy in self.GROUP_STRATEGIES

    def get_time_col(self) -> str:
        """Return the validated time column name.

        Raises:
            ValueError: If ``cv_time_col`` is not configured.
            KeyError: If the column is not present in ``train_data``.
        """
        time_col = getattr(self.config, "cv_time_col", None)
        if not time_col:
            raise ValueError("cv_time_col is required for time cv_strategy.")
        if time_col not in self.train_data.columns:
            raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
        return time_col

    def get_time_ascending(self) -> bool:
        """Return the configured time ordering (ascending by default)."""
        return bool(getattr(self.config, "cv_time_ascending", True))

    def get_group_col(self) -> str:
        """Return the validated group column name.

        Raises:
            ValueError: If ``cv_group_col`` is not configured.
            KeyError: If the column is not present in ``train_data``.
        """
        group_col = getattr(self.config, "cv_group_col", None)
        if not group_col:
            raise ValueError("cv_group_col is required for group cv_strategy.")
        if group_col not in self.train_data.columns:
            raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
        return group_col

    def get_time_ordered_indices(self, X_all: pd.DataFrame) -> np.ndarray:
        """Return positions of ``X_all`` rows ordered by the time column.

        Args:
            X_all: DataFrame whose index labels are looked up in ``train_data``.

        Returns:
            Array of positional indices into ``X_all``, ordered by time.
            ``train_data`` labels that are absent from ``X_all`` are dropped.
        """
        time_col = self.get_time_col()
        ascending = self.get_time_ascending()
        sorted_labels = self.train_data[time_col].sort_values(ascending=ascending).index
        # get_indexer yields -1 for labels missing from X_all; masking those
        # out makes a separate Python-level membership pre-filter unnecessary.
        order = X_all.index.get_indexer(sorted_labels)
        return order[order >= 0]

    def get_groups(self, X_all: pd.DataFrame) -> pd.Series:
        """Return group labels aligned with ``X_all``.

        Args:
            X_all: DataFrame to get groups for.

        Returns:
            The group column of ``train_data`` reindexed to ``X_all.index``.
        """
        group_col = self.get_group_col()
        # Select the column first so only one column is reindexed, not the
        # whole training frame.
        return self.train_data[group_col].reindex(X_all.index)

    def create_train_val_splitter(
        self,
        X_all: pd.DataFrame,
        val_ratio: float,
    ) -> Tuple[Optional[Tuple[np.ndarray, np.ndarray]], Optional[pd.Series]]:
        """Create a single train/validation split for the configured strategy.

        Args:
            X_all: DataFrame to split.
            val_ratio: Fraction of rows assigned to validation.

        Returns:
            ``((train_idx, val_idx), groups)`` where ``groups`` is the group
            label Series for group strategies and ``None`` otherwise.

        Raises:
            ValueError: For time strategies, when ``val_ratio`` leaves either
                side of the chronological split empty.
        """
        if self.is_time_strategy():
            order = self.get_time_ordered_indices(X_all)
            cutoff = int(len(order) * (1.0 - val_ratio))
            if cutoff <= 0 or cutoff >= len(order):
                raise ValueError(f"val_ratio={val_ratio} leaves no data for train/val split.")
            # Earlier rows train, later rows validate (per the time ordering).
            return (order[:cutoff], order[cutoff:]), None

        if self.is_group_strategy():
            groups = self.get_groups(X_all)
            splitter = GroupShuffleSplit(
                n_splits=1, test_size=val_ratio, random_state=self.rand_seed
            )
            train_idx, val_idx = next(splitter.split(X_all, groups=groups))
            return (train_idx, val_idx), groups

        # Random strategy
        splitter = ShuffleSplit(
            n_splits=1, test_size=val_ratio, random_state=self.rand_seed
        )
        train_idx, val_idx = next(splitter.split(X_all))
        return (train_idx, val_idx), None

    def create_cv_splitter(
        self,
        X_all: pd.DataFrame,
        y_all: Optional[pd.Series],
        n_splits: int,
        val_ratio: float,
    ) -> Tuple[Iterable[Tuple[np.ndarray, np.ndarray]], int]:
        """Create a cross-validation split iterator for the configured strategy.

        Args:
            X_all: DataFrame to split.
            y_all: Target series (forwarded to the group splitter).
            n_splits: Requested number of folds (raised to at least 2).
            val_ratio: Validation ratio (used by the random ``ShuffleSplit``).

        Returns:
            ``(split_iterator, actual_n_splits)``.  When there is not enough
            data for two folds, an empty iterator and ``0`` are returned.
        """
        n_splits = max(2, int(n_splits))

        if self.is_group_strategy():
            groups = self.get_groups(X_all)
            n_groups = int(groups.nunique(dropna=False))
            # Cannot have more folds than distinct groups.
            n_splits = min(n_splits, n_groups)
            if n_splits < 2:
                return iter([]), 0
            splitter = GroupKFold(n_splits=n_splits)
            return splitter.split(X_all, y_all, groups=groups), n_splits

        if self.is_time_strategy():
            order = self.get_time_ordered_indices(X_all)
            # TimeSeriesSplit needs more samples than splits, so cap at
            # len(order) - 1.  With fewer than three ordered rows no valid
            # two-fold split exists and the empty result is returned.
            n_splits = min(n_splits, len(order) - 1)
            if n_splits < 2:
                return iter([]), 0
            splitter = TimeSeriesSplit(n_splits=n_splits)
            return _OrderSplitter(splitter, order).split(X_all), n_splits

        # Random strategy
        n_splits = min(n_splits, len(X_all))
        if n_splits < 2:
            return iter([]), 0
        splitter = ShuffleSplit(
            n_splits=n_splits, test_size=val_ratio, random_state=self.rand_seed
        )
        return splitter.split(X_all), n_splits

    def create_kfold_splitter(
        self,
        X_all: pd.DataFrame,
        k: int,
    ) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
        """Create a K-fold split iterator for ensemble training.

        Args:
            X_all: DataFrame to split.
            k: Requested number of folds (raised to at least 2).

        Returns:
            ``(split_iterator, actual_k)``, or ``(None, 0)`` when there is
            not enough data for two folds.
        """
        k = max(2, int(k))
        n_samples = len(X_all)
        if n_samples < 2:
            return None, 0

        if self.is_group_strategy():
            groups = self.get_groups(X_all)
            n_groups = int(groups.nunique(dropna=False))
            # Cannot have more folds than distinct groups.
            k = min(k, n_groups)
            if k < 2:
                return None, 0
            splitter = GroupKFold(n_splits=k)
            return splitter.split(X_all, y=None, groups=groups), k

        if self.is_time_strategy():
            order = self.get_time_ordered_indices(X_all)
            # TimeSeriesSplit needs more samples than splits, so cap at
            # len(order) - 1; fewer than three ordered rows cannot be split.
            k = min(k, len(order) - 1)
            if k < 2:
                return None, 0
            splitter = TimeSeriesSplit(n_splits=k)
            return _OrderSplitter(splitter, order).split(X_all), k

        # Random strategy with KFold
        k = min(k, n_samples)
        if k < 2:
            return None, 0
        splitter = KFold(n_splits=k, shuffle=True, random_state=self.rand_seed)
        return splitter.split(X_all), k
|
|
289
|
+
|
|
290
|
+
|
|
45
291
|
# =============================================================================
|
|
46
292
|
# Trainer system
|
|
47
293
|
# =============================================================================
|
|
@@ -692,6 +938,15 @@ class TrainerBase:
|
|
|
692
938
|
*,
|
|
693
939
|
allow_default: bool = False,
|
|
694
940
|
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
|
941
|
+
"""Resolve train/validation split indices based on configured CV strategy.
|
|
942
|
+
|
|
943
|
+
Args:
|
|
944
|
+
X_all: DataFrame to split
|
|
945
|
+
allow_default: If True, use default val_ratio when config is invalid
|
|
946
|
+
|
|
947
|
+
Returns:
|
|
948
|
+
Tuple of (train_indices, val_indices) or None if not enough data
|
|
949
|
+
"""
|
|
695
950
|
val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
|
|
696
951
|
if not (0.0 < val_ratio < 1.0):
|
|
697
952
|
if not allow_default:
|
|
@@ -700,46 +955,8 @@ class TrainerBase:
|
|
|
700
955
|
if len(X_all) < 10:
|
|
701
956
|
return None
|
|
702
957
|
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
time_col = getattr(self.ctx.config, "cv_time_col", None)
|
|
706
|
-
if not time_col:
|
|
707
|
-
raise ValueError("cv_time_col is required for time cv_strategy.")
|
|
708
|
-
if time_col not in self.ctx.train_data.columns:
|
|
709
|
-
raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
|
|
710
|
-
ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
|
|
711
|
-
order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
|
|
712
|
-
index_set = set(X_all.index)
|
|
713
|
-
order_index = [idx for idx in order_index if idx in index_set]
|
|
714
|
-
order = X_all.index.get_indexer(order_index)
|
|
715
|
-
order = order[order >= 0]
|
|
716
|
-
cutoff = int(len(order) * (1.0 - val_ratio))
|
|
717
|
-
if cutoff <= 0 or cutoff >= len(order):
|
|
718
|
-
raise ValueError(
|
|
719
|
-
f"prop_test={val_ratio} leaves no data for train/val split.")
|
|
720
|
-
return order[:cutoff], order[cutoff:]
|
|
721
|
-
|
|
722
|
-
if strategy in {"group", "grouped"}:
|
|
723
|
-
group_col = getattr(self.ctx.config, "cv_group_col", None)
|
|
724
|
-
if not group_col:
|
|
725
|
-
raise ValueError("cv_group_col is required for group cv_strategy.")
|
|
726
|
-
if group_col not in self.ctx.train_data.columns:
|
|
727
|
-
raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
|
|
728
|
-
groups = self.ctx.train_data.reindex(X_all.index)[group_col]
|
|
729
|
-
splitter = GroupShuffleSplit(
|
|
730
|
-
n_splits=1,
|
|
731
|
-
test_size=val_ratio,
|
|
732
|
-
random_state=self.ctx.rand_seed,
|
|
733
|
-
)
|
|
734
|
-
train_idx, val_idx = next(splitter.split(X_all, groups=groups))
|
|
735
|
-
return train_idx, val_idx
|
|
736
|
-
|
|
737
|
-
splitter = ShuffleSplit(
|
|
738
|
-
n_splits=1,
|
|
739
|
-
test_size=val_ratio,
|
|
740
|
-
random_state=self.ctx.rand_seed,
|
|
741
|
-
)
|
|
742
|
-
train_idx, val_idx = next(splitter.split(X_all))
|
|
958
|
+
resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
|
|
959
|
+
(train_idx, val_idx), _ = resolver.create_train_val_splitter(X_all, val_ratio)
|
|
743
960
|
return train_idx, val_idx
|
|
744
961
|
|
|
745
962
|
def _resolve_time_sample_indices(
|
|
@@ -747,25 +964,34 @@ class TrainerBase:
|
|
|
747
964
|
X_all: pd.DataFrame,
|
|
748
965
|
sample_limit: int,
|
|
749
966
|
) -> Optional[pd.Index]:
|
|
967
|
+
"""Get the most recent indices for time-based sampling.
|
|
968
|
+
|
|
969
|
+
For time-based CV strategies, returns the last `sample_limit` indices
|
|
970
|
+
ordered by time. For other strategies, returns None.
|
|
971
|
+
|
|
972
|
+
Args:
|
|
973
|
+
X_all: DataFrame to sample from
|
|
974
|
+
sample_limit: Maximum number of samples to return
|
|
975
|
+
|
|
976
|
+
Returns:
|
|
977
|
+
Index of sampled rows, or None if not using time-based strategy
|
|
978
|
+
"""
|
|
750
979
|
if sample_limit <= 0:
|
|
751
980
|
return None
|
|
752
|
-
|
|
753
|
-
|
|
981
|
+
|
|
982
|
+
resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
|
|
983
|
+
if not resolver.is_time_strategy():
|
|
754
984
|
return None
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
if time_col not in self.ctx.train_data.columns:
|
|
759
|
-
raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
|
|
760
|
-
ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
|
|
761
|
-
order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
|
|
762
|
-
index_set = set(X_all.index)
|
|
763
|
-
order_index = [idx for idx in order_index if idx in index_set]
|
|
764
|
-
if not order_index:
|
|
985
|
+
|
|
986
|
+
order = resolver.get_time_ordered_indices(X_all)
|
|
987
|
+
if len(order) == 0:
|
|
765
988
|
return None
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
989
|
+
|
|
990
|
+
# Get the last sample_limit indices (most recent in time)
|
|
991
|
+
if len(order) > sample_limit:
|
|
992
|
+
order = order[-sample_limit:]
|
|
993
|
+
|
|
994
|
+
return X_all.index[order]
|
|
769
995
|
|
|
770
996
|
def _resolve_ensemble_splits(
|
|
771
997
|
self,
|
|
@@ -773,60 +999,17 @@ class TrainerBase:
|
|
|
773
999
|
*,
|
|
774
1000
|
k: int,
|
|
775
1001
|
) -> Tuple[Optional[Iterable[Tuple[np.ndarray, np.ndarray]]], int]:
|
|
776
|
-
|
|
777
|
-
n_samples = len(X_all)
|
|
778
|
-
if n_samples < 2:
|
|
779
|
-
return None, 0
|
|
1002
|
+
"""Resolve K-fold splits for ensemble training based on configured CV strategy.
|
|
780
1003
|
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
if not group_col:
|
|
785
|
-
raise ValueError("cv_group_col is required for group cv_strategy.")
|
|
786
|
-
if group_col not in self.ctx.train_data.columns:
|
|
787
|
-
raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
|
|
788
|
-
groups = self.ctx.train_data.reindex(X_all.index)[group_col]
|
|
789
|
-
n_groups = int(groups.nunique(dropna=False))
|
|
790
|
-
if n_groups < 2:
|
|
791
|
-
return None, 0
|
|
792
|
-
if k > n_groups:
|
|
793
|
-
k = n_groups
|
|
794
|
-
if k < 2:
|
|
795
|
-
return None, 0
|
|
796
|
-
splitter = GroupKFold(n_splits=k)
|
|
797
|
-
return splitter.split(X_all, y=None, groups=groups), k
|
|
798
|
-
|
|
799
|
-
if strategy in {"time", "timeseries", "temporal"}:
|
|
800
|
-
time_col = getattr(self.ctx.config, "cv_time_col", None)
|
|
801
|
-
if not time_col:
|
|
802
|
-
raise ValueError("cv_time_col is required for time cv_strategy.")
|
|
803
|
-
if time_col not in self.ctx.train_data.columns:
|
|
804
|
-
raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
|
|
805
|
-
ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
|
|
806
|
-
order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
|
|
807
|
-
index_set = set(X_all.index)
|
|
808
|
-
order_index = [idx for idx in order_index if idx in index_set]
|
|
809
|
-
order = X_all.index.get_indexer(order_index)
|
|
810
|
-
order = order[order >= 0]
|
|
811
|
-
if len(order) < 2:
|
|
812
|
-
return None, 0
|
|
813
|
-
if len(order) <= k:
|
|
814
|
-
k = max(2, len(order) - 1)
|
|
815
|
-
if k < 2:
|
|
816
|
-
return None, 0
|
|
817
|
-
splitter = TimeSeriesSplit(n_splits=k)
|
|
818
|
-
return _OrderSplitter(splitter, order).split(X_all), k
|
|
1004
|
+
Args:
|
|
1005
|
+
X_all: DataFrame to split
|
|
1006
|
+
k: Number of folds requested
|
|
819
1007
|
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
n_splits=k,
|
|
826
|
-
shuffle=True,
|
|
827
|
-
random_state=self.ctx.rand_seed,
|
|
828
|
-
)
|
|
829
|
-
return splitter.split(X_all), k
|
|
1008
|
+
Returns:
|
|
1009
|
+
Tuple of (split_iterator, actual_k) or (None, 0) if not enough data
|
|
1010
|
+
"""
|
|
1011
|
+
resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
|
|
1012
|
+
return resolver.create_kfold_splitter(X_all, k)
|
|
830
1013
|
|
|
831
1014
|
def cross_val_generic(
|
|
832
1015
|
self,
|
|
@@ -892,7 +1075,6 @@ class TrainerBase:
|
|
|
892
1075
|
w_all = w_all.loc[sampled_idx] if w_all is not None else None
|
|
893
1076
|
|
|
894
1077
|
if splitter is None:
|
|
895
|
-
strategy = str(getattr(self.ctx.config, "cv_strategy", "random") or "random").strip().lower()
|
|
896
1078
|
val_ratio = float(self.ctx.prop_test) if self.ctx.prop_test is not None else 0.25
|
|
897
1079
|
if not (0.0 < val_ratio < 1.0):
|
|
898
1080
|
val_ratio = 0.25
|
|
@@ -901,37 +1083,10 @@ class TrainerBase:
|
|
|
901
1083
|
cv_splits = max(2, int(round(1 / val_ratio)))
|
|
902
1084
|
cv_splits = max(2, int(cv_splits))
|
|
903
1085
|
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
if group_col not in self.ctx.train_data.columns:
|
|
909
|
-
raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
|
|
910
|
-
groups = self.ctx.train_data.reindex(X_all.index)[group_col]
|
|
911
|
-
split_iter = GroupKFold(n_splits=cv_splits).split(X_all, y_all, groups=groups)
|
|
912
|
-
elif strategy in {"time", "timeseries", "temporal"}:
|
|
913
|
-
time_col = getattr(self.ctx.config, "cv_time_col", None)
|
|
914
|
-
if not time_col:
|
|
915
|
-
raise ValueError("cv_time_col is required for time cv_strategy.")
|
|
916
|
-
if time_col not in self.ctx.train_data.columns:
|
|
917
|
-
raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
|
|
918
|
-
ascending = bool(getattr(self.ctx.config, "cv_time_ascending", True))
|
|
919
|
-
order_index = self.ctx.train_data[time_col].sort_values(ascending=ascending).index
|
|
920
|
-
index_set = set(X_all.index)
|
|
921
|
-
order_index = [idx for idx in order_index if idx in index_set]
|
|
922
|
-
order = X_all.index.get_indexer(order_index)
|
|
923
|
-
order = order[order >= 0]
|
|
924
|
-
if len(order) <= cv_splits:
|
|
925
|
-
cv_splits = max(2, len(order) - 1)
|
|
926
|
-
if cv_splits < 2:
|
|
927
|
-
raise ValueError("Not enough samples for time-series CV.")
|
|
928
|
-
split_iter = _OrderSplitter(TimeSeriesSplit(n_splits=cv_splits), order).split(X_all)
|
|
929
|
-
else:
|
|
930
|
-
split_iter = ShuffleSplit(
|
|
931
|
-
n_splits=cv_splits,
|
|
932
|
-
test_size=val_ratio,
|
|
933
|
-
random_state=self.ctx.rand_seed
|
|
934
|
-
).split(X_all)
|
|
1086
|
+
resolver = CVStrategyResolver(self.ctx.config, self.ctx.train_data, self.ctx.rand_seed)
|
|
1087
|
+
split_iter, actual_splits = resolver.create_cv_splitter(X_all, y_all, cv_splits, val_ratio)
|
|
1088
|
+
if actual_splits < 2:
|
|
1089
|
+
raise ValueError("Not enough samples for cross-validation.")
|
|
935
1090
|
else:
|
|
936
1091
|
if hasattr(splitter, "split"):
|
|
937
1092
|
split_iter = splitter.split(X_all, y_all, groups=None)
|
|
@@ -939,7 +1094,7 @@ class TrainerBase:
|
|
|
939
1094
|
split_iter = splitter
|
|
940
1095
|
|
|
941
1096
|
losses: List[float] = []
|
|
942
|
-
for train_idx, val_idx in split_iter:
|
|
1097
|
+
for fold_idx, (train_idx, val_idx) in enumerate(split_iter):
|
|
943
1098
|
X_train = X_all.iloc[train_idx]
|
|
944
1099
|
y_train = y_all.iloc[train_idx]
|
|
945
1100
|
X_val = X_all.iloc[val_idx]
|
|
@@ -953,9 +1108,11 @@ class TrainerBase:
|
|
|
953
1108
|
model = model_builder(params)
|
|
954
1109
|
try:
|
|
955
1110
|
if fit_predict_fn:
|
|
1111
|
+
# Avoid duplicate Optuna step reports across folds.
|
|
1112
|
+
trial_for_fold = trial if fold_idx == 0 else None
|
|
956
1113
|
y_pred = fit_predict_fn(
|
|
957
1114
|
model, X_train, y_train, w_train,
|
|
958
|
-
X_val, y_val, w_val,
|
|
1115
|
+
X_val, y_val, w_val, trial_for_fold
|
|
959
1116
|
)
|
|
960
1117
|
else:
|
|
961
1118
|
fit_kwargs = {}
|
|
@@ -1133,4 +1290,3 @@ class TrainerBase:
|
|
|
1133
1290
|
predict_kwargs_train=predict_kwargs_train,
|
|
1134
1291
|
predict_kwargs_test=predict_kwargs_test,
|
|
1135
1292
|
predict_fn=predict_fn)
|
|
1136
|
-
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Backward compatibility re-exports from refactored utils modules.
|
|
2
|
+
|
|
3
|
+
This module ensures all existing imports continue to work:
|
|
4
|
+
from ins_pricing.modelling.core.bayesopt.utils import EPS, IOUtils, ...
|
|
5
|
+
|
|
6
|
+
The utils.py file has been split into focused modules for better maintainability:
|
|
7
|
+
- constants.py: EPS, set_global_seed, etc.
|
|
8
|
+
- io_utils.py: IOUtils for file I/O
|
|
9
|
+
- distributed_utils.py: DistributedUtils, TrainingUtils for DDP
|
|
10
|
+
- torch_trainer_mixin.py: TorchTrainerMixin for PyTorch training
|
|
11
|
+
- metrics_and_devices.py: Metrics, GPU/device management, CV strategies, plotting
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
# Constants and simple utilities
|
|
17
|
+
from .constants import (
|
|
18
|
+
EPS,
|
|
19
|
+
set_global_seed,
|
|
20
|
+
ensure_parent_dir,
|
|
21
|
+
compute_batch_size,
|
|
22
|
+
tweedie_loss,
|
|
23
|
+
infer_factor_and_cate_list,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# I/O utilities
|
|
27
|
+
from .io_utils import (
|
|
28
|
+
IOUtils,
|
|
29
|
+
csv_to_dict,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# Distributed training
|
|
33
|
+
from .distributed_utils import (
|
|
34
|
+
DistributedUtils,
|
|
35
|
+
TrainingUtils,
|
|
36
|
+
free_cuda,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
# PyTorch training mixin
|
|
40
|
+
from .torch_trainer_mixin import (
|
|
41
|
+
TorchTrainerMixin,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# Metrics, devices, CV, and plotting
|
|
45
|
+
from .metrics_and_devices import (
|
|
46
|
+
get_logger,
|
|
47
|
+
MetricFactory,
|
|
48
|
+
GPUMemoryManager,
|
|
49
|
+
DeviceManager,
|
|
50
|
+
CVStrategyResolver,
|
|
51
|
+
PlotUtils,
|
|
52
|
+
split_data,
|
|
53
|
+
plot_lift_list,
|
|
54
|
+
plot_dlift_list,
|
|
55
|
+
_OrderedSplitter,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
__all__ = [
|
|
59
|
+
# Constants
|
|
60
|
+
'EPS',
|
|
61
|
+
'set_global_seed',
|
|
62
|
+
'ensure_parent_dir',
|
|
63
|
+
'compute_batch_size',
|
|
64
|
+
'tweedie_loss',
|
|
65
|
+
'infer_factor_and_cate_list',
|
|
66
|
+
# I/O
|
|
67
|
+
'IOUtils',
|
|
68
|
+
'csv_to_dict',
|
|
69
|
+
# Distributed
|
|
70
|
+
'DistributedUtils',
|
|
71
|
+
'TrainingUtils',
|
|
72
|
+
'free_cuda',
|
|
73
|
+
# PyTorch
|
|
74
|
+
'TorchTrainerMixin',
|
|
75
|
+
# Utilities
|
|
76
|
+
'get_logger',
|
|
77
|
+
'MetricFactory',
|
|
78
|
+
'GPUMemoryManager',
|
|
79
|
+
'DeviceManager',
|
|
80
|
+
'CVStrategyResolver',
|
|
81
|
+
'PlotUtils',
|
|
82
|
+
'split_data',
|
|
83
|
+
'plot_lift_list',
|
|
84
|
+
'plot_dlift_list',
|
|
85
|
+
'_OrderedSplitter',
|
|
86
|
+
]
|