ins-pricing 0.2.9__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. ins_pricing/CHANGELOG.md +93 -0
  2. ins_pricing/README.md +11 -0
  3. ins_pricing/cli/Explain_entry.py +50 -48
  4. ins_pricing/cli/bayesopt_entry_runner.py +699 -569
  5. ins_pricing/cli/utils/evaluation_context.py +320 -0
  6. ins_pricing/cli/utils/import_resolver.py +350 -0
  7. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +449 -0
  8. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +406 -0
  9. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +247 -0
  10. ins_pricing/modelling/core/bayesopt/config_components.py +351 -0
  11. ins_pricing/modelling/core/bayesopt/config_preprocess.py +3 -4
  12. ins_pricing/modelling/core/bayesopt/core.py +153 -94
  13. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +122 -34
  14. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +298 -142
  15. ins_pricing/modelling/core/bayesopt/utils/__init__.py +86 -0
  16. ins_pricing/modelling/core/bayesopt/utils/constants.py +183 -0
  17. ins_pricing/modelling/core/bayesopt/utils/distributed_utils.py +186 -0
  18. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +126 -0
  19. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +540 -0
  20. ins_pricing/modelling/core/bayesopt/utils/torch_trainer_mixin.py +591 -0
  21. ins_pricing/modelling/core/bayesopt/utils.py +98 -1496
  22. ins_pricing/modelling/core/bayesopt/utils_backup.py +1503 -0
  23. ins_pricing/setup.py +1 -1
  24. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/METADATA +14 -1
  25. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/RECORD +27 -14
  26. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/WHEEL +0 -0
  27. {ins_pricing-0.2.9.dist-info → ins_pricing-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,351 @@
1
+ """Nested configuration components for BayesOptConfig.
2
+
3
+ This module provides focused configuration dataclasses that group related settings
4
+ together, improving maintainability and reducing the cognitive load of the main
5
+ BayesOptConfig class.
6
+
7
+ Usage:
8
+ config = BayesOptConfig(
9
+ model_nme="pricing_model",
10
+ resp_nme="claim",
11
+ weight_nme="exposure",
12
+ factor_nmes=["age", "gender"],
13
+ distributed=DistributedConfig(use_ft_ddp=True),
14
+ gnn=GNNConfig(use_approx_knn=False),
15
+ )
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from dataclasses import dataclass, field
21
+ from typing import Any, Dict, List, Optional
22
+
23
+
24
+ @dataclass
25
+ class DistributedConfig:
26
+ """Configuration for distributed training (DDP/DataParallel).
27
+
28
+ Attributes:
29
+ use_resn_data_parallel: Use DataParallel for ResNet
30
+ use_ft_data_parallel: Use DataParallel for FT-Transformer
31
+ use_gnn_data_parallel: Use DataParallel for GNN
32
+ use_resn_ddp: Use DistributedDataParallel for ResNet
33
+ use_ft_ddp: Use DistributedDataParallel for FT-Transformer
34
+ use_gnn_ddp: Use DistributedDataParallel for GNN
35
+ """
36
+
37
+ use_resn_data_parallel: bool = False
38
+ use_ft_data_parallel: bool = False
39
+ use_gnn_data_parallel: bool = False
40
+ use_resn_ddp: bool = False
41
+ use_ft_ddp: bool = False
42
+ use_gnn_ddp: bool = False
43
+
44
+ @classmethod
45
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "DistributedConfig":
46
+ """Create from a flat dictionary with prefixed keys."""
47
+ return cls(
48
+ use_resn_data_parallel=bool(d.get("use_resn_data_parallel", False)),
49
+ use_ft_data_parallel=bool(d.get("use_ft_data_parallel", False)),
50
+ use_gnn_data_parallel=bool(d.get("use_gnn_data_parallel", False)),
51
+ use_resn_ddp=bool(d.get("use_resn_ddp", False)),
52
+ use_ft_ddp=bool(d.get("use_ft_ddp", False)),
53
+ use_gnn_ddp=bool(d.get("use_gnn_ddp", False)),
54
+ )
55
+
56
+
57
+ @dataclass
58
+ class GNNConfig:
59
+ """Configuration for Graph Neural Network training.
60
+
61
+ Attributes:
62
+ use_approx_knn: Use approximate k-NN for graph construction
63
+ approx_knn_threshold: Row count threshold for approximate k-NN
64
+ graph_cache: Path to cache/load adjacency matrix
65
+ max_gpu_knn_nodes: Max nodes for GPU k-NN construction
66
+ knn_gpu_mem_ratio: Fraction of GPU memory for k-NN
67
+ knn_gpu_mem_overhead: Temporary memory overhead multiplier
68
+ """
69
+
70
+ use_approx_knn: bool = True
71
+ approx_knn_threshold: int = 50000
72
+ graph_cache: Optional[str] = None
73
+ max_gpu_knn_nodes: int = 200000
74
+ knn_gpu_mem_ratio: float = 0.9
75
+ knn_gpu_mem_overhead: float = 2.0
76
+
77
+ @classmethod
78
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "GNNConfig":
79
+ """Create from a flat dictionary with prefixed keys."""
80
+ return cls(
81
+ use_approx_knn=bool(d.get("gnn_use_approx_knn", True)),
82
+ approx_knn_threshold=int(d.get("gnn_approx_knn_threshold", 50000)),
83
+ graph_cache=d.get("gnn_graph_cache"),
84
+ max_gpu_knn_nodes=int(d.get("gnn_max_gpu_knn_nodes", 200000)),
85
+ knn_gpu_mem_ratio=float(d.get("gnn_knn_gpu_mem_ratio", 0.9)),
86
+ knn_gpu_mem_overhead=float(d.get("gnn_knn_gpu_mem_overhead", 2.0)),
87
+ )
88
+
89
+
90
+ @dataclass
91
+ class GeoTokenConfig:
92
+ """Configuration for geographic token embeddings.
93
+
94
+ Attributes:
95
+ feature_nmes: Feature column names for geo tokens
96
+ hidden_dim: Hidden dimension for geo token network
97
+ layers: Number of layers in geo token network
98
+ dropout: Dropout rate
99
+ k_neighbors: Number of neighbors for geo tokens
100
+ learning_rate: Learning rate for geo token training
101
+ epochs: Training epochs for geo tokens
102
+ """
103
+
104
+ feature_nmes: Optional[List[str]] = None
105
+ hidden_dim: int = 32
106
+ layers: int = 2
107
+ dropout: float = 0.1
108
+ k_neighbors: int = 10
109
+ learning_rate: float = 1e-3
110
+ epochs: int = 50
111
+
112
+ @classmethod
113
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "GeoTokenConfig":
114
+ """Create from a flat dictionary with prefixed keys."""
115
+ return cls(
116
+ feature_nmes=d.get("geo_feature_nmes"),
117
+ hidden_dim=int(d.get("geo_token_hidden_dim", 32)),
118
+ layers=int(d.get("geo_token_layers", 2)),
119
+ dropout=float(d.get("geo_token_dropout", 0.1)),
120
+ k_neighbors=int(d.get("geo_token_k_neighbors", 10)),
121
+ learning_rate=float(d.get("geo_token_learning_rate", 1e-3)),
122
+ epochs=int(d.get("geo_token_epochs", 50)),
123
+ )
124
+
125
+
126
+ @dataclass
127
+ class RegionConfig:
128
+ """Configuration for region/geographic effects.
129
+
130
+ Attributes:
131
+ province_col: Column name for province/state
132
+ city_col: Column name for city
133
+ effect_alpha: Regularization alpha for region effects
134
+ """
135
+
136
+ province_col: Optional[str] = None
137
+ city_col: Optional[str] = None
138
+ effect_alpha: float = 50.0
139
+
140
+ @classmethod
141
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "RegionConfig":
142
+ """Create from a flat dictionary with prefixed keys."""
143
+ return cls(
144
+ province_col=d.get("region_province_col"),
145
+ city_col=d.get("region_city_col"),
146
+ effect_alpha=float(d.get("region_effect_alpha", 50.0)),
147
+ )
148
+
149
+
150
+ @dataclass
151
+ class FTTransformerConfig:
152
+ """Configuration for FT-Transformer model.
153
+
154
+ Attributes:
155
+ role: Model role ('model', 'embedding', 'unsupervised_embedding')
156
+ feature_prefix: Prefix for generated embedding features
157
+ num_numeric_tokens: Number of numeric tokens
158
+ """
159
+
160
+ role: str = "model"
161
+ feature_prefix: str = "ft_emb"
162
+ num_numeric_tokens: Optional[int] = None
163
+
164
+ @classmethod
165
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "FTTransformerConfig":
166
+ """Create from a flat dictionary with prefixed keys."""
167
+ return cls(
168
+ role=str(d.get("ft_role", "model")),
169
+ feature_prefix=str(d.get("ft_feature_prefix", "ft_emb")),
170
+ num_numeric_tokens=d.get("ft_num_numeric_tokens"),
171
+ )
172
+
173
+
174
+ @dataclass
175
+ class XGBoostConfig:
176
+ """Configuration for XGBoost model.
177
+
178
+ Attributes:
179
+ max_depth_max: Maximum tree depth for hyperparameter tuning
180
+ n_estimators_max: Maximum number of estimators for tuning
181
+ """
182
+
183
+ max_depth_max: int = 25
184
+ n_estimators_max: int = 500
185
+
186
+ @classmethod
187
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "XGBoostConfig":
188
+ """Create from a flat dictionary with prefixed keys."""
189
+ return cls(
190
+ max_depth_max=int(d.get("xgb_max_depth_max", 25)),
191
+ n_estimators_max=int(d.get("xgb_n_estimators_max", 500)),
192
+ )
193
+
194
+
195
+ @dataclass
196
+ class CVConfig:
197
+ """Configuration for cross-validation.
198
+
199
+ Attributes:
200
+ strategy: CV strategy ('random', 'group', 'time', 'stratified')
201
+ splits: Number of CV splits
202
+ group_col: Column for group-based CV
203
+ time_col: Column for time-based CV
204
+ time_ascending: Whether to sort time ascending
205
+ """
206
+
207
+ strategy: str = "random"
208
+ splits: Optional[int] = None
209
+ group_col: Optional[str] = None
210
+ time_col: Optional[str] = None
211
+ time_ascending: bool = True
212
+
213
+ @classmethod
214
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "CVConfig":
215
+ """Create from a flat dictionary with prefixed keys."""
216
+ return cls(
217
+ strategy=str(d.get("cv_strategy", "random")),
218
+ splits=d.get("cv_splits"),
219
+ group_col=d.get("cv_group_col"),
220
+ time_col=d.get("cv_time_col"),
221
+ time_ascending=bool(d.get("cv_time_ascending", True)),
222
+ )
223
+
224
+
225
+ @dataclass
226
+ class FTOOFConfig:
227
+ """Configuration for FT-Transformer out-of-fold predictions.
228
+
229
+ Attributes:
230
+ folds: Number of OOF folds
231
+ strategy: OOF strategy
232
+ shuffle: Whether to shuffle data
233
+ """
234
+
235
+ folds: Optional[int] = None
236
+ strategy: Optional[str] = None
237
+ shuffle: bool = True
238
+
239
+ @classmethod
240
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "FTOOFConfig":
241
+ """Create from a flat dictionary with prefixed keys."""
242
+ return cls(
243
+ folds=d.get("ft_oof_folds"),
244
+ strategy=d.get("ft_oof_strategy"),
245
+ shuffle=bool(d.get("ft_oof_shuffle", True)),
246
+ )
247
+
248
+
249
+ @dataclass
250
+ class OutputConfig:
251
+ """Configuration for output and caching.
252
+
253
+ Attributes:
254
+ output_dir: Base output directory
255
+ optuna_storage: Optuna study storage path
256
+ optuna_study_prefix: Prefix for Optuna study names
257
+ best_params_files: Mapping of trainer keys to param files
258
+ save_preprocess: Whether to save preprocessing artifacts
259
+ preprocess_artifact_path: Path for preprocessing artifacts
260
+ plot_path_style: Plot path style ('nested' or 'flat')
261
+ cache_predictions: Whether to cache predictions
262
+ prediction_cache_dir: Directory for prediction cache
263
+ prediction_cache_format: Format for prediction cache ('parquet' or 'csv')
264
+ """
265
+
266
+ output_dir: Optional[str] = None
267
+ optuna_storage: Optional[str] = None
268
+ optuna_study_prefix: Optional[str] = None
269
+ best_params_files: Optional[Dict[str, str]] = None
270
+ save_preprocess: bool = False
271
+ preprocess_artifact_path: Optional[str] = None
272
+ plot_path_style: str = "nested"
273
+ cache_predictions: bool = False
274
+ prediction_cache_dir: Optional[str] = None
275
+ prediction_cache_format: str = "parquet"
276
+
277
+ @classmethod
278
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "OutputConfig":
279
+ """Create from a flat dictionary with prefixed keys."""
280
+ return cls(
281
+ output_dir=d.get("output_dir"),
282
+ optuna_storage=d.get("optuna_storage"),
283
+ optuna_study_prefix=d.get("optuna_study_prefix"),
284
+ best_params_files=d.get("best_params_files"),
285
+ save_preprocess=bool(d.get("save_preprocess", False)),
286
+ preprocess_artifact_path=d.get("preprocess_artifact_path"),
287
+ plot_path_style=str(d.get("plot_path_style", "nested")),
288
+ cache_predictions=bool(d.get("cache_predictions", False)),
289
+ prediction_cache_dir=d.get("prediction_cache_dir"),
290
+ prediction_cache_format=str(d.get("prediction_cache_format", "parquet")),
291
+ )
292
+
293
+
294
+ @dataclass
295
+ class EnsembleConfig:
296
+ """Configuration for ensemble training.
297
+
298
+ Attributes:
299
+ final_ensemble: Whether to use final ensemble
300
+ final_ensemble_k: Number of models in ensemble
301
+ final_refit: Whether to refit after ensemble
302
+ """
303
+
304
+ final_ensemble: bool = False
305
+ final_ensemble_k: int = 3
306
+ final_refit: bool = True
307
+
308
+ @classmethod
309
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "EnsembleConfig":
310
+ """Create from a flat dictionary with prefixed keys."""
311
+ return cls(
312
+ final_ensemble=bool(d.get("final_ensemble", False)),
313
+ final_ensemble_k=int(d.get("final_ensemble_k", 3)),
314
+ final_refit=bool(d.get("final_refit", True)),
315
+ )
316
+
317
+
318
+ @dataclass
319
+ class TrainingConfig:
320
+ """Core training configuration.
321
+
322
+ Attributes:
323
+ prop_test: Proportion of data for validation
324
+ rand_seed: Random seed for reproducibility
325
+ epochs: Number of training epochs
326
+ use_gpu: Whether to use GPU
327
+ reuse_best_params: Whether to reuse best params
328
+ resn_weight_decay: Weight decay for ResNet
329
+ bo_sample_limit: Sample limit for Bayesian optimization
330
+ """
331
+
332
+ prop_test: float = 0.25
333
+ rand_seed: Optional[int] = None
334
+ epochs: int = 100
335
+ use_gpu: bool = True
336
+ reuse_best_params: bool = False
337
+ resn_weight_decay: float = 1e-4
338
+ bo_sample_limit: Optional[int] = None
339
+
340
+ @classmethod
341
+ def from_flat_dict(cls, d: Dict[str, Any]) -> "TrainingConfig":
342
+ """Create from a flat dictionary with prefixed keys."""
343
+ return cls(
344
+ prop_test=float(d.get("prop_test", 0.25)),
345
+ rand_seed=d.get("rand_seed"),
346
+ epochs=int(d.get("epochs", 100)),
347
+ use_gpu=bool(d.get("use_gpu", True)),
348
+ reuse_best_params=bool(d.get("reuse_best_params", False)),
349
+ resn_weight_decay=float(d.get("resn_weight_decay", 1e-4)),
350
+ bo_sample_limit=d.get("bo_sample_limit"),
351
+ )
@@ -366,10 +366,9 @@ class DatasetPreprocessor:
366
366
  def __init__(self, train_df: pd.DataFrame, test_df: pd.DataFrame,
367
367
  config: BayesOptConfig) -> None:
368
368
  self.config = config
369
- # Use shallow copy to avoid unnecessary memory overhead
370
- # Deep copies only made when actually modifying data
371
- self.train_data = train_df.copy(deep=False)
372
- self.test_data = test_df.copy(deep=False)
369
+ # Copy inputs to avoid mutating caller-provided DataFrames.
370
+ self.train_data = train_df.copy()
371
+ self.test_data = test_df.copy()
373
372
  self.num_features: List[str] = []
374
373
  self.train_oht_data: Optional[pd.DataFrame] = None
375
374
  self.test_oht_data: Optional[pd.DataFrame] = None
@@ -48,7 +48,10 @@ class _CVSplitter:
48
48
  # =============================================================================
49
49
  class BayesOptModel(BayesOptPlottingMixin, BayesOptExplainMixin):
50
50
  def __init__(self, train_data, test_data,
51
- model_nme, resp_nme, weight_nme, factor_nmes: Optional[List[str]] = None, task_type='regression',
51
+ config: Optional[BayesOptConfig] = None,
52
+ # Backward compatibility: individual parameters (DEPRECATED)
53
+ model_nme=None, resp_nme=None, weight_nme=None,
54
+ factor_nmes: Optional[List[str]] = None, task_type='regression',
52
55
  binary_resp_nme=None,
53
56
  cate_list=None, prop_test=0.25, rand_seed=None,
54
57
  epochs=100, use_gpu=True,
@@ -108,6 +111,10 @@ class BayesOptModel(BayesOptPlottingMixin, BayesOptExplainMixin):
108
111
  Args:
109
112
  train_data: Training DataFrame.
110
113
  test_data: Test DataFrame.
114
+ config: BayesOptConfig instance with all configuration (RECOMMENDED).
115
+ If provided, all other parameters are ignored.
116
+
117
+ # DEPRECATED: Individual parameters (use config instead)
111
118
  model_nme: Model name prefix used in outputs.
112
119
  resp_nme: Target column name.
113
120
  weight_nme: Sample weight column name.
@@ -136,101 +143,153 @@ class BayesOptModel(BayesOptPlottingMixin, BayesOptExplainMixin):
136
143
  final_ensemble: Enable k-fold model averaging at the final stage.
137
144
  final_ensemble_k: Number of folds for averaging.
138
145
  final_refit: Refit on full data using best stopping point.
146
+
147
+ Examples:
148
+ # New style (recommended):
149
+ config = BayesOptConfig(
150
+ model_nme="my_model",
151
+ resp_nme="target",
152
+ weight_nme="weight",
153
+ factor_nmes=["feat1", "feat2"]
154
+ )
155
+ model = BayesOptModel(train_df, test_df, config=config)
156
+
157
+ # Old style (deprecated, for backward compatibility):
158
+ model = BayesOptModel(
159
+ train_df, test_df,
160
+ model_nme="my_model",
161
+ resp_nme="target",
162
+ weight_nme="weight",
163
+ factor_nmes=["feat1", "feat2"]
164
+ )
139
165
  """
140
- inferred_factors, inferred_cats = infer_factor_and_cate_list(
141
- train_df=train_data,
142
- test_df=test_data,
143
- resp_nme=resp_nme,
144
- weight_nme=weight_nme,
145
- binary_resp_nme=binary_resp_nme,
146
- factor_nmes=factor_nmes,
147
- cate_list=cate_list,
148
- infer_categorical_max_unique=int(infer_categorical_max_unique),
149
- infer_categorical_max_ratio=float(infer_categorical_max_ratio),
150
- )
166
+ # Detect which API is being used
167
+ if config is not None:
168
+ # New API: config object provided
169
+ if isinstance(config, BayesOptConfig):
170
+ cfg = config
171
+ else:
172
+ raise TypeError(
173
+ f"config must be a BayesOptConfig instance, got {type(config).__name__}"
174
+ )
175
+ else:
176
+ # Old API: individual parameters (backward compatibility)
177
+ # Show deprecation warning
178
+ import warnings
179
+ warnings.warn(
180
+ "Passing individual parameters to BayesOptModel.__init__ is deprecated. "
181
+ "Use the 'config' parameter with a BayesOptConfig instance instead:\n"
182
+ " config = BayesOptConfig(model_nme=..., resp_nme=..., ...)\n"
183
+ " model = BayesOptModel(train_data, test_data, config=config)\n"
184
+ "Individual parameters will be removed in v0.4.0.",
185
+ DeprecationWarning,
186
+ stacklevel=2
187
+ )
151
188
 
152
- cfg = BayesOptConfig(
153
- model_nme=model_nme,
154
- task_type=task_type,
155
- resp_nme=resp_nme,
156
- weight_nme=weight_nme,
157
- factor_nmes=list(inferred_factors),
158
- binary_resp_nme=binary_resp_nme,
159
- cate_list=list(inferred_cats) if inferred_cats else None,
160
- prop_test=prop_test,
161
- rand_seed=rand_seed,
162
- epochs=epochs,
163
- use_gpu=use_gpu,
164
- xgb_max_depth_max=int(xgb_max_depth_max),
165
- xgb_n_estimators_max=int(xgb_n_estimators_max),
166
- use_resn_data_parallel=use_resn_data_parallel,
167
- use_ft_data_parallel=use_ft_data_parallel,
168
- use_resn_ddp=use_resn_ddp,
169
- use_gnn_data_parallel=use_gnn_data_parallel,
170
- use_ft_ddp=use_ft_ddp,
171
- use_gnn_ddp=use_gnn_ddp,
172
- gnn_use_approx_knn=gnn_use_approx_knn,
173
- gnn_approx_knn_threshold=gnn_approx_knn_threshold,
174
- gnn_graph_cache=gnn_graph_cache,
175
- gnn_max_gpu_knn_nodes=gnn_max_gpu_knn_nodes,
176
- gnn_knn_gpu_mem_ratio=gnn_knn_gpu_mem_ratio,
177
- gnn_knn_gpu_mem_overhead=gnn_knn_gpu_mem_overhead,
178
- output_dir=output_dir,
179
- optuna_storage=optuna_storage,
180
- optuna_study_prefix=optuna_study_prefix,
181
- best_params_files=best_params_files,
182
- ft_role=str(ft_role or "model"),
183
- ft_feature_prefix=str(ft_feature_prefix or "ft_emb"),
184
- ft_num_numeric_tokens=ft_num_numeric_tokens,
185
- reuse_best_params=bool(reuse_best_params),
186
- resn_weight_decay=float(resn_weight_decay)
187
- if resn_weight_decay is not None
188
- else 1e-4,
189
- final_ensemble=bool(final_ensemble),
190
- final_ensemble_k=int(final_ensemble_k),
191
- final_refit=bool(final_refit),
192
- cv_strategy=str(cv_strategy or "random"),
193
- cv_splits=cv_splits,
194
- cv_group_col=cv_group_col,
195
- cv_time_col=cv_time_col,
196
- cv_time_ascending=bool(cv_time_ascending),
197
- ft_oof_folds=ft_oof_folds,
198
- ft_oof_strategy=ft_oof_strategy,
199
- ft_oof_shuffle=bool(ft_oof_shuffle),
200
- save_preprocess=bool(save_preprocess),
201
- preprocess_artifact_path=preprocess_artifact_path,
202
- plot_path_style=str(plot_path_style or "nested"),
203
- bo_sample_limit=bo_sample_limit,
204
- cache_predictions=bool(cache_predictions),
205
- prediction_cache_dir=prediction_cache_dir,
206
- prediction_cache_format=str(prediction_cache_format or "parquet"),
207
- region_province_col=region_province_col,
208
- region_city_col=region_city_col,
209
- region_effect_alpha=float(region_effect_alpha)
210
- if region_effect_alpha is not None
211
- else 50.0,
212
- geo_feature_nmes=list(geo_feature_nmes)
213
- if geo_feature_nmes is not None
214
- else None,
215
- geo_token_hidden_dim=int(geo_token_hidden_dim)
216
- if geo_token_hidden_dim is not None
217
- else 32,
218
- geo_token_layers=int(geo_token_layers)
219
- if geo_token_layers is not None
220
- else 2,
221
- geo_token_dropout=float(geo_token_dropout)
222
- if geo_token_dropout is not None
223
- else 0.1,
224
- geo_token_k_neighbors=int(geo_token_k_neighbors)
225
- if geo_token_k_neighbors is not None
226
- else 10,
227
- geo_token_learning_rate=float(geo_token_learning_rate)
228
- if geo_token_learning_rate is not None
229
- else 1e-3,
230
- geo_token_epochs=int(geo_token_epochs)
231
- if geo_token_epochs is not None
232
- else 50,
233
- )
189
+ # Validate required parameters
190
+ if model_nme is None:
191
+ raise ValueError("model_nme is required when not using config parameter")
192
+ if resp_nme is None:
193
+ raise ValueError("resp_nme is required when not using config parameter")
194
+ if weight_nme is None:
195
+ raise ValueError("weight_nme is required when not using config parameter")
196
+
197
+ # Infer categorical features if needed
198
+ inferred_factors, inferred_cats = infer_factor_and_cate_list(
199
+ train_df=train_data,
200
+ test_df=test_data,
201
+ resp_nme=resp_nme,
202
+ weight_nme=weight_nme,
203
+ binary_resp_nme=binary_resp_nme,
204
+ factor_nmes=factor_nmes,
205
+ cate_list=cate_list,
206
+ infer_categorical_max_unique=int(infer_categorical_max_unique),
207
+ infer_categorical_max_ratio=float(infer_categorical_max_ratio),
208
+ )
209
+
210
+ # Construct config from individual parameters
211
+ cfg = BayesOptConfig(
212
+ model_nme=model_nme,
213
+ task_type=task_type,
214
+ resp_nme=resp_nme,
215
+ weight_nme=weight_nme,
216
+ factor_nmes=list(inferred_factors),
217
+ binary_resp_nme=binary_resp_nme,
218
+ cate_list=list(inferred_cats) if inferred_cats else None,
219
+ prop_test=prop_test,
220
+ rand_seed=rand_seed,
221
+ epochs=epochs,
222
+ use_gpu=use_gpu,
223
+ xgb_max_depth_max=int(xgb_max_depth_max),
224
+ xgb_n_estimators_max=int(xgb_n_estimators_max),
225
+ use_resn_data_parallel=use_resn_data_parallel,
226
+ use_ft_data_parallel=use_ft_data_parallel,
227
+ use_resn_ddp=use_resn_ddp,
228
+ use_gnn_data_parallel=use_gnn_data_parallel,
229
+ use_ft_ddp=use_ft_ddp,
230
+ use_gnn_ddp=use_gnn_ddp,
231
+ gnn_use_approx_knn=gnn_use_approx_knn,
232
+ gnn_approx_knn_threshold=gnn_approx_knn_threshold,
233
+ gnn_graph_cache=gnn_graph_cache,
234
+ gnn_max_gpu_knn_nodes=gnn_max_gpu_knn_nodes,
235
+ gnn_knn_gpu_mem_ratio=gnn_knn_gpu_mem_ratio,
236
+ gnn_knn_gpu_mem_overhead=gnn_knn_gpu_mem_overhead,
237
+ output_dir=output_dir,
238
+ optuna_storage=optuna_storage,
239
+ optuna_study_prefix=optuna_study_prefix,
240
+ best_params_files=best_params_files,
241
+ ft_role=str(ft_role or "model"),
242
+ ft_feature_prefix=str(ft_feature_prefix or "ft_emb"),
243
+ ft_num_numeric_tokens=ft_num_numeric_tokens,
244
+ reuse_best_params=bool(reuse_best_params),
245
+ resn_weight_decay=float(resn_weight_decay)
246
+ if resn_weight_decay is not None
247
+ else 1e-4,
248
+ final_ensemble=bool(final_ensemble),
249
+ final_ensemble_k=int(final_ensemble_k),
250
+ final_refit=bool(final_refit),
251
+ cv_strategy=str(cv_strategy or "random"),
252
+ cv_splits=cv_splits,
253
+ cv_group_col=cv_group_col,
254
+ cv_time_col=cv_time_col,
255
+ cv_time_ascending=bool(cv_time_ascending),
256
+ ft_oof_folds=ft_oof_folds,
257
+ ft_oof_strategy=ft_oof_strategy,
258
+ ft_oof_shuffle=bool(ft_oof_shuffle),
259
+ save_preprocess=bool(save_preprocess),
260
+ preprocess_artifact_path=preprocess_artifact_path,
261
+ plot_path_style=str(plot_path_style or "nested"),
262
+ bo_sample_limit=bo_sample_limit,
263
+ cache_predictions=bool(cache_predictions),
264
+ prediction_cache_dir=prediction_cache_dir,
265
+ prediction_cache_format=str(prediction_cache_format or "parquet"),
266
+ region_province_col=region_province_col,
267
+ region_city_col=region_city_col,
268
+ region_effect_alpha=float(region_effect_alpha)
269
+ if region_effect_alpha is not None
270
+ else 50.0,
271
+ geo_feature_nmes=list(geo_feature_nmes)
272
+ if geo_feature_nmes is not None
273
+ else None,
274
+ geo_token_hidden_dim=int(geo_token_hidden_dim)
275
+ if geo_token_hidden_dim is not None
276
+ else 32,
277
+ geo_token_layers=int(geo_token_layers)
278
+ if geo_token_layers is not None
279
+ else 2,
280
+ geo_token_dropout=float(geo_token_dropout)
281
+ if geo_token_dropout is not None
282
+ else 0.1,
283
+ geo_token_k_neighbors=int(geo_token_k_neighbors)
284
+ if geo_token_k_neighbors is not None
285
+ else 10,
286
+ geo_token_learning_rate=float(geo_token_learning_rate)
287
+ if geo_token_learning_rate is not None
288
+ else 1e-3,
289
+ geo_token_epochs=int(geo_token_epochs)
290
+ if geo_token_epochs is not None
291
+ else 50,
292
+ )
234
293
  self.config = cfg
235
294
  self.model_nme = cfg.model_nme
236
295
  self.task_type = cfg.task_type