ins-pricing 0.4.4-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. ins_pricing/README.md +74 -56
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +832 -898
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +1440 -1438
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/README.md +573 -419
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/config_builder.py +1 -0
  17. ins_pricing/frontend/example_workflows.py +1 -1
  18. ins_pricing/governance/__init__.py +20 -20
  19. ins_pricing/governance/release.py +159 -159
  20. ins_pricing/modelling/README.md +67 -0
  21. ins_pricing/modelling/__init__.py +147 -92
  22. ins_pricing/modelling/bayesopt/README.md +59 -0
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -550
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -962
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  32. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  37. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  39. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  40. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  41. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  42. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  43. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  44. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  45. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  46. ins_pricing/modelling/explain/__init__.py +55 -55
  47. ins_pricing/modelling/explain/metrics.py +27 -174
  48. ins_pricing/modelling/explain/permutation.py +237 -237
  49. ins_pricing/modelling/plotting/__init__.py +40 -36
  50. ins_pricing/modelling/plotting/compat.py +228 -0
  51. ins_pricing/modelling/plotting/curves.py +572 -572
  52. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  53. ins_pricing/modelling/plotting/geo.py +362 -362
  54. ins_pricing/modelling/plotting/importance.py +121 -121
  55. ins_pricing/pricing/__init__.py +27 -27
  56. ins_pricing/production/__init__.py +35 -25
  57. ins_pricing/production/{predict.py → inference.py} +140 -57
  58. ins_pricing/production/monitoring.py +8 -21
  59. ins_pricing/reporting/__init__.py +11 -11
  60. ins_pricing/setup.py +1 -1
  61. ins_pricing/tests/production/test_inference.py +90 -0
  62. ins_pricing/utils/__init__.py +116 -83
  63. ins_pricing/utils/device.py +255 -255
  64. ins_pricing/utils/features.py +53 -0
  65. ins_pricing/utils/io.py +72 -0
  66. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  67. ins_pricing/utils/metrics.py +158 -24
  68. ins_pricing/utils/numerics.py +76 -0
  69. ins_pricing/utils/paths.py +9 -1
  70. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +55 -35
  71. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  72. ins_pricing/CHANGELOG.md +0 -272
  73. ins_pricing/RELEASE_NOTES_0.2.8.md +0 -344
  74. ins_pricing/docs/LOSS_FUNCTIONS.md +0 -78
  75. ins_pricing/docs/modelling/BayesOpt_USAGE.md +0 -945
  76. ins_pricing/docs/modelling/README.md +0 -34
  77. ins_pricing/frontend/QUICKSTART.md +0 -152
  78. ins_pricing/modelling/core/BayesOpt.py +0 -146
  79. ins_pricing/modelling/core/__init__.py +0 -1
  80. ins_pricing/modelling/core/bayesopt/PHASE2_REFACTORING_SUMMARY.md +0 -449
  81. ins_pricing/modelling/core/bayesopt/PHASE3_REFACTORING_SUMMARY.md +0 -406
  82. ins_pricing/modelling/core/bayesopt/REFACTORING_SUMMARY.md +0 -247
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.4.dist-info/RECORD +0 -137
  92. /ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  93. /ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  94. /ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  95. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  96. {ins_pricing-0.4.4.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
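Most of the entries above are a package restructure rather than new code: `modelling/core/bayesopt` moves up to `modelling/bayesopt`, the shared loss/IO helpers move into `ins_pricing.utils`, and `production/predict.py` is renamed to `production/inference.py` (with `test_predict.py` replaced by `test_inference.py`). A minimal migration sketch for downstream imports, assuming each relocated module keeps exporting the same public names it did before the move (only `config_preprocess.py` is shown in full below, so the exact re-export surface is an assumption):

```python
# Hedged sketch of the 0.4.4 -> 0.5.0 import-path migration implied by the renames above.

# 0.4.4 layout (removed in 0.5.0):
# from ins_pricing.modelling.core.bayesopt.config_preprocess import BayesOptConfig
# from ins_pricing.modelling.core.bayesopt.utils.losses import normalize_loss_name
# from ins_pricing.production import predict

# 0.5.0 layout (per the rename entries in this diff):
from ins_pricing.modelling.bayesopt.config_preprocess import BayesOptConfig
from ins_pricing.utils.losses import normalize_loss_name
from ins_pricing.production import inference  # formerly production/predict.py
```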
ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py
@@ -1,550 +1,562 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import os
5
- from dataclasses import dataclass, asdict
6
- from datetime import datetime
7
- from pathlib import Path
8
- from typing import Any, Dict, List, Optional
9
-
10
- import numpy as np
11
- import pandas as pd
12
- from sklearn.preprocessing import StandardScaler
13
-
14
- from .utils import IOUtils
15
- from .utils.losses import normalize_loss_name
16
- from ....exceptions import ConfigurationError, DataValidationError
17
-
18
- # NOTE: Some CSV exports may contain invisible BOM characters or leading/trailing
19
- # spaces in column names. Pandas requires exact matches, so we normalize a few
20
- # "required" column names (response/weight/binary response) before validating.
21
-
22
-
23
- def _clean_column_name(name: Any) -> Any:
24
- if not isinstance(name, str):
25
- return name
26
- return name.replace("\ufeff", "").strip()
27
-
28
-
29
- def _normalize_required_columns(
30
- df: pd.DataFrame, required: List[Optional[str]], *, df_label: str
31
- ) -> None:
32
- required_names = [r for r in required if isinstance(r, str) and r.strip()]
33
- if not required_names:
34
- return
35
-
36
- mapping: Dict[Any, Any] = {}
37
- existing = set(df.columns)
38
- for col in df.columns:
39
- cleaned = _clean_column_name(col)
40
- if cleaned != col and cleaned not in existing:
41
- mapping[col] = cleaned
42
- if mapping:
43
- df.rename(columns=mapping, inplace=True)
44
-
45
- existing = set(df.columns)
46
- for req in required_names:
47
- if req in existing:
48
- continue
49
- candidates = [
50
- col
51
- for col in df.columns
52
- if isinstance(col, str) and _clean_column_name(col).lower() == req.lower()
53
- ]
54
- if len(candidates) == 1 and req not in existing:
55
- df.rename(columns={candidates[0]: req}, inplace=True)
56
- existing = set(df.columns)
57
- elif len(candidates) > 1:
58
- raise KeyError(
59
- f"{df_label} has multiple columns matching required {req!r} "
60
- f"(case/space-insensitive): {candidates}"
61
- )
62
-
63
-
64
- # ===== Core components and training wrappers =================================
65
-
66
- # =============================================================================
67
- # Config, preprocessing, and trainer base types
68
- # =============================================================================
69
- @dataclass
70
- class BayesOptConfig:
71
- """Configuration for Bayesian optimization-based model training.
72
-
73
- This dataclass holds all configuration parameters for the BayesOpt training
74
- pipeline, including model settings, distributed training options, and
75
- cross-validation strategies.
76
-
77
- Attributes:
78
- model_nme: Unique identifier for the model
79
- resp_nme: Column name for the response/target variable
80
- weight_nme: Column name for sample weights
81
- factor_nmes: List of feature column names
82
- task_type: Either 'regression' or 'classification'
83
- binary_resp_nme: Column name for binary response (optional)
84
- cate_list: List of categorical feature column names
85
- loss_name: Regression loss ('auto', 'tweedie', 'poisson', 'gamma', 'mse', 'mae')
86
- prop_test: Proportion of data for validation (0.0-1.0)
87
- rand_seed: Random seed for reproducibility
88
- epochs: Number of training epochs
89
- use_gpu: Whether to use GPU acceleration
90
- xgb_max_depth_max: Maximum tree depth for XGBoost tuning
91
- xgb_n_estimators_max: Maximum estimators for XGBoost tuning
92
- use_resn_data_parallel: Use DataParallel for ResNet
93
- use_ft_data_parallel: Use DataParallel for FT-Transformer
94
- use_resn_ddp: Use DDP for ResNet
95
- use_ft_ddp: Use DDP for FT-Transformer
96
- use_gnn_data_parallel: Use DataParallel for GNN
97
- use_gnn_ddp: Use DDP for GNN
98
- ft_role: FT-Transformer role ('model', 'embedding', 'unsupervised_embedding')
99
- cv_strategy: CV strategy ('random', 'group', 'time', 'stratified')
100
-
101
- Example:
102
- >>> config = BayesOptConfig(
103
- ... model_nme="pricing_model",
104
- ... resp_nme="claim_amount",
105
- ... weight_nme="exposure",
106
- ... factor_nmes=["age", "gender", "region"],
107
- ... task_type="regression",
108
- ... use_ft_ddp=True,
109
- ... )
110
- """
111
-
112
- # Required fields
113
- model_nme: str
114
- resp_nme: str
115
- weight_nme: str
116
- factor_nmes: List[str]
117
-
118
- # Task configuration
119
- task_type: str = 'regression'
120
- binary_resp_nme: Optional[str] = None
121
- cate_list: Optional[List[str]] = None
122
- loss_name: str = "auto"
123
-
124
- # Training configuration
125
- prop_test: float = 0.25
126
- rand_seed: Optional[int] = None
127
- epochs: int = 100
128
- use_gpu: bool = True
129
-
130
- # XGBoost settings
131
- xgb_max_depth_max: int = 25
132
- xgb_n_estimators_max: int = 500
133
-
134
- # Distributed training settings
135
- use_resn_data_parallel: bool = False
136
- use_ft_data_parallel: bool = False
137
- use_resn_ddp: bool = False
138
- use_ft_ddp: bool = False
139
- use_gnn_data_parallel: bool = False
140
- use_gnn_ddp: bool = False
141
-
142
- # GNN settings
143
- gnn_use_approx_knn: bool = True
144
- gnn_approx_knn_threshold: int = 50000
145
- gnn_graph_cache: Optional[str] = None
146
- gnn_max_gpu_knn_nodes: Optional[int] = 200000
147
- gnn_knn_gpu_mem_ratio: float = 0.9
148
- gnn_knn_gpu_mem_overhead: float = 2.0
149
-
150
- # Region/Geo settings
151
- region_province_col: Optional[str] = None
152
- region_city_col: Optional[str] = None
153
- region_effect_alpha: float = 50.0
154
- geo_feature_nmes: Optional[List[str]] = None
155
- geo_token_hidden_dim: int = 32
156
- geo_token_layers: int = 2
157
- geo_token_dropout: float = 0.1
158
- geo_token_k_neighbors: int = 10
159
- geo_token_learning_rate: float = 1e-3
160
- geo_token_epochs: int = 50
161
-
162
- # Output settings
163
- output_dir: Optional[str] = None
164
- optuna_storage: Optional[str] = None
165
- optuna_study_prefix: Optional[str] = None
166
- best_params_files: Optional[Dict[str, str]] = None
167
-
168
- # FT-Transformer settings
169
- ft_role: str = "model"
170
- ft_feature_prefix: str = "ft_emb"
171
- ft_num_numeric_tokens: Optional[int] = None
172
-
173
- # Training workflow settings
174
- reuse_best_params: bool = False
175
- resn_weight_decay: float = 1e-4
176
- final_ensemble: bool = False
177
- final_ensemble_k: int = 3
178
- final_refit: bool = True
179
-
180
- # Cross-validation settings
181
- cv_strategy: str = "random"
182
- cv_splits: Optional[int] = None
183
- cv_group_col: Optional[str] = None
184
- cv_time_col: Optional[str] = None
185
- cv_time_ascending: bool = True
186
- ft_oof_folds: Optional[int] = None
187
- ft_oof_strategy: Optional[str] = None
188
- ft_oof_shuffle: bool = True
189
-
190
- # Caching and output settings
191
- save_preprocess: bool = False
192
- preprocess_artifact_path: Optional[str] = None
193
- plot_path_style: str = "nested"
194
- bo_sample_limit: Optional[int] = None
195
- cache_predictions: bool = False
196
- prediction_cache_dir: Optional[str] = None
197
- prediction_cache_format: str = "parquet"
198
- dataloader_workers: Optional[int] = None
199
-
200
- def __post_init__(self) -> None:
201
- """Validate configuration after initialization."""
202
- self._validate()
203
-
204
- def _validate(self) -> None:
205
- """Validate configuration values and raise errors for invalid combinations."""
206
- errors: List[str] = []
207
-
208
- # Validate task_type
209
- valid_task_types = {"regression", "classification"}
210
- if self.task_type not in valid_task_types:
211
- errors.append(
212
- f"task_type must be one of {valid_task_types}, got '{self.task_type}'"
213
- )
214
- if self.dataloader_workers is not None:
215
- try:
216
- if int(self.dataloader_workers) < 0:
217
- errors.append("dataloader_workers must be >= 0 when provided.")
218
- except (TypeError, ValueError):
219
- errors.append("dataloader_workers must be an integer when provided.")
220
- # Validate loss_name
221
- try:
222
- normalized_loss = normalize_loss_name(self.loss_name, self.task_type)
223
- if self.task_type == "classification" and normalized_loss not in {"auto", "logloss", "bce"}:
224
- errors.append(
225
- "loss_name must be 'auto', 'logloss', or 'bce' for classification tasks."
226
- )
227
- except ValueError as exc:
228
- errors.append(str(exc))
229
-
230
- # Validate prop_test
231
- if not 0.0 < self.prop_test < 1.0:
232
- errors.append(
233
- f"prop_test must be between 0 and 1, got {self.prop_test}"
234
- )
235
-
236
- # Validate epochs
237
- if self.epochs < 1:
238
- errors.append(f"epochs must be >= 1, got {self.epochs}")
239
-
240
- # Validate XGBoost settings
241
- if self.xgb_max_depth_max < 1:
242
- errors.append(
243
- f"xgb_max_depth_max must be >= 1, got {self.xgb_max_depth_max}"
244
- )
245
- if self.xgb_n_estimators_max < 1:
246
- errors.append(
247
- f"xgb_n_estimators_max must be >= 1, got {self.xgb_n_estimators_max}"
248
- )
249
-
250
- # Validate distributed training: can't use both DataParallel and DDP
251
- if self.use_resn_data_parallel and self.use_resn_ddp:
252
- errors.append(
253
- "Cannot use both use_resn_data_parallel and use_resn_ddp"
254
- )
255
- if self.use_ft_data_parallel and self.use_ft_ddp:
256
- errors.append(
257
- "Cannot use both use_ft_data_parallel and use_ft_ddp"
258
- )
259
- if self.use_gnn_data_parallel and self.use_gnn_ddp:
260
- errors.append(
261
- "Cannot use both use_gnn_data_parallel and use_gnn_ddp"
262
- )
263
-
264
- # Validate ft_role
265
- valid_ft_roles = {"model", "embedding", "unsupervised_embedding"}
266
- if self.ft_role not in valid_ft_roles:
267
- errors.append(
268
- f"ft_role must be one of {valid_ft_roles}, got '{self.ft_role}'"
269
- )
270
-
271
- # Validate cv_strategy
272
- valid_cv_strategies = {"random", "group", "grouped", "time", "timeseries", "temporal", "stratified"}
273
- if self.cv_strategy not in valid_cv_strategies:
274
- errors.append(
275
- f"cv_strategy must be one of {valid_cv_strategies}, got '{self.cv_strategy}'"
276
- )
277
-
278
- # Validate group CV requires group_col
279
- if self.cv_strategy in {"group", "grouped"} and not self.cv_group_col:
280
- errors.append(
281
- f"cv_group_col is required when cv_strategy is '{self.cv_strategy}'"
282
- )
283
-
284
- # Validate time CV requires time_col
285
- if self.cv_strategy in {"time", "timeseries", "temporal"} and not self.cv_time_col:
286
- errors.append(
287
- f"cv_time_col is required when cv_strategy is '{self.cv_strategy}'"
288
- )
289
-
290
- # Validate prediction_cache_format
291
- valid_cache_formats = {"parquet", "csv"}
292
- if self.prediction_cache_format not in valid_cache_formats:
293
- errors.append(
294
- f"prediction_cache_format must be one of {valid_cache_formats}, "
295
- f"got '{self.prediction_cache_format}'"
296
- )
297
-
298
- # Validate GNN memory settings
299
- if self.gnn_knn_gpu_mem_ratio <= 0 or self.gnn_knn_gpu_mem_ratio > 1.0:
300
- errors.append(
301
- f"gnn_knn_gpu_mem_ratio must be in (0, 1], got {self.gnn_knn_gpu_mem_ratio}"
302
- )
303
-
304
- if errors:
305
- raise ConfigurationError(
306
- "BayesOptConfig validation failed:\n - " + "\n - ".join(errors)
307
- )
308
-
309
-
310
- @dataclass
311
- class PreprocessArtifacts:
312
- factor_nmes: List[str]
313
- cate_list: List[str]
314
- num_features: List[str]
315
- var_nmes: List[str]
316
- cat_categories: Dict[str, List[Any]]
317
- dummy_columns: List[str]
318
- numeric_scalers: Dict[str, Dict[str, float]]
319
- weight_nme: str
320
- resp_nme: str
321
- binary_resp_nme: Optional[str] = None
322
- drop_first: bool = True
323
-
324
-
325
- class OutputManager:
326
- # Centralize output paths for plots, results, and models.
327
-
328
- def __init__(self, root: Optional[str] = None, model_name: str = "model") -> None:
329
- self.root = Path(root or os.getcwd())
330
- self.model_name = model_name
331
- self.plot_dir = self.root / 'plot'
332
- self.result_dir = self.root / 'Results'
333
- self.model_dir = self.root / 'model'
334
-
335
- def _prepare(self, path: Path) -> str:
336
- IOUtils.ensure_parent_dir(str(path))
337
- return str(path)
338
-
339
- def plot_path(self, filename: str) -> str:
340
- return self._prepare(self.plot_dir / filename)
341
-
342
- def result_path(self, filename: str) -> str:
343
- return self._prepare(self.result_dir / filename)
344
-
345
- def model_path(self, filename: str) -> str:
346
- return self._prepare(self.model_dir / filename)
347
-
348
-
349
- class VersionManager:
350
- """Lightweight versioning: save config and best-params snapshots for traceability."""
351
-
352
- def __init__(self, output: OutputManager) -> None:
353
- self.output = output
354
- self.version_dir = Path(self.output.result_dir) / "versions"
355
- IOUtils.ensure_parent_dir(str(self.version_dir))
356
-
357
- def save(self, tag: str, payload: Dict[str, Any]) -> str:
358
- safe_tag = tag.replace(" ", "_")
359
- ts = datetime.now().strftime("%Y%m%d_%H%M%S")
360
- path = self.version_dir / f"{ts}_{safe_tag}.json"
361
- IOUtils.ensure_parent_dir(str(path))
362
- with open(path, "w", encoding="utf-8") as f:
363
- json.dump(payload, f, ensure_ascii=False, indent=2, default=str)
364
- print(f"[Version] Saved snapshot: {path}")
365
- return str(path)
366
-
367
- def load_latest(self, tag: str) -> Optional[Dict[str, Any]]:
368
- """Load the latest snapshot for a tag (sorted by timestamp prefix)."""
369
- safe_tag = tag.replace(" ", "_")
370
- pattern = f"*_{safe_tag}.json"
371
- candidates = sorted(self.version_dir.glob(pattern))
372
- if not candidates:
373
- return None
374
- path = candidates[-1]
375
- try:
376
- return json.loads(path.read_text(encoding="utf-8"))
377
- except Exception as exc:
378
- print(f"[Version] Failed to load snapshot {path}: {exc}")
379
- return None
380
-
381
-
382
- class DatasetPreprocessor:
383
- # Prepare shared train/test views for trainers.
384
-
385
- def __init__(self, train_df: pd.DataFrame, test_df: pd.DataFrame,
386
- config: BayesOptConfig) -> None:
387
- self.config = config
388
- # Copy inputs to avoid mutating caller-provided DataFrames.
389
- self.train_data = train_df.copy()
390
- self.test_data = test_df.copy()
391
- self.num_features: List[str] = []
392
- self.train_oht_data: Optional[pd.DataFrame] = None
393
- self.test_oht_data: Optional[pd.DataFrame] = None
394
- self.train_oht_scl_data: Optional[pd.DataFrame] = None
395
- self.test_oht_scl_data: Optional[pd.DataFrame] = None
396
- self.var_nmes: List[str] = []
397
- self.cat_categories_for_shap: Dict[str, List[Any]] = {}
398
- self.numeric_scalers: Dict[str, Dict[str, float]] = {}
399
-
400
- def run(self) -> "DatasetPreprocessor":
401
- """Run preprocessing: categorical encoding, target clipping, numeric scaling."""
402
- cfg = self.config
403
- _normalize_required_columns(
404
- self.train_data,
405
- [cfg.resp_nme, cfg.weight_nme, cfg.binary_resp_nme],
406
- df_label="Train data",
407
- )
408
- _normalize_required_columns(
409
- self.test_data,
410
- [cfg.resp_nme, cfg.weight_nme, cfg.binary_resp_nme],
411
- df_label="Test data",
412
- )
413
- missing_train = [
414
- col for col in (cfg.resp_nme, cfg.weight_nme)
415
- if col not in self.train_data.columns
416
- ]
417
- if missing_train:
418
- raise DataValidationError(
419
- f"Train data missing required columns: {missing_train}. "
420
- f"Available columns (first 50): {list(self.train_data.columns)[:50]}"
421
- )
422
- if cfg.binary_resp_nme and cfg.binary_resp_nme not in self.train_data.columns:
423
- raise DataValidationError(
424
- f"Train data missing binary response column: {cfg.binary_resp_nme}. "
425
- f"Available columns (first 50): {list(self.train_data.columns)[:50]}"
426
- )
427
-
428
- test_has_resp = cfg.resp_nme in self.test_data.columns
429
- test_has_weight = cfg.weight_nme in self.test_data.columns
430
- test_has_binary = bool(
431
- cfg.binary_resp_nme and cfg.binary_resp_nme in self.test_data.columns
432
- )
433
- if not test_has_weight:
434
- self.test_data[cfg.weight_nme] = 1.0
435
- if not test_has_resp:
436
- self.test_data[cfg.resp_nme] = np.nan
437
- if cfg.binary_resp_nme and cfg.binary_resp_nme not in self.test_data.columns:
438
- self.test_data[cfg.binary_resp_nme] = np.nan
439
-
440
- # Precompute weighted actuals for plots and validation checks.
441
- # Direct assignment is more efficient than .loc[:, col]
442
- self.train_data['w_act'] = self.train_data[cfg.resp_nme] * \
443
- self.train_data[cfg.weight_nme]
444
- if test_has_resp:
445
- self.test_data['w_act'] = self.test_data[cfg.resp_nme] * \
446
- self.test_data[cfg.weight_nme]
447
- if cfg.binary_resp_nme:
448
- self.train_data['w_binary_act'] = self.train_data[cfg.binary_resp_nme] * \
449
- self.train_data[cfg.weight_nme]
450
- if test_has_binary:
451
- self.test_data['w_binary_act'] = self.test_data[cfg.binary_resp_nme] * \
452
- self.test_data[cfg.weight_nme]
453
- # High-quantile clipping absorbs outliers; removing it lets extremes dominate loss.
454
- q99 = self.train_data[cfg.resp_nme].quantile(0.999)
455
- self.train_data[cfg.resp_nme] = self.train_data[cfg.resp_nme].clip(
456
- upper=q99)
457
- cate_list = list(cfg.cate_list or [])
458
- if cate_list:
459
- for cate in cate_list:
460
- self.train_data[cate] = self.train_data[cate].astype(
461
- 'category')
462
- self.test_data[cate] = self.test_data[cate].astype('category')
463
- cats = self.train_data[cate].cat.categories
464
- self.cat_categories_for_shap[cate] = list(cats)
465
- self.num_features = [
466
- nme for nme in cfg.factor_nmes if nme not in cate_list]
467
-
468
- # Memory optimization: Single copy + in-place operations
469
- train_oht = self.train_data[cfg.factor_nmes +
470
- [cfg.weight_nme] + [cfg.resp_nme]].copy()
471
- test_oht = self.test_data[cfg.factor_nmes +
472
- [cfg.weight_nme] + [cfg.resp_nme]].copy()
473
- train_oht = pd.get_dummies(
474
- train_oht,
475
- columns=cate_list,
476
- drop_first=True,
477
- dtype=np.int8
478
- )
479
- test_oht = pd.get_dummies(
480
- test_oht,
481
- columns=cate_list,
482
- drop_first=True,
483
- dtype=np.int8
484
- )
485
-
486
- # Fill missing dummy columns when reindexing to align train/test columns.
487
- test_oht = test_oht.reindex(columns=train_oht.columns, fill_value=0)
488
-
489
- # Keep unscaled one-hot data for fold-specific scaling to avoid leakage.
490
- # Store direct references - these won't be mutated
491
- self.train_oht_data = train_oht
492
- self.test_oht_data = test_oht
493
-
494
- # Only copy if we need to scale numeric features (memory optimization)
495
- if self.num_features:
496
- train_oht_scaled = train_oht.copy()
497
- test_oht_scaled = test_oht.copy()
498
- else:
499
- # No scaling needed, reuse original
500
- train_oht_scaled = train_oht
501
- test_oht_scaled = test_oht
502
- for num_chr in self.num_features:
503
- # Scale per column so features are on comparable ranges for NN stability.
504
- scaler = StandardScaler()
505
- train_oht_scaled[num_chr] = scaler.fit_transform(
506
- train_oht_scaled[num_chr].values.reshape(-1, 1))
507
- test_oht_scaled[num_chr] = scaler.transform(
508
- test_oht_scaled[num_chr].values.reshape(-1, 1))
509
- scale_val = float(getattr(scaler, "scale_", [1.0])[0])
510
- if scale_val == 0.0:
511
- scale_val = 1.0
512
- self.numeric_scalers[num_chr] = {
513
- "mean": float(getattr(scaler, "mean_", [0.0])[0]),
514
- "scale": scale_val,
515
- }
516
- # Fill missing dummy columns when reindexing to align train/test columns.
517
- test_oht_scaled = test_oht_scaled.reindex(
518
- columns=train_oht_scaled.columns, fill_value=0)
519
- self.train_oht_scl_data = train_oht_scaled
520
- self.test_oht_scl_data = test_oht_scaled
521
- excluded = {cfg.weight_nme, cfg.resp_nme}
522
- self.var_nmes = [
523
- col for col in train_oht_scaled.columns if col not in excluded
524
- ]
525
- return self
526
-
527
- def export_artifacts(self) -> PreprocessArtifacts:
528
- dummy_columns: List[str] = []
529
- if self.train_oht_data is not None:
530
- dummy_columns = list(self.train_oht_data.columns)
531
- return PreprocessArtifacts(
532
- factor_nmes=list(self.config.factor_nmes),
533
- cate_list=list(self.config.cate_list or []),
534
- num_features=list(self.num_features),
535
- var_nmes=list(self.var_nmes),
536
- cat_categories=dict(self.cat_categories_for_shap),
537
- dummy_columns=dummy_columns,
538
- numeric_scalers=dict(self.numeric_scalers),
539
- weight_nme=str(self.config.weight_nme),
540
- resp_nme=str(self.config.resp_nme),
541
- binary_resp_nme=self.config.binary_resp_nme,
542
- drop_first=True,
543
- )
544
-
545
- def save_artifacts(self, path: str | Path) -> str:
546
- payload = self.export_artifacts()
547
- target = Path(path)
548
- target.parent.mkdir(parents=True, exist_ok=True)
549
- target.write_text(json.dumps(asdict(payload), ensure_ascii=True, indent=2), encoding="utf-8")
550
- return str(target)
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass, asdict
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from sklearn.preprocessing import StandardScaler
13
+
14
+ from ins_pricing.utils.io import IOUtils
15
+ from ins_pricing.utils.losses import normalize_loss_name
16
+ from ins_pricing.exceptions import ConfigurationError, DataValidationError
17
+
18
+ # NOTE: Some CSV exports may contain invisible BOM characters or leading/trailing
19
+ # spaces in column names. Pandas requires exact matches, so we normalize a few
20
+ # "required" column names (response/weight/binary response) before validating.
21
+
22
+
23
+ def _clean_column_name(name: Any) -> Any:
24
+ if not isinstance(name, str):
25
+ return name
26
+ return name.replace("\ufeff", "").strip()
27
+
28
+
29
+ def _normalize_required_columns(
30
+ df: pd.DataFrame, required: List[Optional[str]], *, df_label: str
31
+ ) -> None:
32
+ required_names = [r for r in required if isinstance(r, str) and r.strip()]
33
+ if not required_names:
34
+ return
35
+
36
+ mapping: Dict[Any, Any] = {}
37
+ existing = set(df.columns)
38
+ for col in df.columns:
39
+ cleaned = _clean_column_name(col)
40
+ if cleaned != col and cleaned not in existing:
41
+ mapping[col] = cleaned
42
+ if mapping:
43
+ df.rename(columns=mapping, inplace=True)
44
+
45
+ existing = set(df.columns)
46
+ for req in required_names:
47
+ if req in existing:
48
+ continue
49
+ candidates = [
50
+ col
51
+ for col in df.columns
52
+ if isinstance(col, str) and _clean_column_name(col).lower() == req.lower()
53
+ ]
54
+ if len(candidates) == 1 and req not in existing:
55
+ df.rename(columns={candidates[0]: req}, inplace=True)
56
+ existing = set(df.columns)
57
+ elif len(candidates) > 1:
58
+ raise KeyError(
59
+ f"{df_label} has multiple columns matching required {req!r} "
60
+ f"(case/space-insensitive): {candidates}"
61
+ )
62
+
63
+
64
+ # ===== Core components and training wrappers =================================
65
+
66
+ # =============================================================================
67
+ # Config, preprocessing, and trainer base types
68
+ # =============================================================================
69
+ @dataclass
70
+ class BayesOptConfig:
71
+ """Configuration for Bayesian optimization-based model training.
72
+
73
+ This dataclass holds all configuration parameters for the BayesOpt training
74
+ pipeline, including model settings, distributed training options, and
75
+ cross-validation strategies.
76
+
77
+ Attributes:
78
+ model_nme: Unique identifier for the model
79
+ resp_nme: Column name for the response/target variable
80
+ weight_nme: Column name for sample weights
81
+ factor_nmes: List of feature column names
82
+ task_type: Either 'regression' or 'classification'
83
+ binary_resp_nme: Column name for binary response (optional)
84
+ cate_list: List of categorical feature column names
85
+ loss_name: Regression loss ('auto', 'tweedie', 'poisson', 'gamma', 'mse', 'mae')
86
+ prop_test: Proportion of data for validation (0.0-1.0)
87
+ rand_seed: Random seed for reproducibility
88
+ epochs: Number of training epochs
89
+ use_gpu: Whether to use GPU acceleration
90
+ xgb_max_depth_max: Maximum tree depth for XGBoost tuning
91
+ xgb_n_estimators_max: Maximum estimators for XGBoost tuning
92
+ use_resn_data_parallel: Use DataParallel for ResNet
93
+ use_ft_data_parallel: Use DataParallel for FT-Transformer
94
+ use_resn_ddp: Use DDP for ResNet
95
+ use_ft_ddp: Use DDP for FT-Transformer
96
+ use_gnn_data_parallel: Use DataParallel for GNN
97
+ use_gnn_ddp: Use DDP for GNN
98
+ ft_role: FT-Transformer role ('model', 'embedding', 'unsupervised_embedding')
99
+ cv_strategy: CV strategy ('random', 'group', 'time', 'stratified')
100
+ build_oht: Whether to build one-hot encoded features (default True)
101
+
102
+ Example:
103
+ >>> config = BayesOptConfig(
104
+ ... model_nme="pricing_model",
105
+ ... resp_nme="claim_amount",
106
+ ... weight_nme="exposure",
107
+ ... factor_nmes=["age", "gender", "region"],
108
+ ... task_type="regression",
109
+ ... use_ft_ddp=True,
110
+ ... )
111
+ """
112
+
113
+ # Required fields
114
+ model_nme: str
115
+ resp_nme: str
116
+ weight_nme: str
117
+ factor_nmes: List[str]
118
+
119
+ # Task configuration
120
+ task_type: str = 'regression'
121
+ binary_resp_nme: Optional[str] = None
122
+ cate_list: Optional[List[str]] = None
123
+ loss_name: str = "auto"
124
+
125
+ # Training configuration
126
+ prop_test: float = 0.25
127
+ rand_seed: Optional[int] = None
128
+ epochs: int = 100
129
+ use_gpu: bool = True
130
+
131
+ # XGBoost settings
132
+ xgb_max_depth_max: int = 25
133
+ xgb_n_estimators_max: int = 500
134
+
135
+ # Distributed training settings
136
+ use_resn_data_parallel: bool = False
137
+ use_ft_data_parallel: bool = False
138
+ use_resn_ddp: bool = False
139
+ use_ft_ddp: bool = False
140
+ use_gnn_data_parallel: bool = False
141
+ use_gnn_ddp: bool = False
142
+
143
+ # GNN settings
144
+ gnn_use_approx_knn: bool = True
145
+ gnn_approx_knn_threshold: int = 50000
146
+ gnn_graph_cache: Optional[str] = None
147
+ gnn_max_gpu_knn_nodes: Optional[int] = 200000
148
+ gnn_knn_gpu_mem_ratio: float = 0.9
149
+ gnn_knn_gpu_mem_overhead: float = 2.0
150
+
151
+ # Region/Geo settings
152
+ region_province_col: Optional[str] = None
153
+ region_city_col: Optional[str] = None
154
+ region_effect_alpha: float = 50.0
155
+ geo_feature_nmes: Optional[List[str]] = None
156
+ geo_token_hidden_dim: int = 32
157
+ geo_token_layers: int = 2
158
+ geo_token_dropout: float = 0.1
159
+ geo_token_k_neighbors: int = 10
160
+ geo_token_learning_rate: float = 1e-3
161
+ geo_token_epochs: int = 50
162
+
163
+ # Output settings
164
+ output_dir: Optional[str] = None
165
+ optuna_storage: Optional[str] = None
166
+ optuna_study_prefix: Optional[str] = None
167
+ best_params_files: Optional[Dict[str, str]] = None
168
+
169
+ # FT-Transformer settings
170
+ ft_role: str = "model"
171
+ ft_feature_prefix: str = "ft_emb"
172
+ ft_num_numeric_tokens: Optional[int] = None
173
+
174
+ # Training workflow settings
175
+ reuse_best_params: bool = False
176
+ resn_weight_decay: float = 1e-4
177
+ final_ensemble: bool = False
178
+ final_ensemble_k: int = 3
179
+ final_refit: bool = True
180
+
181
+ # Cross-validation settings
182
+ cv_strategy: str = "random"
183
+ cv_splits: Optional[int] = None
184
+ cv_group_col: Optional[str] = None
185
+ cv_time_col: Optional[str] = None
186
+ cv_time_ascending: bool = True
187
+ ft_oof_folds: Optional[int] = None
188
+ ft_oof_strategy: Optional[str] = None
189
+ ft_oof_shuffle: bool = True
190
+
191
+ # Caching and output settings
192
+ save_preprocess: bool = False
193
+ preprocess_artifact_path: Optional[str] = None
194
+ plot_path_style: str = "nested"
195
+ bo_sample_limit: Optional[int] = None
196
+ build_oht: bool = True
197
+ cache_predictions: bool = False
198
+ prediction_cache_dir: Optional[str] = None
199
+ prediction_cache_format: str = "parquet"
200
+ dataloader_workers: Optional[int] = None
201
+
202
+ def __post_init__(self) -> None:
203
+ """Validate configuration after initialization."""
204
+ self._validate()
205
+
206
+ def _validate(self) -> None:
207
+ """Validate configuration values and raise errors for invalid combinations."""
208
+ errors: List[str] = []
209
+
210
+ # Validate task_type
211
+ valid_task_types = {"regression", "classification"}
212
+ if self.task_type not in valid_task_types:
213
+ errors.append(
214
+ f"task_type must be one of {valid_task_types}, got '{self.task_type}'"
215
+ )
216
+ if self.dataloader_workers is not None:
217
+ try:
218
+ if int(self.dataloader_workers) < 0:
219
+ errors.append("dataloader_workers must be >= 0 when provided.")
220
+ except (TypeError, ValueError):
221
+ errors.append("dataloader_workers must be an integer when provided.")
222
+ # Validate loss_name
223
+ try:
224
+ normalized_loss = normalize_loss_name(self.loss_name, self.task_type)
225
+ if self.task_type == "classification" and normalized_loss not in {"auto", "logloss", "bce"}:
226
+ errors.append(
227
+ "loss_name must be 'auto', 'logloss', or 'bce' for classification tasks."
228
+ )
229
+ except ValueError as exc:
230
+ errors.append(str(exc))
231
+
232
+ # Validate prop_test
233
+ if not 0.0 < self.prop_test < 1.0:
234
+ errors.append(
235
+ f"prop_test must be between 0 and 1, got {self.prop_test}"
236
+ )
237
+
238
+ # Validate epochs
239
+ if self.epochs < 1:
240
+ errors.append(f"epochs must be >= 1, got {self.epochs}")
241
+
242
+ # Validate XGBoost settings
243
+ if self.xgb_max_depth_max < 1:
244
+ errors.append(
245
+ f"xgb_max_depth_max must be >= 1, got {self.xgb_max_depth_max}"
246
+ )
247
+ if self.xgb_n_estimators_max < 1:
248
+ errors.append(
249
+ f"xgb_n_estimators_max must be >= 1, got {self.xgb_n_estimators_max}"
250
+ )
251
+
252
+ # Validate distributed training: can't use both DataParallel and DDP
253
+ if self.use_resn_data_parallel and self.use_resn_ddp:
254
+ errors.append(
255
+ "Cannot use both use_resn_data_parallel and use_resn_ddp"
256
+ )
257
+ if self.use_ft_data_parallel and self.use_ft_ddp:
258
+ errors.append(
259
+ "Cannot use both use_ft_data_parallel and use_ft_ddp"
260
+ )
261
+ if self.use_gnn_data_parallel and self.use_gnn_ddp:
262
+ errors.append(
263
+ "Cannot use both use_gnn_data_parallel and use_gnn_ddp"
264
+ )
265
+
266
+ # Validate ft_role
267
+ valid_ft_roles = {"model", "embedding", "unsupervised_embedding"}
268
+ if self.ft_role not in valid_ft_roles:
269
+ errors.append(
270
+ f"ft_role must be one of {valid_ft_roles}, got '{self.ft_role}'"
271
+ )
272
+
273
+ # Validate cv_strategy
274
+ valid_cv_strategies = {"random", "group", "grouped", "time", "timeseries", "temporal", "stratified"}
275
+ if self.cv_strategy not in valid_cv_strategies:
276
+ errors.append(
277
+ f"cv_strategy must be one of {valid_cv_strategies}, got '{self.cv_strategy}'"
278
+ )
279
+
280
+ # Validate group CV requires group_col
281
+ if self.cv_strategy in {"group", "grouped"} and not self.cv_group_col:
282
+ errors.append(
283
+ f"cv_group_col is required when cv_strategy is '{self.cv_strategy}'"
284
+ )
285
+
286
+ # Validate time CV requires time_col
287
+ if self.cv_strategy in {"time", "timeseries", "temporal"} and not self.cv_time_col:
288
+ errors.append(
289
+ f"cv_time_col is required when cv_strategy is '{self.cv_strategy}'"
290
+ )
291
+
292
+ # Validate prediction_cache_format
293
+ valid_cache_formats = {"parquet", "csv"}
294
+ if self.prediction_cache_format not in valid_cache_formats:
295
+ errors.append(
296
+ f"prediction_cache_format must be one of {valid_cache_formats}, "
297
+ f"got '{self.prediction_cache_format}'"
298
+ )
299
+
300
+ # Validate GNN memory settings
301
+ if self.gnn_knn_gpu_mem_ratio <= 0 or self.gnn_knn_gpu_mem_ratio > 1.0:
302
+ errors.append(
303
+ f"gnn_knn_gpu_mem_ratio must be in (0, 1], got {self.gnn_knn_gpu_mem_ratio}"
304
+ )
305
+
306
+ if errors:
307
+ raise ConfigurationError(
308
+ "BayesOptConfig validation failed:\n - " + "\n - ".join(errors)
309
+ )
310
+
311
+
312
+ @dataclass
313
+ class PreprocessArtifacts:
314
+ factor_nmes: List[str]
315
+ cate_list: List[str]
316
+ num_features: List[str]
317
+ var_nmes: List[str]
318
+ cat_categories: Dict[str, List[Any]]
319
+ dummy_columns: List[str]
320
+ numeric_scalers: Dict[str, Dict[str, float]]
321
+ weight_nme: str
322
+ resp_nme: str
323
+ binary_resp_nme: Optional[str] = None
324
+ drop_first: bool = True
325
+
326
+
327
+ class OutputManager:
328
+ # Centralize output paths for plots, results, and models.
329
+
330
+ def __init__(self, root: Optional[str] = None, model_name: str = "model") -> None:
331
+ self.root = Path(root or os.getcwd())
332
+ self.model_name = model_name
333
+ self.plot_dir = self.root / 'plot'
334
+ self.result_dir = self.root / 'Results'
335
+ self.model_dir = self.root / 'model'
336
+
337
+ def _prepare(self, path: Path) -> str:
338
+ IOUtils.ensure_parent_dir(str(path))
339
+ return str(path)
340
+
341
+ def plot_path(self, filename: str) -> str:
342
+ return self._prepare(self.plot_dir / filename)
343
+
344
+ def result_path(self, filename: str) -> str:
345
+ return self._prepare(self.result_dir / filename)
346
+
347
+ def model_path(self, filename: str) -> str:
348
+ return self._prepare(self.model_dir / filename)
349
+
350
+
351
+ class VersionManager:
352
+ """Lightweight versioning: save config and best-params snapshots for traceability."""
353
+
354
+ def __init__(self, output: OutputManager) -> None:
355
+ self.output = output
356
+ self.version_dir = Path(self.output.result_dir) / "versions"
357
+ IOUtils.ensure_parent_dir(str(self.version_dir))
358
+
359
+ def save(self, tag: str, payload: Dict[str, Any]) -> str:
360
+ safe_tag = tag.replace(" ", "_")
361
+ ts = datetime.now().strftime("%Y%m%d_%H%M%S")
362
+ path = self.version_dir / f"{ts}_{safe_tag}.json"
363
+ IOUtils.ensure_parent_dir(str(path))
364
+ with open(path, "w", encoding="utf-8") as f:
365
+ json.dump(payload, f, ensure_ascii=False, indent=2, default=str)
366
+ print(f"[Version] Saved snapshot: {path}")
367
+ return str(path)
368
+
369
+ def load_latest(self, tag: str) -> Optional[Dict[str, Any]]:
370
+ """Load the latest snapshot for a tag (sorted by timestamp prefix)."""
371
+ safe_tag = tag.replace(" ", "_")
372
+ pattern = f"*_{safe_tag}.json"
373
+ candidates = sorted(self.version_dir.glob(pattern))
374
+ if not candidates:
375
+ return None
376
+ path = candidates[-1]
377
+ try:
378
+ return json.loads(path.read_text(encoding="utf-8"))
379
+ except Exception as exc:
380
+ print(f"[Version] Failed to load snapshot {path}: {exc}")
381
+ return None
382
+
383
+
384
+ class DatasetPreprocessor:
385
+ # Prepare shared train/test views for trainers.
386
+
387
+ def __init__(self, train_df: pd.DataFrame, test_df: pd.DataFrame,
388
+ config: BayesOptConfig) -> None:
389
+ self.config = config
390
+ # Copy inputs to avoid mutating caller-provided DataFrames.
391
+ self.train_data = train_df.copy()
392
+ self.test_data = test_df.copy()
393
+ self.num_features: List[str] = []
394
+ self.train_oht_data: Optional[pd.DataFrame] = None
395
+ self.test_oht_data: Optional[pd.DataFrame] = None
396
+ self.train_oht_scl_data: Optional[pd.DataFrame] = None
397
+ self.test_oht_scl_data: Optional[pd.DataFrame] = None
398
+ self.var_nmes: List[str] = []
399
+ self.cat_categories_for_shap: Dict[str, List[Any]] = {}
400
+ self.numeric_scalers: Dict[str, Dict[str, float]] = {}
401
+
402
+ def run(self) -> "DatasetPreprocessor":
403
+ """Run preprocessing: categorical encoding, target clipping, numeric scaling."""
404
+ cfg = self.config
405
+ _normalize_required_columns(
406
+ self.train_data,
407
+ [cfg.resp_nme, cfg.weight_nme, cfg.binary_resp_nme],
408
+ df_label="Train data",
409
+ )
410
+ _normalize_required_columns(
411
+ self.test_data,
412
+ [cfg.resp_nme, cfg.weight_nme, cfg.binary_resp_nme],
413
+ df_label="Test data",
414
+ )
415
+ missing_train = [
416
+ col for col in (cfg.resp_nme, cfg.weight_nme)
417
+ if col not in self.train_data.columns
418
+ ]
419
+ if missing_train:
420
+ raise DataValidationError(
421
+ f"Train data missing required columns: {missing_train}. "
422
+ f"Available columns (first 50): {list(self.train_data.columns)[:50]}"
423
+ )
424
+ if cfg.binary_resp_nme and cfg.binary_resp_nme not in self.train_data.columns:
425
+ raise DataValidationError(
426
+ f"Train data missing binary response column: {cfg.binary_resp_nme}. "
427
+ f"Available columns (first 50): {list(self.train_data.columns)[:50]}"
428
+ )
429
+
430
+ test_has_resp = cfg.resp_nme in self.test_data.columns
431
+ test_has_weight = cfg.weight_nme in self.test_data.columns
432
+ test_has_binary = bool(
433
+ cfg.binary_resp_nme and cfg.binary_resp_nme in self.test_data.columns
434
+ )
435
+ if not test_has_weight:
436
+ self.test_data[cfg.weight_nme] = 1.0
437
+ if not test_has_resp:
438
+ self.test_data[cfg.resp_nme] = np.nan
439
+ if cfg.binary_resp_nme and cfg.binary_resp_nme not in self.test_data.columns:
440
+ self.test_data[cfg.binary_resp_nme] = np.nan
441
+
442
+ # Precompute weighted actuals for plots and validation checks.
443
+ # Direct assignment is more efficient than .loc[:, col]
444
+ self.train_data['w_act'] = self.train_data[cfg.resp_nme] * \
445
+ self.train_data[cfg.weight_nme]
446
+ if test_has_resp:
447
+ self.test_data['w_act'] = self.test_data[cfg.resp_nme] * \
448
+ self.test_data[cfg.weight_nme]
449
+ if cfg.binary_resp_nme:
450
+ self.train_data['w_binary_act'] = self.train_data[cfg.binary_resp_nme] * \
451
+ self.train_data[cfg.weight_nme]
452
+ if test_has_binary:
453
+ self.test_data['w_binary_act'] = self.test_data[cfg.binary_resp_nme] * \
454
+ self.test_data[cfg.weight_nme]
455
+ # High-quantile clipping absorbs outliers; removing it lets extremes dominate loss.
456
+ q99 = self.train_data[cfg.resp_nme].quantile(0.999)
457
+ self.train_data[cfg.resp_nme] = self.train_data[cfg.resp_nme].clip(
458
+ upper=q99)
459
+ cate_list = list(cfg.cate_list or [])
460
+ if cate_list:
461
+ for cate in cate_list:
462
+ self.train_data[cate] = self.train_data[cate].astype(
463
+ 'category')
464
+ self.test_data[cate] = self.test_data[cate].astype('category')
465
+ cats = self.train_data[cate].cat.categories
466
+ self.cat_categories_for_shap[cate] = list(cats)
467
+ self.num_features = [
468
+ nme for nme in cfg.factor_nmes if nme not in cate_list]
469
+
470
+ build_oht = bool(getattr(cfg, "build_oht", True))
471
+ if not build_oht:
472
+ print("[Preprocess] build_oht=False; skip one-hot features.", flush=True)
473
+ self.train_oht_data = None
474
+ self.test_oht_data = None
475
+ self.train_oht_scl_data = None
476
+ self.test_oht_scl_data = None
477
+ self.var_nmes = list(cfg.factor_nmes)
478
+ return self
479
+
480
+ # Memory optimization: Single copy + in-place operations
481
+ train_oht = self.train_data[cfg.factor_nmes +
482
+ [cfg.weight_nme] + [cfg.resp_nme]].copy()
483
+ test_oht = self.test_data[cfg.factor_nmes +
484
+ [cfg.weight_nme] + [cfg.resp_nme]].copy()
485
+ train_oht = pd.get_dummies(
486
+ train_oht,
487
+ columns=cate_list,
488
+ drop_first=True,
489
+ dtype=np.int8
490
+ )
491
+ test_oht = pd.get_dummies(
492
+ test_oht,
493
+ columns=cate_list,
494
+ drop_first=True,
495
+ dtype=np.int8
496
+ )
497
+
498
+ # Fill missing dummy columns when reindexing to align train/test columns.
499
+ test_oht = test_oht.reindex(columns=train_oht.columns, fill_value=0)
500
+
501
+ # Keep unscaled one-hot data for fold-specific scaling to avoid leakage.
502
+ # Store direct references - these won't be mutated
503
+ self.train_oht_data = train_oht
504
+ self.test_oht_data = test_oht
505
+
506
+ # Only copy if we need to scale numeric features (memory optimization)
507
+ if self.num_features:
508
+ train_oht_scaled = train_oht.copy()
509
+ test_oht_scaled = test_oht.copy()
510
+ else:
511
+ # No scaling needed, reuse original
512
+ train_oht_scaled = train_oht
513
+ test_oht_scaled = test_oht
514
+ for num_chr in self.num_features:
515
+ # Scale per column so features are on comparable ranges for NN stability.
516
+ scaler = StandardScaler()
517
+ train_oht_scaled[num_chr] = scaler.fit_transform(
518
+ train_oht_scaled[num_chr].values.reshape(-1, 1))
519
+ test_oht_scaled[num_chr] = scaler.transform(
520
+ test_oht_scaled[num_chr].values.reshape(-1, 1))
521
+ scale_val = float(getattr(scaler, "scale_", [1.0])[0])
522
+ if scale_val == 0.0:
523
+ scale_val = 1.0
524
+ self.numeric_scalers[num_chr] = {
525
+ "mean": float(getattr(scaler, "mean_", [0.0])[0]),
526
+ "scale": scale_val,
527
+ }
528
+ # Fill missing dummy columns when reindexing to align train/test columns.
529
+ test_oht_scaled = test_oht_scaled.reindex(
530
+ columns=train_oht_scaled.columns, fill_value=0)
531
+ self.train_oht_scl_data = train_oht_scaled
532
+ self.test_oht_scl_data = test_oht_scaled
533
+ excluded = {cfg.weight_nme, cfg.resp_nme}
534
+ self.var_nmes = [
535
+ col for col in train_oht_scaled.columns if col not in excluded
536
+ ]
537
+ return self
538
+
539
+ def export_artifacts(self) -> PreprocessArtifacts:
540
+ dummy_columns: List[str] = []
541
+ if self.train_oht_data is not None:
542
+ dummy_columns = list(self.train_oht_data.columns)
543
+ return PreprocessArtifacts(
544
+ factor_nmes=list(self.config.factor_nmes),
545
+ cate_list=list(self.config.cate_list or []),
546
+ num_features=list(self.num_features),
547
+ var_nmes=list(self.var_nmes),
548
+ cat_categories=dict(self.cat_categories_for_shap),
549
+ dummy_columns=dummy_columns,
550
+ numeric_scalers=dict(self.numeric_scalers),
551
+ weight_nme=str(self.config.weight_nme),
552
+ resp_nme=str(self.config.resp_nme),
553
+ binary_resp_nme=self.config.binary_resp_nme,
554
+ drop_first=True,
555
+ )
556
+
557
+ def save_artifacts(self, path: str | Path) -> str:
558
+ payload = self.export_artifacts()
559
+ target = Path(path)
560
+ target.parent.mkdir(parents=True, exist_ok=True)
561
+ target.write_text(json.dumps(asdict(payload), ensure_ascii=True, indent=2), encoding="utf-8")
562
+ return str(target)
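Beyond the switch to absolute imports, the functional change in this hunk is the new `build_oht` flag on `BayesOptConfig`: when it is False, `DatasetPreprocessor.run()` skips the one-hot and scaled feature views and leaves `var_nmes` as the raw factor list. A minimal usage sketch with toy data; the import path follows the rename above, and the column names are illustrative assumptions, not values from the package:

```python
import pandas as pd

from ins_pricing.modelling.bayesopt.config_preprocess import (
    BayesOptConfig,
    DatasetPreprocessor,
)

# Toy frames standing in for policy-level training/test data.
train = pd.DataFrame({
    "age": [25, 40, 33, 51],
    "region": ["N", "S", "N", "E"],
    "exposure": [1.0, 0.5, 1.0, 0.8],
    "claim_amount": [0.0, 120.0, 0.0, 430.0],
})
test = train.copy()

config = BayesOptConfig(
    model_nme="demo",
    resp_nme="claim_amount",
    weight_nme="exposure",
    factor_nmes=["age", "region"],
    cate_list=["region"],
    build_oht=False,  # new in 0.5.0: skip one-hot/scaled feature views
)

prep = DatasetPreprocessor(train, test, config).run()
assert prep.train_oht_scl_data is None       # skipped because build_oht=False
assert prep.var_nmes == ["age", "region"]    # falls back to the raw factor list
artifacts = prep.export_artifacts()          # preprocessing metadata still exportable
```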