ins-pricing 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
@@ -0,0 +1,98 @@
1
"""BayesOpt subpackage (split from monolithic BayesOpt.py).

Re-exports the public surface of each submodule — configuration and
preprocessing helpers, the core optimiser, model building blocks, trainer
wrappers, and shared utilities — so callers can keep importing everything
from the package root, as they did with the old monolithic module.
"""

from __future__ import annotations

import torch

from .config_preprocess import (
    BayesOptConfig,
    DatasetPreprocessor,
    OutputManager,
    VersionManager,
)
from .core import BayesOptModel
from .models import (
    FeatureTokenizer,
    FTTransformerCore,
    FTTransformerSklearn,
    GraphNeuralNetSklearn,
    MaskedTabularDataset,
    ResBlock,
    ResNetSequential,
    ResNetSklearn,
    ScaledTransformerEncoderLayer,
    SimpleGraphLayer,
    SimpleGNN,
    TabularDataset,
)
from .trainers import (
    FTTrainer,
    GLMTrainer,
    GNNTrainer,
    ResNetTrainer,
    TrainerBase,
    XGBTrainer,
    _xgb_cuda_available,
)
from .utils import (
    EPS,
    DistributedUtils,
    IOUtils,
    PlotUtils,
    TorchTrainerMixin,
    TrainingUtils,
    compute_batch_size,
    csv_to_dict,
    ensure_parent_dir,
    free_cuda,
    infer_factor_and_cate_list,
    plot_dlift_list,
    plot_lift_list,
    set_global_seed,
    split_data,
    tweedie_loss,
)

# ``torch`` is deliberately re-exported (last entry) so downstream code can
# do ``from ...bayesopt import torch`` without a separate import.
__all__ = [
    # config & preprocessing
    "BayesOptConfig", "DatasetPreprocessor", "OutputManager", "VersionManager",
    # core
    "BayesOptModel",
    # model building blocks
    "FeatureTokenizer", "FTTransformerCore", "FTTransformerSklearn",
    "GraphNeuralNetSklearn", "MaskedTabularDataset", "ResBlock",
    "ResNetSequential", "ResNetSklearn", "ScaledTransformerEncoderLayer",
    "SimpleGraphLayer", "SimpleGNN", "TabularDataset",
    # trainers
    "FTTrainer", "GLMTrainer", "GNNTrainer", "ResNetTrainer", "TrainerBase",
    "XGBTrainer", "_xgb_cuda_available",
    # utilities
    "EPS", "DistributedUtils", "IOUtils", "PlotUtils", "TorchTrainerMixin",
    "TrainingUtils", "compute_batch_size", "csv_to_dict", "ensure_parent_dir",
    "free_cuda", "infer_factor_and_cate_list", "plot_dlift_list",
    "plot_lift_list", "set_global_seed", "split_data", "tweedie_loss",
    "torch",
]
@@ -0,0 +1,303 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from sklearn.preprocessing import StandardScaler
13
+
14
+ from .utils import IOUtils
15
+
16
+ # NOTE: Some CSV exports may contain invisible BOM characters or leading/trailing
17
+ # spaces in column names. Pandas requires exact matches, so we normalize a few
18
+ # "required" column names (response/weight/binary response) before validating.
19
+
20
+
21
+ def _clean_column_name(name: Any) -> Any:
22
+ if not isinstance(name, str):
23
+ return name
24
+ return name.replace("\ufeff", "").strip()
25
+
26
+
27
def _normalize_required_columns(
    df: pd.DataFrame, required: List[Optional[str]], *, df_label: str
) -> None:
    """Rename columns of *df* in place so required names match exactly.

    Two passes: first strip BOM/whitespace artifacts from every column label
    (only where the cleaned label would not collide with an existing one),
    then resolve each required name via a case/space-insensitive match.

    Raises:
        KeyError: if a required name matches more than one column.
    """
    wanted = [name for name in required if isinstance(name, str) and name.strip()]
    if not wanted:
        return

    # Pass 1: drop BOM/whitespace noise wherever the cleaned label is free.
    current = set(df.columns)
    renames: Dict[Any, Any] = {}
    for col in df.columns:
        cleaned = _clean_column_name(col)
        if cleaned != col and cleaned not in current:
            renames[col] = cleaned
    if renames:
        df.rename(columns=renames, inplace=True)

    # Pass 2: map each still-missing required name onto a unique
    # case-insensitive match among the (now cleaned) columns.
    for req in wanted:
        if req in set(df.columns):
            continue
        matches = [
            col
            for col in df.columns
            if isinstance(col, str) and _clean_column_name(col).lower() == req.lower()
        ]
        if len(matches) > 1:
            raise KeyError(
                f"{df_label} has multiple columns matching required {req!r} "
                f"(case/space-insensitive): {matches}"
            )
        if matches:
            df.rename(columns={matches[0]: req}, inplace=True)
60
+
61
+
62
+ # ===== Core components and training wrappers =================================
63
+
64
+ # =============================================================================
65
+ # Config, preprocessing, and trainer base types
66
+ # =============================================================================
67
@dataclass
class BayesOptConfig:
    """Configuration bundle for a BayesOpt modelling run.

    Only the first four fields (model name plus response/weight/factor
    column names) are required; everything else has a default. Fields are
    grouped below by the model family or subsystem they drive.
    """

    # --- identity & required data columns ---
    model_nme: str
    resp_nme: str
    weight_nme: str
    factor_nmes: List[str]
    # Task label consumed by trainers; 'regression' is the default.
    # NOTE(review): other accepted values are not visible here — confirm in trainers.
    task_type: str = 'regression'
    binary_resp_nme: Optional[str] = None
    cate_list: Optional[List[str]] = None  # categorical factor names; None -> none

    # --- splitting / training basics ---
    prop_test: float = 0.25
    rand_seed: Optional[int] = None
    epochs: int = 100
    use_gpu: bool = True

    # --- XGBoost search bounds ---
    xgb_max_depth_max: int = 25
    xgb_n_estimators_max: int = 500

    # --- parallelism switches (DataParallel vs DDP) per torch model family ---
    use_resn_data_parallel: bool = False
    use_ft_data_parallel: bool = False
    use_resn_ddp: bool = False
    use_ft_ddp: bool = False
    use_gnn_data_parallel: bool = False
    use_gnn_ddp: bool = False

    # --- GNN / kNN graph construction ---
    gnn_use_approx_knn: bool = True
    gnn_approx_knn_threshold: int = 50000
    gnn_graph_cache: Optional[str] = None
    gnn_max_gpu_knn_nodes: Optional[int] = 200000
    gnn_knn_gpu_mem_ratio: float = 0.9
    gnn_knn_gpu_mem_overhead: float = 2.0

    # --- regional smoothing & geo tokens ---
    region_province_col: Optional[str] = None  # Province column for hierarchical smoothing
    region_city_col: Optional[str] = None  # City column for hierarchical smoothing
    region_effect_alpha: float = 50.0  # Smoothing strength (pseudo sample size)
    geo_feature_nmes: Optional[List[str]] = None  # Columns for geo tokens; None disables GNN
    geo_token_hidden_dim: int = 32
    geo_token_layers: int = 2
    geo_token_dropout: float = 0.1
    geo_token_k_neighbors: int = 10
    geo_token_learning_rate: float = 1e-3
    geo_token_epochs: int = 50

    # --- outputs & Optuna storage ---
    output_dir: Optional[str] = None
    optuna_storage: Optional[str] = None
    optuna_study_prefix: Optional[str] = None
    best_params_files: Optional[Dict[str, str]] = None

    # FT roles:
    # - "model": FT is a standalone predictor (keep lift/SHAP evaluation)
    # - "embedding": FT only exports embeddings as downstream features
    # - "unsupervised_embedding": masked reconstruction pretraining + embeddings
    ft_role: str = "model"
    ft_feature_prefix: str = "ft_emb"
    ft_num_numeric_tokens: Optional[int] = None

    # --- reuse / regularization / final-model assembly ---
    reuse_best_params: bool = False
    resn_weight_decay: float = 1e-4
    final_ensemble: bool = False
    final_ensemble_k: int = 3
    final_refit: bool = True
120
+
121
+
122
class OutputManager:
    """Resolve output file locations (plots, results, models) under one root.

    Each ``*_path`` call ensures the parent directory exists before handing
    back the path as a plain string, so callers can write immediately.
    """

    def __init__(self, root: Optional[str] = None, model_name: str = "model") -> None:
        base = Path(root) if root else Path(os.getcwd())
        self.root = base
        self.model_name = model_name
        self.plot_dir = base / 'plot'
        self.result_dir = base / 'Results'
        self.model_dir = base / 'model'

    def _prepare(self, path: Path) -> str:
        # Create the enclosing directory on demand, then return a string path.
        target = str(path)
        IOUtils.ensure_parent_dir(target)
        return target

    def plot_path(self, filename: str) -> str:
        """Path for a plot file under ``<root>/plot``."""
        return self._prepare(self.plot_dir / filename)

    def result_path(self, filename: str) -> str:
        """Path for a result file under ``<root>/Results``."""
        return self._prepare(self.result_dir / filename)

    def model_path(self, filename: str) -> str:
        """Path for a model artifact under ``<root>/model``."""
        return self._prepare(self.model_dir / filename)
144
+
145
+
146
class VersionManager:
    """Lightweight versioning: save config and best-params snapshots for traceability."""

    def __init__(self, output: OutputManager) -> None:
        self.output = output
        self.version_dir = Path(self.output.result_dir) / "versions"
        IOUtils.ensure_parent_dir(str(self.version_dir))

    def save(self, tag: str, payload: Dict[str, Any]) -> str:
        """Write *payload* as a timestamped JSON snapshot and return its path."""
        safe_tag = tag.replace(" ", "_")
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        snapshot = self.version_dir / f"{stamp}_{safe_tag}.json"
        IOUtils.ensure_parent_dir(str(snapshot))
        # default=str keeps non-JSON-native values (paths, timestamps) serializable.
        with open(snapshot, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2, default=str)
        print(f"[Version] Saved snapshot: {snapshot}")
        return str(snapshot)

    def load_latest(self, tag: str) -> Optional[Dict[str, Any]]:
        """Load the latest snapshot for a tag (sorted by timestamp prefix)."""
        safe_tag = tag.replace(" ", "_")
        matches = sorted(self.version_dir.glob(f"*_{safe_tag}.json"))
        if not matches:
            return None
        newest = matches[-1]
        try:
            return json.loads(newest.read_text(encoding="utf-8"))
        except Exception as exc:
            # Best-effort read: a corrupt/unreadable snapshot must not crash the run.
            print(f"[Version] Failed to load snapshot {newest}: {exc}")
            return None
177
+
178
+
179
class DatasetPreprocessor:
    """Prepare shared train/test views for trainers.

    Holds deep copies of the raw train/test frames plus derived views:
    unscaled one-hot data (``*_oht_data``) and column-standardized one-hot
    data (``*_oht_scl_data``). ``run()`` must be called before the derived
    attributes are populated.
    """

    def __init__(self, train_df: pd.DataFrame, test_df: pd.DataFrame,
                 config: BayesOptConfig) -> None:
        self.config = config
        # Deep copies: preprocessing mutates columns (clipping, dtype casts)
        # and must not leak back into the caller's frames.
        self.train_data = train_df.copy(deep=True)
        self.test_data = test_df.copy(deep=True)
        self.num_features: List[str] = []                      # non-categorical factors
        self.train_oht_data: Optional[pd.DataFrame] = None     # one-hot, unscaled
        self.test_oht_data: Optional[pd.DataFrame] = None
        self.train_oht_scl_data: Optional[pd.DataFrame] = None  # one-hot, standardized
        self.test_oht_scl_data: Optional[pd.DataFrame] = None
        self.var_nmes: List[str] = []                          # model feature columns
        # Category levels captured at fit time (used for SHAP encoding later
        # — presumably; defined here, consumed outside this block).
        self.cat_categories_for_shap: Dict[str, List[Any]] = {}

    def run(self) -> "DatasetPreprocessor":
        """Run preprocessing: categorical encoding, target clipping, numeric scaling."""
        cfg = self.config
        # Normalize BOM/whitespace/case artifacts in required column names
        # before validating their presence.
        _normalize_required_columns(
            self.train_data,
            [cfg.resp_nme, cfg.weight_nme, cfg.binary_resp_nme],
            df_label="Train data",
        )
        _normalize_required_columns(
            self.test_data,
            [cfg.resp_nme, cfg.weight_nme, cfg.binary_resp_nme],
            df_label="Test data",
        )
        # Train data must carry response and weight; fail fast with context.
        missing_train = [
            col for col in (cfg.resp_nme, cfg.weight_nme)
            if col not in self.train_data.columns
        ]
        if missing_train:
            raise KeyError(
                f"Train data missing required columns: {missing_train}. "
                f"Available columns (first 50): {list(self.train_data.columns)[:50]}"
            )
        if cfg.binary_resp_nme and cfg.binary_resp_nme not in self.train_data.columns:
            raise KeyError(
                f"Train data missing binary response column: {cfg.binary_resp_nme}. "
                f"Available columns (first 50): {list(self.train_data.columns)[:50]}"
            )

        # Test data may be a pure scoring set: default weight to 1.0 and
        # fill absent response columns with NaN instead of failing.
        test_has_resp = cfg.resp_nme in self.test_data.columns
        test_has_weight = cfg.weight_nme in self.test_data.columns
        test_has_binary = bool(
            cfg.binary_resp_nme and cfg.binary_resp_nme in self.test_data.columns
        )
        if not test_has_weight:
            self.test_data[cfg.weight_nme] = 1.0
        if not test_has_resp:
            self.test_data[cfg.resp_nme] = np.nan
        if cfg.binary_resp_nme and cfg.binary_resp_nme not in self.test_data.columns:
            self.test_data[cfg.binary_resp_nme] = np.nan

        # Precompute weighted actuals for plots and validation checks.
        self.train_data.loc[:, 'w_act'] = self.train_data[cfg.resp_nme] * \
            self.train_data[cfg.weight_nme]
        if test_has_resp:
            self.test_data.loc[:, 'w_act'] = self.test_data[cfg.resp_nme] * \
                self.test_data[cfg.weight_nme]
        if cfg.binary_resp_nme:
            self.train_data.loc[:, 'w_binary_act'] = self.train_data[cfg.binary_resp_nme] * \
                self.train_data[cfg.weight_nme]
            if test_has_binary:
                self.test_data.loc[:, 'w_binary_act'] = self.test_data[cfg.binary_resp_nme] * \
                    self.test_data[cfg.weight_nme]
        # High-quantile clipping absorbs outliers; removing it lets extremes dominate loss.
        # NOTE: despite the name 'q99', this is the 0.999 quantile. Clipping
        # happens AFTER 'w_act' is computed, so weighted actuals keep raw values.
        q99 = self.train_data[cfg.resp_nme].quantile(0.999)
        self.train_data[cfg.resp_nme] = self.train_data[cfg.resp_nme].clip(
            upper=q99)
        # Cast categoricals on both frames; record train-side category levels.
        cate_list = list(cfg.cate_list or [])
        if cate_list:
            for cate in cate_list:
                self.train_data[cate] = self.train_data[cate].astype(
                    'category')
                self.test_data[cate] = self.test_data[cate].astype('category')
                cats = self.train_data[cate].cat.categories
                self.cat_categories_for_shap[cate] = list(cats)
        self.num_features = [
            nme for nme in cfg.factor_nmes if nme not in cate_list]
        # One-hot encode (drop_first avoids collinearity; int8 keeps memory low).
        train_oht = self.train_data[cfg.factor_nmes +
                                    [cfg.weight_nme] + [cfg.resp_nme]].copy()
        test_oht = self.test_data[cfg.factor_nmes +
                                  [cfg.weight_nme] + [cfg.resp_nme]].copy()
        train_oht = pd.get_dummies(
            train_oht,
            columns=cate_list,
            drop_first=True,
            dtype=np.int8
        )
        test_oht = pd.get_dummies(
            test_oht,
            columns=cate_list,
            drop_first=True,
            dtype=np.int8
        )

        # Fill missing dummy columns when reindexing to align train/test columns.
        test_oht = test_oht.reindex(columns=train_oht.columns, fill_value=0)

        # Keep unscaled one-hot data for fold-specific scaling to avoid leakage.
        self.train_oht_data = train_oht.copy(deep=True)
        self.test_oht_data = test_oht.copy(deep=True)

        train_oht_scaled = train_oht.copy(deep=True)
        test_oht_scaled = test_oht.copy(deep=True)
        for num_chr in self.num_features:
            # Scale per column so features are on comparable ranges for NN stability.
            # Each scaler is fit on train only, then applied to test (no leakage).
            scaler = StandardScaler()
            train_oht_scaled[num_chr] = scaler.fit_transform(
                train_oht_scaled[num_chr].values.reshape(-1, 1))
            test_oht_scaled[num_chr] = scaler.transform(
                test_oht_scaled[num_chr].values.reshape(-1, 1))
        # Fill missing dummy columns when reindexing to align train/test columns.
        test_oht_scaled = test_oht_scaled.reindex(
            columns=train_oht_scaled.columns, fill_value=0)
        self.train_oht_scl_data = train_oht_scaled
        self.test_oht_scl_data = test_oht_scaled
        # Feature list = every one-hot column except weight and response.
        excluded = {cfg.weight_nme, cfg.resp_nme}
        self.var_nmes = [
            col for col in train_oht_scaled.columns if col not in excluded
        ]
        return self