tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tanml might be problematic.

Files changed (49)
  1. tanml/__init__.py +1 -1
  2. tanml/check_runners/cleaning_repro_runner.py +2 -2
  3. tanml/check_runners/correlation_runner.py +49 -12
  4. tanml/check_runners/explainability_runner.py +12 -22
  5. tanml/check_runners/logistic_stats_runner.py +196 -17
  6. tanml/check_runners/performance_runner.py +82 -26
  7. tanml/check_runners/raw_data_runner.py +29 -14
  8. tanml/check_runners/regression_metrics_runner.py +195 -0
  9. tanml/check_runners/stress_test_runner.py +23 -6
  10. tanml/check_runners/vif_runner.py +33 -27
  11. tanml/checks/correlation.py +241 -41
  12. tanml/checks/explainability/shap_check.py +261 -29
  13. tanml/checks/logit_stats.py +186 -54
  14. tanml/checks/performance_classification.py +305 -0
  15. tanml/checks/raw_data.py +58 -23
  16. tanml/checks/regression_metrics.py +167 -0
  17. tanml/checks/stress_test.py +157 -53
  18. tanml/cli/main.py +99 -27
  19. tanml/engine/check_agent_registry.py +20 -10
  20. tanml/engine/core_engine_agent.py +199 -37
  21. tanml/models/registry.py +329 -0
  22. tanml/report/report_builder.py +1180 -147
  23. tanml/report/templates/report_template_cls.docx +0 -0
  24. tanml/report/templates/report_template_reg.docx +0 -0
  25. tanml/ui/app.py +1205 -0
  26. tanml/utils/data_loader.py +105 -15
  27. tanml-0.1.7.dist-info/METADATA +164 -0
  28. tanml-0.1.7.dist-info/RECORD +54 -0
  29. tanml/cli/arg_parser.py +0 -31
  30. tanml/cli/init_cmd.py +0 -8
  31. tanml/cli/validate_cmd.py +0 -7
  32. tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
  33. tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
  34. tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
  35. tanml/engine/segmentation_agent.py +0 -118
  36. tanml/engine/validation_agent.py +0 -91
  37. tanml/report/templates/report_template.docx +0 -0
  38. tanml/utils/model_loader.py +0 -35
  39. tanml/utils/r_loader.py +0 -30
  40. tanml/utils/sas_loader.py +0 -50
  41. tanml/utils/yaml_generator.py +0 -34
  42. tanml/utils/yaml_loader.py +0 -5
  43. tanml/validate.py +0 -209
  44. tanml-0.1.6.dist-info/METADATA +0 -317
  45. tanml-0.1.6.dist-info/RECORD +0 -62
  46. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
  47. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
  48. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
  49. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
tanml/checks/explainability/shap_check.py
@@ -1,55 +1,287 @@
+# tanml/checks/explainability/shap_check.py
 from tanml.checks.base import BaseCheck
+
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
 import shap
+import numpy as np
 import pandas as pd
-import matplotlib.pyplot as plt
 import traceback
+import warnings
 from pathlib import Path
 from datetime import datetime
+from scipy import sparse as sp
+
+def _safe_numeric_cast_df(df: pd.DataFrame) -> pd.DataFrame:
+    """Convert columns to numeric where possible; leave others unchanged."""
+    out = {}
+    for c in df.columns:
+        s = df[c]
+        try:
+            out[c] = pd.to_numeric(s)  # no 'errors' kwarg
+        except Exception:
+            out[c] = s
+    return pd.DataFrame(out, index=df.index)
 
 
 class SHAPCheck(BaseCheck):
+    """
+    SHAP for regression + binary classification (no multiclass).
+
+    Config under rule_config["explainability"]["shap"]:
+      - enabled: bool (default True, also checked under rule_config["SHAPCheck"]["enabled"])
+      - task: "auto" | "classification" | "regression" (default "auto")
+      - algorithm: "auto" | "tree" | "linear" | "kernel" | "permutation" (default "auto")
+      - model_output: "auto" | "raw" | "log_odds" | "probability" (tree-only hint; default "auto")
+      - background_strategy: "sample" | "kmeans" (default "sample")
+      - background_sample_size: int (default 100)
+      - test_sample_size: int (default 200)
+      - max_display: int (default 20)
+      - seed: int (default 42)
+      - out_dir: str (optional; preferred save folder)
+    """
+
     def __init__(self, model, X_train, X_test, y_train, y_test, rule_config=None, cleaned_df=None):
         super().__init__(model, X_train, X_test, y_train, y_test, rule_config, cleaned_data=cleaned_df)
         self.cleaned_df = cleaned_df
 
-    def run(self):
-        result = {}
+    # -------------------------- helpers --------------------------
 
+    @staticmethod
+    def _to_df(X, names=None):
+        if isinstance(X, pd.DataFrame):
+            return X
+        if sp.issparse(X):
+            df = pd.DataFrame.sparse.from_spmatrix(X)
+        else:
+            df = pd.DataFrame(np.asarray(X))
+        if names is not None and len(names) == df.shape[1]:
+            df.columns = list(names)
+        return df
+
+    @staticmethod
+    def _task(y, forced="auto"):
+        if forced and forced != "auto":
+            return forced
         try:
-            expl_cfg = self.rule_config.get("explainability", {})
-            bg_n = expl_cfg.get("background_sample_size", 100)
-            test_n = expl_cfg.get("test_sample_size", 200)
+            yv = y.iloc[:, 0] if isinstance(y, pd.DataFrame) else (y if isinstance(y, pd.Series) else pd.Series(y))
+            uniq = pd.Series(yv).dropna().unique()
+            return "classification" if len(uniq) <= 2 else "regression"
+        except Exception:
+            return "regression"
+
+    @staticmethod
+    def _pos_cls_idx(model, X_one):
+        """Return index of the positive class (1/True if available, else max-label) for binary classification."""
+        try:
+            if hasattr(model, "classes_") and len(model.classes_) == 2:
+                classes = list(model.classes_)
+                for pos in (1, True):
+                    if pos in classes:
+                        return classes.index(pos)
+                return classes.index(max(classes))
+            if hasattr(model, "predict_proba"):
+                proba = model.predict_proba(X_one)
+                return 1 if proba.shape[1] == 2 else 0
+        except Exception:
+            pass
+        return 1
+
+    @staticmethod
+    def _looks_like_tree(m):
+        mod = type(m).__module__.lower()
+        name = type(m).__name__.lower()
+        return (
+            "xgboost" in mod
+            or "lightgbm" in mod
+            or "catboost" in mod
+            or "sklearn.ensemble" in mod
+            or "sklearn.tree" in mod
+            or "randomforest" in name
+            or "gradientboost" in name
+            or "extratrees" in name
+            or "decisiontree" in name
+        )
+
+    @staticmethod
+    def _looks_like_linear(m):
+        mod = type(m).__module__.lower()
+        name = type(m).__name__.lower()
+        return (
+            "sklearn.linear_model" in mod
+            or "logistic" in name
+            or "linear" in name
+            or "ridge" in name
+            or "lasso" in name
+            or "elastic" in name
+        )
+
+    def _predict_fn(self, is_cls: bool, pos_idx: int | None):
+        """
+        Vectorized prediction function for permutation/kernel explainers.
+        Returns positive-class probability when classification is detected and predict_proba is available.
+        """
+        if is_cls and hasattr(self.model, "predict_proba"):
+            def f(X):
+                p = self.model.predict_proba(X)
+                i = 1 if (p.ndim == 2 and p.shape[1] == 2) else (pos_idx or 0)
+                return p[:, i]
+            return f
+        return self.model.predict
+
+    def _explainer(self, algorithm, background, model_output_hint, is_cls, pos_idx):
+        """
+        Choose fastest viable explainer:
+          - tree → TreeExplainer (with interventional perturbation)
+          - linear → LinearExplainer
+          - default/auto → PermutationExplainer (avoid slow Kernel, unless explicitly requested)
+        """
+        m = self.model
+        alg = (algorithm or "auto").lower()
+
+        # Prefer fast paths in auto
+        if alg == "tree" or (alg == "auto" and self._looks_like_tree(m)):
+            mo = None if model_output_hint == "auto" else model_output_hint
+            expl = shap.TreeExplainer(
+                m, data=background, feature_perturbation="interventional", model_output=mo
+            )
+            return expl, "tree"
 
-            X_sample = self.X_test[:test_n]
-            background = shap.utils.sample(self.X_train, bg_n, random_state=42)
+        if alg == "linear" or (alg == "auto" and self._looks_like_linear(m)):
+            return shap.LinearExplainer(m, background), "linear"
 
-            X_sample = pd.DataFrame(X_sample)
-            background = pd.DataFrame(background)
+        if alg == "permutation" or alg == "auto":
+            fn = self._predict_fn(is_cls, pos_idx)
+            return shap.explainers.Permutation(fn, background, max_evals=2000), "perm"
 
-            explainer = shap.Explainer(self.model, background)
-            shap_exp = explainer(X_sample)
+        # Only use Kernel if explicitly requested
+        if alg == "kernel":
+            fn = self._predict_fn(is_cls, pos_idx)
+            return shap.KernelExplainer(fn, background), "kernel"
 
-            if shap_exp.values.ndim == 3:
-                shap_exp.values = shap_exp.values[:, :, 1]
-                shap_exp.base_values = shap_exp.base_values[:, 1]
+        # Fallback (should not be hit)
+        fn = self._predict_fn(is_cls, pos_idx)
+        return shap.explainers.Permutation(fn, background, max_evals=2000), "perm"
 
-            segment = self.rule_config.get("meta", {}).get("segment", "global")
+    # ---------------------------- main ----------------------------
+
+    def run(self):
+        out = {}
+        try:
+            warnings.filterwarnings("ignore", category=UserWarning)
+
+            # -------- read config (note: under explainability.shap) ----------
+            exp_cfg = (self.rule_config or {}).get("explainability", {}) or {}
+            cfg = exp_cfg.get("shap", {}) if isinstance(exp_cfg, dict) else {}
+            seed = int(cfg.get("seed", 42))
+            bg_n = int(cfg.get("background_sample_size", 100))
+            test_n = int(cfg.get("test_sample_size", 200))
+            task_forced = (cfg.get("task") or "auto").lower()
+            algorithm = (cfg.get("algorithm") or "auto").lower()
+            bg_strategy = (cfg.get("background_strategy") or "sample").lower()
+            model_output_hint = (cfg.get("model_output") or "auto").lower()
+            max_display = int(cfg.get("max_display", 20))
+
+            # -------- resolve output directory + timestamp ----------
+            out_dir_opt = cfg.get("out_dir")
+            options_dir = ((self.rule_config or {}).get("options") or {}).get("save_artifacts_dir")
+            # prefer explicit shap.out_dir, then global artifacts dir, then local fallback
+            outdir = Path(out_dir_opt or options_dir or (Path(__file__).resolve().parents[2] / "tmp_report_assets"))
+            outdir.mkdir(parents=True, exist_ok=True)
             ts = datetime.now().strftime("%Y%m%d_%H%M%S")
-            output_path = Path(f"reports/images/shap_summary_{segment}_{ts}.png")
-            output_path.parent.mkdir(parents=True, exist_ok=True)
 
-            plt.figure(figsize=(8, 6))
-            shap.plots.beeswarm(shap_exp, show=False)
-            plt.savefig(output_path, bbox_inches="tight")
+            # -------- materialize dataframes & coerce dtypes ----------
+            feature_names = list(self.X_train.columns) if isinstance(self.X_train, pd.DataFrame) else None
+            Xtr = self._to_df(self.X_train, feature_names)
+            Xte = self._to_df(self.X_test, feature_names)
+            if Xtr.empty or Xte.empty:
+                raise ValueError("Empty X_train or X_test for SHAP.")
+
+            # Avoid slow implicit object->float conversions during plotting
+            Xtr = _safe_numeric_cast_df(Xtr)
+            Xte = _safe_numeric_cast_df(Xte)
+
+            # -------- task resolution & sanity for binary classification ----------
+            task = self._task(self.y_train, forced=task_forced)
+            is_cls = (task == "classification")
+            if is_cls:
+                yv = self.y_train if isinstance(self.y_train, (pd.Series, pd.DataFrame)) else pd.Series(self.y_train)
+                if len(pd.Series(yv).dropna().unique()) > 2:
+                    raise ValueError("Binary classification only: y_train has >2 classes.")
+
+            # positive-class index hint (used for permutation/kernel predict function)
+            pos_idx_hint = self._pos_cls_idx(self.model, Xte.iloc[:1].values)
+
+            # -------- background selection ----------
+            if bg_strategy == "kmeans" and len(Xtr) > bg_n and not sp.issparse(self.X_train):
+                background = shap.kmeans(Xtr, bg_n, seed=seed)
+            else:
+                background = shap.utils.sample(Xtr, bg_n, random_state=seed)
+
+            # -------- slice test rows to explain ----------
+            Xs = Xte.head(test_n)
+
+            # -------- choose explainer & compute SHAP once ----------
+            explainer, kind = self._explainer(algorithm, background, model_output_hint, is_cls, pos_idx_hint)
+            if kind == "tree":
+                sv = explainer(Xs, check_additivity=False)  # big speedup, visually identical plots
+            else:
+                sv = explainer(Xs)
+
+            bg_shape = background.shape if hasattr(background, "shape") else None
+            print(f"SHAP explainer={type(explainer).__name__} kind={kind} Xs={Xs.shape} "
+                  f"bg={'kmeans' if bg_shape is None else bg_shape}")
+
+            # -------- squeeze to 2-D for binary cls (if needed) ----------
+            vals = sv.values
+            if hasattr(vals, "ndim") and vals.ndim == 3:
+                pos_idx = self._pos_cls_idx(self.model, Xs.iloc[:1].values)
+                sv.values = vals[:, :, pos_idx]
+                if isinstance(sv.base_values, np.ndarray) and sv.base_values.ndim == 2:
+                    sv.base_values = sv.base_values[:, pos_idx]
+            else:
+                pos_idx = None if task == "regression" else self._pos_cls_idx(self.model, Xs.iloc[:1].values)
+
+            # -------- save plots ----------
+            segment = "global"
+
+            beeswarm_path = outdir / f"shap_beeswarm_{segment}_{ts}.png"
+            plt.figure(figsize=(9, 6))
+            shap.plots.beeswarm(sv, max_display=max_display, show=False)
+            plt.tight_layout()
+            plt.savefig(beeswarm_path, bbox_inches="tight", dpi=120, transparent=False)
             plt.close()
 
-            print(f"✅ SHAP plot saved at: {output_path}")
-            result["shap_plot_path"] = str(output_path)
-            result["status"] = "SHAP plot generated successfully"
+            bar_path = outdir / f"shap_bar_{segment}_{ts}.png"
+            plt.figure(figsize=(9, 6))
+            shap.plots.bar(sv, max_display=max_display, show=False)
+            plt.tight_layout()
+            plt.savefig(bar_path, bbox_inches="tight", dpi=120, transparent=False)
+            plt.close()
 
-        except Exception:
-            err = traceback.format_exc()
-            print(f"⚠️ SHAPCheck failed:\n{err}")
-            result["status"] = f"SHAP plot failed:\n{err}"
+            # -------- top features ----------
+            # mean absolute SHAP across rows
+            mean_abs = np.abs(sv.values).mean(axis=0)
+            idx = np.argsort(mean_abs)[::-1][:max_display]
+            cols = list(Xs.columns)
+            top = [{"feature": cols[i] if i < len(cols) else f"f{i}", "mean_abs_shap": float(mean_abs[i])} for i in idx]
+            top_list_pairs = [[d["feature"], d["mean_abs_shap"]] for d in top]  # compat with old report builder
 
-        return result
+            out.update({
+                "status": "ok",
+                "task": task,
+                "positive_class_index": pos_idx if task == "classification" else None,
+                # new + old keys (backward-compatible)
+                "plots": {"beeswarm": str(beeswarm_path), "bar": str(bar_path)},
+                "images": {"beeswarm": str(beeswarm_path), "bar": str(bar_path)},
+                "top_features": top,
+                "shap_top_features": top_list_pairs,
+            })
+            print(f"✅ SHAP saved: {beeswarm_path}, {bar_path}")
+
+        except Exception:
+            out["status"] = "error: " + traceback.format_exc()
+        return out
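
The rewritten check reads its settings one level deeper than before (rule_config["explainability"]["shap"] instead of rule_config["explainability"]) and returns both new ("plots", "top_features") and legacy ("images", "shap_top_features") result keys. A minimal sketch of driving it directly, assuming a scikit-learn binary classifier; the model, data, and out_dir values below are illustrative, not part of this release:

import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from tanml.checks.explainability.shap_check import SHAPCheck

# Illustrative data and model (not from the package).
X, y = make_classification(n_samples=500, n_features=8, random_state=42)
X = pd.DataFrame(X, columns=[f"x{i}" for i in range(8)])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
model = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

rule_config = {
    "explainability": {
        "shap": {
            "task": "auto",             # two unique labels -> "classification"
            "algorithm": "auto",        # a random forest takes the TreeExplainer fast path
            "background_sample_size": 100,
            "test_sample_size": 200,
            "max_display": 20,
            "seed": 42,
            "out_dir": "reports/shap",  # hypothetical folder; omitted -> tmp_report_assets fallback
        }
    }
}

result = SHAPCheck(model, X_train, X_test, y_train, y_test, rule_config).run()
print(result["status"])  # "ok" on success, "error: <traceback>" on failure
print(result["plots"])   # {"beeswarm": "...png", "bar": "...png"}

Note that "auto" never selects KernelExplainer; the slow kernel path runs only when algorithm: "kernel" is set explicitly.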
tanml/checks/logit_stats.py
@@ -1,59 +1,191 @@
 # tanml/checks/logit_stats.py
-
+import numpy as np
+import pandas as pd
 import statsmodels.api as sm
+from contextlib import suppress
+
+# ---------- helpers: winsorize + standardize (no sklearn) ----------
+def _winsorize_df(df: pd.DataFrame, low_q=0.005, high_q=0.995) -> pd.DataFrame:
+    with suppress(Exception):
+        q_low = df.quantile(low_q)
+        q_hi = df.quantile(high_q)
+        df = df.clip(lower=q_low, upper=q_hi, axis=1)
+    return df
+
+def _standardize_df(df: pd.DataFrame) -> pd.DataFrame:
+    mu = df.mean()
+    sd = df.std(ddof=0).replace(0, np.nan)
+    z = (df - mu) / sd
+    return z.fillna(0.0)
+
+def _prep_design_matrix_df(
+    X: pd.DataFrame,
+    ref_columns: pd.Index | None = None,
+    add_const: bool = True
+) -> pd.DataFrame:
+    """
+    Consistent design prep for statsmodels:
+      - ensure DataFrame
+      - one-hot encode categoricals (drop_first=True)
+      - align to ref_columns (if provided; excludes 'const')
+      - coerce to numeric, replace inf, fillna(0)
+      - drop zero-variance columns
+      - winsorize tails, standardize
+      - optionally add intercept
+    """
+    Xd = pd.DataFrame(X).copy()
+    Xd = pd.get_dummies(Xd, drop_first=True)
+
+    if ref_columns is not None:
+        ref_wo_const = [c for c in ref_columns if c != "const"]
+        Xd = Xd.reindex(columns=ref_wo_const, fill_value=0.0)
+
+    # numeric + clean
+    for c in Xd.columns:
+        Xd[c] = pd.to_numeric(Xd[c], errors="coerce")
+    Xd = Xd.replace([np.inf, -np.inf], np.nan).fillna(0.0)
 
-class LogisticStatsCheck:
-    def __init__(self, model, X, y, config):
-        self.model = model
-        self.X = X
-        self.y = y
-        self.config = config or {}
+    # drop zero-variance columns (except 'const', which isn't present yet)
+    var = Xd.var(numeric_only=True)
+    drop_cols = list(var.index[var == 0])
+    if drop_cols:
+        Xd = Xd.drop(columns=drop_cols, errors="ignore")
 
-    def run(self):
+    # stabilize logits: winsorize + standardize
+    Xd = _winsorize_df(Xd, 0.005, 0.995)
+    Xd = _standardize_df(Xd)
+
+    if add_const:
+        Xd = sm.add_constant(Xd, has_constant="add")
+    return Xd
+
+def _encode_binary_target(y) -> pd.Series:
+    yb = pd.Series(y)
+    uniq = pd.unique(yb.dropna())
+    try:
+        cats = pd.Series(uniq).astype("category").cat.categories
+        if set(map(int, cats)) == {0, 1}:
+            return yb.astype(int)
+    except Exception:
+        pass
+    if len(uniq) != 2:
+        raise ValueError("Logit requires a binary target with exactly two classes.")
+    # map majority -> 0, minority -> 1 (stable ordering)
+    counts = yb.value_counts().sort_values(ascending=False).index.tolist()
+    return yb.map({counts[0]: 0, counts[1]: 1}).astype(int)
+
+def _fit_stage1_ridge(yb: pd.Series, Xd: pd.DataFrame, alpha_grid=(1.0, 0.3, 0.1, 0.03, 0.01)):
+    """
+    Stage-1: ridge-regularized logit for stability. Returns best result and alpha.
+    """
+    best = None
+    for a in alpha_grid:
         try:
-            # 1) Add constant and fit the statsmodels Logit
-            Xc = sm.add_constant(self.X, has_constant='add')
-            res = sm.Logit(self.y, Xc).fit(disp=False)
-
-            # 2) Extract coefficient table
-            coef = res.params
-            stderr = res.bse
-            zscore = coef / stderr
-            pvals = res.pvalues
-
-            table = []
-            for feat in coef.index:
-                label = "Intercept" if feat.lower() == "const" else feat
-                table.append({
-                    "feature": label,
-                    "coefficient": float(coef[feat]),
-                    "std_error": float(stderr[feat]),
-                    "z_score": float(zscore[feat]),
-                    "p_value": float(pvals[feat]),
-                })
-
-            # 3) Fit statistics
-            fit = {
-                "log_lik": float(res.llf),
-                "aic": float(res.aic),
-                "bic": float(res.bic),
-                "pseudo_r2": float(res.prsquared),
-            }
-
-            # 4) Full summary text
-            summary = res.summary().as_text()
-
-            return {
-                "table": table,
-                "fit": fit,
-                "summary": summary,
-                "object": res
-            }
-
-        except Exception as e:
-            return {
-                "table": [],
-                "fit": {},
-                "summary": f"LogisticStatsCheck failed: {e}",
-                "object": None
-            }
+            res = sm.Logit(yb, Xd).fit_regularized(alpha=a, L1_wt=0.0, maxiter=2000, disp=False)
+            coef = res.params.values
+            if not np.all(np.isfinite(coef)):
+                continue
+            score = float(getattr(res, "llf", -np.inf))
+            cand = {"alpha": a, "res": res, "score": score}
+            if best is None or cand["score"] > best["score"] + 1e-9 or \
+               (abs(cand["score"] - best["score"]) <= 1e-9 and a < best["alpha"]):
+                best = cand
+        except Exception:
+            continue
+    if best is None:
+        # stronger ridge last-ditch
+        res = sm.Logit(yb, Xd).fit_regularized(alpha=3.0, L1_wt=0.0, maxiter=2000, disp=False)
+        best = {"alpha": 3.0, "res": res, "score": float(getattr(res, "llf", -np.inf))}
+    return best
+
+def _fit_stage2_inference(yb: pd.Series, Xd: pd.DataFrame, start_params: np.ndarray):
+    """
+    Stage-2: try unpenalized MLE (for classic p-values). If it fails, fall back to GLM Binomial with robust SEs.
+    Returns (model_tag, results).
+    """
+    # Unpenalized MLE with good starts
+    try:
+        mle = sm.Logit(yb, Xd).fit(start_params=start_params, method="lbfgs", maxiter=5000, disp=False)
+        return ("logit_mle", mle)
+    except Exception:
+        pass
+
+    # Robust GLM fallback
+    glm = sm.GLM(yb, Xd, family=sm.families.Binomial())
+    glm_res = glm.fit(cov_type="HC3")  # robust/sandwich SEs
+    return ("glm_robust", glm_res)
+
+def compute_logit_stats(X: pd.DataFrame, y) -> dict:
+    """
+    Train-only stats summary (Sections a & b):
+      - robust design prep (winsorize + z-score)
+      - Stage-1 ridge to stabilize
+      - Stage-2 unpenalized MLE for inference (fallback to GLM robust)
+      - returns summary text and a tidy coef table
+    """
+    # design + target
+    Xd = _prep_design_matrix_df(X, ref_columns=None, add_const=True)
+    yb = _encode_binary_target(y)
+
+    # Stage-1: ridge (stability)
+    stg1 = _fit_stage1_ridge(yb, Xd, alpha_grid=(1.0, 0.3, 0.1, 0.03, 0.01))
+    start = stg1["res"].params.values
+
+    # Stage-2: inference
+    model_tag, res = _fit_stage2_inference(yb, Xd, start_params=start)
+
+    # summary text
+    with suppress(Exception):
+        # summary2() is prettier when available
+        summary_text = res.summary2().as_text()
+    if 'summary_text' not in locals():
+        summary_text = str(res.summary())
+
+    # coef table
+    params = res.params
+    bse = res.bse
+    # z or t values depending on object; label as z
+    if hasattr(res, "tvalues"):
+        stat_vals = res.tvalues
+    elif hasattr(res, "zvalues"):
+        stat_vals = res.zvalues
+    else:
+        stat_vals = params / bse.replace(0, np.nan)
+
+    pvals = res.pvalues
+    ci = res.conf_int(alpha=0.05)
+    ci.columns = ["ci_low", "ci_high"]
+
+    coef_df = pd.DataFrame({
+        "feature": params.index,
+        "coef": params.values,
+        "std err": bse.values,
+        "z": stat_vals.values,
+        "P>|z|": pvals.values,
+        "ci_low": ci["ci_low"].values,
+        "ci_high": ci["ci_high"].values,
+    })
+
+    # const first
+    if "const" in coef_df["feature"].values:
+        coef_df = pd.concat(
+            [coef_df.loc[coef_df["feature"] == "const"],
+             coef_df.loc[coef_df["feature"] != "const"].sort_values("feature")],
+            ignore_index=True
+        )
+
+    # round to 6 decimals so small effects don't show up as 0.000000
+    for c in ["coef", "std err", "z", "P>|z|", "ci_low", "ci_high"]:
+        coef_df[c] = pd.to_numeric(coef_df[c], errors="coerce").round(6)
+
+    # annotate header line to reflect fallback if used
+    if model_tag == "glm_robust":
+        summary_text = "Results: GLM Binomial (robust SE)\n" + summary_text
+    else:
+        summary_text = "Results: Logit (unpenalized MLE)\n" + summary_text
+
+    return {
+        "summary_text": summary_text,
+        "coef_table_headers": ["feature", "coef", "std err", "z", "P>|z|", "ci_low", "ci_high"],
+        "coef_table_rows": coef_df.to_dict(orient="records"),
+    }
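
compute_logit_stats replaces the LogisticStatsCheck class removed above and is callable on its own. A small sketch of invoking it on synthetic data (the frame, coefficients, and seed below are invented for illustration):

import numpy as np
import pandas as pd

from tanml.checks.logit_stats import compute_logit_stats

rng = np.random.default_rng(0)
n = 1000
X = pd.DataFrame({
    "utilization": rng.uniform(0, 1, n),
    "tenure": rng.normal(5, 2, n),
    "region": rng.choice(["north", "south"], n),  # one-hot encoded by the design prep
})
# Hypothetical true model used only to generate a binary target.
logits = 2.0 * (X["utilization"] - 0.5) - 0.1 * X["tenure"]
y = (rng.uniform(size=n) < 1.0 / (1.0 + np.exp(-logits))).astype(int)

stats = compute_logit_stats(X, y)
print(stats["summary_text"])        # header line, then the statsmodels summary table
print(stats["coef_table_headers"])  # ["feature", "coef", "std err", "z", "P>|z|", "ci_low", "ci_high"]
print(stats["coef_table_rows"][0])  # the intercept ("const") row is ordered first

Because stage 2 tries the unpenalized MLE first, summary_text normally begins with "Results: Logit (unpenalized MLE)"; the "Results: GLM Binomial (robust SE)" prefix appears only when that fit raises and the GLM fallback is used.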