ins-pricing 0.4.5-py3-none-any.whl → 0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +52 -50
  4. ins_pricing/cli/BayesOpt_incremental.py +39 -105
  5. ins_pricing/cli/Explain_Run.py +31 -23
  6. ins_pricing/cli/Explain_entry.py +532 -579
  7. ins_pricing/cli/Pricing_Run.py +31 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +11 -9
  9. ins_pricing/cli/utils/cli_common.py +256 -256
  10. ins_pricing/cli/utils/cli_config.py +375 -375
  11. ins_pricing/cli/utils/import_resolver.py +382 -365
  12. ins_pricing/cli/utils/notebook_utils.py +340 -340
  13. ins_pricing/cli/watchdog_run.py +209 -201
  14. ins_pricing/frontend/__init__.py +10 -10
  15. ins_pricing/frontend/example_workflows.py +1 -1
  16. ins_pricing/governance/__init__.py +20 -20
  17. ins_pricing/governance/release.py +159 -159
  18. ins_pricing/modelling/__init__.py +147 -92
  19. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +2 -2
  20. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  21. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +562 -562
  22. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +965 -964
  23. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  24. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +482 -548
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +915 -913
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +788 -785
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +448 -446
  29. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1308 -1308
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +3 -3
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +197 -198
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +344 -344
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +283 -283
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +346 -347
  36. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  37. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  38. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  39. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  40. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +623 -623
  42. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  43. ins_pricing/modelling/explain/__init__.py +55 -55
  44. ins_pricing/modelling/explain/metrics.py +27 -174
  45. ins_pricing/modelling/explain/permutation.py +237 -237
  46. ins_pricing/modelling/plotting/__init__.py +40 -36
  47. ins_pricing/modelling/plotting/compat.py +228 -0
  48. ins_pricing/modelling/plotting/curves.py +572 -572
  49. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  50. ins_pricing/modelling/plotting/geo.py +362 -362
  51. ins_pricing/modelling/plotting/importance.py +121 -121
  52. ins_pricing/pricing/__init__.py +27 -27
  53. ins_pricing/production/__init__.py +35 -25
  54. ins_pricing/production/{predict.py → inference.py} +140 -57
  55. ins_pricing/production/monitoring.py +8 -21
  56. ins_pricing/reporting/__init__.py +11 -11
  57. ins_pricing/setup.py +1 -1
  58. ins_pricing/tests/production/test_inference.py +90 -0
  59. ins_pricing/utils/__init__.py +116 -83
  60. ins_pricing/utils/device.py +255 -255
  61. ins_pricing/utils/features.py +53 -0
  62. ins_pricing/utils/io.py +72 -0
  63. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  64. ins_pricing/utils/metrics.py +158 -24
  65. ins_pricing/utils/numerics.py +76 -0
  66. ins_pricing/utils/paths.py +9 -1
  67. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/METADATA +182 -182
  68. ins_pricing-0.5.0.dist-info/RECORD +131 -0
  69. ins_pricing/modelling/core/BayesOpt.py +0 -146
  70. ins_pricing/modelling/core/__init__.py +0 -1
  71. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  72. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  73. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  74. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  75. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  76. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  77. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  78. ins_pricing/tests/production/test_predict.py +0 -233
  79. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  80. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +0 -0
  81. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +0 -0
  82. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +0 -0
  83. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/WHEEL +0 -0
  84. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.0.dist-info}/top_level.txt +0 -0
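
The dominant change in this release is a flattening of the modelling package: everything under ins_pricing/modelling/core/bayesopt moves up to ins_pricing/modelling/bayesopt (entries 19-41), the legacy core helpers are deleted outright (entries 69-77), and production/predict.py becomes production/inference.py with its test renamed to match (entries 54, 58, 78). A minimal migration sketch for downstream code, assuming 0.5.0 re-exports the same top-level symbols from the relocated packages and ships no compatibility shim for the old paths (nothing in this diff shows one):

    # Hypothetical sketch; verify the import paths against the installed 0.5.0 wheel.
    # ins-pricing 0.4.5:
    #   from ins_pricing.modelling.core.bayesopt import BayesOptModel, BayesOptConfig
    #   from ins_pricing.production import predict
    # ins-pricing 0.5.0:
    from ins_pricing.modelling.bayesopt import BayesOptModel, BayesOptConfig
    from ins_pricing.production import inference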
ins_pricing/cli/Explain_entry.py
@@ -1,586 +1,539 @@
- """Config-driven explain runner for trained BayesOpt models."""
-
- from __future__ import annotations
-
+ """Config-driven explain runner for trained BayesOpt models."""
+
+ from __future__ import annotations
+
  from pathlib import Path
+ import importlib.util
  import sys

  if __package__ in {None, ""}:
-     repo_root = Path(__file__).resolve().parents[2]
-     if str(repo_root) not in sys.path:
-         sys.path.insert(0, str(repo_root))
-
- import argparse
- import json
- from typing import Any, Dict, List, Optional, Sequence
-
+     if importlib.util.find_spec("ins_pricing") is None:
+         repo_root = Path(__file__).resolve().parents[2]
+         if str(repo_root) not in sys.path:
+             sys.path.insert(0, str(repo_root))
+
+ import argparse
+ import json
+ from typing import Any, Dict, List, Optional, Sequence
+
  import numpy as np
  import pandas as pd

- try:
-     from .. import bayesopt as ropt  # type: ignore
-     from .utils.cli_common import (  # type: ignore
-         build_model_names,
-         dedupe_preserve_order,
-         load_dataset,
-         resolve_data_path,
-         coerce_dataset_types,
-         split_train_test,
-     )
-     from .utils.cli_config import (  # type: ignore
-         add_config_json_arg,
-         add_output_dir_arg,
-         resolve_and_load_config,
-         resolve_data_config,
-         resolve_explain_output_overrides,
-         resolve_explain_save_dir,
-         resolve_explain_save_root,
-         resolve_model_path_value,
-         resolve_split_config,
-         resolve_runtime_config,
-         resolve_output_dirs,
-     )
- except Exception:  # pragma: no cover
-     try:
-         import bayesopt as ropt  # type: ignore
-         from utils.cli_common import (  # type: ignore
-             build_model_names,
-             dedupe_preserve_order,
-             load_dataset,
-             resolve_data_path,
-             coerce_dataset_types,
-             split_train_test,
-         )
-         from utils.cli_config import (  # type: ignore
-             add_config_json_arg,
-             add_output_dir_arg,
-             resolve_and_load_config,
-             resolve_data_config,
-             resolve_explain_output_overrides,
-             resolve_explain_save_dir,
-             resolve_explain_save_root,
-             resolve_model_path_value,
-             resolve_split_config,
-             resolve_runtime_config,
-             resolve_output_dirs,
-         )
-     except Exception:
-         import ins_pricing.modelling.core.bayesopt as ropt  # type: ignore
-         from ins_pricing.cli.utils.cli_common import (  # type: ignore
-             build_model_names,
-             dedupe_preserve_order,
-             load_dataset,
-             resolve_data_path,
-             coerce_dataset_types,
-             split_train_test,
-         )
-         from ins_pricing.cli.utils.cli_config import (  # type: ignore
-             add_config_json_arg,
-             add_output_dir_arg,
-             resolve_and_load_config,
-             resolve_data_config,
-             resolve_explain_output_overrides,
-             resolve_explain_save_dir,
-             resolve_explain_save_root,
-             resolve_model_path_value,
-             resolve_split_config,
-             resolve_runtime_config,
-             resolve_output_dirs,
-         )
-
- try:
-     from .utils.run_logging import configure_run_logging  # type: ignore
- except Exception:  # pragma: no cover
-     try:
-         from utils.run_logging import configure_run_logging  # type: ignore
-     except Exception:  # pragma: no cover
-         configure_run_logging = None  # type: ignore
-
-
- _SUPPORTED_METHODS = {"permutation", "shap", "integrated_gradients"}
- _METHOD_ALIASES = {
-     "ig": "integrated_gradients",
-     "integrated": "integrated_gradients",
-     "intgrad": "integrated_gradients",
- }
-
-
- def _safe_name(value: str) -> str:
-     return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(value))
-
-
- def _load_dataset(
-     path: Path,
-     *,
-     data_format: str,
-     dtype_map: Optional[Dict[str, Any]],
- ) -> pd.DataFrame:
-     raw = load_dataset(
-         path,
-         data_format=data_format,
-         dtype_map=dtype_map,
-         low_memory=False,
-     )
-     return coerce_dataset_types(raw)
-
-
- def _normalize_methods(raw: Sequence[str]) -> List[str]:
-     methods: List[str] = []
-     for item in raw:
-         key = str(item).strip().lower()
-         if not key:
-             continue
-         key = _METHOD_ALIASES.get(key, key)
-         if key not in _SUPPORTED_METHODS:
-             raise ValueError(f"Unsupported explain method: {item}")
-         methods.append(key)
-     return dedupe_preserve_order(methods)
-
-
- def _save_series(series: pd.Series, path: Path) -> None:
-     path.parent.mkdir(parents=True, exist_ok=True)
-     series.to_frame(name="importance").to_csv(path, index=True)
-
-
- def _save_df(df: pd.DataFrame, path: Path) -> None:
-     path.parent.mkdir(parents=True, exist_ok=True)
-     df.to_csv(path, index=False)
-
-
- def _shap_importance(values: Any, feature_names: Sequence[str]) -> pd.Series:
-     if isinstance(values, list):
-         values = values[0]
-     arr = np.asarray(values)
-     if arr.ndim == 3:
-         arr = arr[0]
-     scores = np.mean(np.abs(arr), axis=0)
-     return pd.Series(scores, index=list(feature_names)).sort_values(ascending=False)
-
-
- def _parse_args() -> argparse.Namespace:
-     parser = argparse.ArgumentParser(
-         description="Run explainability (permutation/SHAP/IG) on trained models."
-     )
-     add_config_json_arg(
-         parser,
-         help_text="Path to config.json (same schema as training).",
-     )
-     parser.add_argument(
-         "--model-keys",
-         nargs="+",
-         default=None,
-         choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
-         help="Model keys to load for explanation (default from config.explain.model_keys).",
-     )
-     parser.add_argument(
-         "--methods",
-         nargs="+",
-         default=None,
-         help="Explain methods: permutation, shap, integrated_gradients (default from config.explain.methods).",
-     )
-     add_output_dir_arg(
-         parser,
-         help_text="Override output root for loading models/results.",
-     )
-     parser.add_argument(
-         "--eval-path",
-         default=None,
-         help="Override validation CSV path (supports {model_name}).",
-     )
-     parser.add_argument(
-         "--on-train",
-         action="store_true",
-         help="Explain on train split instead of validation/test.",
-     )
-     parser.add_argument(
-         "--save-dir",
-         default=None,
-         help="Override output directory for explanation artifacts.",
-     )
-     return parser.parse_args()
-
-
- def _explain_for_model(
-     model: ropt.BayesOptModel,
-     *,
-     model_name: str,
-     model_keys: List[str],
-     methods: List[str],
-     on_train: bool,
-     save_dir: Path,
-     explain_cfg: Dict[str, Any],
- ) -> None:
-     perm_cfg = dict(explain_cfg.get("permutation") or {})
-     shap_cfg = dict(explain_cfg.get("shap") or {})
-     ig_cfg = dict(explain_cfg.get("integrated_gradients") or {})
-
-     perm_metric = perm_cfg.get("metric", explain_cfg.get("metric", "auto"))
-     perm_repeats = int(perm_cfg.get("n_repeats", 5))
-     perm_max_rows = perm_cfg.get("max_rows", 5000)
-     perm_random_state = perm_cfg.get("random_state", None)
-
-     shap_background = int(shap_cfg.get("n_background", 500))
-     shap_samples = int(shap_cfg.get("n_samples", 200))
-     shap_save_values = bool(shap_cfg.get("save_values", False))
-
-     ig_steps = int(ig_cfg.get("steps", 50))
-     ig_batch_size = int(ig_cfg.get("batch_size", 256))
-     ig_target = ig_cfg.get("target", None)
-     ig_baseline = ig_cfg.get("baseline", None)
-     ig_baseline_num = ig_cfg.get("baseline_num", None)
-     ig_baseline_geo = ig_cfg.get("baseline_geo", None)
-     ig_save_values = bool(ig_cfg.get("save_values", False))
-
-     for key in model_keys:
-         trainer = model.trainers.get(key)
-         if trainer is None:
-             print(f"[Explain] Skip {model_name}/{key}: trainer not available.")
-             continue
-         model.load_model(key)
-         trained_model = getattr(model, f"{key}_best", None)
-         if trained_model is None:
-             print(f"[Explain] Skip {model_name}/{key}: model not loaded.")
-             continue
-
-         if key == "ft" and str(model.config.ft_role) != "model":
-             print(f"[Explain] Skip {model_name}/ft: ft_role != 'model'.")
-             continue
-
-         for method in methods:
-             if method == "permutation" and key not in {"xgb", "resn", "ft"}:
-                 print(f"[Explain] Skip permutation for {model_name}/{key}.")
-                 continue
-             if method == "shap" and key not in {"glm", "xgb", "resn", "ft"}:
-                 print(f"[Explain] Skip shap for {model_name}/{key}.")
-                 continue
-             if method == "integrated_gradients" and key not in {"resn", "ft"}:
-                 print(f"[Explain] Skip integrated gradients for {model_name}/{key}.")
-                 continue
-
-             if method == "permutation":
-                 try:
-                     result = model.compute_permutation_importance(
-                         key,
-                         on_train=on_train,
-                         metric=perm_metric,
-                         n_repeats=perm_repeats,
-                         max_rows=perm_max_rows,
-                         random_state=perm_random_state,
-                     )
-                 except Exception as exc:
-                     print(f"[Explain] permutation failed for {model_name}/{key}: {exc}")
-                     continue
-                 out_path = save_dir / f"{_safe_name(model_name)}_{key}_permutation.csv"
-                 _save_df(result, out_path)
-                 print(f"[Explain] Saved permutation -> {out_path}")
-
-             if method == "shap":
-                 try:
-                     if key == "glm":
-                         shap_result = model.compute_shap_glm(
-                             n_background=shap_background,
-                             n_samples=shap_samples,
-                             on_train=on_train,
-                         )
-                     elif key == "xgb":
-                         shap_result = model.compute_shap_xgb(
-                             n_background=shap_background,
-                             n_samples=shap_samples,
-                             on_train=on_train,
-                         )
-                     elif key == "resn":
-                         shap_result = model.compute_shap_resn(
-                             n_background=shap_background,
-                             n_samples=shap_samples,
-                             on_train=on_train,
-                         )
-                     else:
-                         shap_result = model.compute_shap_ft(
-                             n_background=shap_background,
-                             n_samples=shap_samples,
-                             on_train=on_train,
-                         )
-                 except Exception as exc:
-                     print(f"[Explain] shap failed for {model_name}/{key}: {exc}")
-                     continue
-
-                 shap_values = shap_result.get("shap_values")
-                 X_explain = shap_result.get("X_explain")
-                 feature_names = (
-                     list(X_explain.columns)
-                     if isinstance(X_explain, pd.DataFrame)
-                     else list(model.factor_nmes)
-                 )
-                 importance = _shap_importance(shap_values, feature_names)
-                 out_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_importance.csv"
-                 _save_series(importance, out_path)
-                 print(f"[Explain] Saved SHAP importance -> {out_path}")
-
-                 if shap_save_values:
-                     values_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_values.npy"
-                     np.save(values_path, np.array(shap_values, dtype=object), allow_pickle=True)
-                     if isinstance(X_explain, pd.DataFrame):
-                         x_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_X.csv"
-                         _save_df(X_explain, x_path)
-                     meta_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_meta.json"
-                     meta = {
-                         "base_value": shap_result.get("base_value"),
-                         "n_samples": int(len(X_explain)) if X_explain is not None else None,
-                     }
-                     meta_path.write_text(json.dumps(meta, indent=2), encoding="utf-8")
-
-             if method == "integrated_gradients":
-                 try:
-                     if key == "resn":
-                         ig_result = model.compute_integrated_gradients_resn(
-                             on_train=on_train,
-                             baseline=ig_baseline,
-                             steps=ig_steps,
-                             batch_size=ig_batch_size,
-                             target=ig_target,
-                         )
-                         series = ig_result.get("importance")
-                         if isinstance(series, pd.Series):
-                             out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_importance.csv"
-                             _save_series(series, out_path)
-                             print(f"[Explain] Saved IG importance -> {out_path}")
-                         if ig_save_values and "attributions" in ig_result:
-                             attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_attributions.npy"
-                             np.save(attr_path, ig_result.get("attributions"))
-                     else:
-                         ig_result = model.compute_integrated_gradients_ft(
-                             on_train=on_train,
-                             baseline_num=ig_baseline_num,
-                             baseline_geo=ig_baseline_geo,
-                             steps=ig_steps,
-                             batch_size=ig_batch_size,
-                             target=ig_target,
-                         )
-                         series_num = ig_result.get("importance_num")
-                         series_geo = ig_result.get("importance_geo")
-                         if isinstance(series_num, pd.Series):
-                             out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_num_importance.csv"
-                             _save_series(series_num, out_path)
-                             print(f"[Explain] Saved IG num importance -> {out_path}")
-                         if isinstance(series_geo, pd.Series):
-                             out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_geo_importance.csv"
-                             _save_series(series_geo, out_path)
-                             print(f"[Explain] Saved IG geo importance -> {out_path}")
-                         if ig_save_values:
-                             if ig_result.get("attributions_num") is not None:
-                                 attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_num_attributions.npy"
-                                 np.save(attr_path, ig_result.get("attributions_num"))
-                             if ig_result.get("attributions_geo") is not None:
-                                 attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_geo_attributions.npy"
-                                 np.save(attr_path, ig_result.get("attributions_geo"))
-                 except Exception as exc:
-                     print(f"[Explain] integrated gradients failed for {model_name}/{key}: {exc}")
-                     continue
-
-
- def explain_from_config(args: argparse.Namespace) -> None:
-     script_dir = Path(__file__).resolve().parents[1]
-     config_path, cfg = resolve_and_load_config(
-         args.config_json,
-         script_dir,
-         required_keys=["data_dir", "model_list", "model_categories", "target", "weight"],
-     )
-
-     data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
-         cfg,
-         config_path,
-         create_data_dir=True,
-     )
-
-     runtime_cfg = resolve_runtime_config(cfg)
-     output_cfg = resolve_output_dirs(
-         cfg,
-         config_path,
-         output_override=args.output_dir,
-     )
-     output_dir = output_cfg["output_dir"]
-
-     split_cfg = resolve_split_config(cfg)
-     prop_test = split_cfg["prop_test"]
-     rand_seed = runtime_cfg["rand_seed"]
-     split_strategy = split_cfg["split_strategy"]
-     split_group_col = split_cfg["split_group_col"]
-     split_time_col = split_cfg["split_time_col"]
-     split_time_ascending = split_cfg["split_time_ascending"]
-
-     explain_cfg = dict(cfg.get("explain") or {})
-
-     model_keys = args.model_keys or explain_cfg.get("model_keys") or ["xgb"]
-     if "all" in model_keys:
-         model_keys = ["glm", "xgb", "resn", "ft", "gnn"]
-     model_keys = dedupe_preserve_order([str(x) for x in model_keys])
-
-     method_list = args.methods or explain_cfg.get("methods") or ["permutation"]
-     methods = _normalize_methods([str(x) for x in method_list])
-
-     on_train = bool(args.on_train or explain_cfg.get("on_train", False))
-
-     model_names = build_model_names(cfg["model_list"], cfg["model_categories"])
-     if not model_names:
-         raise ValueError("No model names generated from model_list/model_categories.")
-
-     save_root = resolve_explain_save_root(
-         args.save_dir or explain_cfg.get("save_dir"),
-         config_path.parent,
-     )
-
-     for model_name in model_names:
-         train_path = resolve_model_path_value(
-             explain_cfg.get("train_path"),
-             model_name=model_name,
-             base_dir=config_path.parent,
-             data_dir=data_dir,
-         )
-         if train_path is None:
-             train_path = resolve_data_path(
-                 data_dir,
-                 model_name,
-                 data_format=data_format,
-                 path_template=data_path_template,
-             )
-         if not train_path.exists():
-             raise FileNotFoundError(f"Missing training dataset: {train_path}")
-
-         validation_override = args.eval_path or explain_cfg.get("validation_path") or explain_cfg.get("eval_path")
-         validation_path = resolve_model_path_value(
-             validation_override,
-             model_name=model_name,
-             base_dir=config_path.parent,
-             data_dir=data_dir,
-         )
-
-         raw = _load_dataset(
-             train_path,
-             data_format=data_format,
-             dtype_map=dtype_map,
-         )
-         if validation_path is not None:
-             if not validation_path.exists():
-                 raise FileNotFoundError(f"Missing validation dataset: {validation_path}")
-             train_df = raw
-             test_df = _load_dataset(
-                 validation_path,
-                 data_format=data_format,
-                 dtype_map=dtype_map,
-             )
-         else:
-             if float(prop_test) <= 0:
-                 train_df = raw
-                 test_df = raw.copy()
-             else:
-                 train_df, test_df = split_train_test(
-                     raw,
-                     holdout_ratio=prop_test,
-                     strategy=split_strategy,
-                     group_col=split_group_col,
-                     time_col=split_time_col,
-                     time_ascending=split_time_ascending,
-                     rand_seed=rand_seed,
-                     reset_index_mode="time_group",
-                     ratio_label="prop_test",
-                     include_strategy_in_ratio_error=True,
-                 )
-
-         binary_target = cfg.get("binary_target") or cfg.get("binary_resp_nme")
-         feature_list = cfg.get("feature_list")
-         categorical_features = cfg.get("categorical_features")
-         plot_path_style = runtime_cfg["plot_path_style"]
-
-         config_fields = getattr(ropt.BayesOptConfig, "__dataclass_fields__", {})
-         allowed_config_keys = set(config_fields.keys())
-         config_payload = {k: v for k, v in cfg.items() if k in allowed_config_keys}
-         config_payload.update({
-             "model_nme": model_name,
-             "resp_nme": cfg["target"],
-             "weight_nme": cfg["weight"],
-             "factor_nmes": feature_list,
-             "task_type": str(cfg.get("task_type", "regression")),
-             "binary_resp_nme": binary_target,
-             "cate_list": categorical_features,
-             "prop_test": prop_test,
-             "rand_seed": rand_seed,
-             "epochs": int(runtime_cfg["epochs"]),
-             "use_gpu": bool(cfg.get("use_gpu", True)),
-             "output_dir": output_dir,
-             "xgb_max_depth_max": runtime_cfg["xgb_max_depth_max"],
-             "xgb_n_estimators_max": runtime_cfg["xgb_n_estimators_max"],
-             "resn_weight_decay": cfg.get("resn_weight_decay"),
-             "final_ensemble": bool(cfg.get("final_ensemble", False)),
-             "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
-             "final_refit": bool(cfg.get("final_refit", True)),
-             "optuna_storage": runtime_cfg["optuna_storage"],
-             "optuna_study_prefix": runtime_cfg["optuna_study_prefix"],
-             "best_params_files": runtime_cfg["best_params_files"],
-             "gnn_use_approx_knn": cfg.get("gnn_use_approx_knn", True),
-             "gnn_approx_knn_threshold": cfg.get("gnn_approx_knn_threshold", 50000),
-             "gnn_graph_cache": cfg.get("gnn_graph_cache"),
-             "gnn_max_gpu_knn_nodes": cfg.get("gnn_max_gpu_knn_nodes", 200000),
-             "gnn_knn_gpu_mem_ratio": cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
-             "gnn_knn_gpu_mem_overhead": cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
-             "region_province_col": cfg.get("region_province_col"),
-             "region_city_col": cfg.get("region_city_col"),
-             "region_effect_alpha": cfg.get("region_effect_alpha"),
-             "geo_feature_nmes": cfg.get("geo_feature_nmes"),
-             "geo_token_hidden_dim": cfg.get("geo_token_hidden_dim"),
-             "geo_token_layers": cfg.get("geo_token_layers"),
-             "geo_token_dropout": cfg.get("geo_token_dropout"),
-             "geo_token_k_neighbors": cfg.get("geo_token_k_neighbors"),
-             "geo_token_learning_rate": cfg.get("geo_token_learning_rate"),
-             "geo_token_epochs": cfg.get("geo_token_epochs"),
-             "ft_role": str(cfg.get("ft_role", "model")),
-             "ft_feature_prefix": str(cfg.get("ft_feature_prefix", "ft_emb")),
-             "ft_num_numeric_tokens": cfg.get("ft_num_numeric_tokens"),
-             "reuse_best_params": runtime_cfg["reuse_best_params"],
-             "plot_path_style": plot_path_style or "nested",
-         })
-         config_payload = {k: v for k, v in config_payload.items() if v is not None}
-         config = ropt.BayesOptConfig(**config_payload)
-         model = ropt.BayesOptModel(train_df, test_df, config=config)
-
-         output_overrides = resolve_explain_output_overrides(
-             explain_cfg,
-             model_name=model_name,
-             base_dir=config_path.parent,
-         )
-         model_dir_override = output_overrides.get("model_dir")
-         if model_dir_override is not None:
-             model.output_manager.model_dir = model_dir_override
-         result_dir_override = output_overrides.get("result_dir")
-         if result_dir_override is not None:
-             model.output_manager.result_dir = result_dir_override
-         plot_dir_override = output_overrides.get("plot_dir")
-         if plot_dir_override is not None:
-             model.output_manager.plot_dir = plot_dir_override
-
-         save_dir = resolve_explain_save_dir(
-             save_root,
-             result_dir=model.output_manager.result_dir,
-         )
-         save_dir.mkdir(parents=True, exist_ok=True)
-
-         print(f"\n=== Explain model {model_name} ===")
-         _explain_for_model(
-             model,
-             model_name=model_name,
-             model_keys=model_keys,
-             methods=methods,
-             on_train=on_train,
-             save_dir=save_dir,
-             explain_cfg=explain_cfg,
-         )
-
-
- def main() -> None:
-     if configure_run_logging:
-         configure_run_logging(prefix="explain_entry")
-     args = _parse_args()
-     explain_from_config(args)
-
-
- if __name__ == "__main__":
-     main()
+ from ins_pricing.cli.utils.import_resolver import resolve_imports, setup_sys_path
+
+ setup_sys_path()
+ _imports = resolve_imports()
+
+ ropt = _imports.bayesopt
+ if ropt is None:  # pragma: no cover
+     raise ImportError("Failed to resolve ins_pricing.bayesopt for explain CLI.")
+
+ build_model_names = _imports.build_model_names
+ dedupe_preserve_order = _imports.dedupe_preserve_order
+ load_dataset = _imports.load_dataset
+ resolve_data_path = _imports.resolve_data_path
+ coerce_dataset_types = _imports.coerce_dataset_types
+ split_train_test = _imports.split_train_test
+
+ add_config_json_arg = _imports.add_config_json_arg
+ add_output_dir_arg = _imports.add_output_dir_arg
+ resolve_and_load_config = _imports.resolve_and_load_config
+ resolve_data_config = _imports.resolve_data_config
+ resolve_explain_output_overrides = _imports.resolve_explain_output_overrides
+ resolve_explain_save_dir = _imports.resolve_explain_save_dir
+ resolve_explain_save_root = _imports.resolve_explain_save_root
+ resolve_model_path_value = _imports.resolve_model_path_value
+ resolve_split_config = _imports.resolve_split_config
+ resolve_runtime_config = _imports.resolve_runtime_config
+ resolve_output_dirs = _imports.resolve_output_dirs
+
+ configure_run_logging = _imports.configure_run_logging
+
+
+ _SUPPORTED_METHODS = {"permutation", "shap", "integrated_gradients"}
+ _METHOD_ALIASES = {
+     "ig": "integrated_gradients",
+     "integrated": "integrated_gradients",
+     "intgrad": "integrated_gradients",
+ }
+
+
+ def _safe_name(value: str) -> str:
+     return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(value))
+
+
+ def _load_dataset(
+     path: Path,
+     *,
+     data_format: str,
+     dtype_map: Optional[Dict[str, Any]],
+ ) -> pd.DataFrame:
+     raw = load_dataset(
+         path,
+         data_format=data_format,
+         dtype_map=dtype_map,
+         low_memory=False,
+     )
+     return coerce_dataset_types(raw)
+
+
+ def _normalize_methods(raw: Sequence[str]) -> List[str]:
+     methods: List[str] = []
+     for item in raw:
+         key = str(item).strip().lower()
+         if not key:
+             continue
+         key = _METHOD_ALIASES.get(key, key)
+         if key not in _SUPPORTED_METHODS:
+             raise ValueError(f"Unsupported explain method: {item}")
+         methods.append(key)
+     return dedupe_preserve_order(methods)
+
+
+ def _save_series(series: pd.Series, path: Path) -> None:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     series.to_frame(name="importance").to_csv(path, index=True)
+
+
+ def _save_df(df: pd.DataFrame, path: Path) -> None:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     df.to_csv(path, index=False)
+
+
+ def _shap_importance(values: Any, feature_names: Sequence[str]) -> pd.Series:
+     if isinstance(values, list):
+         values = values[0]
+     arr = np.asarray(values)
+     if arr.ndim == 3:
+         arr = arr[0]
+     scores = np.mean(np.abs(arr), axis=0)
+     return pd.Series(scores, index=list(feature_names)).sort_values(ascending=False)
+
+
+ def _parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(
+         description="Run explainability (permutation/SHAP/IG) on trained models."
+     )
+     add_config_json_arg(
+         parser,
+         help_text="Path to config.json (same schema as training).",
+     )
+     parser.add_argument(
+         "--model-keys",
+         nargs="+",
+         default=None,
+         choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
+         help="Model keys to load for explanation (default from config.explain.model_keys).",
+     )
+     parser.add_argument(
+         "--methods",
+         nargs="+",
+         default=None,
+         help="Explain methods: permutation, shap, integrated_gradients (default from config.explain.methods).",
+     )
+     add_output_dir_arg(
+         parser,
+         help_text="Override output root for loading models/results.",
+     )
+     parser.add_argument(
+         "--eval-path",
+         default=None,
+         help="Override validation CSV path (supports {model_name}).",
+     )
+     parser.add_argument(
+         "--on-train",
+         action="store_true",
+         help="Explain on train split instead of validation/test.",
+     )
+     parser.add_argument(
+         "--save-dir",
+         default=None,
+         help="Override output directory for explanation artifacts.",
+     )
+     return parser.parse_args()
+
+
+ def _explain_for_model(
+     model: ropt.BayesOptModel,
+     *,
+     model_name: str,
+     model_keys: List[str],
+     methods: List[str],
+     on_train: bool,
+     save_dir: Path,
+     explain_cfg: Dict[str, Any],
+ ) -> None:
+     perm_cfg = dict(explain_cfg.get("permutation") or {})
+     shap_cfg = dict(explain_cfg.get("shap") or {})
+     ig_cfg = dict(explain_cfg.get("integrated_gradients") or {})
+
+     perm_metric = perm_cfg.get("metric", explain_cfg.get("metric", "auto"))
+     perm_repeats = int(perm_cfg.get("n_repeats", 5))
+     perm_max_rows = perm_cfg.get("max_rows", 5000)
+     perm_random_state = perm_cfg.get("random_state", None)
+
+     shap_background = int(shap_cfg.get("n_background", 500))
+     shap_samples = int(shap_cfg.get("n_samples", 200))
+     shap_save_values = bool(shap_cfg.get("save_values", False))
+
+     ig_steps = int(ig_cfg.get("steps", 50))
+     ig_batch_size = int(ig_cfg.get("batch_size", 256))
+     ig_target = ig_cfg.get("target", None)
+     ig_baseline = ig_cfg.get("baseline", None)
+     ig_baseline_num = ig_cfg.get("baseline_num", None)
+     ig_baseline_geo = ig_cfg.get("baseline_geo", None)
+     ig_save_values = bool(ig_cfg.get("save_values", False))
+
+     for key in model_keys:
+         trainer = model.trainers.get(key)
+         if trainer is None:
+             print(f"[Explain] Skip {model_name}/{key}: trainer not available.")
+             continue
+         model.load_model(key)
+         trained_model = getattr(model, f"{key}_best", None)
+         if trained_model is None:
+             print(f"[Explain] Skip {model_name}/{key}: model not loaded.")
+             continue
+
+         if key == "ft" and str(model.config.ft_role) != "model":
+             print(f"[Explain] Skip {model_name}/ft: ft_role != 'model'.")
+             continue
+
+         for method in methods:
+             if method == "permutation" and key not in {"xgb", "resn", "ft"}:
+                 print(f"[Explain] Skip permutation for {model_name}/{key}.")
+                 continue
+             if method == "shap" and key not in {"glm", "xgb", "resn", "ft"}:
+                 print(f"[Explain] Skip shap for {model_name}/{key}.")
+                 continue
+             if method == "integrated_gradients" and key not in {"resn", "ft"}:
+                 print(f"[Explain] Skip integrated gradients for {model_name}/{key}.")
+                 continue
+
+             if method == "permutation":
+                 try:
+                     result = model.compute_permutation_importance(
+                         key,
+                         on_train=on_train,
+                         metric=perm_metric,
+                         n_repeats=perm_repeats,
+                         max_rows=perm_max_rows,
+                         random_state=perm_random_state,
+                     )
+                 except Exception as exc:
+                     print(f"[Explain] permutation failed for {model_name}/{key}: {exc}")
+                     continue
+                 out_path = save_dir / f"{_safe_name(model_name)}_{key}_permutation.csv"
+                 _save_df(result, out_path)
+                 print(f"[Explain] Saved permutation -> {out_path}")
+
+             if method == "shap":
+                 try:
+                     if key == "glm":
+                         shap_result = model.compute_shap_glm(
+                             n_background=shap_background,
+                             n_samples=shap_samples,
+                             on_train=on_train,
+                         )
+                     elif key == "xgb":
+                         shap_result = model.compute_shap_xgb(
+                             n_background=shap_background,
+                             n_samples=shap_samples,
+                             on_train=on_train,
+                         )
+                     elif key == "resn":
+                         shap_result = model.compute_shap_resn(
+                             n_background=shap_background,
+                             n_samples=shap_samples,
+                             on_train=on_train,
+                         )
+                     else:
+                         shap_result = model.compute_shap_ft(
+                             n_background=shap_background,
+                             n_samples=shap_samples,
+                             on_train=on_train,
+                         )
+                 except Exception as exc:
+                     print(f"[Explain] shap failed for {model_name}/{key}: {exc}")
+                     continue
+
+                 shap_values = shap_result.get("shap_values")
+                 X_explain = shap_result.get("X_explain")
+                 feature_names = (
+                     list(X_explain.columns)
+                     if isinstance(X_explain, pd.DataFrame)
+                     else list(model.factor_nmes)
+                 )
+                 importance = _shap_importance(shap_values, feature_names)
+                 out_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_importance.csv"
+                 _save_series(importance, out_path)
+                 print(f"[Explain] Saved SHAP importance -> {out_path}")
+
+                 if shap_save_values:
+                     values_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_values.npy"
+                     np.save(values_path, np.array(shap_values, dtype=object), allow_pickle=True)
+                     if isinstance(X_explain, pd.DataFrame):
+                         x_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_X.csv"
+                         _save_df(X_explain, x_path)
+                     meta_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_meta.json"
+                     meta = {
+                         "base_value": shap_result.get("base_value"),
+                         "n_samples": int(len(X_explain)) if X_explain is not None else None,
+                     }
+                     meta_path.write_text(json.dumps(meta, indent=2), encoding="utf-8")
+
+             if method == "integrated_gradients":
+                 try:
+                     if key == "resn":
+                         ig_result = model.compute_integrated_gradients_resn(
+                             on_train=on_train,
+                             baseline=ig_baseline,
+                             steps=ig_steps,
+                             batch_size=ig_batch_size,
+                             target=ig_target,
+                         )
+                         series = ig_result.get("importance")
+                         if isinstance(series, pd.Series):
+                             out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_importance.csv"
+                             _save_series(series, out_path)
+                             print(f"[Explain] Saved IG importance -> {out_path}")
+                         if ig_save_values and "attributions" in ig_result:
+                             attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_attributions.npy"
+                             np.save(attr_path, ig_result.get("attributions"))
+                     else:
+                         ig_result = model.compute_integrated_gradients_ft(
+                             on_train=on_train,
+                             baseline_num=ig_baseline_num,
+                             baseline_geo=ig_baseline_geo,
+                             steps=ig_steps,
+                             batch_size=ig_batch_size,
+                             target=ig_target,
+                         )
+                         series_num = ig_result.get("importance_num")
+                         series_geo = ig_result.get("importance_geo")
+                         if isinstance(series_num, pd.Series):
+                             out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_num_importance.csv"
+                             _save_series(series_num, out_path)
+                             print(f"[Explain] Saved IG num importance -> {out_path}")
+                         if isinstance(series_geo, pd.Series):
+                             out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_geo_importance.csv"
+                             _save_series(series_geo, out_path)
+                             print(f"[Explain] Saved IG geo importance -> {out_path}")
+                         if ig_save_values:
+                             if ig_result.get("attributions_num") is not None:
+                                 attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_num_attributions.npy"
+                                 np.save(attr_path, ig_result.get("attributions_num"))
+                             if ig_result.get("attributions_geo") is not None:
+                                 attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_geo_attributions.npy"
+                                 np.save(attr_path, ig_result.get("attributions_geo"))
+                 except Exception as exc:
+                     print(f"[Explain] integrated gradients failed for {model_name}/{key}: {exc}")
+                     continue
+
+
+ def explain_from_config(args: argparse.Namespace) -> None:
+     script_dir = Path(__file__).resolve().parents[1]
+     config_path, cfg = resolve_and_load_config(
+         args.config_json,
+         script_dir,
+         required_keys=["data_dir", "model_list", "model_categories", "target", "weight"],
+     )
+
+     data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
+         cfg,
+         config_path,
+         create_data_dir=True,
+     )
+
+     runtime_cfg = resolve_runtime_config(cfg)
+     output_cfg = resolve_output_dirs(
+         cfg,
+         config_path,
+         output_override=args.output_dir,
+     )
+     output_dir = output_cfg["output_dir"]
+
+     split_cfg = resolve_split_config(cfg)
+     prop_test = split_cfg["prop_test"]
+     rand_seed = runtime_cfg["rand_seed"]
+     split_strategy = split_cfg["split_strategy"]
+     split_group_col = split_cfg["split_group_col"]
+     split_time_col = split_cfg["split_time_col"]
+     split_time_ascending = split_cfg["split_time_ascending"]
+
+     explain_cfg = dict(cfg.get("explain") or {})
+
+     model_keys = args.model_keys or explain_cfg.get("model_keys") or ["xgb"]
+     if "all" in model_keys:
+         model_keys = ["glm", "xgb", "resn", "ft", "gnn"]
+     model_keys = dedupe_preserve_order([str(x) for x in model_keys])
+
+     method_list = args.methods or explain_cfg.get("methods") or ["permutation"]
+     methods = _normalize_methods([str(x) for x in method_list])
+
+     on_train = bool(args.on_train or explain_cfg.get("on_train", False))
+
+     model_names = build_model_names(cfg["model_list"], cfg["model_categories"])
+     if not model_names:
+         raise ValueError("No model names generated from model_list/model_categories.")
+
+     save_root = resolve_explain_save_root(
+         args.save_dir or explain_cfg.get("save_dir"),
+         config_path.parent,
+     )
+
+     for model_name in model_names:
+         train_path = resolve_model_path_value(
+             explain_cfg.get("train_path"),
+             model_name=model_name,
+             base_dir=config_path.parent,
+             data_dir=data_dir,
+         )
+         if train_path is None:
+             train_path = resolve_data_path(
+                 data_dir,
+                 model_name,
+                 data_format=data_format,
+                 path_template=data_path_template,
+             )
+         if not train_path.exists():
+             raise FileNotFoundError(f"Missing training dataset: {train_path}")
+
+         validation_override = args.eval_path or explain_cfg.get("validation_path") or explain_cfg.get("eval_path")
+         validation_path = resolve_model_path_value(
+             validation_override,
+             model_name=model_name,
+             base_dir=config_path.parent,
+             data_dir=data_dir,
+         )
+
+         raw = _load_dataset(
+             train_path,
+             data_format=data_format,
+             dtype_map=dtype_map,
+         )
+         if validation_path is not None:
+             if not validation_path.exists():
+                 raise FileNotFoundError(f"Missing validation dataset: {validation_path}")
+             train_df = raw
+             test_df = _load_dataset(
+                 validation_path,
+                 data_format=data_format,
+                 dtype_map=dtype_map,
+             )
+         else:
+             if float(prop_test) <= 0:
+                 train_df = raw
+                 test_df = raw.copy()
+             else:
+                 train_df, test_df = split_train_test(
+                     raw,
+                     holdout_ratio=prop_test,
+                     strategy=split_strategy,
+                     group_col=split_group_col,
+                     time_col=split_time_col,
+                     time_ascending=split_time_ascending,
+                     rand_seed=rand_seed,
+                     reset_index_mode="time_group",
+                     ratio_label="prop_test",
+                     include_strategy_in_ratio_error=True,
+                 )
+
+         binary_target = cfg.get("binary_target") or cfg.get("binary_resp_nme")
+         feature_list = cfg.get("feature_list")
+         categorical_features = cfg.get("categorical_features")
+         plot_path_style = runtime_cfg["plot_path_style"]
+
+         config_fields = getattr(ropt.BayesOptConfig, "__dataclass_fields__", {})
+         allowed_config_keys = set(config_fields.keys())
+         config_payload = {k: v for k, v in cfg.items() if k in allowed_config_keys}
+         config_payload.update({
+             "model_nme": model_name,
+             "resp_nme": cfg["target"],
+             "weight_nme": cfg["weight"],
+             "factor_nmes": feature_list,
+             "task_type": str(cfg.get("task_type", "regression")),
+             "binary_resp_nme": binary_target,
+             "cate_list": categorical_features,
+             "prop_test": prop_test,
+             "rand_seed": rand_seed,
+             "epochs": int(runtime_cfg["epochs"]),
+             "use_gpu": bool(cfg.get("use_gpu", True)),
+             "output_dir": output_dir,
+             "xgb_max_depth_max": runtime_cfg["xgb_max_depth_max"],
+             "xgb_n_estimators_max": runtime_cfg["xgb_n_estimators_max"],
+             "resn_weight_decay": cfg.get("resn_weight_decay"),
+             "final_ensemble": bool(cfg.get("final_ensemble", False)),
+             "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
+             "final_refit": bool(cfg.get("final_refit", True)),
+             "optuna_storage": runtime_cfg["optuna_storage"],
+             "optuna_study_prefix": runtime_cfg["optuna_study_prefix"],
+             "best_params_files": runtime_cfg["best_params_files"],
+             "gnn_use_approx_knn": cfg.get("gnn_use_approx_knn", True),
+             "gnn_approx_knn_threshold": cfg.get("gnn_approx_knn_threshold", 50000),
+             "gnn_graph_cache": cfg.get("gnn_graph_cache"),
+             "gnn_max_gpu_knn_nodes": cfg.get("gnn_max_gpu_knn_nodes", 200000),
+             "gnn_knn_gpu_mem_ratio": cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
+             "gnn_knn_gpu_mem_overhead": cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
+             "region_province_col": cfg.get("region_province_col"),
+             "region_city_col": cfg.get("region_city_col"),
+             "region_effect_alpha": cfg.get("region_effect_alpha"),
+             "geo_feature_nmes": cfg.get("geo_feature_nmes"),
+             "geo_token_hidden_dim": cfg.get("geo_token_hidden_dim"),
+             "geo_token_layers": cfg.get("geo_token_layers"),
+             "geo_token_dropout": cfg.get("geo_token_dropout"),
+             "geo_token_k_neighbors": cfg.get("geo_token_k_neighbors"),
+             "geo_token_learning_rate": cfg.get("geo_token_learning_rate"),
+             "geo_token_epochs": cfg.get("geo_token_epochs"),
+             "ft_role": str(cfg.get("ft_role", "model")),
+             "ft_feature_prefix": str(cfg.get("ft_feature_prefix", "ft_emb")),
+             "ft_num_numeric_tokens": cfg.get("ft_num_numeric_tokens"),
+             "reuse_best_params": runtime_cfg["reuse_best_params"],
+             "plot_path_style": plot_path_style or "nested",
+         })
+         config_payload = {k: v for k, v in config_payload.items() if v is not None}
+         config = ropt.BayesOptConfig(**config_payload)
+         model = ropt.BayesOptModel(train_df, test_df, config=config)
+
+         output_overrides = resolve_explain_output_overrides(
+             explain_cfg,
+             model_name=model_name,
+             base_dir=config_path.parent,
+         )
+         model_dir_override = output_overrides.get("model_dir")
+         if model_dir_override is not None:
+             model.output_manager.model_dir = model_dir_override
+         result_dir_override = output_overrides.get("result_dir")
+         if result_dir_override is not None:
+             model.output_manager.result_dir = result_dir_override
+         plot_dir_override = output_overrides.get("plot_dir")
+         if plot_dir_override is not None:
+             model.output_manager.plot_dir = plot_dir_override
+
+         save_dir = resolve_explain_save_dir(
+             save_root,
+             result_dir=model.output_manager.result_dir,
+         )
+         save_dir.mkdir(parents=True, exist_ok=True)
+
+         print(f"\n=== Explain model {model_name} ===")
+         _explain_for_model(
+             model,
+             model_name=model_name,
+             model_keys=model_keys,
+             methods=methods,
+             on_train=on_train,
+             save_dir=save_dir,
+             explain_cfg=explain_cfg,
+         )
+
+
+ def main() -> None:
+     if configure_run_logging:
+         configure_run_logging(prefix="explain_entry")
+     args = _parse_args()
+     explain_from_config(args)
+
+
+ if __name__ == "__main__":
+     main()
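
For reference, the explain-related configuration this rewritten Explain_entry.py consumes can be collected into one example. The sketch below is assembled from the explain_cfg.get(...) calls visible in the diff, with each value set to the in-code fallback; it is not a sample shipped with the package, and omitted keys simply fall back to these same defaults:

    {
      "explain": {
        "model_keys": ["xgb"],
        "methods": ["permutation"],
        "on_train": false,
        "permutation": { "metric": "auto", "n_repeats": 5, "max_rows": 5000, "random_state": null },
        "shap": { "n_background": 500, "n_samples": 200, "save_values": false },
        "integrated_gradients": { "steps": 50, "batch_size": 256, "target": null, "save_values": false }
      }
    }

The same settings can be overridden per run via the flags defined in _parse_args (--model-keys, --methods, --eval-path, --on-train, --save-dir). Assuming add_config_json_arg registers a --config-json flag (its definition lives in cli_config.py, which this diff does not show), an invocation would look like:

    python -m ins_pricing.cli.Explain_entry --config-json config.json --model-keys xgb --methods permutation shap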