ins-pricing 0.4.5-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
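The bulk of this release is a package flattening: modules under ins_pricing/modelling/core/bayesopt/ move up to ins_pricing/modelling/bayesopt/, the legacy modelling/core shims are deleted outright (entries 81-89, including BayesOpt.py and utils_backup.py), shared helpers such as losses.py consolidate under ins_pricing/utils/, and production/predict.py is renamed to production/inference.py with its tests moving from test_predict.py to test_inference.py. A minimal migration sketch for downstream imports, assuming the moved modules keep their public names (only the module paths are confirmed by this diff; how downstream code imports them is an assumption):

    # Hypothetical sketch of the 0.4.5 -> 0.5.1 import migration; module paths
    # are taken from the file list above, the import style is illustrative.

    # 0.4.5: the bayesopt package lived under modelling.core
    # from ins_pricing.modelling.core.bayesopt import core

    # 0.5.1: modelling.core is removed; import from modelling.bayesopt instead
    from ins_pricing.modelling.bayesopt import core

    # 0.4.5: from ins_pricing.production import predict
    # 0.5.1: predict.py has been renamed to inference.py
    from ins_pricing.production import inference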
ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py
@@ -1,296 +1,296 @@
 from __future__ import annotations
 
 from typing import Any, Optional
 
 import numpy as np
 import pandas as pd
 import torch
 import statsmodels.api as sm
 
 try:
-    from ...explain import gradients as explain_gradients
-    from ...explain import permutation as explain_permutation
-    from ...explain import shap_utils as explain_shap
+    from ins_pricing.modelling.explain import gradients as explain_gradients
+    from ins_pricing.modelling.explain import permutation as explain_permutation
+    from ins_pricing.modelling.explain import shap_utils as explain_shap
 except Exception: # pragma: no cover - optional for legacy imports
     try: # best-effort for non-package imports
         from ins_pricing.explain import gradients as explain_gradients
         from ins_pricing.explain import permutation as explain_permutation
         from ins_pricing.explain import shap_utils as explain_shap
     except Exception: # pragma: no cover
         explain_gradients = None
         explain_permutation = None
         explain_shap = None
 
 
 class BayesOptExplainMixin:
     def compute_permutation_importance(self,
                                        model_key: str,
                                        on_train: bool = True,
                                        metric: Any = "auto",
                                        n_repeats: int = 5,
                                        max_rows: int = 5000,
                                        random_state: Optional[int] = None):
         if explain_permutation is None:
             raise RuntimeError("explain.permutation is not available.")
 
         model_key = str(model_key)
         data = self.train_data if on_train else self.test_data
         if self.resp_nme not in data.columns:
             raise RuntimeError("Missing response column for permutation importance.")
         y = data[self.resp_nme]
         w = data[self.weight_nme] if self.weight_nme in data.columns else None
 
         if model_key == "resn":
             if self.resn_best is None:
                 raise RuntimeError("ResNet model not trained.")
             X = self.train_oht_scl_data if on_train else self.test_oht_scl_data
             if X is None:
                 raise RuntimeError("Missing standardized features for ResNet.")
             X = X[self.var_nmes]
             predict_fn = lambda df: self.resn_best.predict(df)
         elif model_key == "ft":
             if self.ft_best is None:
                 raise RuntimeError("FT model not trained.")
             if str(self.config.ft_role) != "model":
                 raise RuntimeError("FT role is not 'model'; FT predictions unavailable.")
             X = data[self.factor_nmes]
             geo_tokens = self.train_geo_tokens if on_train else self.test_geo_tokens
             geo_np = None
             if geo_tokens is not None:
                 geo_np = geo_tokens.to_numpy(dtype=np.float32, copy=False)
             predict_fn = lambda df, geo=geo_np: self.ft_best.predict(df, geo_tokens=geo)
         elif model_key == "xgb":
             if self.xgb_best is None:
                 raise RuntimeError("XGB model not trained.")
             X = data[self.factor_nmes]
             predict_fn = lambda df: self.xgb_best.predict(df)
         else:
             raise ValueError("Unsupported model_key for permutation importance.")
 
         return explain_permutation.permutation_importance(
             predict_fn,
             X,
             y,
             sample_weight=w,
             metric=metric,
             task_type=self.task_type,
             n_repeats=n_repeats,
             random_state=random_state,
             max_rows=max_rows,
         )
 
     # ========= Deep explainability: Integrated Gradients =========
 
     def compute_integrated_gradients_resn(self,
                                           on_train: bool = True,
                                           baseline: Any = None,
                                           steps: int = 50,
                                           batch_size: int = 256,
                                           target: Optional[int] = None):
         if explain_gradients is None:
             raise RuntimeError("explain.gradients is not available.")
         if self.resn_best is None:
             raise RuntimeError("ResNet model not trained.")
         X = self.train_oht_scl_data if on_train else self.test_oht_scl_data
         if X is None:
             raise RuntimeError("Missing standardized features for ResNet.")
         X = X[self.var_nmes]
         return explain_gradients.resnet_integrated_gradients(
             self.resn_best,
             X,
             baseline=baseline,
             steps=steps,
             batch_size=batch_size,
             target=target,
         )
 
 
     def compute_integrated_gradients_ft(self,
                                         on_train: bool = True,
                                         geo_tokens: Optional[np.ndarray] = None,
                                         baseline_num: Any = None,
                                         baseline_geo: Any = None,
                                         steps: int = 50,
                                         batch_size: int = 256,
                                         target: Optional[int] = None):
         if explain_gradients is None:
             raise RuntimeError("explain.gradients is not available.")
         if self.ft_best is None:
             raise RuntimeError("FT model not trained.")
         if str(self.config.ft_role) != "model":
             raise RuntimeError("FT role is not 'model'; FT explanations unavailable.")
 
         data = self.train_data if on_train else self.test_data
         X = data[self.factor_nmes]
 
         if geo_tokens is None and getattr(self.ft_best, "num_geo", 0) > 0:
             tokens_df = self.train_geo_tokens if on_train else self.test_geo_tokens
             if tokens_df is not None:
                 geo_tokens = tokens_df.to_numpy(dtype=np.float32, copy=False)
 
         return explain_gradients.ft_integrated_gradients(
             self.ft_best,
             X,
             geo_tokens=geo_tokens,
             baseline_num=baseline_num,
             baseline_geo=baseline_geo,
             steps=steps,
             batch_size=batch_size,
             target=target,
         )
 
     def _sample_rows(self, data: pd.DataFrame, n: int) -> pd.DataFrame:
         if len(data) == 0:
             return data
         return data.sample(min(len(data), n), random_state=self.rand_seed)
 
     @staticmethod
     def _shap_nsamples(arr: np.ndarray, max_nsamples: int = 300) -> int:
         min_needed = arr.shape[1] + 2
         return max(min_needed, min(max_nsamples, arr.shape[0] * arr.shape[1]))
 
 
     def _build_ft_shap_matrix(self, data: pd.DataFrame) -> np.ndarray:
         matrices = []
         for col in self.factor_nmes:
             s = data[col]
             if col in self.cate_list:
                 cats = pd.Categorical(
                     s,
                     categories=self.cat_categories_for_shap[col]
                 )
                 codes = np.asarray(cats.codes, dtype=np.float64).reshape(-1, 1)
                 matrices.append(codes)
             else:
                 vals = pd.to_numeric(s, errors="coerce")
                 arr = vals.to_numpy(dtype=np.float64, copy=True).reshape(-1, 1)
                 matrices.append(arr)
         X_mat = np.concatenate(matrices, axis=1) # Result shape (N, F)
         return X_mat
 
 
     def _decode_ft_shap_matrix_to_df(self, X_mat: np.ndarray) -> pd.DataFrame:
         data_dict = {}
         for j, col in enumerate(self.factor_nmes):
             col_vals = X_mat[:, j]
             if col in self.cate_list:
                 cats = self.cat_categories_for_shap[col]
                 codes = np.round(col_vals).astype(int)
                 codes = np.clip(codes, -1, len(cats) - 1)
                 cat_series = pd.Categorical.from_codes(
                     codes,
                     categories=cats
                 )
                 data_dict[col] = cat_series
             else:
                 data_dict[col] = col_vals.astype(float)
 
         df = pd.DataFrame(data_dict, columns=self.factor_nmes)
         for col in self.cate_list:
             if col in df.columns:
                 df[col] = df[col].astype("category")
         return df
 
 
     def _build_glm_design(self, data: pd.DataFrame) -> pd.DataFrame:
         X = data[self.var_nmes]
         return sm.add_constant(X, has_constant='add')
 
 
     def _compute_shap_core(self,
                            model_key: str,
                            n_background: int,
                            n_samples: int,
                            on_train: bool,
                            X_df: pd.DataFrame,
                            prep_fn,
                            predict_fn,
                            cleanup_fn=None):
         if explain_shap is None:
             raise RuntimeError("explain.shap_utils is not available.")
         return explain_shap.compute_shap_core(
             self,
             model_key,
             n_background,
             n_samples,
             on_train,
             X_df=X_df,
             prep_fn=prep_fn,
             predict_fn=predict_fn,
             cleanup_fn=cleanup_fn,
         )
 
     # ========= GLM SHAP explainability =========
 
     def compute_shap_glm(self, n_background: int = 500,
                          n_samples: int = 200,
                          on_train: bool = True):
         if explain_shap is None:
             raise RuntimeError("explain.shap_utils is not available.")
         self.shap_glm = explain_shap.compute_shap_glm(
             self,
             n_background=n_background,
             n_samples=n_samples,
             on_train=on_train,
         )
         return self.shap_glm
 
     # ========= XGBoost SHAP explainability =========
 
     def compute_shap_xgb(self, n_background: int = 500,
                          n_samples: int = 200,
                          on_train: bool = True):
         if explain_shap is None:
             raise RuntimeError("explain.shap_utils is not available.")
         self.shap_xgb = explain_shap.compute_shap_xgb(
             self,
             n_background=n_background,
             n_samples=n_samples,
             on_train=on_train,
         )
         return self.shap_xgb
 
     # ========= ResNet SHAP explainability =========
 
     def _resn_predict_wrapper(self, X_np):
         model = self.resn_best.resnet.to("cpu")
         with torch.no_grad():
             X_tensor = torch.tensor(X_np, dtype=torch.float32)
             y_pred = model(X_tensor).cpu().numpy()
         y_pred = np.clip(y_pred, 1e-6, None)
         return y_pred.reshape(-1)
 
 
     def compute_shap_resn(self, n_background: int = 500,
                           n_samples: int = 200,
                           on_train: bool = True):
         if explain_shap is None:
             raise RuntimeError("explain.shap_utils is not available.")
         self.shap_resn = explain_shap.compute_shap_resn(
             self,
             n_background=n_background,
             n_samples=n_samples,
             on_train=on_train,
         )
         return self.shap_resn
 
     # ========= FT-Transformer SHAP explainability =========
 
     def _ft_shap_predict_wrapper(self, X_mat: np.ndarray) -> np.ndarray:
         df_input = self._decode_ft_shap_matrix_to_df(X_mat)
         y_pred = self.ft_best.predict(df_input)
         return np.asarray(y_pred, dtype=np.float64).reshape(-1)
 
 
     def compute_shap_ft(self, n_background: int = 500,
                         n_samples: int = 200,
                         on_train: bool = True):
         if explain_shap is None:
             raise RuntimeError("explain.shap_utils is not available.")
         self.shap_ft = explain_shap.compute_shap_ft(
             self,
             n_background=n_background,
             n_samples=n_samples,
             on_train=on_train,
         )
         return self.shap_ft
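The hunk above is representative of the 0.5.1 reorganisation: the only functional change in this module is the first-choice import path for the explain helpers, which moves from the package-relative "from ...explain import ..." to the absolute "from ins_pricing.modelling.explain import ..."; the fallback chain and the rest of the mixin carry over verbatim. For orientation, a hedged usage sketch of the mixin API shown in the hunk, assuming `bo` is a fitted object that inherits BayesOptExplainMixin with trained xgb/resn/ft sub-models (how `bo` is constructed is outside this diff; method names and defaults are taken from the code above):

    # Assumes `bo` inherits BayesOptExplainMixin and its sub-models are fitted.
    # Permutation importance against the held-out split for the XGBoost model.
    imp = bo.compute_permutation_importance("xgb", on_train=False, n_repeats=5)

    # Integrated gradients for the ResNet model; requires the standardized
    # feature matrix (train/test_oht_scl_data) to be populated.
    ig = bo.compute_integrated_gradients_resn(on_train=False, steps=50, batch_size=256)

    # SHAP attributions; the result is cached on the instance (bo.shap_xgb)
    # and also returned.
    shap_vals = bo.compute_shap_xgb(n_background=500, n_samples=200, on_train=True)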