ins-pricing 0.4.5-py3-none-any.whl → 0.5.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ins_pricing/README.md +48 -22
- ins_pricing/__init__.py +142 -90
- ins_pricing/cli/BayesOpt_entry.py +58 -46
- ins_pricing/cli/BayesOpt_incremental.py +77 -110
- ins_pricing/cli/Explain_Run.py +42 -23
- ins_pricing/cli/Explain_entry.py +551 -577
- ins_pricing/cli/Pricing_Run.py +42 -23
- ins_pricing/cli/bayesopt_entry_runner.py +51 -16
- ins_pricing/cli/utils/bootstrap.py +23 -0
- ins_pricing/cli/utils/cli_common.py +256 -256
- ins_pricing/cli/utils/cli_config.py +379 -360
- ins_pricing/cli/utils/import_resolver.py +375 -358
- ins_pricing/cli/utils/notebook_utils.py +256 -242
- ins_pricing/cli/watchdog_run.py +216 -198
- ins_pricing/frontend/__init__.py +10 -10
- ins_pricing/frontend/app.py +132 -61
- ins_pricing/frontend/config_builder.py +33 -0
- ins_pricing/frontend/example_config.json +11 -0
- ins_pricing/frontend/example_workflows.py +1 -1
- ins_pricing/frontend/runner.py +340 -388
- ins_pricing/governance/__init__.py +20 -20
- ins_pricing/governance/release.py +159 -159
- ins_pricing/modelling/README.md +1 -1
- ins_pricing/modelling/__init__.py +147 -92
- ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
- ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
- ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
- ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
- ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
- ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
- ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
- ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
- ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
- ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
- ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
- ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
- ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
- ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
- ins_pricing/modelling/explain/__init__.py +55 -55
- ins_pricing/modelling/explain/metrics.py +27 -174
- ins_pricing/modelling/explain/permutation.py +237 -237
- ins_pricing/modelling/plotting/__init__.py +40 -36
- ins_pricing/modelling/plotting/compat.py +228 -0
- ins_pricing/modelling/plotting/curves.py +572 -572
- ins_pricing/modelling/plotting/diagnostics.py +163 -163
- ins_pricing/modelling/plotting/geo.py +362 -362
- ins_pricing/modelling/plotting/importance.py +121 -121
- ins_pricing/pricing/__init__.py +27 -27
- ins_pricing/pricing/factors.py +67 -56
- ins_pricing/production/__init__.py +35 -25
- ins_pricing/production/{predict.py → inference.py} +140 -57
- ins_pricing/production/monitoring.py +8 -21
- ins_pricing/reporting/__init__.py +11 -11
- ins_pricing/setup.py +1 -1
- ins_pricing/tests/production/test_inference.py +90 -0
- ins_pricing/utils/__init__.py +112 -78
- ins_pricing/utils/device.py +258 -237
- ins_pricing/utils/features.py +53 -0
- ins_pricing/utils/io.py +72 -0
- ins_pricing/utils/logging.py +34 -1
- ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
- ins_pricing/utils/metrics.py +158 -24
- ins_pricing/utils/numerics.py +76 -0
- ins_pricing/utils/paths.py +9 -1
- ins_pricing/utils/profiling.py +8 -4
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
- ins_pricing-0.5.1.dist-info/RECORD +132 -0
- ins_pricing/modelling/core/BayesOpt.py +0 -146
- ins_pricing/modelling/core/__init__.py +0 -1
- ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
- ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
- ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
- ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
- ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
- ins_pricing/modelling/core/bayesopt/utils.py +0 -105
- ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
- ins_pricing/tests/production/test_predict.py +0 -233
- ins_pricing-0.4.5.dist-info/RECORD +0 -130
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
- {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
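The `{core/bayesopt → bayesopt}` and `{predict.py → inference.py}` renames above change public import paths for downstream code. A hedged migration sketch (the old paths are inferred from the 0.4.5 layout in this list; the `ins_pricing.utils` targets are confirmed by the diff body below):

```python
# ins_pricing 0.4.5 (old layout):
#   from ins_pricing.modelling.core.bayesopt import core
#   from ins_pricing.production import predict

# ins_pricing 0.5.1 (new layout):
from ins_pricing.modelling.bayesopt import core
from ins_pricing.production import inference
from ins_pricing.utils.io import IOUtils                   # moved from modelling/core/bayesopt/utils/io_utils.py
from ins_pricing.utils.losses import normalize_loss_name   # moved from modelling/core/bayesopt/utils/losses.py
```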
ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py

```diff
@@ -1,562 +1,599 @@
-from __future__ import annotations
-
-import json
-import os
-from dataclasses import dataclass, asdict
-from datetime import datetime
-from pathlib import Path
-from typing import Any, Dict, List, Optional
-
-import numpy as np
-import pandas as pd
-from sklearn.preprocessing import StandardScaler
-
-from .utils import IOUtils
-from .utils.losses import normalize_loss_name
-from
… (removed lines 17-88 are not recoverable from this rendering)
+from __future__ import annotations
+
+import json
+import os
+from dataclasses import dataclass, asdict
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import pandas as pd
+from sklearn.preprocessing import StandardScaler
+
+from ins_pricing.utils.io import IOUtils
+from ins_pricing.utils.losses import normalize_loss_name
+from ins_pricing.exceptions import ConfigurationError, DataValidationError
+from ins_pricing.utils import get_logger, log_print
+
+_logger = get_logger("ins_pricing.modelling.bayesopt.config_preprocess")
+
+
+def _log(*args, **kwargs) -> None:
+    log_print(_logger, *args, **kwargs)
+
+# NOTE: Some CSV exports may contain invisible BOM characters or leading/trailing
+# spaces in column names. Pandas requires exact matches, so we normalize a few
+# "required" column names (response/weight/binary response) before validating.
+
+
+def _clean_column_name(name: Any) -> Any:
+    if not isinstance(name, str):
+        return name
+    return name.replace("\ufeff", "").strip()
+
+
+def _normalize_required_columns(
+    df: pd.DataFrame, required: List[Optional[str]], *, df_label: str
+) -> None:
+    required_names = [r for r in required if isinstance(r, str) and r.strip()]
+    if not required_names:
+        return
+
+    mapping: Dict[Any, Any] = {}
+    existing = set(df.columns)
+    for col in df.columns:
+        cleaned = _clean_column_name(col)
+        if cleaned != col and cleaned not in existing:
+            mapping[col] = cleaned
+    if mapping:
+        df.rename(columns=mapping, inplace=True)
+
+    existing = set(df.columns)
+    for req in required_names:
+        if req in existing:
+            continue
+        candidates = [
+            col
+            for col in df.columns
+            if isinstance(col, str) and _clean_column_name(col).lower() == req.lower()
+        ]
+        if len(candidates) == 1 and req not in existing:
+            df.rename(columns={candidates[0]: req}, inplace=True)
+            existing = set(df.columns)
+        elif len(candidates) > 1:
+            raise KeyError(
+                f"{df_label} has multiple columns matching required {req!r} "
+                f"(case/space-insensitive): {candidates}"
+            )
+
+
```
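The column normalization above is easiest to see on a tiny frame; a minimal sketch with hypothetical headers (one BOM-prefixed, one space-padded):

```python
import pandas as pd
from ins_pricing.modelling.bayesopt.config_preprocess import _normalize_required_columns

# Hypothetical CSV export: headers carry a BOM and stray spaces.
df = pd.DataFrame({"\ufeffclaim_amount": [100.0], " exposure ": [1.0]})
_normalize_required_columns(df, ["claim_amount", "exposure"], df_label="Train data")
assert list(df.columns) == ["claim_amount", "exposure"]
```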
```diff
+# ===== Core components and training wrappers =================================
+
+# =============================================================================
+# Config, preprocessing, and trainer base types
+# =============================================================================
+@dataclass
+class BayesOptConfig:
+    """Configuration for Bayesian optimization-based model training.
+
+    This dataclass holds all configuration parameters for the BayesOpt training
+    pipeline, including model settings, distributed training options, and
+    cross-validation strategies.
+
+    Attributes:
+        model_nme: Unique identifier for the model
+        resp_nme: Column name for the response/target variable
+        weight_nme: Column name for sample weights
+        factor_nmes: List of feature column names
+        task_type: Either 'regression' or 'classification'
+        binary_resp_nme: Column name for binary response (optional)
+        cate_list: List of categorical feature column names
+        loss_name: Regression loss ('auto', 'tweedie', 'poisson', 'gamma', 'mse', 'mae')
+        prop_test: Proportion of data for validation (0.0-1.0)
+        rand_seed: Random seed for reproducibility
+        epochs: Number of training epochs
         use_gpu: Whether to use GPU acceleration
         xgb_max_depth_max: Maximum tree depth for XGBoost tuning
         xgb_n_estimators_max: Maximum estimators for XGBoost tuning
… (removed lines 92-130 are not recoverable from this rendering)
+        xgb_gpu_id: GPU device id for XGBoost (None = default)
+        xgb_cleanup_per_fold: Whether to cleanup GPU memory after each XGBoost fold
+        xgb_cleanup_synchronize: Whether to synchronize CUDA during XGBoost cleanup
+        xgb_use_dmatrix: Whether to use xgb.train with DMatrix/QuantileDMatrix
+        ft_cleanup_per_fold: Whether to cleanup GPU memory after each FT fold
+        ft_cleanup_synchronize: Whether to synchronize CUDA during FT cleanup
+        resn_cleanup_per_fold: Whether to cleanup GPU memory after each ResNet fold
+        resn_cleanup_synchronize: Whether to synchronize CUDA during ResNet cleanup
+        gnn_cleanup_per_fold: Whether to cleanup GPU memory after each GNN fold
+        gnn_cleanup_synchronize: Whether to synchronize CUDA during GNN cleanup
+        optuna_cleanup_synchronize: Whether to synchronize CUDA during Optuna cleanup
+        use_resn_data_parallel: Use DataParallel for ResNet
+        use_ft_data_parallel: Use DataParallel for FT-Transformer
+        use_resn_ddp: Use DDP for ResNet
+        use_ft_ddp: Use DDP for FT-Transformer
+        use_gnn_data_parallel: Use DataParallel for GNN
+        use_gnn_ddp: Use DDP for GNN
+        ft_role: FT-Transformer role ('model', 'embedding', 'unsupervised_embedding')
+        cv_strategy: CV strategy ('random', 'group', 'time', 'stratified')
+        build_oht: Whether to build one-hot encoded features (default True)
+
+    Example:
+        >>> config = BayesOptConfig(
+        ...     model_nme="pricing_model",
+        ...     resp_nme="claim_amount",
+        ...     weight_nme="exposure",
+        ...     factor_nmes=["age", "gender", "region"],
+        ...     task_type="regression",
+        ...     use_ft_ddp=True,
+        ... )
+    """
+
+    # Required fields
+    model_nme: str
+    resp_nme: str
+    weight_nme: str
+    factor_nmes: List[str]
+
+    # Task configuration
+    task_type: str = 'regression'
+    binary_resp_nme: Optional[str] = None
+    cate_list: Optional[List[str]] = None
+    loss_name: str = "auto"
+
+    # Training configuration
+    prop_test: float = 0.25
+    rand_seed: Optional[int] = None
+    epochs: int = 100
+    use_gpu: bool = True
+
     # XGBoost settings
     xgb_max_depth_max: int = 25
     xgb_n_estimators_max: int = 500
… (removed lines 134-246 are not recoverable from this rendering)
+    xgb_gpu_id: Optional[int] = None
+    xgb_cleanup_per_fold: bool = False
+    xgb_cleanup_synchronize: bool = False
+    xgb_use_dmatrix: bool = True
+    ft_cleanup_per_fold: bool = False
+    ft_cleanup_synchronize: bool = False
+    resn_cleanup_per_fold: bool = False
+    resn_cleanup_synchronize: bool = False
+    gnn_cleanup_per_fold: bool = False
+    gnn_cleanup_synchronize: bool = False
+    optuna_cleanup_synchronize: bool = False
+
+    # Distributed training settings
+    use_resn_data_parallel: bool = False
+    use_ft_data_parallel: bool = False
+    use_resn_ddp: bool = False
+    use_ft_ddp: bool = False
+    use_gnn_data_parallel: bool = False
+    use_gnn_ddp: bool = False
+
+    # GNN settings
+    gnn_use_approx_knn: bool = True
+    gnn_approx_knn_threshold: int = 50000
+    gnn_graph_cache: Optional[str] = None
+    gnn_max_gpu_knn_nodes: Optional[int] = 200000
+    gnn_knn_gpu_mem_ratio: float = 0.9
+    gnn_knn_gpu_mem_overhead: float = 2.0
+
+    # Region/Geo settings
+    region_province_col: Optional[str] = None
+    region_city_col: Optional[str] = None
+    region_effect_alpha: float = 50.0
+    geo_feature_nmes: Optional[List[str]] = None
+    geo_token_hidden_dim: int = 32
+    geo_token_layers: int = 2
+    geo_token_dropout: float = 0.1
+    geo_token_k_neighbors: int = 10
+    geo_token_learning_rate: float = 1e-3
+    geo_token_epochs: int = 50
+
+    # Output settings
+    output_dir: Optional[str] = None
+    optuna_storage: Optional[str] = None
+    optuna_study_prefix: Optional[str] = None
+    best_params_files: Optional[Dict[str, str]] = None
+
+    # FT-Transformer settings
+    ft_role: str = "model"
+    ft_feature_prefix: str = "ft_emb"
+    ft_num_numeric_tokens: Optional[int] = None
+
+    # Training workflow settings
+    reuse_best_params: bool = False
+    resn_weight_decay: float = 1e-4
+    final_ensemble: bool = False
+    final_ensemble_k: int = 3
+    final_refit: bool = True
+
+    # Cross-validation settings
+    cv_strategy: str = "random"
+    cv_splits: Optional[int] = None
+    cv_group_col: Optional[str] = None
+    cv_time_col: Optional[str] = None
+    cv_time_ascending: bool = True
+    ft_oof_folds: Optional[int] = None
+    ft_oof_strategy: Optional[str] = None
+    ft_oof_shuffle: bool = True
+
+    # Caching and output settings
+    save_preprocess: bool = False
+    preprocess_artifact_path: Optional[str] = None
+    plot_path_style: str = "nested"
+    bo_sample_limit: Optional[int] = None
+    build_oht: bool = True
+    cache_predictions: bool = False
+    prediction_cache_dir: Optional[str] = None
+    prediction_cache_format: str = "parquet"
+    dataloader_workers: Optional[int] = None
+
+    def __post_init__(self) -> None:
+        """Validate configuration after initialization."""
+        self._validate()
+
+    def _validate(self) -> None:
+        """Validate configuration values and raise errors for invalid combinations."""
+        errors: List[str] = []
+
+        # Validate task_type
+        valid_task_types = {"regression", "classification"}
+        if self.task_type not in valid_task_types:
+            errors.append(
+                f"task_type must be one of {valid_task_types}, got '{self.task_type}'"
+            )
+        if self.dataloader_workers is not None:
+            try:
+                if int(self.dataloader_workers) < 0:
+                    errors.append("dataloader_workers must be >= 0 when provided.")
+            except (TypeError, ValueError):
+                errors.append("dataloader_workers must be an integer when provided.")
+        # Validate loss_name
+        try:
+            normalized_loss = normalize_loss_name(self.loss_name, self.task_type)
+            if self.task_type == "classification" and normalized_loss not in {"auto", "logloss", "bce"}:
+                errors.append(
+                    "loss_name must be 'auto', 'logloss', or 'bce' for classification tasks."
+                )
+        except ValueError as exc:
+            errors.append(str(exc))
+
+        # Validate prop_test
+        if not 0.0 < self.prop_test < 1.0:
+            errors.append(
+                f"prop_test must be between 0 and 1, got {self.prop_test}"
+            )
+
+        # Validate epochs
+        if self.epochs < 1:
+            errors.append(f"epochs must be >= 1, got {self.epochs}")
+
+        # Validate XGBoost settings
+        if self.xgb_max_depth_max < 1:
+            errors.append(
+                f"xgb_max_depth_max must be >= 1, got {self.xgb_max_depth_max}"
+            )
         if self.xgb_n_estimators_max < 1:
             errors.append(
                 f"xgb_n_estimators_max must be >= 1, got {self.xgb_n_estimators_max}"
             )
… (removed lines 251-562 are not recoverable from this rendering)
+        if self.xgb_gpu_id is not None:
+            try:
+                gpu_id = int(self.xgb_gpu_id)
+            except (TypeError, ValueError):
+                errors.append(f"xgb_gpu_id must be an integer, got {self.xgb_gpu_id!r}")
+            else:
+                if gpu_id < 0:
+                    errors.append(f"xgb_gpu_id must be >= 0, got {gpu_id}")
+
+        # Validate distributed training: can't use both DataParallel and DDP
+        if self.use_resn_data_parallel and self.use_resn_ddp:
+            errors.append(
+                "Cannot use both use_resn_data_parallel and use_resn_ddp"
+            )
+        if self.use_ft_data_parallel and self.use_ft_ddp:
+            errors.append(
+                "Cannot use both use_ft_data_parallel and use_ft_ddp"
+            )
+        if self.use_gnn_data_parallel and self.use_gnn_ddp:
+            errors.append(
+                "Cannot use both use_gnn_data_parallel and use_gnn_ddp"
+            )
+
+        # Validate ft_role
+        valid_ft_roles = {"model", "embedding", "unsupervised_embedding"}
+        if self.ft_role not in valid_ft_roles:
+            errors.append(
+                f"ft_role must be one of {valid_ft_roles}, got '{self.ft_role}'"
+            )
+
+        # Validate cv_strategy
+        valid_cv_strategies = {"random", "group", "grouped", "time", "timeseries", "temporal", "stratified"}
+        if self.cv_strategy not in valid_cv_strategies:
+            errors.append(
+                f"cv_strategy must be one of {valid_cv_strategies}, got '{self.cv_strategy}'"
+            )
+
+        # Validate group CV requires group_col
+        if self.cv_strategy in {"group", "grouped"} and not self.cv_group_col:
+            errors.append(
+                f"cv_group_col is required when cv_strategy is '{self.cv_strategy}'"
+            )
+
+        # Validate time CV requires time_col
+        if self.cv_strategy in {"time", "timeseries", "temporal"} and not self.cv_time_col:
+            errors.append(
+                f"cv_time_col is required when cv_strategy is '{self.cv_strategy}'"
+            )
+
+        # Validate prediction_cache_format
+        valid_cache_formats = {"parquet", "csv"}
+        if self.prediction_cache_format not in valid_cache_formats:
+            errors.append(
+                f"prediction_cache_format must be one of {valid_cache_formats}, "
+                f"got '{self.prediction_cache_format}'"
+            )
+
+        # Validate GNN memory settings
+        if self.gnn_knn_gpu_mem_ratio <= 0 or self.gnn_knn_gpu_mem_ratio > 1.0:
+            errors.append(
+                f"gnn_knn_gpu_mem_ratio must be in (0, 1], got {self.gnn_knn_gpu_mem_ratio}"
+            )
+
+        if errors:
+            raise ConfigurationError(
+                "BayesOptConfig validation failed:\n - " + "\n - ".join(errors)
+            )
+
+
```
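Validation errors are accumulated and raised together as a single `ConfigurationError`, so one run reports every failed check. A minimal sketch (toy values with two deliberate conflicts; module path taken from the 0.5.1 layout above):

```python
from ins_pricing.exceptions import ConfigurationError
from ins_pricing.modelling.bayesopt.config_preprocess import BayesOptConfig

try:
    BayesOptConfig(
        model_nme="demo",
        resp_nme="claim_amount",
        weight_nme="exposure",
        factor_nmes=["age", "region"],
        cv_strategy="group",        # no cv_group_col given -> one error
        use_ft_data_parallel=True,
        use_ft_ddp=True,            # DataParallel + DDP conflict -> second error
    )
except ConfigurationError as exc:
    print(exc)  # both messages, one per "- " bullet
```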
```diff
+@dataclass
+class PreprocessArtifacts:
+    factor_nmes: List[str]
+    cate_list: List[str]
+    num_features: List[str]
+    var_nmes: List[str]
+    cat_categories: Dict[str, List[Any]]
+    dummy_columns: List[str]
+    numeric_scalers: Dict[str, Dict[str, float]]
+    weight_nme: str
+    resp_nme: str
+    binary_resp_nme: Optional[str] = None
+    drop_first: bool = True
+
+
+class OutputManager:
+    # Centralize output paths for plots, results, and models.
+
+    def __init__(self, root: Optional[str] = None, model_name: str = "model") -> None:
+        self.root = Path(root or os.getcwd())
+        self.model_name = model_name
+        self.plot_dir = self.root / 'plot'
+        self.result_dir = self.root / 'Results'
+        self.model_dir = self.root / 'model'
+
+    def _prepare(self, path: Path) -> str:
+        IOUtils.ensure_parent_dir(str(path))
+        return str(path)
+
+    def plot_path(self, filename: str) -> str:
+        return self._prepare(self.plot_dir / filename)
+
+    def result_path(self, filename: str) -> str:
+        return self._prepare(self.result_dir / filename)
+
+    def model_path(self, filename: str) -> str:
+        return self._prepare(self.model_dir / filename)
+
+
+class VersionManager:
+    """Lightweight versioning: save config and best-params snapshots for traceability."""
+
+    def __init__(self, output: OutputManager) -> None:
+        self.output = output
+        self.version_dir = Path(self.output.result_dir) / "versions"
+        IOUtils.ensure_parent_dir(str(self.version_dir))
+
+    def save(self, tag: str, payload: Dict[str, Any]) -> str:
+        safe_tag = tag.replace(" ", "_")
+        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+        path = self.version_dir / f"{ts}_{safe_tag}.json"
+        IOUtils.ensure_parent_dir(str(path))
+        with open(path, "w", encoding="utf-8") as f:
+            json.dump(payload, f, ensure_ascii=False, indent=2, default=str)
+        _log(f"[Version] Saved snapshot: {path}")
+        return str(path)
+
+    def load_latest(self, tag: str) -> Optional[Dict[str, Any]]:
+        """Load the latest snapshot for a tag (sorted by timestamp prefix)."""
+        safe_tag = tag.replace(" ", "_")
+        pattern = f"*_{safe_tag}.json"
+        candidates = sorted(self.version_dir.glob(pattern))
+        if not candidates:
+            return None
+        path = candidates[-1]
+        try:
+            return json.loads(path.read_text(encoding="utf-8"))
+        except Exception as exc:
+            _log(f"[Version] Failed to load snapshot {path}: {exc}")
+            return None
+
+
+class DatasetPreprocessor:
+    # Prepare shared train/test views for trainers.
+
+    def __init__(self, train_df: pd.DataFrame, test_df: pd.DataFrame,
+                 config: BayesOptConfig) -> None:
+        self.config = config
+        # Copy inputs to avoid mutating caller-provided DataFrames.
+        self.train_data = train_df.copy()
+        self.test_data = test_df.copy()
+        self.num_features: List[str] = []
+        self.train_oht_data: Optional[pd.DataFrame] = None
+        self.test_oht_data: Optional[pd.DataFrame] = None
+        self.train_oht_scl_data: Optional[pd.DataFrame] = None
+        self.test_oht_scl_data: Optional[pd.DataFrame] = None
+        self.var_nmes: List[str] = []
+        self.cat_categories_for_shap: Dict[str, List[Any]] = {}
+        self.numeric_scalers: Dict[str, Dict[str, float]] = {}
+
+    def run(self) -> "DatasetPreprocessor":
+        """Run preprocessing: categorical encoding, target clipping, numeric scaling."""
+        cfg = self.config
+        _normalize_required_columns(
+            self.train_data,
+            [cfg.resp_nme, cfg.weight_nme, cfg.binary_resp_nme],
+            df_label="Train data",
+        )
+        _normalize_required_columns(
+            self.test_data,
+            [cfg.resp_nme, cfg.weight_nme, cfg.binary_resp_nme],
+            df_label="Test data",
+        )
+        missing_train = [
+            col for col in (cfg.resp_nme, cfg.weight_nme)
+            if col not in self.train_data.columns
+        ]
+        if missing_train:
+            raise DataValidationError(
+                f"Train data missing required columns: {missing_train}. "
+                f"Available columns (first 50): {list(self.train_data.columns)[:50]}"
+            )
+        if cfg.binary_resp_nme and cfg.binary_resp_nme not in self.train_data.columns:
+            raise DataValidationError(
+                f"Train data missing binary response column: {cfg.binary_resp_nme}. "
+                f"Available columns (first 50): {list(self.train_data.columns)[:50]}"
+            )
+
+        test_has_resp = cfg.resp_nme in self.test_data.columns
+        test_has_weight = cfg.weight_nme in self.test_data.columns
+        test_has_binary = bool(
+            cfg.binary_resp_nme and cfg.binary_resp_nme in self.test_data.columns
+        )
+        if not test_has_weight:
+            self.test_data[cfg.weight_nme] = 1.0
+        if not test_has_resp:
+            self.test_data[cfg.resp_nme] = np.nan
+        if cfg.binary_resp_nme and cfg.binary_resp_nme not in self.test_data.columns:
+            self.test_data[cfg.binary_resp_nme] = np.nan
+
+        # Precompute weighted actuals for plots and validation checks.
+        # Direct assignment is more efficient than .loc[:, col]
+        self.train_data['w_act'] = self.train_data[cfg.resp_nme] * \
+            self.train_data[cfg.weight_nme]
+        if test_has_resp:
+            self.test_data['w_act'] = self.test_data[cfg.resp_nme] * \
+                self.test_data[cfg.weight_nme]
+        if cfg.binary_resp_nme:
+            self.train_data['w_binary_act'] = self.train_data[cfg.binary_resp_nme] * \
+                self.train_data[cfg.weight_nme]
+            if test_has_binary:
+                self.test_data['w_binary_act'] = self.test_data[cfg.binary_resp_nme] * \
+                    self.test_data[cfg.weight_nme]
+        # High-quantile clipping absorbs outliers; removing it lets extremes dominate loss.
+        q99 = self.train_data[cfg.resp_nme].quantile(0.999)
+        self.train_data[cfg.resp_nme] = self.train_data[cfg.resp_nme].clip(
+            upper=q99)
+        cate_list = list(cfg.cate_list or [])
+        if cate_list:
+            for cate in cate_list:
+                self.train_data[cate] = self.train_data[cate].astype(
+                    'category')
+                self.test_data[cate] = self.test_data[cate].astype('category')
+                cats = self.train_data[cate].cat.categories
+                self.cat_categories_for_shap[cate] = list(cats)
+        self.num_features = [
+            nme for nme in cfg.factor_nmes if nme not in cate_list]
+
+        build_oht = bool(getattr(cfg, "build_oht", True))
+        if not build_oht:
+            _log("[Preprocess] build_oht=False; skip one-hot features.", flush=True)
+            self.train_oht_data = None
+            self.test_oht_data = None
+            self.train_oht_scl_data = None
+            self.test_oht_scl_data = None
+            self.var_nmes = list(cfg.factor_nmes)
+            return self
+
+        # Memory optimization: Single copy + in-place operations
+        train_oht = self.train_data[cfg.factor_nmes +
+                                    [cfg.weight_nme] + [cfg.resp_nme]].copy()
+        test_oht = self.test_data[cfg.factor_nmes +
+                                  [cfg.weight_nme] + [cfg.resp_nme]].copy()
+        train_oht = pd.get_dummies(
+            train_oht,
+            columns=cate_list,
+            drop_first=True,
+            dtype=np.int8
+        )
+        test_oht = pd.get_dummies(
+            test_oht,
+            columns=cate_list,
+            drop_first=True,
+            dtype=np.int8
+        )
+
+        # Fill missing dummy columns when reindexing to align train/test columns.
+        test_oht = test_oht.reindex(columns=train_oht.columns, fill_value=0)
+
+        # Keep unscaled one-hot data for fold-specific scaling to avoid leakage.
+        # Store direct references - these won't be mutated
+        self.train_oht_data = train_oht
+        self.test_oht_data = test_oht
+
+        # Only copy if we need to scale numeric features (memory optimization)
+        if self.num_features:
+            train_oht_scaled = train_oht.copy()
+            test_oht_scaled = test_oht.copy()
+        else:
+            # No scaling needed, reuse original
+            train_oht_scaled = train_oht
+            test_oht_scaled = test_oht
+        for num_chr in self.num_features:
+            # Scale per column so features are on comparable ranges for NN stability.
+            scaler = StandardScaler()
+            train_oht_scaled[num_chr] = scaler.fit_transform(
+                train_oht_scaled[num_chr].values.reshape(-1, 1))
+            test_oht_scaled[num_chr] = scaler.transform(
+                test_oht_scaled[num_chr].values.reshape(-1, 1))
+            scale_val = float(getattr(scaler, "scale_", [1.0])[0])
+            if scale_val == 0.0:
+                scale_val = 1.0
+            self.numeric_scalers[num_chr] = {
+                "mean": float(getattr(scaler, "mean_", [0.0])[0]),
+                "scale": scale_val,
+            }
+        # Fill missing dummy columns when reindexing to align train/test columns.
+        test_oht_scaled = test_oht_scaled.reindex(
+            columns=train_oht_scaled.columns, fill_value=0)
+        self.train_oht_scl_data = train_oht_scaled
+        self.test_oht_scl_data = test_oht_scaled
+        excluded = {cfg.weight_nme, cfg.resp_nme}
+        self.var_nmes = [
+            col for col in train_oht_scaled.columns if col not in excluded
+        ]
+        return self
+
+    def export_artifacts(self) -> PreprocessArtifacts:
+        dummy_columns: List[str] = []
+        if self.train_oht_data is not None:
+            dummy_columns = list(self.train_oht_data.columns)
+        return PreprocessArtifacts(
+            factor_nmes=list(self.config.factor_nmes),
+            cate_list=list(self.config.cate_list or []),
+            num_features=list(self.num_features),
+            var_nmes=list(self.var_nmes),
+            cat_categories=dict(self.cat_categories_for_shap),
+            dummy_columns=dummy_columns,
+            numeric_scalers=dict(self.numeric_scalers),
+            weight_nme=str(self.config.weight_nme),
+            resp_nme=str(self.config.resp_nme),
+            binary_resp_nme=self.config.binary_resp_nme,
+            drop_first=True,
+        )
+
+    def save_artifacts(self, path: str | Path) -> str:
+        payload = self.export_artifacts()
+        target = Path(path)
+        target.parent.mkdir(parents=True, exist_ok=True)
+        target.write_text(json.dumps(asdict(payload), ensure_ascii=True, indent=2), encoding="utf-8")
+        return str(target)
```