ins-pricing 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. ins_pricing/README.md +9 -6
  2. ins_pricing/__init__.py +3 -11
  3. ins_pricing/cli/BayesOpt_entry.py +24 -0
  4. ins_pricing/{modelling → cli}/BayesOpt_incremental.py +197 -64
  5. ins_pricing/cli/Explain_Run.py +25 -0
  6. ins_pricing/{modelling → cli}/Explain_entry.py +169 -124
  7. ins_pricing/cli/Pricing_Run.py +25 -0
  8. ins_pricing/cli/__init__.py +1 -0
  9. ins_pricing/cli/bayesopt_entry_runner.py +1312 -0
  10. ins_pricing/cli/utils/__init__.py +1 -0
  11. ins_pricing/cli/utils/cli_common.py +320 -0
  12. ins_pricing/cli/utils/cli_config.py +375 -0
  13. ins_pricing/{modelling → cli/utils}/notebook_utils.py +74 -19
  14. {ins_pricing_gemini/modelling → ins_pricing/cli}/watchdog_run.py +2 -2
  15. ins_pricing/{modelling → docs/modelling}/BayesOpt_USAGE.md +69 -49
  16. ins_pricing/docs/modelling/README.md +34 -0
  17. ins_pricing/modelling/__init__.py +57 -6
  18. ins_pricing/modelling/core/__init__.py +1 -0
  19. ins_pricing/modelling/{bayesopt → core/bayesopt}/config_preprocess.py +64 -1
  20. ins_pricing/modelling/{bayesopt → core/bayesopt}/core.py +150 -810
  21. ins_pricing/modelling/core/bayesopt/model_explain_mixin.py +296 -0
  22. ins_pricing/modelling/core/bayesopt/model_plotting_mixin.py +548 -0
  23. ins_pricing/modelling/core/bayesopt/models/__init__.py +27 -0
  24. ins_pricing/modelling/core/bayesopt/models/model_ft_components.py +316 -0
  25. ins_pricing/modelling/core/bayesopt/models/model_ft_trainer.py +808 -0
  26. ins_pricing/modelling/core/bayesopt/models/model_gnn.py +675 -0
  27. ins_pricing/modelling/core/bayesopt/models/model_resn.py +435 -0
  28. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +19 -0
  29. ins_pricing/modelling/core/bayesopt/trainers/trainer_base.py +1020 -0
  30. ins_pricing/modelling/core/bayesopt/trainers/trainer_ft.py +787 -0
  31. ins_pricing/modelling/core/bayesopt/trainers/trainer_glm.py +195 -0
  32. ins_pricing/modelling/core/bayesopt/trainers/trainer_gnn.py +312 -0
  33. ins_pricing/modelling/core/bayesopt/trainers/trainer_resn.py +261 -0
  34. ins_pricing/modelling/core/bayesopt/trainers/trainer_xgb.py +348 -0
  35. ins_pricing/modelling/{bayesopt → core/bayesopt}/utils.py +2 -2
  36. ins_pricing/modelling/core/evaluation.py +115 -0
  37. ins_pricing/production/__init__.py +4 -0
  38. ins_pricing/production/preprocess.py +71 -0
  39. ins_pricing/setup.py +10 -5
  40. {ins_pricing_gemini/modelling/tests → ins_pricing/tests/modelling}/test_plotting.py +2 -2
  41. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/METADATA +4 -4
  42. ins_pricing-0.2.0.dist-info/RECORD +125 -0
  43. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/top_level.txt +0 -1
  44. ins_pricing/modelling/BayesOpt_entry.py +0 -633
  45. ins_pricing/modelling/Explain_Run.py +0 -36
  46. ins_pricing/modelling/Pricing_Run.py +0 -36
  47. ins_pricing/modelling/README.md +0 -33
  48. ins_pricing/modelling/bayesopt/models.py +0 -2196
  49. ins_pricing/modelling/bayesopt/trainers.py +0 -2446
  50. ins_pricing/modelling/cli_common.py +0 -136
  51. ins_pricing/modelling/tests/test_plotting.py +0 -63
  52. ins_pricing/modelling/watchdog_run.py +0 -211
  53. ins_pricing-0.1.11.dist-info/RECORD +0 -169
  54. ins_pricing_gemini/__init__.py +0 -23
  55. ins_pricing_gemini/governance/__init__.py +0 -20
  56. ins_pricing_gemini/governance/approval.py +0 -93
  57. ins_pricing_gemini/governance/audit.py +0 -37
  58. ins_pricing_gemini/governance/registry.py +0 -99
  59. ins_pricing_gemini/governance/release.py +0 -159
  60. ins_pricing_gemini/modelling/Explain_Run.py +0 -36
  61. ins_pricing_gemini/modelling/Pricing_Run.py +0 -36
  62. ins_pricing_gemini/modelling/__init__.py +0 -151
  63. ins_pricing_gemini/modelling/cli_common.py +0 -141
  64. ins_pricing_gemini/modelling/config.py +0 -249
  65. ins_pricing_gemini/modelling/config_preprocess.py +0 -254
  66. ins_pricing_gemini/modelling/core.py +0 -741
  67. ins_pricing_gemini/modelling/data_container.py +0 -42
  68. ins_pricing_gemini/modelling/explain/__init__.py +0 -55
  69. ins_pricing_gemini/modelling/explain/gradients.py +0 -334
  70. ins_pricing_gemini/modelling/explain/metrics.py +0 -176
  71. ins_pricing_gemini/modelling/explain/permutation.py +0 -155
  72. ins_pricing_gemini/modelling/explain/shap_utils.py +0 -146
  73. ins_pricing_gemini/modelling/features.py +0 -215
  74. ins_pricing_gemini/modelling/model_manager.py +0 -148
  75. ins_pricing_gemini/modelling/model_plotting.py +0 -463
  76. ins_pricing_gemini/modelling/models.py +0 -2203
  77. ins_pricing_gemini/modelling/notebook_utils.py +0 -294
  78. ins_pricing_gemini/modelling/plotting/__init__.py +0 -45
  79. ins_pricing_gemini/modelling/plotting/common.py +0 -63
  80. ins_pricing_gemini/modelling/plotting/curves.py +0 -572
  81. ins_pricing_gemini/modelling/plotting/diagnostics.py +0 -139
  82. ins_pricing_gemini/modelling/plotting/geo.py +0 -362
  83. ins_pricing_gemini/modelling/plotting/importance.py +0 -121
  84. ins_pricing_gemini/modelling/run_logging.py +0 -133
  85. ins_pricing_gemini/modelling/tests/conftest.py +0 -8
  86. ins_pricing_gemini/modelling/tests/test_cross_val_generic.py +0 -66
  87. ins_pricing_gemini/modelling/tests/test_distributed_utils.py +0 -18
  88. ins_pricing_gemini/modelling/tests/test_explain.py +0 -56
  89. ins_pricing_gemini/modelling/tests/test_geo_tokens_split.py +0 -49
  90. ins_pricing_gemini/modelling/tests/test_graph_cache.py +0 -33
  91. ins_pricing_gemini/modelling/tests/test_plotting_library.py +0 -150
  92. ins_pricing_gemini/modelling/tests/test_preprocessor.py +0 -48
  93. ins_pricing_gemini/modelling/trainers.py +0 -2447
  94. ins_pricing_gemini/modelling/utils.py +0 -1020
  95. ins_pricing_gemini/pricing/__init__.py +0 -27
  96. ins_pricing_gemini/pricing/calibration.py +0 -39
  97. ins_pricing_gemini/pricing/data_quality.py +0 -117
  98. ins_pricing_gemini/pricing/exposure.py +0 -85
  99. ins_pricing_gemini/pricing/factors.py +0 -91
  100. ins_pricing_gemini/pricing/monitoring.py +0 -99
  101. ins_pricing_gemini/pricing/rate_table.py +0 -78
  102. ins_pricing_gemini/production/__init__.py +0 -21
  103. ins_pricing_gemini/production/drift.py +0 -30
  104. ins_pricing_gemini/production/monitoring.py +0 -143
  105. ins_pricing_gemini/production/scoring.py +0 -40
  106. ins_pricing_gemini/reporting/__init__.py +0 -11
  107. ins_pricing_gemini/reporting/report_builder.py +0 -72
  108. ins_pricing_gemini/reporting/scheduler.py +0 -45
  109. ins_pricing_gemini/scripts/BayesOpt_incremental.py +0 -722
  110. ins_pricing_gemini/scripts/Explain_entry.py +0 -545
  111. ins_pricing_gemini/scripts/__init__.py +0 -1
  112. ins_pricing_gemini/scripts/train.py +0 -568
  113. ins_pricing_gemini/setup.py +0 -55
  114. ins_pricing_gemini/smoke_test.py +0 -28
  115. /ins_pricing/{modelling → cli/utils}/run_logging.py +0 -0
  116. /ins_pricing/modelling/{BayesOpt.py → core/BayesOpt.py} +0 -0
  117. /ins_pricing/modelling/{bayesopt → core/bayesopt}/__init__.py +0 -0
  118. /ins_pricing/{modelling/tests → tests/modelling}/conftest.py +0 -0
  119. /ins_pricing/{modelling/tests → tests/modelling}/test_cross_val_generic.py +0 -0
  120. /ins_pricing/{modelling/tests → tests/modelling}/test_distributed_utils.py +0 -0
  121. /ins_pricing/{modelling/tests → tests/modelling}/test_explain.py +0 -0
  122. /ins_pricing/{modelling/tests → tests/modelling}/test_geo_tokens_split.py +0 -0
  123. /ins_pricing/{modelling/tests → tests/modelling}/test_graph_cache.py +0 -0
  124. /ins_pricing/{modelling/tests → tests/modelling}/test_plotting_library.py +0 -0
  125. /ins_pricing/{modelling/tests → tests/modelling}/test_preprocessor.py +0 -0
  126. {ins_pricing-0.1.11.dist-info → ins_pricing-0.2.0.dist-info}/WHEEL +0 -0
@@ -3,62 +3,50 @@ from __future__ import annotations
  from dataclasses import asdict
  from datetime import datetime
  import os
+ from pathlib import Path
  from typing import Any, Dict, List, Optional
-
- try: # matplotlib is optional; avoid hard import failures in headless/minimal envs
-     import matplotlib
-     if os.name != "nt" and not os.environ.get("DISPLAY") and not os.environ.get("MPLBACKEND"):
-         matplotlib.use("Agg")
-     import matplotlib.pyplot as plt
-     _MPL_IMPORT_ERROR: Optional[BaseException] = None
- except Exception as exc: # pragma: no cover - optional dependency
-     plt = None # type: ignore[assignment]
-     _MPL_IMPORT_ERROR = exc
  import numpy as np
  import pandas as pd
  import torch
- import statsmodels.api as sm
- from sklearn.model_selection import ShuffleSplit
+ from sklearn.model_selection import GroupKFold, ShuffleSplit, TimeSeriesSplit
  from sklearn.preprocessing import StandardScaler

  from .config_preprocess import BayesOptConfig, DatasetPreprocessor, OutputManager, VersionManager
+ from .model_explain_mixin import BayesOptExplainMixin
+ from .model_plotting_mixin import BayesOptPlottingMixin
  from .models import GraphNeuralNetSklearn
  from .trainers import FTTrainer, GLMTrainer, GNNTrainer, ResNetTrainer, XGBTrainer
- from .utils import EPS, PlotUtils, infer_factor_and_cate_list, set_global_seed
- try:
-     from ..plotting import curves as plot_curves
-     from ..plotting import diagnostics as plot_diagnostics
-     from ..plotting.common import PlotStyle, finalize_figure
-     from ..explain import gradients as explain_gradients
-     from ..explain import permutation as explain_permutation
-     from ..explain import shap_utils as explain_shap
- except Exception: # pragma: no cover - optional for legacy imports
-     try: # best-effort for non-package imports
-         from ins_pricing.plotting import curves as plot_curves
-         from ins_pricing.plotting import diagnostics as plot_diagnostics
-         from ins_pricing.plotting.common import PlotStyle, finalize_figure
-         from ins_pricing.explain import gradients as explain_gradients
-         from ins_pricing.explain import permutation as explain_permutation
-         from ins_pricing.explain import shap_utils as explain_shap
-     except Exception: # pragma: no cover
-         plot_curves = None
-         plot_diagnostics = None
-         PlotStyle = None
-         finalize_figure = None
-         explain_gradients = None
-         explain_permutation = None
-         explain_shap = None
-
-
- def _plot_skip(label: str) -> None:
-     if _MPL_IMPORT_ERROR is not None:
-         print(f"[Plot] Skip {label}: matplotlib unavailable ({_MPL_IMPORT_ERROR}).", flush=True)
-     else:
-         print(f"[Plot] Skip {label}: matplotlib unavailable.", flush=True)
+ from .utils import EPS, infer_factor_and_cate_list, set_global_seed
+
+
+ class _CVSplitter:
+     """Wrapper to carry optional groups or time order for CV splits."""
+
+     def __init__(
+         self,
+         splitter,
+         *,
+         groups: Optional[pd.Series] = None,
+         order: Optional[np.ndarray] = None,
+     ) -> None:
+         self._splitter = splitter
+         self._groups = groups
+         self._order = order
+
+     def split(self, X, y=None, groups=None):
+         if self._order is not None:
+             order = np.asarray(self._order)
+             X_ord = X.iloc[order] if hasattr(X, "iloc") else X[order]
+             for tr_idx, val_idx in self._splitter.split(X_ord, y=y):
+                 yield order[tr_idx], order[val_idx]
+             return
+         use_groups = groups if groups is not None else self._groups
+         for tr_idx, val_idx in self._splitter.split(X, y=y, groups=use_groups):
+             yield tr_idx, val_idx

  # BayesOpt orchestration and SHAP utilities
  # =============================================================================
- class BayesOptModel:
+ class BayesOptModel(BayesOptPlottingMixin, BayesOptExplainMixin):
      def __init__(self, train_data, test_data,
                   model_nme, resp_nme, weight_nme, factor_nmes: Optional[List[str]] = None, task_type='regression',
                   binary_resp_nme=None,
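Note: the new _CVSplitter keeps sklearn's split(X, y, groups) generator contract while pinning the groups (or a precomputed time order) at construction time, so downstream trainers can stay splitter-agnostic. A minimal standalone sketch of the grouped case on synthetic data (names are illustrative, not package API):

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import GroupKFold

    # Four policies, two rows each; a policy must never straddle train/validation.
    df = pd.DataFrame({
        "x": np.arange(8.0),
        "policy_id": ["a", "a", "b", "b", "c", "c", "d", "d"],
    })
    groups = df["policy_id"]

    # _CVSplitter(GroupKFold(n_splits=4), groups=groups).split(df) delegates like this:
    for tr_idx, val_idx in GroupKFold(n_splits=4).split(df, groups=groups):
        assert set(groups.iloc[tr_idx]).isdisjoint(groups.iloc[val_idx])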
@@ -89,7 +77,32 @@ class BayesOptModel:
                   final_refit: bool = True,
                   optuna_storage: Optional[str] = None,
                   optuna_study_prefix: Optional[str] = None,
-                  best_params_files: Optional[Dict[str, str]] = None):
+                  best_params_files: Optional[Dict[str, str]] = None,
+                  cv_strategy: Optional[str] = None,
+                  cv_splits: Optional[int] = None,
+                  cv_group_col: Optional[str] = None,
+                  cv_time_col: Optional[str] = None,
+                  cv_time_ascending: bool = True,
+                  ft_oof_folds: Optional[int] = None,
+                  ft_oof_strategy: Optional[str] = None,
+                  ft_oof_shuffle: bool = True,
+                  save_preprocess: bool = False,
+                  preprocess_artifact_path: Optional[str] = None,
+                  plot_path_style: Optional[str] = None,
+                  bo_sample_limit: Optional[int] = None,
+                  cache_predictions: bool = False,
+                  prediction_cache_dir: Optional[str] = None,
+                  prediction_cache_format: Optional[str] = None,
+                  region_province_col: Optional[str] = None,
+                  region_city_col: Optional[str] = None,
+                  region_effect_alpha: Optional[float] = None,
+                  geo_feature_nmes: Optional[List[str]] = None,
+                  geo_token_hidden_dim: Optional[int] = None,
+                  geo_token_layers: Optional[int] = None,
+                  geo_token_dropout: Optional[float] = None,
+                  geo_token_k_neighbors: Optional[int] = None,
+                  geo_token_learning_rate: Optional[float] = None,
+                  geo_token_epochs: Optional[int] = None):
          """Orchestrate BayesOpt training across multiple trainers.

          Args:
@@ -176,6 +189,47 @@
              final_ensemble=bool(final_ensemble),
              final_ensemble_k=int(final_ensemble_k),
              final_refit=bool(final_refit),
+             cv_strategy=str(cv_strategy or "random"),
+             cv_splits=cv_splits,
+             cv_group_col=cv_group_col,
+             cv_time_col=cv_time_col,
+             cv_time_ascending=bool(cv_time_ascending),
+             ft_oof_folds=ft_oof_folds,
+             ft_oof_strategy=ft_oof_strategy,
+             ft_oof_shuffle=bool(ft_oof_shuffle),
+             save_preprocess=bool(save_preprocess),
+             preprocess_artifact_path=preprocess_artifact_path,
+             plot_path_style=str(plot_path_style or "nested"),
+             bo_sample_limit=bo_sample_limit,
+             cache_predictions=bool(cache_predictions),
+             prediction_cache_dir=prediction_cache_dir,
+             prediction_cache_format=str(prediction_cache_format or "parquet"),
+             region_province_col=region_province_col,
+             region_city_col=region_city_col,
+             region_effect_alpha=float(region_effect_alpha)
+             if region_effect_alpha is not None
+             else 50.0,
+             geo_feature_nmes=list(geo_feature_nmes)
+             if geo_feature_nmes is not None
+             else None,
+             geo_token_hidden_dim=int(geo_token_hidden_dim)
+             if geo_token_hidden_dim is not None
+             else 32,
+             geo_token_layers=int(geo_token_layers)
+             if geo_token_layers is not None
+             else 2,
+             geo_token_dropout=float(geo_token_dropout)
+             if geo_token_dropout is not None
+             else 0.1,
+             geo_token_k_neighbors=int(geo_token_k_neighbors)
+             if geo_token_k_neighbors is not None
+             else 10,
+             geo_token_learning_rate=float(geo_token_learning_rate)
+             if geo_token_learning_rate is not None
+             else 1e-3,
+             geo_token_epochs=int(geo_token_epochs)
+             if geo_token_epochs is not None
+             else 50,
          )
          self.config = cfg
          self.model_nme = cfg.model_nme
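Note the forwarding idiom used throughout the config block: value if value is not None else default rather than value or default, so an explicit 0, 0.0 or False survives. In isolation:

    def coalesce(value, default):
        # Unlike `value or default`, this keeps legitimate falsy inputs (0, 0.0, False).
        return value if value is not None else default

    assert coalesce(None, 50.0) == 50.0   # unset -> documented default
    assert coalesce(0.0, 50.0) == 0.0     # `or` would silently return 50.0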
@@ -204,15 +258,24 @@
          self.var_nmes = preprocessor.var_nmes
          self.num_features = preprocessor.num_features
          self.cat_categories_for_shap = preprocessor.cat_categories_for_shap
+         if getattr(self.config, "save_preprocess", False):
+             artifact_path = getattr(self.config, "preprocess_artifact_path", None)
+             if artifact_path:
+                 target = Path(str(artifact_path))
+                 if not target.is_absolute():
+                     target = Path(self.output_manager.result_dir) / target
+             else:
+                 target = Path(self.output_manager.result_path(
+                     f"{self.model_nme}_preprocess.json"
+                 ))
+             preprocessor.save_artifacts(target)
          self.geo_token_cols: List[str] = []
          self.train_geo_tokens: Optional[pd.DataFrame] = None
          self.test_geo_tokens: Optional[pd.DataFrame] = None
          self.geo_gnn_model: Optional[GraphNeuralNetSklearn] = None
          self._add_region_effect()

-         self.cv = ShuffleSplit(n_splits=int(1/self.prop_test),
-                                test_size=self.prop_test,
-                                random_state=self.rand_seed)
+         self.cv = self._build_cv_splitter()
          if self.task_type == 'classification':
              self.obj = 'binary:logistic'
          else: # regression task
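The artifact-saving branch above resolves the target in three steps: an absolute path wins, a relative path is anchored at the run's result directory, and no path falls back to <model_nme>_preprocess.json. A self-contained sketch of that rule (the function and directory names here are made up):

    from pathlib import Path
    from typing import Optional

    def resolve_artifact(path_str: Optional[str], result_dir: str, default_name: str) -> Path:
        # Absolute paths are taken verbatim; relative ones are anchored at
        # result_dir; None falls back to a default filename.
        if path_str:
            target = Path(path_str)
            return target if target.is_absolute() else Path(result_dir) / target
        return Path(result_dir) / default_name

    assert resolve_artifact(None, "results", "m_preprocess.json") == Path("results/m_preprocess.json")
    assert resolve_artifact("pp.json", "results", "unused") == Path("results/pp.json")
    assert resolve_artifact("/tmp/pp.json", "results", "unused") == Path("/tmp/pp.json")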
@@ -261,6 +324,45 @@
          self.ft_load = None
          self.version_manager = VersionManager(self.output_manager)

+     def _build_cv_splitter(self) -> _CVSplitter:
+         strategy = str(getattr(self.config, "cv_strategy", "random") or "random").strip().lower()
+         val_ratio = float(self.prop_test) if self.prop_test is not None else 0.25
+         if not (0.0 < val_ratio < 1.0):
+             val_ratio = 0.25
+         cv_splits = getattr(self.config, "cv_splits", None)
+         if cv_splits is None:
+             cv_splits = max(2, int(round(1 / val_ratio)))
+         cv_splits = max(2, int(cv_splits))
+
+         if strategy in {"group", "grouped"}:
+             group_col = getattr(self.config, "cv_group_col", None)
+             if not group_col:
+                 raise ValueError("cv_group_col is required for group cv_strategy.")
+             if group_col not in self.train_data.columns:
+                 raise KeyError(f"cv_group_col '{group_col}' not in train_data.")
+             groups = self.train_data[group_col]
+             splitter = GroupKFold(n_splits=cv_splits)
+             return _CVSplitter(splitter, groups=groups)
+
+         if strategy in {"time", "timeseries", "temporal"}:
+             time_col = getattr(self.config, "cv_time_col", None)
+             if not time_col:
+                 raise ValueError("cv_time_col is required for time cv_strategy.")
+             if time_col not in self.train_data.columns:
+                 raise KeyError(f"cv_time_col '{time_col}' not in train_data.")
+             ascending = bool(getattr(self.config, "cv_time_ascending", True))
+             order_index = self.train_data[time_col].sort_values(ascending=ascending).index
+             order = self.train_data.index.get_indexer(order_index)
+             splitter = TimeSeriesSplit(n_splits=cv_splits)
+             return _CVSplitter(splitter, order=order)
+
+         splitter = ShuffleSplit(
+             n_splits=cv_splits,
+             test_size=val_ratio,
+             random_state=self.rand_seed,
+         )
+         return _CVSplitter(splitter)
+
      def default_tweedie_power(self, obj: Optional[str] = None) -> Optional[float]:
          if self.task_type == 'classification':
              return None
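For the time strategy, rows are ranked by the timestamp column and TimeSeriesSplit runs over that ranking, so each validation fold is strictly later than its training data; _CVSplitter then translates fold positions back to the original row order. A runnable sketch of those mechanics on synthetic data:

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import TimeSeriesSplit

    df = pd.DataFrame(
        {"ts": pd.to_datetime(["2021-03-01", "2021-01-01", "2021-04-01", "2021-02-01"])},
        index=[10, 11, 12, 13],  # non-positional index, as in real frames
    )

    # Same two steps as _build_cv_splitter: sort by time, then map the sorted
    # labels back to positional indices with get_indexer.
    order_index = df["ts"].sort_values(ascending=True).index
    order = df.index.get_indexer(order_index)  # array([1, 3, 0, 2])

    for tr, val in TimeSeriesSplit(n_splits=3).split(order):
        tr_pos, val_pos = order[tr], order[val]  # positions into df, oldest first
        assert df["ts"].iloc[tr_pos].max() < df["ts"].iloc[val_pos].min()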
@@ -473,142 +575,6 @@
              col for col in self.train_oht_scl_data.columns if col not in excluded
          ]

-     # Single-factor plotting helper.
-     def plot_oneway(
-         self,
-         n_bins=10,
-         pred_col: Optional[str] = None,
-         pred_label: Optional[str] = None,
-         pred_weighted: Optional[bool] = None,
-         plot_subdir: Optional[str] = None,
-     ):
-         if plt is None and plot_diagnostics is None:
-             _plot_skip("oneway plot")
-             return
-         if pred_col is not None and pred_col not in self.train_data.columns:
-             print(
-                 f"[Oneway] Missing prediction column '{pred_col}'; skip predicted line.",
-                 flush=True,
-             )
-             pred_col = None
-         if pred_weighted is None and pred_col is not None:
-             pred_weighted = pred_col.startswith("w_pred_")
-         if pred_weighted is None:
-             pred_weighted = False
-         plot_subdir = plot_subdir.strip("/\\") if plot_subdir else "oneway"
-         plot_prefix = f"{self.model_nme}/{plot_subdir}"
-
-         def _safe_tag(value: str) -> str:
-             return (
-                 value.strip()
-                 .replace(" ", "_")
-                 .replace("/", "_")
-                 .replace("\\", "_")
-                 .replace(":", "_")
-             )
-
-         if plot_diagnostics is None:
-             for c in self.factor_nmes:
-                 fig = plt.figure(figsize=(7, 5))
-                 if c in self.cate_list:
-                     group_col = c
-                     plot_source = self.train_data
-                 else:
-                     group_col = f'{c}_bins'
-                     bins = pd.qcut(
-                         self.train_data[c],
-                         n_bins,
-                         duplicates='drop' # Drop duplicate quantiles to avoid errors.
-                     )
-                     plot_source = self.train_data.assign(**{group_col: bins})
-                 if pred_col is not None and pred_col in plot_source.columns:
-                     if pred_weighted:
-                         plot_source = plot_source.assign(
-                             _pred_w=plot_source[pred_col]
-                         )
-                     else:
-                         plot_source = plot_source.assign(
-                             _pred_w=plot_source[pred_col] * plot_source[self.weight_nme]
-                         )
-                 plot_data = plot_source.groupby(
-                     [group_col], observed=True).sum(numeric_only=True)
-                 plot_data.reset_index(inplace=True)
-                 plot_data['act_v'] = plot_data['w_act'] / \
-                     plot_data[self.weight_nme]
-                 if pred_col is not None and "_pred_w" in plot_data.columns:
-                     plot_data["pred_v"] = plot_data["_pred_w"] / plot_data[self.weight_nme]
-                 ax = fig.add_subplot(111)
-                 ax.plot(plot_data.index, plot_data['act_v'],
-                         label='Actual', color='red')
-                 if pred_col is not None and "pred_v" in plot_data.columns:
-                     ax.plot(
-                         plot_data.index,
-                         plot_data["pred_v"],
-                         label=pred_label or "Predicted",
-                         color="tab:blue",
-                     )
-                 ax.set_title(
-                     'Analysis of %s : Train Data' % group_col,
-                     fontsize=8)
-                 plt.xticks(plot_data.index,
-                            list(plot_data[group_col].astype(str)),
-                            rotation=90)
-                 if len(list(plot_data[group_col].astype(str))) > 50:
-                     plt.xticks(fontsize=3)
-                 else:
-                     plt.xticks(fontsize=6)
-                 plt.yticks(fontsize=6)
-                 ax2 = ax.twinx()
-                 ax2.bar(plot_data.index,
-                         plot_data[self.weight_nme],
-                         alpha=0.5, color='seagreen')
-                 plt.yticks(fontsize=6)
-                 plt.margins(0.05)
-                 plt.subplots_adjust(wspace=0.3)
-                 if pred_col is not None and "pred_v" in plot_data.columns:
-                     ax.legend(fontsize=6)
-                 pred_tag = _safe_tag(pred_label or pred_col) if pred_col else None
-                 if pred_tag:
-                     filename = f'00_{self.model_nme}_{group_col}_oneway_{pred_tag}.png'
-                 else:
-                     filename = f'00_{self.model_nme}_{group_col}_oneway.png'
-                 save_path = self.output_manager.plot_path(
-                     f'{plot_prefix}/{filename}')
-                 plt.savefig(save_path, dpi=300)
-                 plt.close(fig)
-             return
-
-         if "w_act" not in self.train_data.columns:
-             print("[Oneway] Missing w_act column; skip plotting.", flush=True)
-             return
-
-         for c in self.factor_nmes:
-             is_cat = c in (self.cate_list or [])
-             group_col = c if is_cat else f"{c}_bins"
-             title = f"Analysis of {group_col} : Train Data"
-             pred_tag = _safe_tag(pred_label or pred_col) if pred_col else None
-             if pred_tag:
-                 filename = f"00_{self.model_nme}_{group_col}_oneway_{pred_tag}.png"
-             else:
-                 filename = f"00_{self.model_nme}_{group_col}_oneway.png"
-             save_path = self.output_manager.plot_path(
-                 f"{plot_prefix}/{filename}"
-             )
-             plot_diagnostics.plot_oneway(
-                 self.train_data,
-                 feature=c,
-                 weight_col=self.weight_nme,
-                 target_col="w_act",
-                 pred_col=pred_col,
-                 pred_weighted=pred_weighted,
-                 pred_label=pred_label,
-                 n_bins=n_bins,
-                 is_categorical=is_cat,
-                 title=title,
-                 save_path=save_path,
-                 show=False,
-             )
-
      def _require_trainer(self, model_key: str) -> "TrainerBase":
          trainer = self.trainers.get(model_key)
          if trainer is None:
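The ~140 removed lines above are the oneway plotting body; per the file list it now lives in model_plotting_mixin.py (+548 lines) and reaches BayesOptModel through BayesOptPlottingMixin. A toy sketch of that mixin contract (not the package's actual classes): the mixin carries behaviour only and reads attributes the host is expected to define.

    class PlottingMixin:
        def plot_oneway(self):
            for col in self.factor_nmes:  # supplied by the host class
                print(f"oneway plot of {col}, weighted by {self.weight_nme}")

    class Model(PlottingMixin):
        def __init__(self):
            self.factor_nmes = ["age", "region"]
            self.weight_nme = "exposure"

    Model().plot_oneway()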
@@ -896,486 +862,6 @@ class BayesOptModel:
      def bayesopt_ft(self, max_evals=50):
          self.optimize_model('ft', max_evals)

-     # Lift curve plotting.
-     def plot_lift(self, model_label, pred_nme, n_bins=10):
-         if plt is None:
-             _plot_skip("lift plot")
-             return
-         model_map = {
-             'Xgboost': 'pred_xgb',
-             'ResNet': 'pred_resn',
-             'ResNetClassifier': 'pred_resn',
-             'GLM': 'pred_glm',
-             'GNN': 'pred_gnn',
-         }
-         if str(self.config.ft_role) == "model":
-             model_map.update({
-                 'FTTransformer': 'pred_ft',
-                 'FTTransformerClassifier': 'pred_ft',
-             })
-         for k, v in model_map.items():
-             if model_label.startswith(k):
-                 pred_nme = v
-                 break
-         safe_label = (
-             str(model_label)
-             .replace(" ", "_")
-             .replace("/", "_")
-             .replace("\\", "_")
-             .replace(":", "_")
-         )
-         plot_prefix = f"{self.model_nme}/lift"
-         filename = f"01_{self.model_nme}_{safe_label}_lift.png"
-
-         datasets = []
-         for title, data in [
-             ('Lift Chart on Train Data', self.train_data),
-             ('Lift Chart on Test Data', self.test_data),
-         ]:
-             if 'w_act' not in data.columns or data['w_act'].isna().all():
-                 print(
-                     f"[Lift] Missing labels for {title}; skip.",
-                     flush=True,
-                 )
-                 continue
-             datasets.append((title, data))
-
-         if not datasets:
-             print("[Lift] No labeled data available; skip plotting.", flush=True)
-             return
-
-         if plot_curves is None:
-             fig = plt.figure(figsize=(11, 5))
-             positions = [111] if len(datasets) == 1 else [121, 122]
-             for pos, (title, data) in zip(positions, datasets):
-                 if pred_nme not in data.columns or f'w_{pred_nme}' not in data.columns:
-                     print(
-                         f"[Lift] Missing prediction columns in {title}; skip.",
-                         flush=True,
-                     )
-                     continue
-                 lift_df = pd.DataFrame({
-                     'pred': data[pred_nme].values,
-                     'w_pred': data[f'w_{pred_nme}'].values,
-                     'act': data['w_act'].values,
-                     'weight': data[self.weight_nme].values
-                 })
-                 plot_data = PlotUtils.split_data(lift_df, 'pred', 'weight', n_bins)
-                 denom = np.maximum(plot_data['weight'], EPS)
-                 plot_data['exp_v'] = plot_data['w_pred'] / denom
-                 plot_data['act_v'] = plot_data['act'] / denom
-                 plot_data = plot_data.reset_index()
-
-                 ax = fig.add_subplot(pos)
-                 PlotUtils.plot_lift_ax(ax, plot_data, title)
-
-             plt.subplots_adjust(wspace=0.3)
-             save_path = self.output_manager.plot_path(
-                 f"{plot_prefix}/{filename}")
-             plt.savefig(save_path, dpi=300)
-             plt.show()
-             plt.close(fig)
-             return
-
-         style = PlotStyle() if PlotStyle else None
-         fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
-         if len(datasets) == 1:
-             axes = [axes]
-
-         for ax, (title, data) in zip(axes, datasets):
-             pred_vals = None
-             if pred_nme in data.columns:
-                 pred_vals = data[pred_nme].values
-             else:
-                 w_pred_col = f"w_{pred_nme}"
-                 if w_pred_col in data.columns:
-                     denom = np.maximum(data[self.weight_nme].values, EPS)
-                     pred_vals = data[w_pred_col].values / denom
-             if pred_vals is None:
-                 print(
-                     f"[Lift] Missing prediction columns in {title}; skip.",
-                     flush=True,
-                 )
-                 continue
-
-             plot_curves.plot_lift_curve(
-                 pred_vals,
-                 data['w_act'].values,
-                 data[self.weight_nme].values,
-                 n_bins=n_bins,
-                 title=title,
-                 pred_label="Predicted",
-                 act_label="Actual",
-                 weight_label="Earned Exposure",
-                 pred_weighted=False,
-                 actual_weighted=True,
-                 ax=ax,
-                 show=False,
-                 style=style,
-             )
-
-         plt.subplots_adjust(wspace=0.3)
-         save_path = self.output_manager.plot_path(
-             f"{plot_prefix}/{filename}")
-         if finalize_figure:
-             finalize_figure(fig, save_path=save_path, show=True, style=style)
-         else:
-             plt.savefig(save_path, dpi=300)
-             plt.show()
-         plt.close(fig)
-
-     # Double lift curve plot.
-     def plot_dlift(self, model_comp: List[str] = ['xgb', 'resn'], n_bins: int = 10) -> None:
-         # Compare two models across bins.
-         # Args:
-         #     model_comp: model keys to compare (e.g., ['xgb', 'resn']).
-         #     n_bins: number of bins for lift curves.
-         if plt is None:
-             _plot_skip("double lift plot")
-             return
-         if len(model_comp) != 2:
-             raise ValueError("`model_comp` must contain two models to compare.")
-
-         model_name_map = {
-             'xgb': 'Xgboost',
-             'resn': 'ResNet',
-             'glm': 'GLM',
-             'gnn': 'GNN',
-         }
-         if str(self.config.ft_role) == "model":
-             model_name_map['ft'] = 'FTTransformer'
-
-         name1, name2 = model_comp
-         if name1 not in model_name_map or name2 not in model_name_map:
-             raise ValueError(f"Unsupported model key. Choose from {list(model_name_map.keys())}.")
-         plot_prefix = f"{self.model_nme}/double_lift"
-         filename = f"02_{self.model_nme}_dlift_{name1}_vs_{name2}.png"
-
-         datasets = []
-         for data_name, data in [('Train Data', self.train_data),
-                                 ('Test Data', self.test_data)]:
-             if 'w_act' not in data.columns or data['w_act'].isna().all():
-                 print(
-                     f"[Double Lift] Missing labels for {data_name}; skip.",
-                     flush=True,
-                 )
-                 continue
-             datasets.append((data_name, data))
-
-         if not datasets:
-             print("[Double Lift] No labeled data available; skip plotting.", flush=True)
-             return
-
-         if plot_curves is None:
-             fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
-             if len(datasets) == 1:
-                 axes = [axes]
-
-             for ax, (data_name, data) in zip(axes, datasets):
-                 pred1_col = f'w_pred_{name1}'
-                 pred2_col = f'w_pred_{name2}'
-
-                 if pred1_col not in data.columns or pred2_col not in data.columns:
-                     print(
-                         f"Warning: missing prediction columns {pred1_col} or {pred2_col} in {data_name}. Skip plot.")
-                     continue
-
-                 lift_data = pd.DataFrame({
-                     'pred1': data[pred1_col].values,
-                     'pred2': data[pred2_col].values,
-                     'diff_ly': data[pred1_col].values / np.maximum(data[pred2_col].values, EPS),
-                     'act': data['w_act'].values,
-                     'weight': data[self.weight_nme].values
-                 })
-                 plot_data = PlotUtils.split_data(
-                     lift_data, 'diff_ly', 'weight', n_bins)
-                 denom = np.maximum(plot_data['act'], EPS)
-                 plot_data['exp_v1'] = plot_data['pred1'] / denom
-                 plot_data['exp_v2'] = plot_data['pred2'] / denom
-                 plot_data['act_v'] = plot_data['act'] / denom
-                 plot_data.reset_index(inplace=True)
-
-                 label1 = model_name_map[name1]
-                 label2 = model_name_map[name2]
-
-                 PlotUtils.plot_dlift_ax(
-                     ax, plot_data, f'Double Lift Chart on {data_name}', label1, label2)
-
-             plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8, wspace=0.3)
-             save_path = self.output_manager.plot_path(
-                 f"{plot_prefix}/{filename}")
-             plt.savefig(save_path, dpi=300)
-             plt.show()
-             plt.close(fig)
-             return
-
-         style = PlotStyle() if PlotStyle else None
-         fig, axes = plt.subplots(1, len(datasets), figsize=(11, 5))
-         if len(datasets) == 1:
-             axes = [axes]
-
-         label1 = model_name_map[name1]
-         label2 = model_name_map[name2]
-
-         for ax, (data_name, data) in zip(axes, datasets):
-             weight_vals = data[self.weight_nme].values
-             pred1 = None
-             pred2 = None
-
-             pred1_col = f"pred_{name1}"
-             pred2_col = f"pred_{name2}"
-             if pred1_col in data.columns:
-                 pred1 = data[pred1_col].values
-             else:
-                 w_pred1_col = f"w_pred_{name1}"
-                 if w_pred1_col in data.columns:
-                     pred1 = data[w_pred1_col].values / np.maximum(weight_vals, EPS)
-
-             if pred2_col in data.columns:
-                 pred2 = data[pred2_col].values
-             else:
-                 w_pred2_col = f"w_pred_{name2}"
-                 if w_pred2_col in data.columns:
-                     pred2 = data[w_pred2_col].values / np.maximum(weight_vals, EPS)
-
-             if pred1 is None or pred2 is None:
-                 print(
-                     f"Warning: missing pred_{name1}/pred_{name2} or w_pred columns in {data_name}. Skip plot.")
-                 continue
-
-             plot_curves.plot_double_lift_curve(
-                 pred1,
-                 pred2,
-                 data['w_act'].values,
-                 weight_vals,
-                 n_bins=n_bins,
-                 title=f"Double Lift Chart on {data_name}",
-                 label1=label1,
-                 label2=label2,
-                 pred1_weighted=False,
-                 pred2_weighted=False,
-                 actual_weighted=True,
-                 ax=ax,
-                 show=False,
-                 style=style,
-             )
-
-         plt.subplots_adjust(bottom=0.25, top=0.95, right=0.8, wspace=0.3)
-         save_path = self.output_manager.plot_path(
-             f"{plot_prefix}/{filename}")
-         if finalize_figure:
-             finalize_figure(fig, save_path=save_path, show=True, style=style)
-         else:
-             plt.savefig(save_path, dpi=300)
-             plt.show()
-         plt.close(fig)
-
-     # Conversion lift curve plot.
-     def plot_conversion_lift(self, model_pred_col: str, n_bins: int = 20):
-         if plt is None:
-             _plot_skip("conversion lift plot")
-             return
-         if not self.binary_resp_nme:
-             print("Error: `binary_resp_nme` not provided at BayesOptModel init; cannot plot conversion lift.")
-             return
-
-         if plot_curves is None:
-             fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
-             datasets = {
-                 'Train Data': self.train_data,
-                 'Test Data': self.test_data
-             }
-
-             for ax, (data_name, data) in zip(axes, datasets.items()):
-                 if model_pred_col not in data.columns:
-                     print(f"Warning: missing prediction column '{model_pred_col}' in {data_name}. Skip plot.")
-                     continue
-
-                 # Sort by model prediction and compute bins.
-                 plot_data = data.sort_values(by=model_pred_col).copy()
-                 plot_data['cum_weight'] = plot_data[self.weight_nme].cumsum()
-                 total_weight = plot_data[self.weight_nme].sum()
-
-                 if total_weight > EPS:
-                     plot_data['bin'] = pd.cut(
-                         plot_data['cum_weight'],
-                         bins=n_bins,
-                         labels=False,
-                         right=False
-                     )
-                 else:
-                     plot_data['bin'] = 0
-
-                 # Aggregate by bins.
-                 lift_agg = plot_data.groupby('bin').agg(
-                     total_weight=(self.weight_nme, 'sum'),
-                     actual_conversions=(self.binary_resp_nme, 'sum'),
-                     weighted_conversions=('w_binary_act', 'sum'),
-                     avg_pred=(model_pred_col, 'mean')
-                 ).reset_index()
-
-                 # Compute conversion rate.
-                 lift_agg['conversion_rate'] = lift_agg['weighted_conversions'] / \
-                     lift_agg['total_weight']
-
-                 # Compute overall average conversion rate.
-                 overall_conversion_rate = data['w_binary_act'].sum(
-                 ) / data[self.weight_nme].sum()
-                 ax.axhline(y=overall_conversion_rate, color='gray', linestyle='--',
-                            label=f'Overall Avg Rate ({overall_conversion_rate:.2%})')
-
-                 ax.plot(lift_agg['bin'], lift_agg['conversion_rate'],
-                         marker='o', linestyle='-', label='Actual Conversion Rate')
-                 ax.set_title(f'Conversion Rate Lift Chart on {data_name}')
-                 ax.set_xlabel(f'Model Score Decile (based on {model_pred_col})')
-                 ax.set_ylabel('Conversion Rate')
-                 ax.grid(True, linestyle='--', alpha=0.6)
-                 ax.legend()
-
-             plt.tight_layout()
-             plt.show()
-             return
-
-         fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
-         datasets = {
-             'Train Data': self.train_data,
-             'Test Data': self.test_data
-         }
-
-         for ax, (data_name, data) in zip(axes, datasets.items()):
-             if model_pred_col not in data.columns:
-                 print(f"Warning: missing prediction column '{model_pred_col}' in {data_name}. Skip plot.")
-                 continue
-
-             plot_curves.plot_conversion_lift(
-                 data[model_pred_col].values,
-                 data[self.binary_resp_nme].values,
-                 data[self.weight_nme].values,
-                 n_bins=n_bins,
-                 title=f'Conversion Rate Lift Chart on {data_name}',
-                 ax=ax,
-                 show=False,
-             )
-
-         plt.tight_layout()
-         plt.show()
-
-     # ========= Lightweight explainability: Permutation Importance =========
-     def compute_permutation_importance(self,
-                                        model_key: str,
-                                        on_train: bool = True,
-                                        metric: Any = "auto",
-                                        n_repeats: int = 5,
-                                        max_rows: int = 5000,
-                                        random_state: Optional[int] = None):
-         if explain_permutation is None:
-             raise RuntimeError("explain.permutation is not available.")
-
-         model_key = str(model_key)
-         data = self.train_data if on_train else self.test_data
-         if self.resp_nme not in data.columns:
-             raise RuntimeError("Missing response column for permutation importance.")
-         y = data[self.resp_nme]
-         w = data[self.weight_nme] if self.weight_nme in data.columns else None
-
-         if model_key == "resn":
-             if self.resn_best is None:
-                 raise RuntimeError("ResNet model not trained.")
-             X = self.train_oht_scl_data if on_train else self.test_oht_scl_data
-             if X is None:
-                 raise RuntimeError("Missing standardized features for ResNet.")
-             X = X[self.var_nmes]
-             predict_fn = lambda df: self.resn_best.predict(df)
-         elif model_key == "ft":
-             if self.ft_best is None:
-                 raise RuntimeError("FT model not trained.")
-             if str(self.config.ft_role) != "model":
-                 raise RuntimeError("FT role is not 'model'; FT predictions unavailable.")
-             X = data[self.factor_nmes]
-             geo_tokens = self.train_geo_tokens if on_train else self.test_geo_tokens
-             geo_np = None
-             if geo_tokens is not None:
-                 geo_np = geo_tokens.to_numpy(dtype=np.float32, copy=False)
-             predict_fn = lambda df, geo=geo_np: self.ft_best.predict(df, geo_tokens=geo)
-         elif model_key == "xgb":
-             if self.xgb_best is None:
-                 raise RuntimeError("XGB model not trained.")
-             X = data[self.factor_nmes]
-             predict_fn = lambda df: self.xgb_best.predict(df)
-         else:
-             raise ValueError("Unsupported model_key for permutation importance.")
-
-         return explain_permutation.permutation_importance(
-             predict_fn,
-             X,
-             y,
-             sample_weight=w,
-             metric=metric,
-             task_type=self.task_type,
-             n_repeats=n_repeats,
-             random_state=random_state,
-             max_rows=max_rows,
-         )
-
-     # ========= Deep explainability: Integrated Gradients =========
-     def compute_integrated_gradients_resn(self,
-                                           on_train: bool = True,
-                                           baseline: Any = None,
-                                           steps: int = 50,
-                                           batch_size: int = 256,
-                                           target: Optional[int] = None):
-         if explain_gradients is None:
-             raise RuntimeError("explain.gradients is not available.")
-         if self.resn_best is None:
-             raise RuntimeError("ResNet model not trained.")
-         X = self.train_oht_scl_data if on_train else self.test_oht_scl_data
-         if X is None:
-             raise RuntimeError("Missing standardized features for ResNet.")
-         X = X[self.var_nmes]
-         return explain_gradients.resnet_integrated_gradients(
-             self.resn_best,
-             X,
-             baseline=baseline,
-             steps=steps,
-             batch_size=batch_size,
-             target=target,
-         )
-
-     def compute_integrated_gradients_ft(self,
-                                         on_train: bool = True,
-                                         geo_tokens: Optional[np.ndarray] = None,
-                                         baseline_num: Any = None,
-                                         baseline_geo: Any = None,
-                                         steps: int = 50,
-                                         batch_size: int = 256,
-                                         target: Optional[int] = None):
-         if explain_gradients is None:
-             raise RuntimeError("explain.gradients is not available.")
-         if self.ft_best is None:
-             raise RuntimeError("FT model not trained.")
-         if str(self.config.ft_role) != "model":
-             raise RuntimeError("FT role is not 'model'; FT explanations unavailable.")
-
-         data = self.train_data if on_train else self.test_data
-         X = data[self.factor_nmes]
-
-         if geo_tokens is None and getattr(self.ft_best, "num_geo", 0) > 0:
-             tokens_df = self.train_geo_tokens if on_train else self.test_geo_tokens
-             if tokens_df is not None:
-                 geo_tokens = tokens_df.to_numpy(dtype=np.float32, copy=False)
-
-         return explain_gradients.ft_integrated_gradients(
-             self.ft_best,
-             X,
-             geo_tokens=geo_tokens,
-             baseline_num=baseline_num,
-             baseline_geo=baseline_geo,
-             steps=steps,
-             batch_size=batch_size,
-             target=target,
-         )
-
-     # Save model
      def save_model(self, model_name=None):
          keys = [model_name] if model_name else self.trainers.keys()
          for key in keys:
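plot_lift, plot_dlift, plot_conversion_lift and the explainability entry points removed here move to the plotting and explain mixins. The core idea behind compute_permutation_importance is unchanged: a feature's importance is the metric degradation observed when that one column is shuffled. A generic, self-contained sketch of the technique (not the package's exact signature):

    import numpy as np

    def permutation_importance(predict, X, y, metric, n_repeats=3, seed=0):
        # Importance of column j = mean metric degradation after shuffling column j.
        rng = np.random.default_rng(seed)
        base = metric(y, predict(X))
        result = {}
        for j in range(X.shape[1]):
            deltas = []
            for _ in range(n_repeats):
                Xp = X.copy()
                rng.shuffle(Xp[:, j])
                deltas.append(metric(y, predict(Xp)) - base)
            result[j] = float(np.mean(deltas))
        return result

    # Toy check: y depends only on column 0, so shuffling it hurts far more.
    X = np.random.default_rng(42).normal(size=(200, 2))
    y = 3.0 * X[:, 0]
    mse = lambda a, b: float(np.mean((a - b) ** 2))
    scores = permutation_importance(lambda M: 3.0 * M[:, 0], X, y, mse)
    assert scores[0] > scores[1]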
@@ -1401,149 +887,3 @@ class BayesOptModel:
          else:
              if model_name:
                  print(f"[load_model] Warning: Unknown model key {key}")
-
-     def _sample_rows(self, data: pd.DataFrame, n: int) -> pd.DataFrame:
-         if len(data) == 0:
-             return data
-         return data.sample(min(len(data), n), random_state=self.rand_seed)
-
-     @staticmethod
-     def _shap_nsamples(arr: np.ndarray, max_nsamples: int = 300) -> int:
-         min_needed = arr.shape[1] + 2
-         return max(min_needed, min(max_nsamples, arr.shape[0] * arr.shape[1]))
-
-     def _build_ft_shap_matrix(self, data: pd.DataFrame) -> np.ndarray:
-         matrices = []
-         for col in self.factor_nmes:
-             s = data[col]
-             if col in self.cate_list:
-                 cats = pd.Categorical(
-                     s,
-                     categories=self.cat_categories_for_shap[col]
-                 )
-                 codes = np.asarray(cats.codes, dtype=np.float64).reshape(-1, 1)
-                 matrices.append(codes)
-             else:
-                 vals = pd.to_numeric(s, errors="coerce")
-                 arr = vals.to_numpy(dtype=np.float64, copy=True).reshape(-1, 1)
-                 matrices.append(arr)
-         X_mat = np.concatenate(matrices, axis=1) # Result shape (N, F)
-         return X_mat
-
-     def _decode_ft_shap_matrix_to_df(self, X_mat: np.ndarray) -> pd.DataFrame:
-         data_dict = {}
-         for j, col in enumerate(self.factor_nmes):
-             col_vals = X_mat[:, j]
-             if col in self.cate_list:
-                 cats = self.cat_categories_for_shap[col]
-                 codes = np.round(col_vals).astype(int)
-                 codes = np.clip(codes, -1, len(cats) - 1)
-                 cat_series = pd.Categorical.from_codes(
-                     codes,
-                     categories=cats
-                 )
-                 data_dict[col] = cat_series
-             else:
-                 data_dict[col] = col_vals.astype(float)
-
-         df = pd.DataFrame(data_dict, columns=self.factor_nmes)
-         for col in self.cate_list:
-             if col in df.columns:
-                 df[col] = df[col].astype("category")
-         return df
-
-     def _build_glm_design(self, data: pd.DataFrame) -> pd.DataFrame:
-         X = data[self.var_nmes]
-         return sm.add_constant(X, has_constant='add')
-
-     def _compute_shap_core(self,
-                            model_key: str,
-                            n_background: int,
-                            n_samples: int,
-                            on_train: bool,
-                            X_df: pd.DataFrame,
-                            prep_fn,
-                            predict_fn,
-                            cleanup_fn=None):
-         if explain_shap is None:
-             raise RuntimeError("explain.shap_utils is not available.")
-         return explain_shap.compute_shap_core(
-             self,
-             model_key,
-             n_background,
-             n_samples,
-             on_train,
-             X_df=X_df,
-             prep_fn=prep_fn,
-             predict_fn=predict_fn,
-             cleanup_fn=cleanup_fn,
-         )
-
-     # ========= GLM SHAP explainability =========
-     def compute_shap_glm(self, n_background: int = 500,
-                          n_samples: int = 200,
-                          on_train: bool = True):
-         if explain_shap is None:
-             raise RuntimeError("explain.shap_utils is not available.")
-         self.shap_glm = explain_shap.compute_shap_glm(
-             self,
-             n_background=n_background,
-             n_samples=n_samples,
-             on_train=on_train,
-         )
-         return self.shap_glm
-
-     # ========= XGBoost SHAP explainability =========
-     def compute_shap_xgb(self, n_background: int = 500,
-                          n_samples: int = 200,
-                          on_train: bool = True):
-         if explain_shap is None:
-             raise RuntimeError("explain.shap_utils is not available.")
-         self.shap_xgb = explain_shap.compute_shap_xgb(
-             self,
-             n_background=n_background,
-             n_samples=n_samples,
-             on_train=on_train,
-         )
-         return self.shap_xgb
-
-     # ========= ResNet SHAP explainability =========
-     def _resn_predict_wrapper(self, X_np):
-         model = self.resn_best.resnet.to("cpu")
-         with torch.no_grad():
-             X_tensor = torch.tensor(X_np, dtype=torch.float32)
-             y_pred = model(X_tensor).cpu().numpy()
-         y_pred = np.clip(y_pred, 1e-6, None)
-         return y_pred.reshape(-1)
-
-     def compute_shap_resn(self, n_background: int = 500,
-                           n_samples: int = 200,
-                           on_train: bool = True):
-         if explain_shap is None:
-             raise RuntimeError("explain.shap_utils is not available.")
-         self.shap_resn = explain_shap.compute_shap_resn(
-             self,
-             n_background=n_background,
-             n_samples=n_samples,
-             on_train=on_train,
-         )
-         return self.shap_resn
-
-     # ========= FT-Transformer SHAP explainability =========
-     def _ft_shap_predict_wrapper(self, X_mat: np.ndarray) -> np.ndarray:
-         df_input = self._decode_ft_shap_matrix_to_df(X_mat)
-         y_pred = self.ft_best.predict(df_input)
-         return np.asarray(y_pred, dtype=np.float64).reshape(-1)
-
-     def compute_shap_ft(self, n_background: int = 500,
-                         n_samples: int = 200,
-                         on_train: bool = True):
-         if explain_shap is None:
-             raise RuntimeError("explain.shap_utils is not available.")
-         self.shap_ft = explain_shap.compute_shap_ft(
-             self,
-             n_background=n_background,
-             n_samples=n_samples,
-             on_train=on_train,
-         )
-         return self.shap_ft
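The removed SHAP helpers survive in the explain mixin (model_explain_mixin.py, +296 lines). The encode/decode pair _build_ft_shap_matrix / _decode_ft_shap_matrix_to_df exists because KernelSHAP perturbs a plain float matrix, so categoricals travel as category codes and must be mapped back to labels before predict. The roundtrip in isolation:

    import numpy as np
    import pandas as pd

    cats = ["low", "mid", "high"]
    s = pd.Series(["mid", "high", "low"])

    # Encode: label -> float category code (unknown labels become -1).
    codes = pd.Categorical(s, categories=cats).codes.astype(np.float64)

    # Decode: round, clip to the valid code range, then map codes back to labels.
    back = pd.Categorical.from_codes(
        np.clip(np.round(codes).astype(int), -1, len(cats) - 1),
        categories=cats,
    )
    assert list(back) == ["mid", "high", "low"]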