ins-pricing 0.4.5-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
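The list above reflects two themes in 0.5.1: the modelling/core/bayesopt tree is flattened to modelling/bayesopt (with the legacy core shims, utils.py, and utils_backup.py deleted outright), and production/predict.py becomes production/inference.py, with tests moving from test_predict.py to test_inference.py. A minimal migration sketch for downstream code follows; it assumes the public classes are re-exported from the new package roots, which this diff suggests (BayesOptConfig/BayesOptModel resolve through the bayesopt package below) but does not confirm.

```python
# Hypothetical import migration implied by the renames above; exact
# re-exports are an assumption, not confirmed by this diff.

# 0.4.5 layout:
#   from ins_pricing.modelling.core.bayesopt import BayesOptConfig, BayesOptModel
#   from ins_pricing.production import predict

# 0.5.1 layout (assumed re-exports):
from ins_pricing.modelling.bayesopt import BayesOptConfig, BayesOptModel
from ins_pricing.production import inference
```

Of the 93 changed files, only the diff for ins_pricing/cli/Explain_entry.py (+551 -577, matching the hunk's 586 → 560 lines) is reproduced below.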
@@ -1,586 +1,560 @@
-"""Config-driven explain runner for trained BayesOpt models."""
-
-from __future__ import annotations
-
+"""Config-driven explain runner for trained BayesOpt models."""
+
+from __future__ import annotations
+
 from pathlib import Path
+import importlib.util
 import sys
 
-if __package__ in {None, ""}:
-    repo_root = Path(__file__).resolve().parents[2]
-    if str(repo_root) not in sys.path:
-        sys.path.insert(0, str(repo_root))
-
-import argparse
-import json
-from typing import Any, Dict, List, Optional, Sequence
-
+def _ensure_repo_root() -> None:
+    if __package__ not in {None, ""}:
+        return
+    if importlib.util.find_spec("ins_pricing") is not None:
+        return
+    bootstrap_path = Path(__file__).resolve().parents[1] / "utils" / "bootstrap.py"
+    spec = importlib.util.spec_from_file_location("ins_pricing.cli.utils.bootstrap", bootstrap_path)
+    if spec is None or spec.loader is None:
+        return
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    module.ensure_repo_root()
+
+
+_ensure_repo_root()
+
+import argparse
+import json
+from typing import Any, Dict, List, Optional, Sequence
+
 import numpy as np
 import pandas as pd
 
-try:
-    from .. import bayesopt as ropt  # type: ignore
-    from .utils.cli_common import (  # type: ignore
-        build_model_names,
-        dedupe_preserve_order,
-        load_dataset,
-        resolve_data_path,
-        coerce_dataset_types,
-        split_train_test,
-    )
-    from .utils.cli_config import (  # type: ignore
-        add_config_json_arg,
-        add_output_dir_arg,
-        resolve_and_load_config,
-        resolve_data_config,
-        resolve_explain_output_overrides,
-        resolve_explain_save_dir,
-        resolve_explain_save_root,
-        resolve_model_path_value,
-        resolve_split_config,
-        resolve_runtime_config,
-        resolve_output_dirs,
-    )
-except Exception:  # pragma: no cover
-    try:
-        import bayesopt as ropt  # type: ignore
-        from utils.cli_common import (  # type: ignore
-            build_model_names,
-            dedupe_preserve_order,
-            load_dataset,
-            resolve_data_path,
-            coerce_dataset_types,
-            split_train_test,
-        )
-        from utils.cli_config import (  # type: ignore
-            add_config_json_arg,
-            add_output_dir_arg,
-            resolve_and_load_config,
-            resolve_data_config,
-            resolve_explain_output_overrides,
-            resolve_explain_save_dir,
-            resolve_explain_save_root,
-            resolve_model_path_value,
-            resolve_split_config,
-            resolve_runtime_config,
-            resolve_output_dirs,
-        )
-    except Exception:
-        import ins_pricing.modelling.core.bayesopt as ropt  # type: ignore
-        from ins_pricing.cli.utils.cli_common import (  # type: ignore
-            build_model_names,
-            dedupe_preserve_order,
-            load_dataset,
-            resolve_data_path,
-            coerce_dataset_types,
-            split_train_test,
-        )
-        from ins_pricing.cli.utils.cli_config import (  # type: ignore
-            add_config_json_arg,
-            add_output_dir_arg,
-            resolve_and_load_config,
-            resolve_data_config,
-            resolve_explain_output_overrides,
-            resolve_explain_save_dir,
-            resolve_explain_save_root,
-            resolve_model_path_value,
-            resolve_split_config,
-            resolve_runtime_config,
-            resolve_output_dirs,
-        )
-
-try:
-    from .utils.run_logging import configure_run_logging  # type: ignore
-except Exception:  # pragma: no cover
-    try:
-        from utils.run_logging import configure_run_logging  # type: ignore
-    except Exception:  # pragma: no cover
-        configure_run_logging = None  # type: ignore
-
-
-_SUPPORTED_METHODS = {"permutation", "shap", "integrated_gradients"}
-_METHOD_ALIASES = {
-    "ig": "integrated_gradients",
-    "integrated": "integrated_gradients",
-    "intgrad": "integrated_gradients",
-}
-
-
-def _safe_name(value: str) -> str:
-    return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(value))
-
-
-def _load_dataset(
-    path: Path,
-    *,
-    data_format: str,
-    dtype_map: Optional[Dict[str, Any]],
-) -> pd.DataFrame:
-    raw = load_dataset(
-        path,
-        data_format=data_format,
-        dtype_map=dtype_map,
-        low_memory=False,
-    )
-    return coerce_dataset_types(raw)
-
-
-def _normalize_methods(raw: Sequence[str]) -> List[str]:
-    methods: List[str] = []
-    for item in raw:
-        key = str(item).strip().lower()
-        if not key:
-            continue
-        key = _METHOD_ALIASES.get(key, key)
-        if key not in _SUPPORTED_METHODS:
-            raise ValueError(f"Unsupported explain method: {item}")
-        methods.append(key)
-    return dedupe_preserve_order(methods)
-
-
-def _save_series(series: pd.Series, path: Path) -> None:
-    path.parent.mkdir(parents=True, exist_ok=True)
-    series.to_frame(name="importance").to_csv(path, index=True)
-
-
-def _save_df(df: pd.DataFrame, path: Path) -> None:
-    path.parent.mkdir(parents=True, exist_ok=True)
-    df.to_csv(path, index=False)
-
-
-def _shap_importance(values: Any, feature_names: Sequence[str]) -> pd.Series:
-    if isinstance(values, list):
-        values = values[0]
-    arr = np.asarray(values)
-    if arr.ndim == 3:
-        arr = arr[0]
-    scores = np.mean(np.abs(arr), axis=0)
-    return pd.Series(scores, index=list(feature_names)).sort_values(ascending=False)
-
-
-def _parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(
-        description="Run explainability (permutation/SHAP/IG) on trained models."
-    )
-    add_config_json_arg(
-        parser,
-        help_text="Path to config.json (same schema as training).",
-    )
-    parser.add_argument(
-        "--model-keys",
-        nargs="+",
-        default=None,
-        choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
-        help="Model keys to load for explanation (default from config.explain.model_keys).",
-    )
-    parser.add_argument(
-        "--methods",
-        nargs="+",
-        default=None,
-        help="Explain methods: permutation, shap, integrated_gradients (default from config.explain.methods).",
-    )
-    add_output_dir_arg(
-        parser,
-        help_text="Override output root for loading models/results.",
-    )
-    parser.add_argument(
-        "--eval-path",
-        default=None,
-        help="Override validation CSV path (supports {model_name}).",
-    )
-    parser.add_argument(
-        "--on-train",
-        action="store_true",
-        help="Explain on train split instead of validation/test.",
-    )
-    parser.add_argument(
-        "--save-dir",
-        default=None,
-        help="Override output directory for explanation artifacts.",
-    )
-    return parser.parse_args()
-
-
-def _explain_for_model(
-    model: ropt.BayesOptModel,
-    *,
-    model_name: str,
-    model_keys: List[str],
-    methods: List[str],
-    on_train: bool,
-    save_dir: Path,
-    explain_cfg: Dict[str, Any],
-) -> None:
-    perm_cfg = dict(explain_cfg.get("permutation") or {})
-    shap_cfg = dict(explain_cfg.get("shap") or {})
-    ig_cfg = dict(explain_cfg.get("integrated_gradients") or {})
-
-    perm_metric = perm_cfg.get("metric", explain_cfg.get("metric", "auto"))
-    perm_repeats = int(perm_cfg.get("n_repeats", 5))
-    perm_max_rows = perm_cfg.get("max_rows", 5000)
-    perm_random_state = perm_cfg.get("random_state", None)
-
-    shap_background = int(shap_cfg.get("n_background", 500))
-    shap_samples = int(shap_cfg.get("n_samples", 200))
-    shap_save_values = bool(shap_cfg.get("save_values", False))
-
-    ig_steps = int(ig_cfg.get("steps", 50))
-    ig_batch_size = int(ig_cfg.get("batch_size", 256))
-    ig_target = ig_cfg.get("target", None)
-    ig_baseline = ig_cfg.get("baseline", None)
-    ig_baseline_num = ig_cfg.get("baseline_num", None)
-    ig_baseline_geo = ig_cfg.get("baseline_geo", None)
-    ig_save_values = bool(ig_cfg.get("save_values", False))
-
-    for key in model_keys:
-        trainer = model.trainers.get(key)
-        if trainer is None:
-            print(f"[Explain] Skip {model_name}/{key}: trainer not available.")
-            continue
-        model.load_model(key)
-        trained_model = getattr(model, f"{key}_best", None)
-        if trained_model is None:
-            print(f"[Explain] Skip {model_name}/{key}: model not loaded.")
-            continue
-
-        if key == "ft" and str(model.config.ft_role) != "model":
-            print(f"[Explain] Skip {model_name}/ft: ft_role != 'model'.")
-            continue
-
-        for method in methods:
-            if method == "permutation" and key not in {"xgb", "resn", "ft"}:
-                print(f"[Explain] Skip permutation for {model_name}/{key}.")
-                continue
-            if method == "shap" and key not in {"glm", "xgb", "resn", "ft"}:
-                print(f"[Explain] Skip shap for {model_name}/{key}.")
-                continue
-            if method == "integrated_gradients" and key not in {"resn", "ft"}:
-                print(f"[Explain] Skip integrated gradients for {model_name}/{key}.")
-                continue
-
-            if method == "permutation":
-                try:
-                    result = model.compute_permutation_importance(
-                        key,
-                        on_train=on_train,
-                        metric=perm_metric,
-                        n_repeats=perm_repeats,
-                        max_rows=perm_max_rows,
-                        random_state=perm_random_state,
-                    )
-                except Exception as exc:
-                    print(f"[Explain] permutation failed for {model_name}/{key}: {exc}")
-                    continue
-                out_path = save_dir / f"{_safe_name(model_name)}_{key}_permutation.csv"
-                _save_df(result, out_path)
-                print(f"[Explain] Saved permutation -> {out_path}")
-
-            if method == "shap":
-                try:
-                    if key == "glm":
-                        shap_result = model.compute_shap_glm(
-                            n_background=shap_background,
-                            n_samples=shap_samples,
-                            on_train=on_train,
-                        )
-                    elif key == "xgb":
-                        shap_result = model.compute_shap_xgb(
-                            n_background=shap_background,
-                            n_samples=shap_samples,
-                            on_train=on_train,
-                        )
-                    elif key == "resn":
-                        shap_result = model.compute_shap_resn(
-                            n_background=shap_background,
-                            n_samples=shap_samples,
-                            on_train=on_train,
-                        )
-                    else:
-                        shap_result = model.compute_shap_ft(
-                            n_background=shap_background,
-                            n_samples=shap_samples,
-                            on_train=on_train,
-                        )
-                except Exception as exc:
-                    print(f"[Explain] shap failed for {model_name}/{key}: {exc}")
-                    continue
-
-                shap_values = shap_result.get("shap_values")
-                X_explain = shap_result.get("X_explain")
-                feature_names = (
-                    list(X_explain.columns)
-                    if isinstance(X_explain, pd.DataFrame)
-                    else list(model.factor_nmes)
-                )
-                importance = _shap_importance(shap_values, feature_names)
-                out_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_importance.csv"
-                _save_series(importance, out_path)
-                print(f"[Explain] Saved SHAP importance -> {out_path}")
-
-                if shap_save_values:
-                    values_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_values.npy"
-                    np.save(values_path, np.array(shap_values, dtype=object), allow_pickle=True)
-                    if isinstance(X_explain, pd.DataFrame):
-                        x_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_X.csv"
-                        _save_df(X_explain, x_path)
-                    meta_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_meta.json"
-                    meta = {
-                        "base_value": shap_result.get("base_value"),
-                        "n_samples": int(len(X_explain)) if X_explain is not None else None,
-                    }
-                    meta_path.write_text(json.dumps(meta, indent=2), encoding="utf-8")
-
-            if method == "integrated_gradients":
-                try:
-                    if key == "resn":
-                        ig_result = model.compute_integrated_gradients_resn(
-                            on_train=on_train,
-                            baseline=ig_baseline,
-                            steps=ig_steps,
-                            batch_size=ig_batch_size,
-                            target=ig_target,
-                        )
-                        series = ig_result.get("importance")
-                        if isinstance(series, pd.Series):
-                            out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_importance.csv"
-                            _save_series(series, out_path)
-                            print(f"[Explain] Saved IG importance -> {out_path}")
-                        if ig_save_values and "attributions" in ig_result:
-                            attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_attributions.npy"
-                            np.save(attr_path, ig_result.get("attributions"))
-                    else:
-                        ig_result = model.compute_integrated_gradients_ft(
-                            on_train=on_train,
-                            baseline_num=ig_baseline_num,
-                            baseline_geo=ig_baseline_geo,
-                            steps=ig_steps,
-                            batch_size=ig_batch_size,
-                            target=ig_target,
-                        )
-                        series_num = ig_result.get("importance_num")
-                        series_geo = ig_result.get("importance_geo")
-                        if isinstance(series_num, pd.Series):
-                            out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_num_importance.csv"
-                            _save_series(series_num, out_path)
-                            print(f"[Explain] Saved IG num importance -> {out_path}")
-                        if isinstance(series_geo, pd.Series):
-                            out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_geo_importance.csv"
-                            _save_series(series_geo, out_path)
-                            print(f"[Explain] Saved IG geo importance -> {out_path}")
-                        if ig_save_values:
-                            if ig_result.get("attributions_num") is not None:
-                                attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_num_attributions.npy"
-                                np.save(attr_path, ig_result.get("attributions_num"))
-                            if ig_result.get("attributions_geo") is not None:
-                                attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_geo_attributions.npy"
-                                np.save(attr_path, ig_result.get("attributions_geo"))
-                except Exception as exc:
-                    print(f"[Explain] integrated gradients failed for {model_name}/{key}: {exc}")
-                    continue
-
-
-def explain_from_config(args: argparse.Namespace) -> None:
-    script_dir = Path(__file__).resolve().parents[1]
-    config_path, cfg = resolve_and_load_config(
-        args.config_json,
-        script_dir,
-        required_keys=["data_dir", "model_list", "model_categories", "target", "weight"],
-    )
-
-    data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
-        cfg,
-        config_path,
-        create_data_dir=True,
-    )
-
-    runtime_cfg = resolve_runtime_config(cfg)
-    output_cfg = resolve_output_dirs(
-        cfg,
-        config_path,
-        output_override=args.output_dir,
-    )
-    output_dir = output_cfg["output_dir"]
-
-    split_cfg = resolve_split_config(cfg)
-    prop_test = split_cfg["prop_test"]
-    rand_seed = runtime_cfg["rand_seed"]
-    split_strategy = split_cfg["split_strategy"]
-    split_group_col = split_cfg["split_group_col"]
-    split_time_col = split_cfg["split_time_col"]
-    split_time_ascending = split_cfg["split_time_ascending"]
-
-    explain_cfg = dict(cfg.get("explain") or {})
-
-    model_keys = args.model_keys or explain_cfg.get("model_keys") or ["xgb"]
-    if "all" in model_keys:
-        model_keys = ["glm", "xgb", "resn", "ft", "gnn"]
-    model_keys = dedupe_preserve_order([str(x) for x in model_keys])
-
-    method_list = args.methods or explain_cfg.get("methods") or ["permutation"]
-    methods = _normalize_methods([str(x) for x in method_list])
-
-    on_train = bool(args.on_train or explain_cfg.get("on_train", False))
-
-    model_names = build_model_names(cfg["model_list"], cfg["model_categories"])
-    if not model_names:
-        raise ValueError("No model names generated from model_list/model_categories.")
-
-    save_root = resolve_explain_save_root(
-        args.save_dir or explain_cfg.get("save_dir"),
-        config_path.parent,
-    )
-
-    for model_name in model_names:
-        train_path = resolve_model_path_value(
-            explain_cfg.get("train_path"),
-            model_name=model_name,
-            base_dir=config_path.parent,
-            data_dir=data_dir,
-        )
-        if train_path is None:
-            train_path = resolve_data_path(
-                data_dir,
-                model_name,
-                data_format=data_format,
-                path_template=data_path_template,
-            )
-        if not train_path.exists():
-            raise FileNotFoundError(f"Missing training dataset: {train_path}")
-
-        validation_override = args.eval_path or explain_cfg.get("validation_path") or explain_cfg.get("eval_path")
-        validation_path = resolve_model_path_value(
-            validation_override,
-            model_name=model_name,
-            base_dir=config_path.parent,
-            data_dir=data_dir,
-        )
-
-        raw = _load_dataset(
-            train_path,
-            data_format=data_format,
-            dtype_map=dtype_map,
-        )
-        if validation_path is not None:
-            if not validation_path.exists():
-                raise FileNotFoundError(f"Missing validation dataset: {validation_path}")
-            train_df = raw
-            test_df = _load_dataset(
-                validation_path,
-                data_format=data_format,
-                dtype_map=dtype_map,
-            )
-        else:
-            if float(prop_test) <= 0:
-                train_df = raw
-                test_df = raw.copy()
-            else:
-                train_df, test_df = split_train_test(
-                    raw,
-                    holdout_ratio=prop_test,
-                    strategy=split_strategy,
-                    group_col=split_group_col,
-                    time_col=split_time_col,
-                    time_ascending=split_time_ascending,
-                    rand_seed=rand_seed,
-                    reset_index_mode="time_group",
-                    ratio_label="prop_test",
-                    include_strategy_in_ratio_error=True,
-                )
-
-        binary_target = cfg.get("binary_target") or cfg.get("binary_resp_nme")
-        feature_list = cfg.get("feature_list")
-        categorical_features = cfg.get("categorical_features")
-        plot_path_style = runtime_cfg["plot_path_style"]
-
-        config_fields = getattr(ropt.BayesOptConfig, "__dataclass_fields__", {})
-        allowed_config_keys = set(config_fields.keys())
-        config_payload = {k: v for k, v in cfg.items() if k in allowed_config_keys}
-        config_payload.update({
-            "model_nme": model_name,
-            "resp_nme": cfg["target"],
-            "weight_nme": cfg["weight"],
-            "factor_nmes": feature_list,
-            "task_type": str(cfg.get("task_type", "regression")),
-            "binary_resp_nme": binary_target,
-            "cate_list": categorical_features,
-            "prop_test": prop_test,
-            "rand_seed": rand_seed,
-            "epochs": int(runtime_cfg["epochs"]),
-            "use_gpu": bool(cfg.get("use_gpu", True)),
-            "output_dir": output_dir,
+from ins_pricing.cli.utils.import_resolver import resolve_imports, setup_sys_path
+
+setup_sys_path()
+_imports = resolve_imports()
+
+ropt = _imports.bayesopt
+if ropt is None:  # pragma: no cover
+    raise ImportError("Failed to resolve ins_pricing.bayesopt for explain CLI.")
+
+build_model_names = _imports.build_model_names
+dedupe_preserve_order = _imports.dedupe_preserve_order
+load_dataset = _imports.load_dataset
+resolve_data_path = _imports.resolve_data_path
+coerce_dataset_types = _imports.coerce_dataset_types
+split_train_test = _imports.split_train_test
+
+add_config_json_arg = _imports.add_config_json_arg
+add_output_dir_arg = _imports.add_output_dir_arg
+resolve_and_load_config = _imports.resolve_and_load_config
+resolve_data_config = _imports.resolve_data_config
+resolve_explain_output_overrides = _imports.resolve_explain_output_overrides
+resolve_explain_save_dir = _imports.resolve_explain_save_dir
+resolve_explain_save_root = _imports.resolve_explain_save_root
+resolve_model_path_value = _imports.resolve_model_path_value
+resolve_split_config = _imports.resolve_split_config
+resolve_runtime_config = _imports.resolve_runtime_config
+resolve_output_dirs = _imports.resolve_output_dirs
+
+configure_run_logging = _imports.configure_run_logging
+
+
+_SUPPORTED_METHODS = {"permutation", "shap", "integrated_gradients"}
+_METHOD_ALIASES = {
+    "ig": "integrated_gradients",
+    "integrated": "integrated_gradients",
+    "intgrad": "integrated_gradients",
+}
+
+
+def _safe_name(value: str) -> str:
+    return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(value))
+
+
+def _load_dataset(
+    path: Path,
+    *,
+    data_format: str,
+    dtype_map: Optional[Dict[str, Any]],
+) -> pd.DataFrame:
+    raw = load_dataset(
+        path,
+        data_format=data_format,
+        dtype_map=dtype_map,
+        low_memory=False,
+    )
+    return coerce_dataset_types(raw)
+
+
+def _normalize_methods(raw: Sequence[str]) -> List[str]:
+    methods: List[str] = []
+    for item in raw:
+        key = str(item).strip().lower()
+        if not key:
+            continue
+        key = _METHOD_ALIASES.get(key, key)
+        if key not in _SUPPORTED_METHODS:
+            raise ValueError(f"Unsupported explain method: {item}")
+        methods.append(key)
+    return dedupe_preserve_order(methods)
+
+
+def _save_series(series: pd.Series, path: Path) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    series.to_frame(name="importance").to_csv(path, index=True)
+
+
+def _save_df(df: pd.DataFrame, path: Path) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    df.to_csv(path, index=False)
+
+
+def _shap_importance(values: Any, feature_names: Sequence[str]) -> pd.Series:
+    if isinstance(values, list):
+        values = values[0]
+    arr = np.asarray(values)
+    if arr.ndim == 3:
+        arr = arr[0]
+    scores = np.mean(np.abs(arr), axis=0)
+    return pd.Series(scores, index=list(feature_names)).sort_values(ascending=False)
+
+
+def _parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Run explainability (permutation/SHAP/IG) on trained models."
+    )
+    add_config_json_arg(
+        parser,
+        help_text="Path to config.json (same schema as training).",
+    )
+    parser.add_argument(
+        "--model-keys",
+        nargs="+",
+        default=None,
+        choices=["glm", "xgb", "resn", "ft", "gnn", "all"],
+        help="Model keys to load for explanation (default from config.explain.model_keys).",
+    )
+    parser.add_argument(
+        "--methods",
+        nargs="+",
+        default=None,
+        help="Explain methods: permutation, shap, integrated_gradients (default from config.explain.methods).",
+    )
+    add_output_dir_arg(
+        parser,
+        help_text="Override output root for loading models/results.",
+    )
+    parser.add_argument(
+        "--eval-path",
+        default=None,
+        help="Override validation CSV path (supports {model_name}).",
+    )
+    parser.add_argument(
+        "--on-train",
+        action="store_true",
+        help="Explain on train split instead of validation/test.",
+    )
+    parser.add_argument(
+        "--save-dir",
+        default=None,
+        help="Override output directory for explanation artifacts.",
+    )
+    return parser.parse_args()
+
+
+def _explain_for_model(
+    model: ropt.BayesOptModel,
+    *,
+    model_name: str,
+    model_keys: List[str],
+    methods: List[str],
+    on_train: bool,
+    save_dir: Path,
+    explain_cfg: Dict[str, Any],
+) -> None:
+    perm_cfg = dict(explain_cfg.get("permutation") or {})
+    shap_cfg = dict(explain_cfg.get("shap") or {})
+    ig_cfg = dict(explain_cfg.get("integrated_gradients") or {})
+
+    perm_metric = perm_cfg.get("metric", explain_cfg.get("metric", "auto"))
+    perm_repeats = int(perm_cfg.get("n_repeats", 5))
+    perm_max_rows = perm_cfg.get("max_rows", 5000)
+    perm_random_state = perm_cfg.get("random_state", None)
+
+    shap_background = int(shap_cfg.get("n_background", 500))
+    shap_samples = int(shap_cfg.get("n_samples", 200))
+    shap_save_values = bool(shap_cfg.get("save_values", False))
+
+    ig_steps = int(ig_cfg.get("steps", 50))
+    ig_batch_size = int(ig_cfg.get("batch_size", 256))
+    ig_target = ig_cfg.get("target", None)
+    ig_baseline = ig_cfg.get("baseline", None)
+    ig_baseline_num = ig_cfg.get("baseline_num", None)
+    ig_baseline_geo = ig_cfg.get("baseline_geo", None)
+    ig_save_values = bool(ig_cfg.get("save_values", False))
+
+    for key in model_keys:
+        trainer = model.trainers.get(key)
+        if trainer is None:
+            print(f"[Explain] Skip {model_name}/{key}: trainer not available.")
+            continue
+        model.load_model(key)
+        trained_model = getattr(model, f"{key}_best", None)
+        if trained_model is None:
+            print(f"[Explain] Skip {model_name}/{key}: model not loaded.")
+            continue
+
+        if key == "ft" and str(model.config.ft_role) != "model":
+            print(f"[Explain] Skip {model_name}/ft: ft_role != 'model'.")
+            continue
+
+        for method in methods:
+            if method == "permutation" and key not in {"xgb", "resn", "ft"}:
+                print(f"[Explain] Skip permutation for {model_name}/{key}.")
+                continue
+            if method == "shap" and key not in {"glm", "xgb", "resn", "ft"}:
+                print(f"[Explain] Skip shap for {model_name}/{key}.")
+                continue
+            if method == "integrated_gradients" and key not in {"resn", "ft"}:
+                print(f"[Explain] Skip integrated gradients for {model_name}/{key}.")
+                continue
+
+            if method == "permutation":
+                try:
+                    result = model.compute_permutation_importance(
+                        key,
+                        on_train=on_train,
+                        metric=perm_metric,
+                        n_repeats=perm_repeats,
+                        max_rows=perm_max_rows,
+                        random_state=perm_random_state,
+                    )
+                except Exception as exc:
+                    print(f"[Explain] permutation failed for {model_name}/{key}: {exc}")
+                    continue
+                out_path = save_dir / f"{_safe_name(model_name)}_{key}_permutation.csv"
+                _save_df(result, out_path)
+                print(f"[Explain] Saved permutation -> {out_path}")
+
+            if method == "shap":
+                try:
+                    if key == "glm":
+                        shap_result = model.compute_shap_glm(
+                            n_background=shap_background,
+                            n_samples=shap_samples,
+                            on_train=on_train,
+                        )
+                    elif key == "xgb":
+                        shap_result = model.compute_shap_xgb(
+                            n_background=shap_background,
+                            n_samples=shap_samples,
+                            on_train=on_train,
+                        )
+                    elif key == "resn":
+                        shap_result = model.compute_shap_resn(
+                            n_background=shap_background,
+                            n_samples=shap_samples,
+                            on_train=on_train,
+                        )
+                    else:
+                        shap_result = model.compute_shap_ft(
+                            n_background=shap_background,
+                            n_samples=shap_samples,
+                            on_train=on_train,
+                        )
+                except Exception as exc:
+                    print(f"[Explain] shap failed for {model_name}/{key}: {exc}")
+                    continue
+
+                shap_values = shap_result.get("shap_values")
+                X_explain = shap_result.get("X_explain")
+                feature_names = (
+                    list(X_explain.columns)
+                    if isinstance(X_explain, pd.DataFrame)
+                    else list(model.factor_nmes)
+                )
+                importance = _shap_importance(shap_values, feature_names)
+                out_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_importance.csv"
+                _save_series(importance, out_path)
+                print(f"[Explain] Saved SHAP importance -> {out_path}")
+
+                if shap_save_values:
+                    values_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_values.npy"
+                    np.save(values_path, np.array(shap_values, dtype=object), allow_pickle=True)
+                    if isinstance(X_explain, pd.DataFrame):
+                        x_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_X.csv"
+                        _save_df(X_explain, x_path)
+                    meta_path = save_dir / f"{_safe_name(model_name)}_{key}_shap_meta.json"
+                    meta = {
+                        "base_value": shap_result.get("base_value"),
+                        "n_samples": int(len(X_explain)) if X_explain is not None else None,
+                    }
+                    meta_path.write_text(json.dumps(meta, indent=2), encoding="utf-8")
+
+            if method == "integrated_gradients":
+                try:
+                    if key == "resn":
+                        ig_result = model.compute_integrated_gradients_resn(
+                            on_train=on_train,
+                            baseline=ig_baseline,
+                            steps=ig_steps,
+                            batch_size=ig_batch_size,
+                            target=ig_target,
+                        )
+                        series = ig_result.get("importance")
+                        if isinstance(series, pd.Series):
+                            out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_importance.csv"
+                            _save_series(series, out_path)
+                            print(f"[Explain] Saved IG importance -> {out_path}")
+                        if ig_save_values and "attributions" in ig_result:
+                            attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_attributions.npy"
+                            np.save(attr_path, ig_result.get("attributions"))
+                    else:
+                        ig_result = model.compute_integrated_gradients_ft(
+                            on_train=on_train,
+                            baseline_num=ig_baseline_num,
+                            baseline_geo=ig_baseline_geo,
+                            steps=ig_steps,
+                            batch_size=ig_batch_size,
+                            target=ig_target,
+                        )
+                        series_num = ig_result.get("importance_num")
+                        series_geo = ig_result.get("importance_geo")
+                        if isinstance(series_num, pd.Series):
+                            out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_num_importance.csv"
+                            _save_series(series_num, out_path)
+                            print(f"[Explain] Saved IG num importance -> {out_path}")
+                        if isinstance(series_geo, pd.Series):
+                            out_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_geo_importance.csv"
+                            _save_series(series_geo, out_path)
+                            print(f"[Explain] Saved IG geo importance -> {out_path}")
+                        if ig_save_values:
+                            if ig_result.get("attributions_num") is not None:
+                                attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_num_attributions.npy"
+                                np.save(attr_path, ig_result.get("attributions_num"))
+                            if ig_result.get("attributions_geo") is not None:
+                                attr_path = save_dir / f"{_safe_name(model_name)}_{key}_ig_geo_attributions.npy"
+                                np.save(attr_path, ig_result.get("attributions_geo"))
+                except Exception as exc:
+                    print(f"[Explain] integrated gradients failed for {model_name}/{key}: {exc}")
+                    continue
+
+
+def explain_from_config(args: argparse.Namespace) -> None:
+    script_dir = Path(__file__).resolve().parents[1]
+    config_path, cfg = resolve_and_load_config(
+        args.config_json,
+        script_dir,
+        required_keys=["data_dir", "model_list", "model_categories", "target", "weight"],
+    )
+
+    data_dir, data_format, data_path_template, dtype_map = resolve_data_config(
+        cfg,
+        config_path,
+        create_data_dir=True,
+    )
+
+    runtime_cfg = resolve_runtime_config(cfg)
+    output_cfg = resolve_output_dirs(
+        cfg,
+        config_path,
+        output_override=args.output_dir,
+    )
+    output_dir = output_cfg["output_dir"]
+
+    split_cfg = resolve_split_config(cfg)
+    prop_test = split_cfg["prop_test"]
+    rand_seed = runtime_cfg["rand_seed"]
+    split_strategy = split_cfg["split_strategy"]
+    split_group_col = split_cfg["split_group_col"]
+    split_time_col = split_cfg["split_time_col"]
+    split_time_ascending = split_cfg["split_time_ascending"]
+
+    explain_cfg = dict(cfg.get("explain") or {})
+
+    model_keys = args.model_keys or explain_cfg.get("model_keys") or ["xgb"]
+    if "all" in model_keys:
+        model_keys = ["glm", "xgb", "resn", "ft", "gnn"]
+    model_keys = dedupe_preserve_order([str(x) for x in model_keys])
+
+    method_list = args.methods or explain_cfg.get("methods") or ["permutation"]
+    methods = _normalize_methods([str(x) for x in method_list])
+
+    on_train = bool(args.on_train or explain_cfg.get("on_train", False))
+
+    model_names = build_model_names(cfg["model_list"], cfg["model_categories"])
+    if not model_names:
+        raise ValueError("No model names generated from model_list/model_categories.")
+
+    save_root = resolve_explain_save_root(
+        args.save_dir or explain_cfg.get("save_dir"),
+        config_path.parent,
+    )
+
+    for model_name in model_names:
+        train_path = resolve_model_path_value(
+            explain_cfg.get("train_path"),
+            model_name=model_name,
+            base_dir=config_path.parent,
+            data_dir=data_dir,
+        )
+        if train_path is None:
+            train_path = resolve_data_path(
+                data_dir,
+                model_name,
+                data_format=data_format,
+                path_template=data_path_template,
+            )
+        if not train_path.exists():
+            raise FileNotFoundError(f"Missing training dataset: {train_path}")
+
+        validation_override = args.eval_path or explain_cfg.get("validation_path") or explain_cfg.get("eval_path")
+        validation_path = resolve_model_path_value(
+            validation_override,
+            model_name=model_name,
+            base_dir=config_path.parent,
+            data_dir=data_dir,
+        )
+
+        raw = _load_dataset(
+            train_path,
+            data_format=data_format,
+            dtype_map=dtype_map,
+        )
+        if validation_path is not None:
+            if not validation_path.exists():
+                raise FileNotFoundError(f"Missing validation dataset: {validation_path}")
+            train_df = raw
+            test_df = _load_dataset(
+                validation_path,
+                data_format=data_format,
+                dtype_map=dtype_map,
+            )
+        else:
+            if float(prop_test) <= 0:
+                train_df = raw
+                test_df = raw.copy()
+            else:
+                train_df, test_df = split_train_test(
+                    raw,
+                    holdout_ratio=prop_test,
+                    strategy=split_strategy,
+                    group_col=split_group_col,
+                    time_col=split_time_col,
+                    time_ascending=split_time_ascending,
+                    rand_seed=rand_seed,
+                    reset_index_mode="time_group",
+                    ratio_label="prop_test",
+                    include_strategy_in_ratio_error=True,
+                )
+
+        binary_target = cfg.get("binary_target") or cfg.get("binary_resp_nme")
+        feature_list = cfg.get("feature_list")
+        categorical_features = cfg.get("categorical_features")
+        plot_path_style = runtime_cfg["plot_path_style"]
+
+        config_fields = getattr(ropt.BayesOptConfig, "__dataclass_fields__", {})
+        allowed_config_keys = set(config_fields.keys())
+        config_payload = {k: v for k, v in cfg.items() if k in allowed_config_keys}
+        config_payload.update({
+            "model_nme": model_name,
+            "resp_nme": cfg["target"],
+            "weight_nme": cfg["weight"],
+            "factor_nmes": feature_list,
+            "task_type": str(cfg.get("task_type", "regression")),
+            "binary_resp_nme": binary_target,
+            "cate_list": categorical_features,
+            "prop_test": prop_test,
+            "rand_seed": rand_seed,
+            "epochs": int(runtime_cfg["epochs"]),
+            "use_gpu": bool(cfg.get("use_gpu", True)),
+            "output_dir": output_dir,
             "xgb_max_depth_max": runtime_cfg["xgb_max_depth_max"],
             "xgb_n_estimators_max": runtime_cfg["xgb_n_estimators_max"],
+            "xgb_gpu_id": runtime_cfg["xgb_gpu_id"],
+            "xgb_cleanup_per_fold": runtime_cfg["xgb_cleanup_per_fold"],
+            "xgb_cleanup_synchronize": runtime_cfg["xgb_cleanup_synchronize"],
+            "xgb_use_dmatrix": runtime_cfg["xgb_use_dmatrix"],
+            "ft_cleanup_per_fold": runtime_cfg["ft_cleanup_per_fold"],
+            "ft_cleanup_synchronize": runtime_cfg["ft_cleanup_synchronize"],
+            "resn_cleanup_per_fold": runtime_cfg["resn_cleanup_per_fold"],
+            "resn_cleanup_synchronize": runtime_cfg["resn_cleanup_synchronize"],
+            "gnn_cleanup_per_fold": runtime_cfg["gnn_cleanup_per_fold"],
+            "gnn_cleanup_synchronize": runtime_cfg["gnn_cleanup_synchronize"],
+            "optuna_cleanup_synchronize": runtime_cfg["optuna_cleanup_synchronize"],
             "resn_weight_decay": cfg.get("resn_weight_decay"),
-            "final_ensemble": bool(cfg.get("final_ensemble", False)),
-            "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
-            "final_refit": bool(cfg.get("final_refit", True)),
-            "optuna_storage": runtime_cfg["optuna_storage"],
-            "optuna_study_prefix": runtime_cfg["optuna_study_prefix"],
-            "best_params_files": runtime_cfg["best_params_files"],
-            "gnn_use_approx_knn": cfg.get("gnn_use_approx_knn", True),
-            "gnn_approx_knn_threshold": cfg.get("gnn_approx_knn_threshold", 50000),
-            "gnn_graph_cache": cfg.get("gnn_graph_cache"),
-            "gnn_max_gpu_knn_nodes": cfg.get("gnn_max_gpu_knn_nodes", 200000),
-            "gnn_knn_gpu_mem_ratio": cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
-            "gnn_knn_gpu_mem_overhead": cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
-            "region_province_col": cfg.get("region_province_col"),
-            "region_city_col": cfg.get("region_city_col"),
-            "region_effect_alpha": cfg.get("region_effect_alpha"),
-            "geo_feature_nmes": cfg.get("geo_feature_nmes"),
-            "geo_token_hidden_dim": cfg.get("geo_token_hidden_dim"),
-            "geo_token_layers": cfg.get("geo_token_layers"),
-            "geo_token_dropout": cfg.get("geo_token_dropout"),
-            "geo_token_k_neighbors": cfg.get("geo_token_k_neighbors"),
-            "geo_token_learning_rate": cfg.get("geo_token_learning_rate"),
-            "geo_token_epochs": cfg.get("geo_token_epochs"),
-            "ft_role": str(cfg.get("ft_role", "model")),
-            "ft_feature_prefix": str(cfg.get("ft_feature_prefix", "ft_emb")),
-            "ft_num_numeric_tokens": cfg.get("ft_num_numeric_tokens"),
-            "reuse_best_params": runtime_cfg["reuse_best_params"],
-            "plot_path_style": plot_path_style or "nested",
-        })
-        config_payload = {k: v for k, v in config_payload.items() if v is not None}
-        config = ropt.BayesOptConfig(**config_payload)
-        model = ropt.BayesOptModel(train_df, test_df, config=config)
-
-        output_overrides = resolve_explain_output_overrides(
-            explain_cfg,
-            model_name=model_name,
-            base_dir=config_path.parent,
-        )
-        model_dir_override = output_overrides.get("model_dir")
-        if model_dir_override is not None:
-            model.output_manager.model_dir = model_dir_override
-        result_dir_override = output_overrides.get("result_dir")
-        if result_dir_override is not None:
-            model.output_manager.result_dir = result_dir_override
-        plot_dir_override = output_overrides.get("plot_dir")
-        if plot_dir_override is not None:
-            model.output_manager.plot_dir = plot_dir_override
-
-        save_dir = resolve_explain_save_dir(
-            save_root,
-            result_dir=model.output_manager.result_dir,
-        )
-        save_dir.mkdir(parents=True, exist_ok=True)
-
-        print(f"\n=== Explain model {model_name} ===")
-        _explain_for_model(
-            model,
-            model_name=model_name,
-            model_keys=model_keys,
-            methods=methods,
-            on_train=on_train,
-            save_dir=save_dir,
-            explain_cfg=explain_cfg,
-        )
-
-
-def main() -> None:
-    if configure_run_logging:
-        configure_run_logging(prefix="explain_entry")
-    args = _parse_args()
-    explain_from_config(args)
-
-
-if __name__ == "__main__":
-    main()
+            "final_ensemble": bool(cfg.get("final_ensemble", False)),
+            "final_ensemble_k": int(cfg.get("final_ensemble_k", 3)),
+            "final_refit": bool(cfg.get("final_refit", True)),
+            "optuna_storage": runtime_cfg["optuna_storage"],
+            "optuna_study_prefix": runtime_cfg["optuna_study_prefix"],
+            "best_params_files": runtime_cfg["best_params_files"],
+            "gnn_use_approx_knn": cfg.get("gnn_use_approx_knn", True),
+            "gnn_approx_knn_threshold": cfg.get("gnn_approx_knn_threshold", 50000),
+            "gnn_graph_cache": cfg.get("gnn_graph_cache"),
+            "gnn_max_gpu_knn_nodes": cfg.get("gnn_max_gpu_knn_nodes", 200000),
+            "gnn_knn_gpu_mem_ratio": cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
+            "gnn_knn_gpu_mem_overhead": cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
+            "region_province_col": cfg.get("region_province_col"),
+            "region_city_col": cfg.get("region_city_col"),
+            "region_effect_alpha": cfg.get("region_effect_alpha"),
+            "geo_feature_nmes": cfg.get("geo_feature_nmes"),
+            "geo_token_hidden_dim": cfg.get("geo_token_hidden_dim"),
+            "geo_token_layers": cfg.get("geo_token_layers"),
+            "geo_token_dropout": cfg.get("geo_token_dropout"),
+            "geo_token_k_neighbors": cfg.get("geo_token_k_neighbors"),
+            "geo_token_learning_rate": cfg.get("geo_token_learning_rate"),
+            "geo_token_epochs": cfg.get("geo_token_epochs"),
+            "ft_role": str(cfg.get("ft_role", "model")),
+            "ft_feature_prefix": str(cfg.get("ft_feature_prefix", "ft_emb")),
+            "ft_num_numeric_tokens": cfg.get("ft_num_numeric_tokens"),
+            "reuse_best_params": runtime_cfg["reuse_best_params"],
+            "plot_path_style": plot_path_style or "nested",
+        })
+        config_payload = {k: v for k, v in config_payload.items() if v is not None}
+        config = ropt.BayesOptConfig(**config_payload)
+        model = ropt.BayesOptModel(train_df, test_df, config=config)
+
+        output_overrides = resolve_explain_output_overrides(
+            explain_cfg,
+            model_name=model_name,
+            base_dir=config_path.parent,
+        )
+        model_dir_override = output_overrides.get("model_dir")
+        if model_dir_override is not None:
+            model.output_manager.model_dir = model_dir_override
+        result_dir_override = output_overrides.get("result_dir")
+        if result_dir_override is not None:
+            model.output_manager.result_dir = result_dir_override
+        plot_dir_override = output_overrides.get("plot_dir")
+        if plot_dir_override is not None:
+            model.output_manager.plot_dir = plot_dir_override
+
+        save_dir = resolve_explain_save_dir(
+            save_root,
+            result_dir=model.output_manager.result_dir,
+        )
+        save_dir.mkdir(parents=True, exist_ok=True)
+
+        print(f"\n=== Explain model {model_name} ===")
+        _explain_for_model(
+            model,
+            model_name=model_name,
+            model_keys=model_keys,
+            methods=methods,
+            on_train=on_train,
+            save_dir=save_dir,
+            explain_cfg=explain_cfg,
+        )
+
+
+def main() -> None:
+    if configure_run_logging:
+        configure_run_logging(prefix="explain_entry")
+    args = _parse_args()
+    explain_from_config(args)
+
+
+if __name__ == "__main__":
+    main()
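For reference, the explain block that explain_from_config reads from config.json can be sketched as below. Key names and defaults are taken from the code above; the surrounding training-schema fields (data_dir, model_list, model_categories, target, weight) and the example values are assumptions. The CLI flags (--model-keys, --methods, --eval-path, --on-train, --save-dir) take precedence over the corresponding config values.

```python
# Sketch of a config "explain" section, written as a Python dict; key names
# and defaults mirror _explain_for_model / explain_from_config above, while
# the example values themselves are hypothetical.
explain_block = {
    "model_keys": ["xgb", "resn"],             # "all" expands to glm/xgb/resn/ft/gnn; default is ["xgb"]
    "methods": ["permutation", "shap", "ig"],  # "ig"/"integrated"/"intgrad" alias integrated_gradients
    "on_train": False,                         # True explains the train split instead of the holdout
    "save_dir": "explain_outputs",             # optional override of the artifact directory
    "permutation": {"metric": "auto", "n_repeats": 5, "max_rows": 5000},
    "shap": {"n_background": 500, "n_samples": 200, "save_values": False},
    "integrated_gradients": {"steps": 50, "batch_size": 256, "save_values": False},
}
```

Note that permutation importance runs only for xgb/resn/ft, SHAP for glm/xgb/resn/ft, and integrated gradients for resn/ft; other model/method combinations are skipped with a printed message rather than an error.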