ins-pricing 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169):
  1. ins_pricing/README.md +60 -0
  2. ins_pricing/__init__.py +102 -0
  3. ins_pricing/governance/README.md +18 -0
  4. ins_pricing/governance/__init__.py +20 -0
  5. ins_pricing/governance/approval.py +93 -0
  6. ins_pricing/governance/audit.py +37 -0
  7. ins_pricing/governance/registry.py +99 -0
  8. ins_pricing/governance/release.py +159 -0
  9. ins_pricing/modelling/BayesOpt.py +146 -0
  10. ins_pricing/modelling/BayesOpt_USAGE.md +925 -0
  11. ins_pricing/modelling/BayesOpt_entry.py +575 -0
  12. ins_pricing/modelling/BayesOpt_incremental.py +731 -0
  13. ins_pricing/modelling/Explain_Run.py +36 -0
  14. ins_pricing/modelling/Explain_entry.py +539 -0
  15. ins_pricing/modelling/Pricing_Run.py +36 -0
  16. ins_pricing/modelling/README.md +33 -0
  17. ins_pricing/modelling/__init__.py +44 -0
  18. ins_pricing/modelling/bayesopt/__init__.py +98 -0
  19. ins_pricing/modelling/bayesopt/config_preprocess.py +303 -0
  20. ins_pricing/modelling/bayesopt/core.py +1476 -0
  21. ins_pricing/modelling/bayesopt/models.py +2196 -0
  22. ins_pricing/modelling/bayesopt/trainers.py +2446 -0
  23. ins_pricing/modelling/bayesopt/utils.py +1021 -0
  24. ins_pricing/modelling/cli_common.py +136 -0
  25. ins_pricing/modelling/explain/__init__.py +55 -0
  26. ins_pricing/modelling/explain/gradients.py +334 -0
  27. ins_pricing/modelling/explain/metrics.py +176 -0
  28. ins_pricing/modelling/explain/permutation.py +155 -0
  29. ins_pricing/modelling/explain/shap_utils.py +146 -0
  30. ins_pricing/modelling/notebook_utils.py +284 -0
  31. ins_pricing/modelling/plotting/__init__.py +45 -0
  32. ins_pricing/modelling/plotting/common.py +63 -0
  33. ins_pricing/modelling/plotting/curves.py +572 -0
  34. ins_pricing/modelling/plotting/diagnostics.py +139 -0
  35. ins_pricing/modelling/plotting/geo.py +362 -0
  36. ins_pricing/modelling/plotting/importance.py +121 -0
  37. ins_pricing/modelling/run_logging.py +133 -0
  38. ins_pricing/modelling/tests/conftest.py +8 -0
  39. ins_pricing/modelling/tests/test_cross_val_generic.py +66 -0
  40. ins_pricing/modelling/tests/test_distributed_utils.py +18 -0
  41. ins_pricing/modelling/tests/test_explain.py +56 -0
  42. ins_pricing/modelling/tests/test_geo_tokens_split.py +49 -0
  43. ins_pricing/modelling/tests/test_graph_cache.py +33 -0
  44. ins_pricing/modelling/tests/test_plotting.py +63 -0
  45. ins_pricing/modelling/tests/test_plotting_library.py +150 -0
  46. ins_pricing/modelling/tests/test_preprocessor.py +48 -0
  47. ins_pricing/modelling/watchdog_run.py +211 -0
  48. ins_pricing/pricing/README.md +44 -0
  49. ins_pricing/pricing/__init__.py +27 -0
  50. ins_pricing/pricing/calibration.py +39 -0
  51. ins_pricing/pricing/data_quality.py +117 -0
  52. ins_pricing/pricing/exposure.py +85 -0
  53. ins_pricing/pricing/factors.py +91 -0
  54. ins_pricing/pricing/monitoring.py +99 -0
  55. ins_pricing/pricing/rate_table.py +78 -0
  56. ins_pricing/production/__init__.py +21 -0
  57. ins_pricing/production/drift.py +30 -0
  58. ins_pricing/production/monitoring.py +143 -0
  59. ins_pricing/production/scoring.py +40 -0
  60. ins_pricing/reporting/README.md +20 -0
  61. ins_pricing/reporting/__init__.py +11 -0
  62. ins_pricing/reporting/report_builder.py +72 -0
  63. ins_pricing/reporting/scheduler.py +45 -0
  64. ins_pricing/setup.py +41 -0
  65. ins_pricing v2/__init__.py +23 -0
  66. ins_pricing v2/governance/__init__.py +20 -0
  67. ins_pricing v2/governance/approval.py +93 -0
  68. ins_pricing v2/governance/audit.py +37 -0
  69. ins_pricing v2/governance/registry.py +99 -0
  70. ins_pricing v2/governance/release.py +159 -0
  71. ins_pricing v2/modelling/Explain_Run.py +36 -0
  72. ins_pricing v2/modelling/Pricing_Run.py +36 -0
  73. ins_pricing v2/modelling/__init__.py +151 -0
  74. ins_pricing v2/modelling/cli_common.py +141 -0
  75. ins_pricing v2/modelling/config.py +249 -0
  76. ins_pricing v2/modelling/config_preprocess.py +254 -0
  77. ins_pricing v2/modelling/core.py +741 -0
  78. ins_pricing v2/modelling/data_container.py +42 -0
  79. ins_pricing v2/modelling/explain/__init__.py +55 -0
  80. ins_pricing v2/modelling/explain/gradients.py +334 -0
  81. ins_pricing v2/modelling/explain/metrics.py +176 -0
  82. ins_pricing v2/modelling/explain/permutation.py +155 -0
  83. ins_pricing v2/modelling/explain/shap_utils.py +146 -0
  84. ins_pricing v2/modelling/features.py +215 -0
  85. ins_pricing v2/modelling/model_manager.py +148 -0
  86. ins_pricing v2/modelling/model_plotting.py +463 -0
  87. ins_pricing v2/modelling/models.py +2203 -0
  88. ins_pricing v2/modelling/notebook_utils.py +294 -0
  89. ins_pricing v2/modelling/plotting/__init__.py +45 -0
  90. ins_pricing v2/modelling/plotting/common.py +63 -0
  91. ins_pricing v2/modelling/plotting/curves.py +572 -0
  92. ins_pricing v2/modelling/plotting/diagnostics.py +139 -0
  93. ins_pricing v2/modelling/plotting/geo.py +362 -0
  94. ins_pricing v2/modelling/plotting/importance.py +121 -0
  95. ins_pricing v2/modelling/run_logging.py +133 -0
  96. ins_pricing v2/modelling/tests/conftest.py +8 -0
  97. ins_pricing v2/modelling/tests/test_cross_val_generic.py +66 -0
  98. ins_pricing v2/modelling/tests/test_distributed_utils.py +18 -0
  99. ins_pricing v2/modelling/tests/test_explain.py +56 -0
  100. ins_pricing v2/modelling/tests/test_geo_tokens_split.py +49 -0
  101. ins_pricing v2/modelling/tests/test_graph_cache.py +33 -0
  102. ins_pricing v2/modelling/tests/test_plotting.py +63 -0
  103. ins_pricing v2/modelling/tests/test_plotting_library.py +150 -0
  104. ins_pricing v2/modelling/tests/test_preprocessor.py +48 -0
  105. ins_pricing v2/modelling/trainers.py +2447 -0
  106. ins_pricing v2/modelling/utils.py +1020 -0
  107. ins_pricing v2/modelling/watchdog_run.py +211 -0
  108. ins_pricing v2/pricing/__init__.py +27 -0
  109. ins_pricing v2/pricing/calibration.py +39 -0
  110. ins_pricing v2/pricing/data_quality.py +117 -0
  111. ins_pricing v2/pricing/exposure.py +85 -0
  112. ins_pricing v2/pricing/factors.py +91 -0
  113. ins_pricing v2/pricing/monitoring.py +99 -0
  114. ins_pricing v2/pricing/rate_table.py +78 -0
  115. ins_pricing v2/production/__init__.py +21 -0
  116. ins_pricing v2/production/drift.py +30 -0
  117. ins_pricing v2/production/monitoring.py +143 -0
  118. ins_pricing v2/production/scoring.py +40 -0
  119. ins_pricing v2/reporting/__init__.py +11 -0
  120. ins_pricing v2/reporting/report_builder.py +72 -0
  121. ins_pricing v2/reporting/scheduler.py +45 -0
  122. ins_pricing v2/scripts/BayesOpt_incremental.py +722 -0
  123. ins_pricing v2/scripts/Explain_entry.py +545 -0
  124. ins_pricing v2/scripts/__init__.py +1 -0
  125. ins_pricing v2/scripts/train.py +568 -0
  126. ins_pricing v2/setup.py +55 -0
  127. ins_pricing v2/smoke_test.py +28 -0
  128. ins_pricing-0.1.6.dist-info/METADATA +78 -0
  129. ins_pricing-0.1.6.dist-info/RECORD +169 -0
  130. ins_pricing-0.1.6.dist-info/WHEEL +5 -0
  131. ins_pricing-0.1.6.dist-info/top_level.txt +4 -0
  132. user_packages/__init__.py +105 -0
  133. user_packages legacy/BayesOpt.py +5659 -0
  134. user_packages legacy/BayesOpt_entry.py +513 -0
  135. user_packages legacy/BayesOpt_incremental.py +685 -0
  136. user_packages legacy/Pricing_Run.py +36 -0
  137. user_packages legacy/Try/BayesOpt Legacy251213.py +3719 -0
  138. user_packages legacy/Try/BayesOpt Legacy251215.py +3758 -0
  139. user_packages legacy/Try/BayesOpt lagecy251201.py +3506 -0
  140. user_packages legacy/Try/BayesOpt lagecy251218.py +3992 -0
  141. user_packages legacy/Try/BayesOpt legacy.py +3280 -0
  142. user_packages legacy/Try/BayesOpt.py +838 -0
  143. user_packages legacy/Try/BayesOptAll.py +1569 -0
  144. user_packages legacy/Try/BayesOptAllPlatform.py +909 -0
  145. user_packages legacy/Try/BayesOptCPUGPU.py +1877 -0
  146. user_packages legacy/Try/BayesOptSearch.py +830 -0
  147. user_packages legacy/Try/BayesOptSearchOrigin.py +829 -0
  148. user_packages legacy/Try/BayesOptV1.py +1911 -0
  149. user_packages legacy/Try/BayesOptV10.py +2973 -0
  150. user_packages legacy/Try/BayesOptV11.py +3001 -0
  151. user_packages legacy/Try/BayesOptV12.py +3001 -0
  152. user_packages legacy/Try/BayesOptV2.py +2065 -0
  153. user_packages legacy/Try/BayesOptV3.py +2209 -0
  154. user_packages legacy/Try/BayesOptV4.py +2342 -0
  155. user_packages legacy/Try/BayesOptV5.py +2372 -0
  156. user_packages legacy/Try/BayesOptV6.py +2759 -0
  157. user_packages legacy/Try/BayesOptV7.py +2832 -0
  158. user_packages legacy/Try/BayesOptV8Codex.py +2731 -0
  159. user_packages legacy/Try/BayesOptV8Gemini.py +2614 -0
  160. user_packages legacy/Try/BayesOptV9.py +2927 -0
  161. user_packages legacy/Try/BayesOpt_entry legacy.py +313 -0
  162. user_packages legacy/Try/ModelBayesOptSearch.py +359 -0
  163. user_packages legacy/Try/ResNetBayesOptSearch.py +249 -0
  164. user_packages legacy/Try/XgbBayesOptSearch.py +121 -0
  165. user_packages legacy/Try/xgbbayesopt.py +523 -0
  166. user_packages legacy/__init__.py +19 -0
  167. user_packages legacy/cli_common.py +124 -0
  168. user_packages legacy/notebook_utils.py +228 -0
  169. user_packages legacy/watchdog_run.py +202 -0
@@ -0,0 +1,545 @@
1
+ """Config-driven explain runner for trained BayesOpt models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Optional, Sequence
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from sklearn.model_selection import train_test_split
13
+
14
+ try:
15
+ import ins_pricing.modelling as ropt
16
+ from ins_pricing.modelling.cli_common import (
17
+ build_model_names,
18
+ dedupe_preserve_order,
19
+ load_config_json,
20
+ normalize_config_paths,
21
+ resolve_config_path,
22
+ resolve_path,
23
+ set_env,
24
+ )
25
+ except ImportError:
26
+ import sys
27
+ from pathlib import Path
28
+ _pkg_root = Path(__file__).resolve().parent.parent
29
+ if str(_pkg_root) not in sys.path:
30
+ sys.path.insert(0, str(_pkg_root))
31
+
32
+ try:
33
+ from modelling import core as ropt
34
+ from modelling.cli_common import (
35
+ build_model_names,
36
+ dedupe_preserve_order,
37
+ load_config_json,
38
+ normalize_config_paths,
39
+ resolve_config_path,
40
+ resolve_path,
41
+ set_env,
42
+ )
43
+ except ImportError:
44
+ import ins_pricing.modelling as ropt
45
+ from ins_pricing.modelling.cli_common import (
46
+ build_model_names,
47
+ dedupe_preserve_order,
48
+ load_config_json,
49
+ normalize_config_paths,
50
+ resolve_config_path,
51
+ resolve_path,
52
+ set_env,
53
+ )
54
+
55
+ try:
56
+ from .run_logging import configure_run_logging # type: ignore
57
+ except Exception: # pragma: no cover
58
+ try:
59
+ from run_logging import configure_run_logging # type: ignore
60
+ except Exception: # pragma: no cover
61
+ configure_run_logging = None # type: ignore
62
+
63
+
64
# Canonical names of the explain methods this runner can execute.
_SUPPORTED_METHODS = {
    "permutation",
    "shap",
    "integrated_gradients",
}

# Short spellings accepted as input, mapped to their canonical method name.
_METHOD_ALIASES = dict(
    ig="integrated_gradients",
    integrated="integrated_gradients",
    intgrad="integrated_gradients",
)
70
+
71
+
72
def _safe_name(value: str) -> str:
    """Return ``str(value)`` with unsafe characters replaced by underscores.

    Alphanumeric characters and ``-``, ``_``, ``.`` are kept as-is; every
    other character becomes ``_``. Used to build artifact file names.
    """
    allowed_punctuation = "-_."
    cleaned = []
    for char in str(value):
        keep = char.isalnum() or char in allowed_punctuation
        cleaned.append(char if keep else "_")
    return "".join(cleaned)
74
+
75
+
76
def _load_dataset(path: Path) -> pd.DataFrame:
    """Read a CSV file and fill missing values column by column.

    Numeric columns are passed through ``pd.to_numeric(errors="coerce")`` and
    their missing entries become ``0``. All other columns are cast to object
    dtype and their missing entries become the literal string ``"<NA>"``.
    """
    frame = pd.read_csv(path, low_memory=False).copy()
    for name in frame.columns:
        column = frame[name]
        if not pd.api.types.is_numeric_dtype(column):
            frame[name] = column.astype("object").fillna("<NA>")
            continue
        frame[name] = pd.to_numeric(column, errors="coerce").fillna(0)
    return frame
86
+
87
+
88
def _resolve_path_value(
    value: Any,
    *,
    model_name: str,
    base_dir: Path,
    data_dir: Optional[Path] = None,
) -> Optional[Path]:
    """Turn a configured path value into a concrete path for one model.

    Args:
        value: ``None``, a path-like value, or a dict keyed by model name.
        model_name: Model whose entry is looked up and which is substituted
            for a ``{model_name}`` placeholder in the path text.
        base_dir: Directory handed to ``resolve_path`` for the fallback lookup.
        data_dir: Optional directory tried first for relative paths.

    Returns:
        ``None`` when no value is configured for this model. Otherwise the
        resolved ``data_dir`` candidate when it exists, or whatever
        ``resolve_path`` returns for the path text.
    """
    entry = value.get(model_name) if isinstance(value, dict) else value
    if entry is None:
        return None

    text = str(entry)
    try:
        text = text.format(model_name=model_name)
    except Exception:
        # Best effort: if the text cannot be formatted (for example stray
        # braces), keep it verbatim.
        pass

    if data_dir is not None and not Path(text).is_absolute():
        in_data_dir = data_dir / text
        if in_data_dir.exists():
            return in_data_dir.resolve()

    return resolve_path(text, base_dir)
114
+
115
+
116
def _normalize_methods(raw: Sequence[str]) -> List[str]:
    """Canonicalise a sequence of explain-method names.

    Each item is stripped and lower-cased, blank items are dropped, aliases
    from ``_METHOD_ALIASES`` are expanded, and the result is passed through
    ``dedupe_preserve_order``.

    Raises:
        ValueError: If an item is not in ``_SUPPORTED_METHODS`` after alias
            expansion.
    """
    normalized: List[str] = []
    for item in raw:
        token = str(item).strip().lower()
        if token:
            canonical = _METHOD_ALIASES.get(token, token)
            if canonical not in _SUPPORTED_METHODS:
                raise ValueError(f"Unsupported explain method: {item}")
            normalized.append(canonical)
    return dedupe_preserve_order(normalized)
127
+
128
+
129
def _save_series(series: pd.Series, path: Path) -> None:
    """Write *series* to *path* as a CSV with the index and an ``importance`` column.

    Missing parent directories are created first.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    frame = series.to_frame(name="importance")
    frame.to_csv(path, index=True)
132
+
133
+
134
def _save_df(df: pd.DataFrame, path: Path) -> None:
    """Write *df* to *path* as a CSV without the index column.

    Missing parent directories are created first.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    df.to_csv(path, index=False)
137
+
138
+
139
def _shap_importance(values: Any, feature_names: Sequence[str]) -> pd.Series:
    """Collapse SHAP values into one mean-absolute score per feature.

    Args:
        values: SHAP values as a 2-D ``(samples, features)`` array, a 1-D
            vector for a single row, a list of per-output arrays (the first
            entry is used), or a 3-D array (the first output is used).
        feature_names: Labels for the feature axis; used as the result index.

    Returns:
        Mean ``|value|`` per feature, sorted from largest to smallest.
    """
    names = list(feature_names)
    if isinstance(values, list):
        # Per-output list: only the first output is summarised.
        values = values[0]
    arr = np.asarray(values)
    if arr.ndim == 3:
        if arr.shape[1] == len(names) and arr.shape[2] != len(names):
            # Features sit on axis 1, so the trailing axis holds the outputs
            # (presumably the newer SHAP multi-output layout; see SHAP docs).
            # Taking arr[0] here would slice away the sample axis and yield a
            # score vector whose length does not match the feature names.
            arr = arr[:, :, 0]
        else:
            # Outputs-first layout: keep the first output's sample matrix.
            arr = arr[0]
    elif arr.ndim == 1:
        # A single explained row: treat it as a one-sample matrix so each
        # feature keeps its own score instead of all sharing one scalar mean.
        arr = arr.reshape(1, -1)
    scores = np.mean(np.abs(arr), axis=0)
    return pd.Series(scores, index=names).sort_values(ascending=False)
147
+
148
+
149
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser for the explain runner and parse ``sys.argv``.

    Only ``--config-json`` is required; every other option defaults to
    ``None`` (``--on-train`` to ``False``).
    """
    cli = argparse.ArgumentParser(
        description="Run explainability (permutation/SHAP/IG) on trained models."
    )
    # (flag, add_argument keyword arguments) in the order they are registered.
    option_specs = (
        (
            "--config-json",
            {
                "required": True,
                "help": "Path to config.json (same schema as training).",
            },
        ),
        (
            "--model-keys",
            {
                "nargs": "+",
                "default": None,
                "choices": ["glm", "xgb", "resn", "ft", "gnn", "all"],
                "help": "Model keys to load for explanation (default from config.explain.model_keys).",
            },
        ),
        (
            "--methods",
            {
                "nargs": "+",
                "default": None,
                "help": "Explain methods: permutation, shap, integrated_gradients (default from config.explain.methods).",
            },
        ),
        (
            "--output-dir",
            {
                "default": None,
                "help": "Override output root for loading models/results.",
            },
        ),
        (
            "--eval-path",
            {
                "default": None,
                "help": "Override validation CSV path (supports {model_name}).",
            },
        ),
        (
            "--on-train",
            {
                "action": "store_true",
                "help": "Explain on train split instead of validation/test.",
            },
        ),
        (
            "--save-dir",
            {
                "default": None,
                "help": "Override output directory for explanation artifacts.",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
192
+
193
+
194
def _to_jsonable(value: Any) -> Any:
    """Convert NumPy scalars/arrays (also inside lists/tuples) to plain Python values."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value


def _run_permutation(
    model: Any,
    key: str,
    *,
    tag: str,
    stem: str,
    on_train: bool,
    save_dir: Path,
    perm_kwargs: Dict[str, Any],
) -> None:
    """Compute permutation importance for one model key and save it as CSV."""
    try:
        result = model.compute_permutation_importance(
            key, on_train=on_train, **perm_kwargs
        )
    except Exception as exc:
        print(f"[Explain] permutation failed for {tag}: {exc}")
        return
    out_path = save_dir / f"{stem}_permutation.csv"
    _save_df(result, out_path)
    print(f"[Explain] Saved permutation -> {out_path}")


def _run_shap(
    model: Any,
    key: str,
    *,
    tag: str,
    stem: str,
    on_train: bool,
    save_dir: Path,
    shap_kwargs: Dict[str, Any],
    save_values: bool,
) -> None:
    """Compute SHAP values for one model key; save importance and optional raw artifacts."""
    try:
        # glm/xgb/resn have dedicated methods; any other key uses the ft one.
        suffix = key if key in ("glm", "xgb", "resn") else "ft"
        compute = getattr(model, f"compute_shap_{suffix}")
        shap_result = compute(on_train=on_train, **shap_kwargs)
    except Exception as exc:
        print(f"[Explain] shap failed for {tag}: {exc}")
        return

    shap_values = shap_result.get("shap_values")
    X_explain = shap_result.get("X_explain")
    if isinstance(X_explain, pd.DataFrame):
        feature_names = list(X_explain.columns)
    else:
        feature_names = list(model.factor_nmes)
    importance = _shap_importance(shap_values, feature_names)
    out_path = save_dir / f"{stem}_shap_importance.csv"
    _save_series(importance, out_path)
    print(f"[Explain] Saved SHAP importance -> {out_path}")

    if not save_values:
        return
    values_path = save_dir / f"{stem}_shap_values.npy"
    np.save(values_path, np.array(shap_values, dtype=object), allow_pickle=True)
    if isinstance(X_explain, pd.DataFrame):
        _save_df(X_explain, save_dir / f"{stem}_shap_X.csv")
    meta = {
        # base_value may be a NumPy scalar/array (assumption, the producer is
        # not visible here); json.dumps rejects those, so convert first.
        "base_value": _to_jsonable(shap_result.get("base_value")),
        "n_samples": int(len(X_explain)) if X_explain is not None else None,
    }
    meta_path = save_dir / f"{stem}_shap_meta.json"
    meta_path.write_text(json.dumps(meta, indent=2), encoding="utf-8")


def _run_integrated_gradients(
    model: Any,
    key: str,
    *,
    tag: str,
    stem: str,
    on_train: bool,
    save_dir: Path,
    ig_kwargs: Dict[str, Any],
    baselines: Dict[str, Any],
    save_values: bool,
) -> None:
    """Compute integrated gradients for one model key and save importances/attributions."""
    try:
        if key == "resn":
            ig_result = model.compute_integrated_gradients_resn(
                on_train=on_train,
                baseline=baselines["baseline"],
                **ig_kwargs,
            )
            series = ig_result.get("importance")
            if isinstance(series, pd.Series):
                out_path = save_dir / f"{stem}_ig_importance.csv"
                _save_series(series, out_path)
                print(f"[Explain] Saved IG importance -> {out_path}")
            if save_values and "attributions" in ig_result:
                attr_path = save_dir / f"{stem}_ig_attributions.npy"
                np.save(attr_path, ig_result.get("attributions"))
        else:
            ig_result = model.compute_integrated_gradients_ft(
                on_train=on_train,
                baseline_num=baselines["baseline_num"],
                baseline_geo=baselines["baseline_geo"],
                **ig_kwargs,
            )
            # The ft result carries separate "num" and "geo" parts.
            for part in ("num", "geo"):
                series = ig_result.get(f"importance_{part}")
                if isinstance(series, pd.Series):
                    out_path = save_dir / f"{stem}_ig_{part}_importance.csv"
                    _save_series(series, out_path)
                    print(f"[Explain] Saved IG {part} importance -> {out_path}")
            if save_values:
                for part in ("num", "geo"):
                    attributions = ig_result.get(f"attributions_{part}")
                    if attributions is not None:
                        attr_path = save_dir / f"{stem}_ig_{part}_attributions.npy"
                        np.save(attr_path, attributions)
    except Exception as exc:
        print(f"[Explain] integrated gradients failed for {tag}: {exc}")


def _explain_for_model(
    model: ropt.BayesOptModel,
    *,
    model_name: str,
    model_keys: List[str],
    methods: List[str],
    on_train: bool,
    save_dir: Path,
    explain_cfg: Dict[str, Any],
) -> None:
    """Run the requested explain methods for every requested model key.

    For each key the trained model is loaded via ``model.load_model``; keys
    without a trainer, without a loaded ``<key>_best`` attribute, or an ``ft``
    key whose ``ft_role`` is not ``"model"`` are skipped with a message.
    Method/key combinations that are not supported are skipped as well.
    Failures while computing a method are printed and do not stop the loop.

    Args:
        model: Model wrapper exposing trainers, ``load_model`` and the
            ``compute_*`` explain methods.
        model_name: Name used in log messages and artifact file names.
        model_keys: Model keys to explain (e.g. ``"xgb"``, ``"ft"``).
        methods: Canonical method names (``permutation``, ``shap``,
            ``integrated_gradients``).
        on_train: Forwarded to every ``compute_*`` call.
        save_dir: Directory that receives all artifacts.
        explain_cfg: ``explain`` section of the config; the ``permutation``,
            ``shap`` and ``integrated_gradients`` sub-dicts tune each method.
    """
    perm_cfg = dict(explain_cfg.get("permutation") or {})
    shap_cfg = dict(explain_cfg.get("shap") or {})
    ig_cfg = dict(explain_cfg.get("integrated_gradients") or {})

    # All settings are parsed up front so a malformed value fails before any
    # model is loaded.
    perm_kwargs = {
        "metric": perm_cfg.get("metric", explain_cfg.get("metric", "auto")),
        "n_repeats": int(perm_cfg.get("n_repeats", 5)),
        "max_rows": perm_cfg.get("max_rows", 5000),
        "random_state": perm_cfg.get("random_state", None),
    }
    shap_kwargs = {
        "n_background": int(shap_cfg.get("n_background", 500)),
        "n_samples": int(shap_cfg.get("n_samples", 200)),
    }
    shap_save_values = bool(shap_cfg.get("save_values", False))
    ig_kwargs = {
        "steps": int(ig_cfg.get("steps", 50)),
        "batch_size": int(ig_cfg.get("batch_size", 256)),
        "target": ig_cfg.get("target", None),
    }
    ig_baselines = {
        "baseline": ig_cfg.get("baseline", None),
        "baseline_num": ig_cfg.get("baseline_num", None),
        "baseline_geo": ig_cfg.get("baseline_geo", None),
    }
    ig_save_values = bool(ig_cfg.get("save_values", False))

    # method -> (model keys it applies to, wording used in the skip message)
    applicability = {
        "permutation": ({"xgb", "resn", "ft"}, "permutation"),
        "shap": ({"glm", "xgb", "resn", "ft"}, "shap"),
        "integrated_gradients": ({"resn", "ft"}, "integrated gradients"),
    }

    for key in model_keys:
        trainer = model.trainers.get(key)
        if trainer is None:
            print(f"[Explain] Skip {model_name}/{key}: trainer not available.")
            continue
        model.load_model(key)
        trained_model = getattr(model, f"{key}_best", None)
        if trained_model is None:
            print(f"[Explain] Skip {model_name}/{key}: model not loaded.")
            continue

        if key == "ft" and str(model.config.ft_role) != "model":
            print(f"[Explain] Skip {model_name}/ft: ft_role != 'model'.")
            continue

        tag = f"{model_name}/{key}"
        stem = f"{_safe_name(model_name)}_{key}"

        for method in methods:
            if method in applicability:
                allowed_keys, label = applicability[method]
                if key not in allowed_keys:
                    print(f"[Explain] Skip {label} for {model_name}/{key}.")
                    continue

            if method == "permutation":
                _run_permutation(
                    model,
                    key,
                    tag=tag,
                    stem=stem,
                    on_train=on_train,
                    save_dir=save_dir,
                    perm_kwargs=perm_kwargs,
                )
            elif method == "shap":
                _run_shap(
                    model,
                    key,
                    tag=tag,
                    stem=stem,
                    on_train=on_train,
                    save_dir=save_dir,
                    shap_kwargs=shap_kwargs,
                    save_values=shap_save_values,
                )
            elif method == "integrated_gradients":
                _run_integrated_gradients(
                    model,
                    key,
                    tag=tag,
                    stem=stem,
                    on_train=on_train,
                    save_dir=save_dir,
                    ig_kwargs=ig_kwargs,
                    baselines=ig_baselines,
                    save_values=ig_save_values,
                )
370
+
371
+
372
def _as_str_list(value: Any) -> List[str]:
    """Coerce a config/CLI value to a list of strings; a bare string counts as one item."""
    if isinstance(value, str):
        return [value]
    return [str(item) for item in value]


def explain_from_config(args: argparse.Namespace) -> None:
    """Load the config, rebuild each model wrapper and run the explain methods.

    Command-line values in *args* take precedence over the ``explain`` section
    of the config. For every model name derived from ``model_list`` and
    ``model_categories`` the training CSV is loaded, an evaluation split is
    taken from an explicit validation file or from ``train_test_split``, a
    ``BayesOptModel`` is constructed and ``_explain_for_model`` is invoked.

    Args:
        args: Parsed CLI arguments as produced by ``_parse_args``.

    Raises:
        ValueError: If no model names can be generated or a method name is
            unsupported.
        FileNotFoundError: If a training or validation dataset is missing.
    """
    script_dir = Path(__file__).resolve().parent
    config_path = resolve_config_path(args.config_json, script_dir)
    cfg = load_config_json(
        config_path,
        required_keys=["data_dir", "model_list", "model_categories", "target", "weight"],
    )
    cfg = normalize_config_paths(cfg, config_path)
    config_dir = config_path.parent

    set_env(cfg.get("env", {}))

    data_dir = Path(cfg["data_dir"])
    data_dir.mkdir(parents=True, exist_ok=True)

    output_dir = args.output_dir or cfg.get("output_dir")
    if isinstance(output_dir, str) and output_dir.strip():
        resolved = resolve_path(output_dir, config_dir)
        if resolved is not None:
            output_dir = str(resolved)

    prop_test = cfg.get("prop_test", 0.25)
    rand_seed = cfg.get("rand_seed", 13)

    explain_cfg = dict(cfg.get("explain") or {})

    # A config may give a single key/method as a bare string; wrap it so it
    # is not iterated character by character.
    model_keys = _as_str_list(
        args.model_keys or explain_cfg.get("model_keys") or ["xgb"]
    )
    if "all" in model_keys:
        model_keys = ["glm", "xgb", "resn", "ft", "gnn"]
    model_keys = dedupe_preserve_order(model_keys)

    methods = _normalize_methods(
        _as_str_list(args.methods or explain_cfg.get("methods") or ["permutation"])
    )

    on_train = bool(args.on_train or explain_cfg.get("on_train", False))

    model_names = build_model_names(cfg["model_list"], cfg["model_categories"])
    if not model_names:
        raise ValueError("No model names generated from model_list/model_categories.")

    save_dir_raw = args.save_dir or explain_cfg.get("save_dir")
    if save_dir_raw:
        resolved = resolve_path(str(save_dir_raw), config_dir)
        save_root = resolved if resolved is not None else Path(str(save_dir_raw))
    else:
        save_root = None

    # These settings do not depend on the model name, so read them once.
    validation_override = (
        args.eval_path
        or explain_cfg.get("validation_path")
        or explain_cfg.get("eval_path")
    )
    binary_target = cfg.get("binary_target") or cfg.get("binary_resp_nme")
    feature_list = cfg.get("feature_list")
    categorical_features = cfg.get("categorical_features")
    # output_manager attribute -> configured override (may be per-model dicts).
    dir_overrides = (
        ("model_dir", explain_cfg.get("model_dir")),
        ("result_dir", explain_cfg.get("result_dir") or explain_cfg.get("results_dir")),
        ("plot_dir", explain_cfg.get("plot_dir")),
    )

    for model_name in model_names:
        train_path = _resolve_path_value(
            explain_cfg.get("train_path"),
            model_name=model_name,
            base_dir=config_dir,
            data_dir=data_dir,
        )
        if train_path is None:
            train_path = data_dir / f"{model_name}.csv"
        if not train_path.exists():
            raise FileNotFoundError(f"Missing training dataset: {train_path}")

        validation_path = _resolve_path_value(
            validation_override,
            model_name=model_name,
            base_dir=config_dir,
            data_dir=data_dir,
        )

        raw = _load_dataset(train_path)
        if validation_path is not None:
            if not validation_path.exists():
                raise FileNotFoundError(f"Missing validation dataset: {validation_path}")
            train_df = raw
            test_df = _load_dataset(validation_path)
        elif float(prop_test) <= 0:
            # No hold-out requested: evaluate on a copy of the full data.
            train_df = raw
            test_df = raw.copy()
        else:
            train_df, test_df = train_test_split(
                raw, test_size=prop_test, random_state=rand_seed
            )

        model = ropt.BayesOptModel(
            train_df,
            test_df,
            model_name,
            cfg["target"],
            cfg["weight"],
            feature_list,
            binary_resp_nme=binary_target,
            cate_list=categorical_features,
            prop_test=prop_test,
            rand_seed=rand_seed,
            epochs=int(cfg.get("epochs", 50)),
            use_gpu=bool(cfg.get("use_gpu", True)),
            output_dir=output_dir,
            xgb_max_depth_max=int(cfg.get("xgb_max_depth_max", 25)),
            xgb_n_estimators_max=int(cfg.get("xgb_n_estimators_max", 500)),
            resn_weight_decay=cfg.get("resn_weight_decay"),
            final_ensemble=bool(cfg.get("final_ensemble", False)),
            final_ensemble_k=int(cfg.get("final_ensemble_k", 3)),
            final_refit=bool(cfg.get("final_refit", True)),
            optuna_storage=cfg.get("optuna_storage"),
            optuna_study_prefix=cfg.get("optuna_study_prefix"),
            best_params_files=cfg.get("best_params_files"),
            gnn_use_approx_knn=cfg.get("gnn_use_approx_knn", True),
            gnn_approx_knn_threshold=cfg.get("gnn_approx_knn_threshold", 50000),
            gnn_graph_cache=cfg.get("gnn_graph_cache"),
            gnn_max_gpu_knn_nodes=cfg.get("gnn_max_gpu_knn_nodes", 200000),
            gnn_knn_gpu_mem_ratio=cfg.get("gnn_knn_gpu_mem_ratio", 0.9),
            gnn_knn_gpu_mem_overhead=cfg.get("gnn_knn_gpu_mem_overhead", 2.0),
            ft_role=str(cfg.get("ft_role", "model")),
            ft_feature_prefix=str(cfg.get("ft_feature_prefix", "ft_emb")),
            ft_num_numeric_tokens=cfg.get("ft_num_numeric_tokens"),
            infer_categorical_max_unique=int(cfg.get("infer_categorical_max_unique", 50)),
            infer_categorical_max_ratio=float(cfg.get("infer_categorical_max_ratio", 0.05)),
            reuse_best_params=bool(cfg.get("reuse_best_params", False)),
        )

        # Point the output manager at explicitly configured directories.
        for attr_name, configured in dir_overrides:
            override = _resolve_path_value(
                configured,
                model_name=model_name,
                base_dir=config_dir,
                data_dir=None,
            )
            if override is not None:
                setattr(model.output_manager, attr_name, override)

        if save_root is None:
            save_dir = Path(model.output_manager.result_dir) / "explain"
        else:
            save_dir = Path(save_root)
        save_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n=== Explain model {model_name} ===")
        _explain_for_model(
            model,
            model_name=model_name,
            model_keys=model_keys,
            methods=methods,
            on_train=on_train,
            save_dir=save_dir,
            explain_cfg=explain_cfg,
        )
535
+
536
+
537
def main() -> None:
    """CLI entry point: set up run logging when available, then run the explain job.

    ``configure_run_logging`` is ``None`` when the optional ``run_logging``
    module could not be imported, in which case logging setup is skipped.
    """
    if configure_run_logging:
        configure_run_logging(prefix="explain_entry")
    explain_from_config(_parse_args())


if __name__ == "__main__":
    main()
@@ -0,0 +1 @@
1
+ """Scripts package for ins_pricing CLI tools."""