tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tanml might be problematic. Click here for more details.

Files changed (49) hide show
  1. tanml/__init__.py +1 -1
  2. tanml/check_runners/cleaning_repro_runner.py +2 -2
  3. tanml/check_runners/correlation_runner.py +49 -12
  4. tanml/check_runners/explainability_runner.py +12 -22
  5. tanml/check_runners/logistic_stats_runner.py +196 -17
  6. tanml/check_runners/performance_runner.py +82 -26
  7. tanml/check_runners/raw_data_runner.py +29 -14
  8. tanml/check_runners/regression_metrics_runner.py +195 -0
  9. tanml/check_runners/stress_test_runner.py +23 -6
  10. tanml/check_runners/vif_runner.py +33 -27
  11. tanml/checks/correlation.py +241 -41
  12. tanml/checks/explainability/shap_check.py +261 -29
  13. tanml/checks/logit_stats.py +186 -54
  14. tanml/checks/performance_classification.py +305 -0
  15. tanml/checks/raw_data.py +58 -23
  16. tanml/checks/regression_metrics.py +167 -0
  17. tanml/checks/stress_test.py +157 -53
  18. tanml/cli/main.py +99 -27
  19. tanml/engine/check_agent_registry.py +20 -10
  20. tanml/engine/core_engine_agent.py +199 -37
  21. tanml/models/registry.py +329 -0
  22. tanml/report/report_builder.py +1180 -147
  23. tanml/report/templates/report_template_cls.docx +0 -0
  24. tanml/report/templates/report_template_reg.docx +0 -0
  25. tanml/ui/app.py +1205 -0
  26. tanml/utils/data_loader.py +105 -15
  27. tanml-0.1.7.dist-info/METADATA +164 -0
  28. tanml-0.1.7.dist-info/RECORD +54 -0
  29. tanml/cli/arg_parser.py +0 -31
  30. tanml/cli/init_cmd.py +0 -8
  31. tanml/cli/validate_cmd.py +0 -7
  32. tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
  33. tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
  34. tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
  35. tanml/engine/segmentation_agent.py +0 -118
  36. tanml/engine/validation_agent.py +0 -91
  37. tanml/report/templates/report_template.docx +0 -0
  38. tanml/utils/model_loader.py +0 -35
  39. tanml/utils/r_loader.py +0 -30
  40. tanml/utils/sas_loader.py +0 -50
  41. tanml/utils/yaml_generator.py +0 -34
  42. tanml/utils/yaml_loader.py +0 -5
  43. tanml/validate.py +0 -209
  44. tanml-0.1.6.dist-info/METADATA +0 -317
  45. tanml-0.1.6.dist-info/RECORD +0 -62
  46. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
  47. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
  48. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
  49. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,18 @@
1
1
  # tanml/report/report_builder.py
2
- from docxtpl import DocxTemplate, InlineImage
3
- from docx.shared import Inches, Mm
4
2
  from docx import Document
3
+ from docx.shared import Inches, Mm
5
4
  from pathlib import Path
6
- import os, imgkit, copy as pycopy
5
+ import os, re, math, copy as pycopy
7
6
  from importlib.resources import files
7
+ import numpy as np # needed for rounding helpers etc.
8
+ from docx.enum.text import WD_ALIGN_PARAGRAPH
9
+ from docx.enum.table import WD_ALIGN_VERTICAL
10
+
11
+
8
12
 
9
13
  TMP_DIR = Path(__file__).resolve().parents[1] / "tmp_report_assets"
10
14
  TMP_DIR.mkdir(exist_ok=True)
11
15
 
12
-
13
16
  class AttrDict(dict):
14
17
  def __getattr__(self, item):
15
18
  try:
@@ -17,20 +20,432 @@ class AttrDict(dict):
17
20
  except KeyError:
18
21
  raise AttributeError(item)
19
22
 
23
+ def _read_csv_as_rows(path, max_rows=None):
24
+ import csv
25
+ rows = []
26
+ if not path or not os.path.exists(path):
27
+ return rows
28
+ with open(path, newline="", encoding="utf-8") as f:
29
+ r = csv.DictReader(f)
30
+ for i, row in enumerate(r):
31
+ if max_rows is not None and i >= max_rows:
32
+ break
33
+ rows.append(row)
34
+ return rows
35
+
36
+
37
+ USE_TEXT_PLACEHOLDERS = True
38
+ TXT = {
39
+ "not_applicable": "Not applicable",
40
+ "not_provided": "Dataset not provided",
41
+ "none_detected": "None detected",
42
+ "no_issues": "None (no issues detected)",
43
+ "unknown": "Unknown",
44
+ "dash": "—",
45
+ }
46
+ def _p(key: str) -> str:
47
+ return TXT["dash"] if not USE_TEXT_PLACEHOLDERS else TXT.get(key, TXT["unknown"])
48
+
49
+
50
+ def _fallback_pairs_from_corr_matrix(corr_csv_path: str, threshold: float, top_k: int = 200):
51
+ """
52
+ Build 'top pairs' rows from a saved correlation MATRIX CSV when the engine
53
+ did not emit 'top_pairs_main_csv'/'top_pairs_csv'.
54
+
55
+ Returns a list of dicts with headers:
56
+ ["feature_i", "feature_j", "corr", "n_used", "pct_missing_i", "pct_missing_j"]
57
+ """
58
+ rows = []
59
+ if not corr_csv_path or not os.path.exists(corr_csv_path):
60
+ return rows
61
+ try:
62
+ import pandas as pd
63
+ df = pd.read_csv(corr_csv_path, index_col=0)
64
+ # Ensure square matrix with aligned labels
65
+ if df.shape[0] == 0 or df.shape[0] != df.shape[1]:
66
+ return rows
67
+ cols = list(df.columns)
68
+
69
+ # Collect upper triangle |r| >= threshold
70
+ pairs = []
71
+ for i in range(len(cols)):
72
+ for j in range(i + 1, len(cols)):
73
+ try:
74
+ r = float(df.iloc[i, j])
75
+ except Exception:
76
+ continue
77
+ if abs(r) >= float(threshold):
78
+ pairs.append((cols[i], cols[j], r))
79
+
80
+ # Sort by absolute correlation desc, then take top_k
81
+ pairs.sort(key=lambda t: abs(t[2]), reverse=True)
82
+ pairs = pairs[: int(top_k)]
83
+
84
+ # Map to expected schema; we don't have n_used / pct_missing_* here
85
+ for a, b, r in pairs:
86
+ rows.append({
87
+ "feature_i": a,
88
+ "feature_j": b,
89
+ "corr": f"{float(r):.4f}",
90
+ "n_used": "", # unknown in matrix-only fallback
91
+ "pct_missing_i": "", # unknown
92
+ "pct_missing_j": "", # unknown
93
+ })
94
+ return rows
95
+ except Exception:
96
+ return []
97
+
98
+ def build_table(doc, headers, rows):
99
+ tbl = doc.add_table(rows=1, cols=len(headers))
100
+ tbl.style = "Table Grid"
101
+ for i, h in enumerate(headers):
102
+ tbl.rows[0].cells[i].text = str(h)
103
+ for r in rows:
104
+ vals = [r.get(h, "") for h in headers] if isinstance(r, dict) else list(r)
105
+ vals += [""] * (len(headers) - len(vals))
106
+ row = tbl.add_row().cells
107
+ for i, v in enumerate(vals):
108
+ row[i].text = str(v)
109
+ return tbl
110
+
111
+ # --- Pretty Decile Lift Table (3-line headers, auto-fit, numeric alignment) ---
112
+ def build_decile_lift_table(doc, headers, rows):
113
+ """
114
+ Renders the decile lift table with:
115
+ - compact labels
116
+ - up to THREE lines per header cell (forced line breaks)
117
+ - widths auto-scaled to page text width
118
+ - right-aligned numeric columns
119
+ """
120
+
121
+ # 1) Compact labels (keep Total)
122
+ label_map = {
123
+ "decile": "Decile",
124
+ "total": "Total",
125
+ "events": "Events",
126
+ "avg_score": "Avg score",
127
+ "event_rate": "Event rate",
128
+ "lift": "Lift",
129
+ "cum_events": "Cum. events",
130
+ "cum_total": "Cum. total",
131
+ "cum_capture_rate": "Cum. capture rate",
132
+ "cum_population": "Cum. population",
133
+ "cum_gain": "Cum. gain",
134
+ }
135
+ pretty_headers = [label_map.get(h.strip(), h.replace("_", " ").strip()) for h in headers]
136
+
137
+ # 2) Helper: force up to 3 lines per label (balanced by words)
138
+ def split_to_max_lines(label: str, max_lines: int = 3):
139
+ parts = str(label).split()
140
+ if len(parts) <= 1:
141
+ return [label]
142
+ # Greedy balance into up to max_lines buckets
143
+ buckets = [[] for _ in range(max_lines)]
144
+ # pre-fill minimal distribution
145
+ for i, w in enumerate(parts):
146
+ buckets[i % max_lines].append(w)
147
+ # join each bucket; trim empty
148
+ lines = [" ".join(b).strip() for b in buckets if b]
149
+ # remove trailing empties
150
+ lines = [ln for ln in lines if ln]
151
+ return lines[:max_lines]
152
+
153
+ # 3) Build table (single header row; we’ll insert line breaks in each header cell)
154
+ tbl = doc.add_table(rows=1, cols=len(pretty_headers))
155
+ tbl.style = "Table Grid"
156
+ tbl.autofit = False
157
+
158
+ # 4) Header row with forced breaks (up to 3 lines)
159
+ hdr_row = tbl.rows[0]
160
+ for j, h in enumerate(pretty_headers):
161
+ cell = hdr_row.cells[j]
162
+ # Clear cell safely
163
+ cell.text = ""
164
+ p = cell.paragraphs[0] if cell.paragraphs else cell.add_paragraph()
165
+ p.alignment = WD_ALIGN_PARAGRAPH.CENTER
166
+ cell.vertical_alignment = WD_ALIGN_VERTICAL.CENTER
167
+
168
+ lines = split_to_max_lines(h, max_lines=3)
169
+ if not lines:
170
+ continue
171
+ run = p.add_run(str(lines[0]))
172
+ for seg in lines[1:]:
173
+ run.add_break() # hard line break
174
+ p.add_run(str(seg))
175
+
176
+ # 5) Body rows (preserve original header order)
177
+ for r in rows:
178
+ vals = [r.get(h, "") for h in headers] if isinstance(r, dict) else list(r)
179
+ vals += [""] * (len(pretty_headers) - len(vals))
180
+ cells = tbl.add_row().cells
181
+ for j, v in enumerate(vals):
182
+ cells[j].text = "" if v is None else str(v)
183
+
184
+ # 6) Base widths (tight) in inches — will be scaled to fit page text width
185
+ # Order corresponds to incoming headers: decile, total, events, ...
186
+ base_widths = [0.50, 0.60, 0.60, 0.70, 0.70, 0.55, 0.80, 0.80, 0.95, 0.95, 0.85]
187
+ widths = base_widths[:len(pretty_headers)]
188
+
189
+ # 7) Compute usable width from document section (with a small safety margin)
190
+ try:
191
+ sec = doc.sections[0]
192
+ usable_in = float(sec.page_width - sec.left_margin - sec.right_margin) / 914400.0
193
+ usable_in = max(usable_in - 0.05, 5.8)
194
+ except Exception:
195
+ usable_in = 6.7
196
+
197
+ # 8) Auto-scale to fit
198
+ total_w = sum(widths)
199
+ if total_w > 0 and usable_in > 0:
200
+ scale = min(1.0, usable_in / total_w)
201
+ widths = [w * scale for w in widths]
202
+
203
+ # 9) Apply widths
204
+ for j, w in enumerate(widths):
205
+ for row in tbl.rows:
206
+ row.cells[j].width = Inches(w)
207
+
208
+ # 10) Right-align numeric columns
209
+ numeric_like = {
210
+ "Total", "Events", "Avg score", "Event rate", "Lift",
211
+ "Cum. events", "Cum. total", "Cum. capture rate",
212
+ "Cum. population", "Cum. gain",
213
+ }
214
+ for i, h in enumerate(pretty_headers):
215
+ right = h in numeric_like
216
+ for row in tbl.rows[1:]:
217
+ for p in row.cells[i].paragraphs:
218
+ p.alignment = WD_ALIGN_PARAGRAPH.RIGHT if right else WD_ALIGN_PARAGRAPH.LEFT
219
+
220
+ return tbl
221
+
222
+
223
+
224
+ def insert_after(doc, anchor, tbl):
225
+ for p in doc.paragraphs:
226
+ if anchor.lower() in p.text.lower():
227
+ parent = p._p.getparent()
228
+ parent.insert(parent.index(p._p) + 1, tbl._tbl)
229
+ return
230
+ print(f"⚠️ anchor «{anchor}» not found")
20
231
 
232
+ def insert_image_grid(doc, anchor: str, img_paths, cols: int = 3, width_in: float = 2.2):
233
+ paths = [p for p in (img_paths or []) if p and os.path.exists(p)]
234
+ if not paths:
235
+ for p in doc.paragraphs:
236
+ if anchor.lower() in p.text.lower():
237
+ after = p.insert_paragraph_after()
238
+ after.add_run("(no plots available)")
239
+ return
240
+ print(f"⚠️ anchor «{anchor}» not found (no plots)")
241
+ return
242
+
243
+ rows = max(1, math.ceil(len(paths) / max(1, cols)))
244
+ tbl = doc.add_table(rows=rows, cols=cols)
245
+ tbl.style = "Table Grid"
246
+ i = 0
247
+ for r in range(rows):
248
+ for c in range(cols):
249
+ cell = tbl.cell(r, c)
250
+ if i < len(paths):
251
+ para = cell.paragraphs[0]
252
+ para.text = ""
253
+ run = para.add_run()
254
+ run.add_picture(paths[i], width=Inches(width_in))
255
+ i += 1
256
+ else:
257
+ cell.text = ""
258
+ insert_after(doc, anchor, tbl)
259
+
260
+ # ---------- placeholder & image replacement ----------
261
+ _PLACEHOLDER = re.compile(r"\{\{([A-Za-z0-9_\.]+)\}\}")
262
+ _IMG_MARKER = re.compile(r"\[\[IMG:([A-Za-z0-9_\-]+)\]\]")
263
+
264
+ def _get_nested(d, dotted, default=""):
265
+ cur = d
266
+ for part in dotted.split("."):
267
+ if not isinstance(cur, dict) or part not in cur:
268
+ return default
269
+ cur = cur[part]
270
+ return cur
271
+
272
+ def _replace_text_placeholders(doc, mapping):
273
+ # paragraphs
274
+ for p in doc.paragraphs:
275
+ full = p.text
276
+ def repl(m):
277
+ key = m.group(1)
278
+ return str(mapping.get(key, _get_nested(mapping, key, "")) or "")
279
+ new_full = _PLACEHOLDER.sub(repl, full)
280
+ if new_full != full:
281
+ for _ in range(len(p.runs)-1, -1, -1):
282
+ p.runs[_].text = ""
283
+ if p.runs:
284
+ p.runs[0].text = new_full
285
+ else:
286
+ p.add_run(new_full)
287
+ # table cells
288
+ for t in doc.tables:
289
+ for row in t.rows:
290
+ for cell in row.cells:
291
+ for p in cell.paragraphs:
292
+ full = p.text
293
+ new_full = _PLACEHOLDER.sub(
294
+ lambda m: str(mapping.get(m.group(1), _get_nested(mapping, m.group(1), "")) or ""),
295
+ full
296
+ )
297
+ if new_full != full:
298
+ for _ in range(len(p.runs)-1, -1, -1):
299
+ p.runs[_].text = ""
300
+ if p.runs:
301
+ p.runs[0].text = new_full
302
+ else:
303
+ p.add_run(new_full)
304
+
305
+ def _replace_image_markers(doc, images_map, *, width_in=4.8):
306
+ def _put_image_in_paragraph(p, path):
307
+ if path and os.path.exists(path):
308
+ p.text = ""
309
+ run = p.add_run()
310
+ run.add_picture(path, width=Inches(width_in))
311
+ else:
312
+ p.text = "(image not available)"
313
+ # paragraphs
314
+ for p in list(doc.paragraphs):
315
+ m = _IMG_MARKER.search(p.text)
316
+ if m:
317
+ key = m.group(1)
318
+ _put_image_in_paragraph(p, images_map.get(key))
319
+ # table cells
320
+ for t in doc.tables:
321
+ for row in t.rows:
322
+ for cell in row.cells:
323
+ for p in cell.paragraphs:
324
+ m = _IMG_MARKER.search(p.text)
325
+ if m:
326
+ key = m.group(1)
327
+ _put_image_in_paragraph(p, images_map.get(key))
328
+
329
+ # ---------- formatters ----------
330
+ def _fmt2(x, nd=2, dash="—"):
331
+ try:
332
+ if x is None: return dash
333
+ xf = float(x)
334
+ if xf != xf: # NaN
335
+ return dash
336
+ return f"{xf:.{nd}f}"
337
+ except Exception:
338
+ return dash
339
+
340
+ def _fmt_ratio_or_pct(x, pct=False, dash="—", nd=2):
341
+ try:
342
+ if x is None:
343
+ return dash
344
+ xf = float(x)
345
+ if pct:
346
+ return f"{xf*100:.{nd}f}%"
347
+ return f"{xf:.{nd}f}"
348
+ except Exception:
349
+ return dash
350
+
351
+ def _fmt_list(lst, *, max_items=20, sep=", "):
352
+ if not lst:
353
+ return "—"
354
+ if isinstance(lst, (str, bytes)):
355
+ return str(lst)
356
+ try:
357
+ seq = list(lst)
358
+ except Exception:
359
+ return str(lst)
360
+ if len(seq) <= max_items:
361
+ return sep.join(map(str, seq))
362
+ head = sep.join(map(str, seq[:max_items]))
363
+ return f"{head}{sep}… (+{len(seq)-max_items} more)"
364
+
365
+
366
+ # put this right below _fmt_list(...)
367
+ def _fmt_list_or_message(lst, *, empty_msg="None (no issues detected)", max_items=20, sep=", "):
368
+ if not lst:
369
+ return empty_msg
370
+ try:
371
+ seq = list(lst) if not isinstance(lst, (str, bytes)) else [lst]
372
+ except Exception:
373
+ return str(lst)
374
+ if len(seq) <= max_items:
375
+ return sep.join(map(str, seq))
376
+ head = sep.join(map(str, seq[:max_items]))
377
+ return f"{head}{sep}… (+{len(seq)-max_items} more)"
378
+
379
+
380
+ def _fmt_feature_names(v, *, max_names=30):
381
+ if v is None:
382
+ return ""
383
+ if isinstance(v, (list, tuple)):
384
+ if len(v) <= max_names:
385
+ return ", ".join(map(str, v))
386
+ head = ", ".join(map(str, v[:max_names]))
387
+ return f"{head}, … (+{len(v)-max_names} more)"
388
+ return str(v)
389
+
390
+ def _fmt_target_balance(tb):
391
+ if not isinstance(tb, dict) or not tb:
392
+ return ""
393
+ vals = list(tb.values())
394
+ are_probs = all(isinstance(x, (int, float)) and 0 <= float(x) <= 1 for x in vals)
395
+ items = []
396
+ for k in sorted(tb.keys()):
397
+ v = tb[k]
398
+ if are_probs:
399
+ try:
400
+ items.append(f"{k}: {float(v)*100:.1f}%")
401
+ except Exception:
402
+ items.append(f"{k}: {v}")
403
+ else:
404
+ items.append(f"{k}: {v}")
405
+ return ", ".join(items)
406
+
407
+ # ---------- SHAP helpers ----------
408
+ def _attach_shap_to_context(results: dict, context: dict) -> None:
409
+ shap_ctx = {
410
+ "shap_section": False,
411
+ "shap_beeswarm_path": None,
412
+ "shap_bar_path": None,
413
+ "shap_top_features": [],
414
+ }
415
+ try:
416
+ shap_res = (results or {}).get("SHAPCheck") or {}
417
+ if "SHAPCheck" in shap_res and isinstance(shap_res["SHAPCheck"], dict):
418
+ shap_res = shap_res["SHAPCheck"]
419
+ plots = shap_res.get("plots") or {}
420
+ bees = plots.get("beeswarm")
421
+ barp = plots.get("bar")
422
+ legacy = shap_res.get("shap_plot_path")
423
+ if not bees and legacy:
424
+ bees = legacy
425
+ shap_ctx["shap_beeswarm_path"] = bees if bees and os.path.exists(bees) else None
426
+ shap_ctx["shap_bar_path"] = barp if barp and os.path.exists(barp) else None
427
+ shap_ctx["shap_top_features"] = shap_res.get("top_features") or []
428
+ shap_ctx["shap_section"] = bool(
429
+ shap_ctx["shap_beeswarm_path"] or shap_ctx["shap_bar_path"] or shap_ctx["shap_top_features"]
430
+ )
431
+ except Exception as e:
432
+ print("⚠️ SHAP context attach failed:", e)
433
+ context.update(shap_ctx)
434
+
435
+ # --- main class --------------------------------------------------------------
21
436
  class ReportBuilder:
22
- """Build a Word report from validation results."""
437
+ """Build a Word report from validation results (no docxtpl)."""
23
438
 
24
439
  def __init__(self, results, template_path, output_path):
25
- self.results = results
440
+ self.results = results or {}
26
441
  self.template_path = template_path or files("tanml.report.templates").joinpath("report_template.docx")
27
-
28
442
  self.output_path = output_path
29
443
 
30
- corr = results.get("CorrelationCheck", {})
31
- self.corr_heatmap_path = corr.get("heatmap_path")
32
- self.corr_pearson_path = corr.get("pearson_csv", "N/A")
33
- self.corr_spearman_path = corr.get("spearman_csv", "N/A")
444
+ corr = self.results.get("CorrelationCheck", {}) or {}
445
+ corr_artifacts = corr.get("artifacts", corr)
446
+ self.corr_heatmap_path = corr_artifacts.get("heatmap_path")
447
+ self.corr_pearson_path = corr_artifacts.get("pearson_csv", "N/A")
448
+ self.corr_spearman_path = corr_artifacts.get("spearman_csv", "N/A")
34
449
 
35
450
  def _grab(self, name, default=None):
36
451
  return (
@@ -39,192 +454,810 @@ class ReportBuilder:
39
454
  or default
40
455
  )
41
456
 
42
- def build(self):
43
- doc = DocxTemplate(str(self.template_path))
44
-
457
+ @staticmethod
458
+ def _extract_baseline_logit_rows(ctx):
459
+ """
460
+ Read Logit baseline metrics from ctx["LogitStats"]["baseline_metrics"]
461
+ and map to a simple 2-col table.
462
+ """
463
+ logit = (ctx.get("LogitStats") or {})
464
+ bm = logit.get("baseline_metrics") or {}
465
+ summary = bm.get("summary") or bm # support both shapes
45
466
 
46
- # Ensure RuleEngineCheck exists so template never crashes
47
- self.results.setdefault(
48
- "RuleEngineCheck", AttrDict({"rules": {}, "overall_pass": True})
49
- )
467
+ order = [
468
+ ("ROC-AUC", ("AUC","auc","roc_auc","rocauc")),
469
+ ("KS", ("KS","ks")),
470
+ ("F1", ("F1","f1","f1_score")),
471
+ ("PR AUC", ("PR_AUC","pr_auc","average_precision","prauc")),
472
+ ("Gini", ("GINI","gini")),
473
+ ("Precision", ("Precision","precision","prec")),
474
+ ("Recall", ("Recall","recall","tpr","sensitivity")),
475
+ ("Accuracy", ("Accuracy","accuracy","acc")),
476
+ ("Brier Score", ("Brier","brier","brier_score")),
477
+ ("Log Loss", ("Log Loss","log_loss","logloss")),
478
+ ]
479
+ rows = []
480
+ for label, keys in order:
481
+ for k in keys:
482
+ if k in summary and summary[k] is not None:
483
+ v = summary[k]
484
+ try:
485
+ v = f"{float(v):.4f}"
486
+ except Exception:
487
+ v = str(v)
488
+ rows.append({"metric": label, "value": v})
489
+ break
490
+ return rows
50
491
 
51
- # Jinja context
492
+ def build(self):
493
+ # ===== 1) Build context =================================================
494
+ self.results.setdefault("RuleEngineCheck", AttrDict({"rules": {}, "overall_pass": True}))
52
495
  ctx = pycopy.deepcopy(self.results)
53
- ctx.update(self.results.get("check_results", {})) # 1st-level flatten
496
+ ctx.update(self.results.get("check_results", {}))
54
497
 
498
+ # unwrap {"Key":{"Key":...}}
55
499
  for k, v in list(ctx.items()):
56
500
  if isinstance(v, dict) and k in v and len(v) == 1:
57
501
  ctx[k] = v[k]
58
502
 
503
+ # defaults
59
504
  for k, note in [
60
505
  ("RawDataCheck", "Raw-data check skipped"),
61
506
  ("CleaningReproCheck", "Cleaning-repro check skipped"),
62
507
  ]:
63
508
  ctx.setdefault(k, AttrDict({"note": note}))
64
509
 
510
+ # Model meta fallback
65
511
  if "ModelMetaCheck" not in ctx or "model_type" not in ctx["ModelMetaCheck"]:
66
512
  meta_fields = [
67
- "model_type",
68
- "model_class",
69
- "module",
70
- "n_features",
71
- "feature_names",
72
- "n_train_rows",
73
- "target_balance",
74
- "hyperparam_table",
75
- "attributes",
513
+ "model_type", "model_class", "module", "n_features", "feature_names",
514
+ "n_train_rows", "target_balance", "hyperparam_table", "attributes",
76
515
  ]
77
516
  meta = {f: self.results.get(f) for f in meta_fields if self.results.get(f) is not None}
78
517
  ctx["ModelMetaCheck"] = AttrDict(meta or {"note": "Model metadata not available"})
79
518
 
80
- # SHAP image
81
- shap_path = self._grab("SHAPCheck", {}).get("shap_plot_path")
82
- ctx["shap_plot"] = (
83
- InlineImage(doc, shap_path, width=Inches(5))
84
- if shap_path and os.path.exists(shap_path)
85
- else "SHAP plot not available"
86
- )
519
+ # SHAP
520
+ _attach_shap_to_context(self.results, ctx)
87
521
 
88
- # EDA
522
+ # EDA
89
523
  eda = self._grab("EDACheck", {})
90
524
  ctx["eda_summary_path"] = eda.get("summary_stats", "N/A")
91
525
  ctx["eda_missing_path"] = eda.get("missing_values", "N/A")
92
- ctx["eda_images"] = [
93
- InlineImage(doc, os.path.join("reports/eda", fn), width=Inches(4.5))
526
+ ctx["eda_images_paths"] = [
527
+ os.path.join("reports/eda", fn)
528
+ for fn in (eda.get("visualizations", []) or [])
94
529
  if os.path.exists(os.path.join("reports/eda", fn))
95
- else f"Missing: {fn}"
96
- for fn in eda.get("visualizations", [])
97
530
  ]
98
531
 
99
- # Correlation visuals
100
- if self.corr_heatmap_path and os.path.exists(self.corr_heatmap_path):
101
- ctx["correlation_heatmap"] = InlineImage(
102
- doc, self.corr_heatmap_path, width=Inches(5)
103
- )
104
- else:
105
- ctx["correlation_heatmap"] = "Heatmap not available"
532
+ # Correlation
533
+ corr_res = self._grab("CorrelationCheck", {}) or {}
534
+ corr_artifacts = corr_res.get("artifacts", corr_res)
535
+ corr_summary = corr_res.get("summary", {}) or {}
536
+
537
+ def _pick(d, *keys, default=None):
538
+ for k in keys:
539
+ if k in d and d[k] is not None:
540
+ return d[k]
541
+ return default
542
+
543
+ top_pairs_csv = (
544
+ corr_artifacts.get("top_pairs_csv")
545
+ or corr_artifacts.get("pairs_csv")
546
+ or corr_artifacts.get("csv")
547
+ )
548
+ top_pairs_main_csv = (
549
+ corr_artifacts.get("top_pairs_main_csv")
550
+ or corr_artifacts.get("main_csv")
551
+ )
552
+ heatmap_path = (
553
+ corr_artifacts.get("heatmap_path")
554
+ or corr_artifacts.get("heatmap")
555
+ or self.corr_heatmap_path
556
+ )
557
+
558
+ method = _pick(corr_summary, "method", "corr_method")
559
+ threshold = _pick(corr_summary, "threshold", "high_corr_threshold")
560
+ n_numeric_features = _pick(corr_summary, "n_numeric_features", "numeric_features", "n_features_numeric")
561
+ plotted_features = _pick(corr_summary, "plotted_features", "features_plotted")
562
+ plotted_full_matrix = _pick(corr_summary, "plotted_full_matrix", "full_matrix")
563
+ n_pairs_total = _pick(corr_summary, "n_pairs_total", "pairs_total")
564
+ n_pairs_flagged = _pick(corr_summary, "n_pairs_flagged_ge_threshold", "n_pairs_flagged")
565
+
566
+ ctx.setdefault("CorrelationCheck", {})
567
+ ctx["CorrelationCheck"].update({
568
+ "method": method,
569
+ "threshold": threshold,
570
+ "n_numeric_features": n_numeric_features,
571
+ "plotted_features": plotted_features,
572
+ "plotted_full_matrix": plotted_full_matrix,
573
+ "n_pairs_total": n_pairs_total,
574
+ "n_pairs_flagged_ge_threshold": n_pairs_flagged,
575
+ "top_pairs_csv": top_pairs_csv,
576
+ "top_pairs_main_csv": top_pairs_main_csv,
577
+ })
106
578
  ctx["correlation_pearson_path"] = self.corr_pearson_path
107
579
  ctx["correlation_spearman_path"] = self.corr_spearman_path
580
+ ctx["correlation_heatmap_path"] = heatmap_path
108
581
 
109
- # Performance
110
- perf = self._grab("PerformanceCheck", {})
111
- if not perf:
112
- perf = {
113
- "accuracy": "N/A",
114
- "auc": "N/A",
115
- "ks": "N/A",
116
- "f1": "N/A",
117
- "confusion_matrix": [],
118
- }
119
- ctx.setdefault("check_results", {})["PerformanceCheck"] = perf
120
- ctx["PerformanceCheck"] = perf
582
+ corr_preview_headers = [
583
+ "feature_i", "feature_j", "corr",
584
+ "n_used", "pct_missing_i", "pct_missing_j",
585
+ ]
586
+ pairs_csv_for_preview = top_pairs_main_csv or top_pairs_csv
587
+ corr_preview_rows = _read_csv_as_rows(pairs_csv_for_preview, max_rows=None)
121
588
 
122
- # Logistic summary image
123
- if "LogisticStatsCheck_obj" in self.results:
124
- try:
125
- add_logit_summary_image(
126
- doc, self.results["LogisticStatsCheck_obj"], ctx, "LogitSummaryImg"
127
- )
128
- except Exception as e:
129
- print("⚠️ logistic summary image failed:", e)
589
+ # round numeric fields — esp. 'corr' to 4 decimals
590
+ rounded_corr_rows = []
591
+ for r in corr_preview_rows:
592
+ new_row = {}
593
+ for h in corr_preview_headers:
594
+ val = r.get(h, "")
595
+ if h == "corr":
596
+ try:
597
+ new_row[h] = f"{float(val):.4f}"
598
+ except Exception:
599
+ new_row[h] = val
600
+ elif h in ("pct_missing_i", "pct_missing_j"):
601
+ try:
602
+ new_row[h] = f"{float(val):.2f}"
603
+ except Exception:
604
+ new_row[h] = val
605
+ else:
606
+ new_row[h] = val
607
+ rounded_corr_rows.append(new_row)
130
608
 
131
- # VIF
132
- vif = self._grab("VIFCheck", {})
133
- ctx["VIFCheck"] = AttrDict(
134
- vif
135
- if isinstance(vif, dict) and "vif_table" in vif
136
- else {"vif_table": [], "high_vif_features": [], "error": "Invalid VIFCheck"}
137
- )
609
+ corr_preview_rows = rounded_corr_rows
610
+
611
+
612
+
613
+ # ---- Fallback: if no top-pairs CSV rows, derive from pearson/spearman matrix ----
614
+ if not corr_preview_rows:
615
+ cc_summary = (ctx.get("CorrelationCheck") or {}).get("summary", {}) or {}
616
+ thr = cc_summary.get("threshold")
617
+ if thr is None:
618
+ # Also try the normalized keys we stored earlier
619
+ thr = (ctx.get("CorrelationCheck") or {}).get("threshold") or 0.80
620
+
621
+ # Prefer pearson; fall back to spearman
622
+ pearson_path = ctx.get("correlation_pearson_path")
623
+ spearman_path = ctx.get("correlation_spearman_path")
624
+
625
+ fallback_rows = []
626
+ if pearson_path and os.path.exists(pearson_path):
627
+ fallback_rows = _fallback_pairs_from_corr_matrix(pearson_path, float(thr), top_k=200)
628
+ elif spearman_path and os.path.exists(spearman_path):
629
+ fallback_rows = _fallback_pairs_from_corr_matrix(spearman_path, float(thr), top_k=200)
630
+
631
+ # Use fallback if we found any
632
+ if fallback_rows:
633
+ corr_preview_rows = fallback_rows
634
+
635
+
636
+ # ---------- Performance (classification) ----------
637
+ perf_root = self.results.get("performance", {}) or {}
638
+ cls = perf_root.get("classification", {}) or {}
639
+ cls_summary = (cls.get("summary") or {})
640
+ cls_plots = (cls.get("plots") or {})
641
+ cls_tables = (cls.get("tables") or {})
138
642
 
139
- # Stress / cluster
643
+ def _pick_first(d, *keys):
644
+ for k in keys:
645
+ if k in d and d[k] is not None:
646
+ return d[k]
647
+ low = {str(k).lower(): v for k, v in d.items()}
648
+ for k in keys:
649
+ lk = str(k).lower()
650
+ if lk in low and low[lk] is not None:
651
+ return low[lk]
652
+ return None
653
+
654
+ ctx["classification_summary"] = {
655
+ "AUC": _pick_first(cls_summary, "AUC", "auc", "roc_auc"),
656
+ "KS": _pick_first(cls_summary, "KS", "ks"),
657
+ "GINI": _pick_first(cls_summary, "GINI", "gini"),
658
+ "PR_AUC": _pick_first(cls_summary, "PR_AUC", "pr_auc", "average_precision"),
659
+ "F1": _pick_first(cls_summary, "F1", "f1", "f1_score"),
660
+ "Precision": _pick_first(cls_summary, "Precision", "precision"),
661
+ "Recall": _pick_first(cls_summary, "Recall", "recall", "tpr", "sensitivity"),
662
+ "Accuracy": _pick_first(cls_summary, "Accuracy", "accuracy"),
663
+ "Brier": _pick_first(cls_summary, "Brier", "brier", "brier_score"),
664
+ }
665
+ ctx["classification_tables"] = {
666
+ "confusion": _read_csv_as_rows(_pick_first(cls_tables, "confusion_csv", "confusion")),
667
+ "lift": _read_csv_as_rows(_pick_first(cls_tables, "lift_csv", "decile_lift_csv", "gain_lift_csv")),
668
+ }
669
+ ctx["classification_plot_paths"] = {
670
+ "roc": _pick_first(cls_plots, "roc", "roc_curve"),
671
+ "pr": _pick_first(cls_plots, "pr", "pr_curve", "precision_recall"),
672
+ "lift": _pick_first(cls_plots, "lift", "cumulative_gain", "lift_curve"),
673
+ "calibration": _pick_first(cls_plots, "calibration", "reliability"),
674
+ "confusion": _pick_first(cls_plots, "confusion", "confusion_matrix"),
675
+ "ks": _pick_first(cls_plots, "ks", "ks_curve"),
676
+ }
677
+
678
+ # Regression metrics
679
+ reg = self._grab("RegressionMetrics", {}) or {}
680
+ reg_art = reg.get("artifacts", {}) or {}
681
+ ctx.setdefault("RegressionMetrics", {})
682
+ ctx["RegressionMetrics"].update({
683
+ "notes": reg.get("notes", []),
684
+ "rmse": reg.get("rmse"),
685
+ "mae": reg.get("mae"),
686
+ "median_ae": reg.get("median_ae"),
687
+ "r2": reg.get("r2"),
688
+ "r2_adjusted": reg.get("r2_adjusted"),
689
+ "mape_or_smape": reg.get("mape_or_smape"),
690
+ "mape_used": reg.get("mape_used"),
691
+ "artifacts": {
692
+ "pred_vs_actual": reg_art.get("pred_vs_actual"),
693
+ "residuals_vs_pred": reg_art.get("residuals_vs_pred"),
694
+ "residual_hist": reg_art.get("residual_hist"),
695
+ "qq_plot": reg_art.get("qq_plot"),
696
+ "abs_error_box": reg_art.get("abs_error_box"),
697
+ "abs_error_violin": reg_art.get("abs_error_violin"),
698
+ }
699
+ })
700
+
701
+ # Summary (ensure)
702
+ ctx.setdefault("summary", {})
703
+ ctx["summary"].setdefault("rmse", ctx["RegressionMetrics"].get("rmse"))
704
+ ctx["summary"].setdefault("mae", ctx["RegressionMetrics"].get("mae"))
705
+ ctx["summary"].setdefault("r2", ctx["RegressionMetrics"].get("r2"))
706
+ if "task_type" not in ctx:
707
+ ctx["task_type"] = "regression" if ctx["RegressionMetrics"].get("rmse") is not None else "classification"
708
+
709
+ # Stress list→dict back-compat
140
710
  if isinstance(self.results.get("StressTestCheck"), list):
141
711
  ctx["StressTestCheck"] = {"table": self.results["StressTestCheck"]}
142
712
 
143
- cluster_rows = self._grab("InputClusterCheck", {}).get("cluster_table", [])
144
- ctx.setdefault("InputClusterCheck", {})["cluster_table"] = [
713
+ # Input Cluster Coverage
714
+ icc = self._grab("InputClusterCoverageCheck", {}) or self._grab("InputClusterCheck", {}) or {}
715
+ icc_art = icc.get("artifacts", icc)
716
+ cluster_rows = icc.get("cluster_table") or icc_art.get("cluster_table") or []
717
+ cluster_csv = (icc.get("cluster_csv") or icc_art.get("cluster_csv") or icc_art.get("csv") or "—")
718
+ plot_path = (icc.get("cluster_plot_img") or icc_art.get("cluster_plot_img") or icc_art.get("plot_img"))
719
+
720
+ ctx.setdefault("InputClusterCheck", {})
721
+ ctx["InputClusterCheck"]["cluster_csv"] = cluster_csv
722
+ ctx["InputClusterCheck"]["cluster_table"] = [
145
723
  {
146
- "Cluster": r.get("Cluster") or r.get("cluster"),
724
+ "Cluster": r.get("Cluster") or r.get("cluster") or r.get("label"),
147
725
  "Count": r.get("Count") or r.get("count"),
148
726
  "Percent": r.get("Percent") or r.get("percent"),
149
727
  }
150
- for r in cluster_rows
151
- if isinstance(r, dict)
728
+ for r in cluster_rows if isinstance(r, dict)
152
729
  ]
153
- plot_path = self._grab("InputClusterCheck", {}).get("cluster_plot_img")
154
- ctx["InputClusterCheck"]["cluster_plot_img"] = (
155
- InlineImage(doc, plot_path, width=Inches(5))
156
- if plot_path and os.path.exists(plot_path)
157
- else "Plot not available"
730
+ ctx["ks_curve_path"] = ctx["classification_plot_paths"].get("ks")
731
+
732
+ # VIF normalize
733
+ vif = self._grab("VIFCheck", {})
734
+ ctx["VIFCheck"] = AttrDict(
735
+ vif if isinstance(vif, dict) and "vif_table" in vif
736
+ else {"vif_table": [], "high_vif_features": [], "error": "Invalid VIFCheck"}
158
737
  )
159
738
 
160
- # Render DOCX template
161
- print("🟢 ctx top-level keys:", list(ctx.keys()))
162
- print("🔍 RawDataCheck value:", ctx.get("RawDataCheck"))
739
+ # ===== 2) Open template, replace text placeholders & images ==========
740
+ doc = Document(str(self.template_path))
163
741
 
164
- doc.render(ctx)
165
- doc.save(self.output_path)
742
+ # -------- scalar_map (text placeholders) ----------------------------
743
+ scalar_map = {
744
+ "validation_date": ctx.get("validation_date", "") or "",
745
+ "validated_by": ctx.get("validated_by", "") or "",
746
+ "model_path": ctx.get("model_path", "") or "",
747
+ "task_type": (ctx.get("task_type") or "classification").title(),
748
+ "ModelMetaCheck.model_class": (ctx.get("ModelMetaCheck", {}) or {}).get("model_class", ""),
749
+ "ModelMetaCheck.module": (ctx.get("ModelMetaCheck", {}) or {}).get("module", ""),
750
+ "ModelMetaCheck.model_type": (ctx.get("ModelMetaCheck", {}) or {}).get("model_type", ""),
751
+ "ModelMetaCheck.n_features": (ctx.get("ModelMetaCheck", {}) or {}).get("n_features", ""),
752
+ "ModelMetaCheck.feature_names": _fmt_feature_names((ctx.get("ModelMetaCheck", {}) or {}).get("feature_names")),
753
+ "ModelMetaCheck.n_train_rows": (ctx.get("ModelMetaCheck", {}) or {}).get("n_train_rows", ""),
754
+ "ModelMetaCheck.target_balance": _fmt_target_balance((ctx.get("ModelMetaCheck", {}) or {}).get("target_balance")),
755
+ }
166
756
 
167
- # Auto-insert tables after anchors
168
- tbl_specs = [
169
- {
170
- "anchor": "Stress Testing Results",
171
- "headers": [
172
- "feature",
173
- "perturbation",
174
- "accuracy",
175
- "auc",
176
- "delta_accuracy",
177
- "delta_auc",
178
- ],
179
- "rows": ctx.get("StressTestCheck", {}).get("table", []),
180
- },
181
- {
182
- "anchor": "Cluster Summary Table:",
183
- "headers": ["Cluster", "Count", "Percent"],
184
- "rows": ctx.get("InputClusterCheck", {}).get("cluster_table", []),
185
- },
186
- {
187
- "anchor": "Variance Inflation Factor (VIF) Check",
188
- "headers": ["Feature", "VIF"],
189
- "rows": ctx.get("VIFCheck", {}).get("vif_table", []),
190
- },
191
- ]
757
+ # --- Logistic (Logit) summary text for classification template ---
758
+ logit_ctx = ctx.get("LogitStats") or {}
759
+ scalar_map["LogitStats.summary_text"] = logit_ctx.get("summary_text") or ""
760
+
761
+ # REGRESSION rounded values / labels
762
+ scalar_map.update({
763
+ "summary.rmse2": _fmt2(ctx.get("summary", {}).get("rmse")),
764
+ "summary.mae2": _fmt2(ctx.get("summary", {}).get("mae")),
765
+ "summary.r22": _fmt2(ctx.get("summary", {}).get("r2")),
766
+ "RegressionMetrics.r2_adjusted2": _fmt2(ctx.get("RegressionMetrics", {}).get("r2_adjusted")),
767
+ "RegressionMetrics.median_ae2": _fmt2(ctx.get("RegressionMetrics", {}).get("median_ae")),
768
+ })
769
+
770
+ scalar_map.update({
771
+ "eda_summary_path": ctx.get("eda_summary_path") or "(not available)",
772
+ "eda_missing_path": ctx.get("eda_missing_path") or "(not available)",
773
+ })
774
+ rm = ctx.get("RegressionMetrics", {}) or {}
775
+ mape_or_smape = rm.get("mape_or_smape")
776
+ mape_used = rm.get("mape_used")
777
+ scalar_map["RegressionMetrics.mape_label"] = (
778
+ "N/A" if mape_or_smape is None else f"{_fmt2(mape_or_smape)}% ({'MAPE' if mape_used else 'SMAPE'})"
779
+ )
780
+ notes_list = rm.get("notes") or []
781
+ notes_text = "\n".join(map(str, notes_list)).strip() if notes_list else ""
782
+ scalar_map["RegressionMetrics.notes_text"] = notes_text if notes_text else "None"
783
+
784
+ # CLASSIFICATION rounded values
785
+ cs = ctx.get("classification_summary", {}) or {}
786
+ scalar_map.update({
787
+ "classification_summary.AUC2": _fmt2(cs.get("AUC")),
788
+ "classification_summary.KS2": _fmt2(cs.get("KS")),
789
+ "classification_summary.F12": _fmt2(cs.get("F1")),
790
+ "classification_summary.PR_AUC2": _fmt2(cs.get("PR_AUC")),
791
+ "classification_summary.GINI2": _fmt2(cs.get("GINI")),
792
+ "classification_summary.Precision2": _fmt2(cs.get("Precision")),
793
+ "classification_summary.Recall2": _fmt2(cs.get("Recall")),
794
+ "classification_summary.Accuracy2": _fmt2(cs.get("Accuracy")),
795
+ "classification_summary.Brier2": _fmt2(cs.get("Brier")),
796
+ })
797
+
798
+ # DataQualityCheck (train/test)
799
+ dq = ctx.get("DataQualityCheck") or {}
800
+ def _pick_dq(dq_root, split, key, fallback_key=None):
801
+ split_d = dq_root.get(split) or {}
802
+ if key in split_d and split_d[key] is not None:
803
+ return split_d[key]
804
+ if fallback_key and (fallback_key in dq_root) and dq_root[fallback_key] is not None:
805
+ return dq_root[fallback_key]
806
+ alt = {
807
+ ("avg_missing",): ["avg_missing_rate", "missing_rate"],
808
+ ("columns_with_missing",): ["cols_with_missing", "missing_columns"],
809
+ }
810
+ for k_tuple, alts in alt.items():
811
+ if key in k_tuple:
812
+ for a in alts:
813
+ if split_d.get(a) is not None:
814
+ return split_d[a]
815
+ if fallback_key and dq_root.get(a) is not None:
816
+ return dq_root[a]
817
+ return None
818
+
819
+ train_avg_missing = _pick_dq(dq, "train", "avg_missing", fallback_key="avg_missing")
820
+ test_avg_missing = _pick_dq(dq, "test", "avg_missing", fallback_key="avg_missing")
821
+ train_cols_mis = _pick_dq(dq, "train", "columns_with_missing", fallback_key="columns_with_missing")
822
+ test_cols_mis = _pick_dq(dq, "test", "columns_with_missing", fallback_key="columns_with_missing")
823
+ const_cols_dq = dq.get("constant_columns")
824
+ if const_cols_dq is None:
825
+ const_cols_dq = (dq.get("train") or {}).get("constant_columns") or (dq.get("test") or {}).get("constant_columns")
826
+
827
+ scalar_map.update({
828
+ "DataQualityCheck.train_avg_missing": _fmt_ratio_or_pct(train_avg_missing, pct=True, dash=_p("unknown")),
829
+ "DataQualityCheck.test_avg_missing": _fmt_ratio_or_pct(test_avg_missing, pct=True, dash=_p("unknown")),
830
+ "DataQualityCheck.train_cols_missing": _fmt_list_or_message(train_cols_mis, empty_msg=_p("none_detected"), max_items=25),
831
+ "DataQualityCheck.test_cols_missing": _fmt_list_or_message(test_cols_mis, empty_msg=_p("none_detected"), max_items=25),
832
+ "DataQualityCheck.constant_columns_str": _fmt_list_or_message(const_cols_dq, empty_msg=_p("none_detected"), max_items=25),
833
+ })
834
+
835
+ # RawDataCheck
836
+ rd = ctx.get("RawDataCheck") or {}
837
+ total_rows = rd.get("total_rows") or rd.get("n_rows") or rd.get("rows")
838
+ total_cols = rd.get("total_columns") or rd.get("n_columns") or rd.get("columns")
839
+ avg_missing = rd.get("avg_missing") or rd.get("avg_missing_rate") or rd.get("missing_rate")
840
+ dup_rows = rd.get("duplicate_rows") or rd.get("n_duplicate_rows") or rd.get("duplicates")
841
+ cols_with_missing = rd.get("columns_with_missing") or rd.get("cols_with_missing") or rd.get("missing_columns")
842
+ const_cols = rd.get("constant_columns") or rd.get("constant")
843
+ raw_skipped = bool(rd.get("note")) and all(
844
+ v is None for v in [total_rows, total_cols, avg_missing, dup_rows, cols_with_missing, const_cols]
845
+ )
846
+
847
+ scalar_map.update({
848
+ "RawDataCheck.total_rows":
849
+ _p("not_provided") if raw_skipped else (total_rows if total_rows is not None else _p("unknown")),
850
+ "RawDataCheck.total_columns":
851
+ _p("not_provided") if raw_skipped else (total_cols if total_cols is not None else _p("unknown")),
852
+ "RawDataCheck.avg_missing_pct":
853
+ _p("not_provided") if raw_skipped else _fmt_ratio_or_pct(avg_missing, pct=True, dash=_p("unknown")),
854
+ "RawDataCheck.columns_with_missing_str":
855
+ _p("not_provided") if raw_skipped else _fmt_list_or_message(cols_with_missing, empty_msg=_p("none_detected"), max_items=30),
856
+ "RawDataCheck.duplicate_rows":
857
+ _p("not_provided") if raw_skipped else (str(dup_rows) if dup_rows not in (None, 0) else _p("none_detected")),
858
+ "RawDataCheck.constant_columns_str":
859
+ _p("not_provided") if raw_skipped else _fmt_list_or_message(const_cols, empty_msg=_p("none_detected"), max_items=30),
860
+ })
861
+
862
+
863
+ # Correlation formatted fields & notes
864
+ cc = (ctx.get("CorrelationCheck") or {})
865
+ def _fmt_int(x, dash="—"):
866
+ try:
867
+ if x is None: return dash
868
+ return str(int(round(float(x))))
869
+ except Exception:
870
+ return dash
871
+ def _fmt_float(x, nd=2, dash="—"):
872
+ try:
873
+ if x is None: return dash
874
+ return f"{float(x):.{nd}f}"
875
+ except Exception:
876
+ return dash
877
+
878
+ notes_val = cc.get("notes")
879
+ if isinstance(notes_val, (list, tuple)):
880
+ notes_text = "Notes:\n" + "\n".join(f"- {str(n)}" for n in notes_val) if notes_val else ""
881
+ elif isinstance(notes_val, str):
882
+ notes_text = "Notes:\n- " + notes_val if notes_val.strip() else ""
883
+ else:
884
+ notes_text = ""
885
+
886
+ scalar_map.update({
887
+ "CorrelationCheck.method2": cc.get("method", "") or "",
888
+ "CorrelationCheck.threshold2": _fmt_float(cc.get("threshold")),
889
+ "CorrelationCheck.plotted_features2": _fmt_int(cc.get("plotted_features")),
890
+ "CorrelationCheck.n_numeric_features2": _fmt_int(cc.get("n_numeric_features")),
891
+ "CorrelationCheck.n_pairs_flagged_ge_threshold2": _fmt_int(cc.get("n_pairs_flagged_ge_threshold")),
892
+ "CorrelationCheck.n_pairs_total2": _fmt_int(cc.get("n_pairs_total")),
893
+ "CorrelationCheck.top_pairs_csv_path": cc.get("top_pairs_csv") or cc.get("top_pairs_main_csv") or "",
894
+ "CorrelationCheck.notes_text": notes_text,
895
+ })
896
+ scalar_map.update({
897
+ "correlation_pearson_path": ctx.get("correlation_pearson_path", "") or "",
898
+ "correlation_spearman_path": ctx.get("correlation_spearman_path", "") or "",
899
+ })
900
+
901
+ # VIF fallback text
902
+ vif_ctx = ctx.get("VIFCheck") or {}
903
+ vif_rows = vif_ctx.get("vif_table") or []
904
+ scalar_map["VIFCheck.note_text"] = "" if (isinstance(vif_rows, list) and len(vif_rows) > 0) else "_No VIF results were generated for this run._"
905
+
906
+ # Stress fallback text
907
+ st_rows = (ctx.get("StressTestCheck") or {}).get("table") or []
908
+ scalar_map["StressTest.note_text"] = "" if (isinstance(st_rows, list) and len(st_rows) > 0) else "_No stress-test results were generated for this run._"
909
+
910
+ # InputCluster fallback text
911
+ icc_rows = ctx.get("InputClusterCheck", {}).get("cluster_table") or []
912
+ scalar_map["InputClusterCheck.note_text"] = "" if (isinstance(icc_rows, list) and len(icc_rows) > 0) else "_No cluster summary was generated for this run._"
913
+ scalar_map["InputClusterCheck.cluster_csv"] = ctx.get("InputClusterCheck", {}).get("cluster_csv") or "(not available)"
914
+
915
+ # RuleEngineCheck pretty text
916
+ rec = ctx.get("RuleEngineCheck") or {}
917
+ rules = rec.get("rules") or {}
918
+ def _fmt_rule_value(v):
919
+ if isinstance(v, bool):
920
+ return "✅ Pass" if v else "❌ Fail"
921
+ return str(v)
922
+ overall_pass_val = rec.get("overall_pass")
923
+ if isinstance(overall_pass_val, bool):
924
+ overall_pass_str = "✅ Yes" if overall_pass_val else "❌ No"
925
+ else:
926
+ overall_pass_str = str(overall_pass_val) if overall_pass_val is not None else "—"
927
+ if isinstance(rules, dict) and rules:
928
+ lines = [f"- {str(k)}: {_fmt_rule_value(v)}" for k, v in rules.items()]
929
+ rules_text = "\n".join(lines)
930
+ else:
931
+ rules_text = "No rule results were generated for this run."
932
+ scalar_map.update({
933
+ "RuleEngineCheck.overall_pass_str": overall_pass_str,
934
+ "RuleEngineCheck.rules_text": rules_text,
935
+ })
936
+
937
+ # ModelMeta hyperparams & attributes text
938
+ mm = ctx.get("ModelMetaCheck") or {}
939
+ hyper_table = mm.get("hyperparam_table") or []
940
+ if isinstance(hyper_table, list) and hyper_table:
941
+ hp_lines = []
942
+ for row in hyper_table:
943
+ try:
944
+ k = row.get("param"); v = row.get("value")
945
+ except Exception:
946
+ k, v = None, None
947
+ if k is not None:
948
+ hp_lines.append(f"- {k}: {v}")
949
+ hyperparams_text = "\n".join(hp_lines) if hp_lines else "_No hyperparameters provided._"
950
+ else:
951
+ hyperparams_text = "_No hyperparameters provided._"
952
+
953
+ attrs = mm.get("attributes")
954
+ if isinstance(attrs, dict) and attrs:
955
+ attr_lines = [f"- {k}: {v}" for k, v in attrs.items()]
956
+ attributes_text = "\n".join(attr_lines)
957
+ else:
958
+ attributes_text = (str(attrs) if attrs not in (None, "", []) else "_No attributes available._")
959
+
960
+ scalar_map.update({
961
+ "ModelMetaCheck.hyperparams_text": hyperparams_text,
962
+ "ModelMetaCheck.attributes_text": attributes_text,
963
+ })
964
+
965
+ # Logistic stats (optional)
966
+ lsf = ctx.get("LogisticStatsFit") or {}
967
+ scalar_map.update({
968
+ "LogisticStatsFit.log_lik2": _fmt2(lsf.get("log_lik")),
969
+ "LogisticStatsFit.aic2": _fmt2(lsf.get("aic")),
970
+ "LogisticStatsFit.bic2": _fmt2(lsf.get("bic")),
971
+ "LogisticStatsFit.pseudo_r22": _fmt2(lsf.get("pseudo_r2")),
972
+ "LogisticStatsSummary_text": (
973
+ ctx.get("LogisticStatsSummary_text")
974
+ or ctx.get("LogisticStatsSummary")
975
+ or "Logistic diagnostics not available."
976
+ ),
977
+ })
978
+
979
+ # Linear (OLS) stats (optional)
980
+ lin = ctx.get("LinearStats") or {}
981
+ scalar_map["LinearStats.summary_text"] = lin.get("summary_text", "") or ""
982
+
983
+ # EDA inclusion controls
984
+ all_eda = ctx.get("eda_images_paths") or []
985
+ opts = ctx.get("report_options") or {}
986
+ K = opts.get("eda_max_images")
987
+ if K is None:
988
+ K = len(all_eda)
989
+ else:
990
+ try:
991
+ K = int(K)
992
+ except Exception:
993
+ K = len(all_eda)
994
+ eda_cols = int(opts.get("eda_grid_cols", 3))
995
+ eda_subset = all_eda[:K] if K >= 0 else all_eda
996
+ scalar_map["eda_count_note"] = f"(showing {len(eda_subset)} of {len(all_eda)})" if len(all_eda) != len(eda_subset) else ""
997
+ ctx["_eda_subset"] = eda_subset
998
+ ctx["_eda_cols"] = eda_cols
999
+
1000
+ # text replacements
1001
+ _replace_text_placeholders(doc, scalar_map)
1002
+
1003
+ # image markers
1004
+ images_map = {
1005
+ "roc": ctx["classification_plot_paths"].get("roc"),
1006
+ "pr": ctx["classification_plot_paths"].get("pr"),
1007
+ "lift": ctx["classification_plot_paths"].get("lift"),
1008
+ "calibration": ctx["classification_plot_paths"].get("calibration"),
1009
+ "confusion": ctx["classification_plot_paths"].get("confusion"),
1010
+ "ks": ctx.get("ks_curve_path"),
1011
+ "correlation_heatmap": ctx.get("correlation_heatmap_path"),
1012
+ "shap_beeswarm": ctx.get("shap_beeswarm_path"),
1013
+ "shap_bar": ctx.get("shap_bar_path"),
1014
+ "reg_pred_vs_actual": ctx["RegressionMetrics"]["artifacts"].get("pred_vs_actual"),
1015
+ "reg_residuals_vs_pred": ctx["RegressionMetrics"]["artifacts"].get("residuals_vs_pred"),
1016
+ "reg_residual_hist": ctx["RegressionMetrics"]["artifacts"].get("residual_hist"),
1017
+ "reg_qq": ctx["RegressionMetrics"]["artifacts"].get("qq_plot"),
1018
+ "reg_abs_error_box": ctx["RegressionMetrics"]["artifacts"].get("abs_error_box"),
1019
+ "reg_abs_error_violin": ctx["RegressionMetrics"]["artifacts"].get("abs_error_violin"),
1020
+ "cluster_plot": plot_path,
1021
+ "logit_summary": ctx.get("logit_summary_path"),
1022
+ }
1023
+ _replace_image_markers(doc, images_map, width_in=4.8)
1024
+
1025
+ # save once before inserting tables/grids by anchors
1026
+ doc.save(self.output_path)
192
1027
 
1028
+ # ===== 3) Post-render insertions via anchors ==========================
193
1029
  docx = Document(self.output_path)
194
- for spec in tbl_specs:
195
- if spec["rows"]:
196
- tbl = build_table(docx, spec["headers"], spec["rows"])
197
- insert_after(docx, spec["anchor"], tbl)
198
- print(f"✅ added table after «{spec['anchor']}»")
199
- docx.save(self.output_path)
200
1030
 
1031
+ # EDA grid
1032
+ insert_image_grid(
1033
+ docx,
1034
+ anchor="Distribution Plots",
1035
+ img_paths=ctx.get("_eda_subset") or [],
1036
+ cols=ctx.get("_eda_cols") or 3,
1037
+ width_in=2.2
1038
+ )
201
1039
 
202
- def build_table(doc, headers, rows):
203
- tbl = doc.add_table(rows=1, cols=len(headers))
204
- tbl.style = "Table Grid"
205
- for i, h in enumerate(headers):
206
- tbl.rows[0].cells[i].text = str(h)
207
- for r in rows:
208
- vals = [r.get(h, "") for h in headers] if isinstance(r, dict) else list(r)
209
- vals += [""] * (len(headers) - len(vals))
210
- row = tbl.add_row().cells
211
- for i, v in enumerate(vals):
212
- row[i].text = str(v)
213
- return tbl
1040
+ # helpers
1041
+ def _round_row_numbers(row, places=2):
1042
+ if not isinstance(row, dict):
1043
+ return row
1044
+ out = {}
1045
+ for k, v in row.items():
1046
+ try:
1047
+ if v is None:
1048
+ out[k] = v
1049
+ else:
1050
+ fv = float(v)
1051
+ out[k] = round(fv, places) if np.isfinite(fv) else v
1052
+ except Exception:
1053
+ out[k] = v
1054
+ return out
214
1055
 
1056
+ stress_rows = (ctx.get("StressTestCheck", {}) or {}).get("table", []) or []
1057
+ stress_rows = [_round_row_numbers(r) for r in stress_rows]
215
1058
 
216
- def insert_after(doc, anchor, tbl):
217
- for p in doc.paragraphs:
218
- if anchor.lower() in p.text.lower():
219
- parent = p._p.getparent()
220
- parent.insert(parent.index(p._p) + 1, tbl._tbl)
221
- return
222
- print(f"⚠️ anchor «{anchor}» not found")
1059
+ # OLS coefficients (rounded)
1060
+ ols_coeff_rows = (ctx.get("LinearStats") or {}).get("coeff_table", []) or []
1061
+ ols_coeff_rows = [_round_row_numbers(r, places=4) for r in ols_coeff_rows]
223
1062
 
1063
+ # Ensure logistic minimal coef table exists/rounded (legacy 2-col)
1064
+ if "coef_table" not in ctx or not isinstance(ctx["coef_table"], list):
1065
+ ctx["coef_table"] = []
1066
+ logit_coeff_rows = ctx.get("coef_table", []) or []
1067
+ logit_coeff_rows = [_round_row_numbers(r, places=4) for r in logit_coeff_rows]
1068
+ ctx["coef_table"] = logit_coeff_rows
224
1069
 
225
- def add_logit_summary_image(tpl_doc, sm_results, ctx, key):
226
- html = TMP_DIR / "logit_summary.html"
227
- html.write_text(sm_results.summary().as_html(), encoding="utf8")
228
- png = TMP_DIR / "logit_summary.png"
229
- imgkit.from_file(str(html), str(png), options={"quiet": ""})
230
- ctx[key] = InlineImage(tpl_doc, str(png), width=Mm(160))
1070
+ # ---------- Build full logistic coefficients table (OLS-like) ----------
1071
+ def _coerce_float_or_blank(v):
1072
+ try:
1073
+ if v is None:
1074
+ return ""
1075
+ fv = float(v)
1076
+ return fv if np.isfinite(fv) else ""
1077
+ except Exception:
1078
+ return ""
1079
+
1080
+ logit_full_rows = []
1081
+
1082
+
1083
+ stats_full = None
1084
+ for source_key in ("LogisticStats", "LogisticStatsCheck"):
1085
+ cand = self._grab(source_key) or {}
1086
+ if isinstance(cand, dict) and isinstance(cand.get("coef_table_full"), list):
1087
+ stats_full = cand["coef_table_full"]
1088
+ break
1089
+ if isinstance(stats_full, list) and stats_full:
1090
+ for r in stats_full:
1091
+ logit_full_rows.append({
1092
+ "feature": str(r.get("feature")),
1093
+ "coef": _coerce_float_or_blank(r.get("coef")),
1094
+ "std err": _coerce_float_or_blank(r.get("std_err")),
1095
+ "z": _coerce_float_or_blank(r.get("z")),
1096
+ "P>|z|": _coerce_float_or_blank(r.get("p") or r.get("p_value") or r.get("p>|z|")),
1097
+ "ci_low": _coerce_float_or_blank(r.get("ci_low")),
1098
+ "ci_high": _coerce_float_or_blank(r.get("ci_high")),
1099
+ })
1100
+
1101
+ # (2) fallback: synthesize from sklearn-like attributes (coef_ / intercept_)
1102
+ if not logit_full_rows and (ctx.get("task_type") or "").lower() == "classification":
1103
+ mm = ctx.get("ModelMetaCheck") or {}
1104
+ attrs = (mm.get("attributes") or {})
1105
+ feat_names = mm.get("feature_names") or []
1106
+ def _flatten(x):
1107
+ if x is None: return None
1108
+ arr = np.array(x)
1109
+ return arr.flatten().tolist()
1110
+ flat_coefs = _flatten(attrs.get("coef_"))
1111
+ intercept = attrs.get("intercept_")
1112
+ if isinstance(flat_coefs, list) and len(flat_coefs) == len(feat_names):
1113
+ for f, c in zip(feat_names, flat_coefs):
1114
+ logit_full_rows.append({
1115
+ "feature": str(f),
1116
+ "coef": _coerce_float_or_blank(c),
1117
+ "std err": "", "z": "", "P>|z|": "", "ci_low": "", "ci_high": ""
1118
+ })
1119
+ if intercept is not None:
1120
+ try:
1121
+ b0 = float(intercept[0] if isinstance(intercept, (list, tuple, np.ndarray)) else intercept)
1122
+ except Exception:
1123
+ b0 = None
1124
+ logit_full_rows.insert(0, {
1125
+ "feature": "Intercept",
1126
+ "coef": _coerce_float_or_blank(b0),
1127
+ "std err": "", "z": "", "P>|z|": "", "ci_low": "", "ci_high": ""
1128
+ })
1129
+
1130
+ ctx["logit_coef_full_rows"] = logit_full_rows
1131
+
1132
+ # ---------- task-aware table list ----------
1133
+ task = (ctx.get("task_type") or "classification").lower()
1134
+
1135
+ if task == "regression":
1136
+ stress_headers = ["feature", "perturbation", "rmse", "r2", "delta_rmse", "delta_r2"]
1137
+ tbl_specs = [
1138
+ {
1139
+ "anchor": "Top High-Correlation Feature Pairs (|r| ≥ Threshold)",
1140
+ "headers": ["feature_i", "feature_j", "corr", "n_used", "pct_missing_i", "pct_missing_j"],
1141
+ "rows": corr_preview_rows,
1142
+ },
1143
+ {
1144
+ "anchor": "Stress Testing Results",
1145
+ "headers": stress_headers,
1146
+ "rows": stress_rows,
1147
+ },
1148
+ {
1149
+ "anchor": "Cluster Summary Table",
1150
+ "headers": ["Cluster", "Count", "Percent"],
1151
+ "rows": ctx.get("InputClusterCheck", {}).get("cluster_table", []),
1152
+ },
1153
+ {
1154
+ "anchor": "Variance Inflation Factor (VIF) Check",
1155
+ "headers": ["Feature", "VIF"],
1156
+ "rows": ctx.get("VIFCheck", {}).get("vif_table", []),
1157
+ },
1158
+ {
1159
+ "anchor": "OLS Coefficients (Regression)",
1160
+ "headers": ["feature", "coef", "std err", "t", "P>|t|", "ci_low", "ci_high"],
1161
+ "rows": ols_coeff_rows,
1162
+ },
1163
+ {
1164
+ "anchor": "Top SHAP Features",
1165
+ "headers": ["feature", "mean_abs_shap"],
1166
+ "rows": ctx.get("shap_top_features", []) or [],
1167
+ },
1168
+ ]
1169
+ else:
1170
+ stress_headers = ["feature", "perturbation", "accuracy", "auc", "delta_accuracy", "delta_auc"]
1171
+
1172
+ labels = None
1173
+ try:
1174
+ labels = (cls.get("labels") or cls_tables.get("labels") or ctx.get("classification_labels"))
1175
+ except Exception:
1176
+ labels = None
1177
+
1178
+ def _mk_confusion_headers(labs):
1179
+ if not labs or not isinstance(labs, (list, tuple)) or len(labs) == 0:
1180
+ return ["", "Pred 0", "Pred 1"]
1181
+ return [""] + [f"Pred {str(l)}" for l in labs]
1182
+
1183
+ confusion_rows = ctx.get("classification_tables", {}).get("confusion", []) or []
1184
+ confusion_headers = _mk_confusion_headers(labels)
1185
+
1186
+ logit_stats = ctx.get("LogitStats") or {}
1187
+ logit_headers = logit_stats.get("coef_table_headers") or ["feature","coef","std err","z","P>|z|","ci_low","ci_high"]
1188
+ logit_rows = (
1189
+ logit_stats.get("coef_table_rows")
1190
+ or ctx.get("logit_coef_full_rows")
1191
+ or ctx.get("coef_table")
1192
+ or []
1193
+ )
1194
+
1195
+ baseline_logit_rows = self._extract_baseline_logit_rows(ctx)
1196
+
1197
+ tbl_specs = [
1198
+ {
1199
+ "anchor": "Top High-Correlation Feature Pairs (|r| ≥ Threshold)",
1200
+ "headers": ["feature_i", "feature_j", "corr", "n_used", "pct_missing_i", "pct_missing_j"],
1201
+ "rows": corr_preview_rows,
1202
+ },
1203
+ {
1204
+ "anchor": "Stress Testing Results",
1205
+ "headers": stress_headers,
1206
+ "rows": stress_rows,
1207
+ },
1208
+ {
1209
+ "anchor": "Cluster Summary Table",
1210
+ "headers": ["Cluster", "Count", "Percent"],
1211
+ "rows": ctx.get("InputClusterCheck", {}).get("cluster_table", []),
1212
+ },
1213
+ {
1214
+ "anchor": "Variance Inflation Factor (VIF) Check",
1215
+ "headers": ["Feature", "VIF"],
1216
+ "rows": ctx.get("VIFCheck", {}).get("vif_table", []),
1217
+ },
1218
+ {
1219
+ "anchor": "Confusion Matrix (Classification)",
1220
+ "headers": confusion_headers,
1221
+ "rows": confusion_rows if confusion_rows else [{"": "(no confusion matrix available)"}],
1222
+ },
1223
+ {
1224
+ "anchor": "Decile Lift Table",
1225
+ "headers": ["decile", "total", "events", "avg_score", "event_rate", "lift",
1226
+ "cum_events", "cum_total", "cum_capture_rate", "cum_population", "cum_gain"],
1227
+ "rows": ctx.get("classification_tables", {}).get("lift", []),
1228
+ },
1229
+ {
1230
+ "anchor": "Baseline Metrics",
1231
+ "headers": ["metric", "value"],
1232
+ "rows": baseline_logit_rows,
1233
+ },
1234
+ {
1235
+ "anchor": "Logistic Regression Coefficients ",
1236
+ "headers": logit_headers,
1237
+ "rows": logit_rows if logit_rows else [
1238
+ {"feature": "(no coefficients available)", "coef": "", "std err": "", "z": "", "P>|z|": "", "ci_low": "", "ci_high": ""}
1239
+ ],
1240
+ },
1241
+ {
1242
+ "anchor": "Feature Importances (Tree-Based Models)",
1243
+ "headers": ["feature", "importance"],
1244
+ "rows": ctx.get("feature_importance_table", []) or [],
1245
+ },
1246
+ {
1247
+ "anchor": "Top SHAP Features",
1248
+ "headers": ["feature", "mean_abs_shap"],
1249
+ "rows": ctx.get("shap_top_features", []) or [],
1250
+ },
1251
+ ]
1252
+
1253
+ for spec in tbl_specs:
1254
+ if spec["rows"]:
1255
+ if spec["anchor"] == "Decile Lift Table":
1256
+ tbl = build_decile_lift_table(docx, spec["headers"], spec["rows"])
1257
+ else:
1258
+ tbl = build_table(docx, spec["headers"], spec["rows"])
1259
+ insert_after(docx, spec["anchor"], tbl)
1260
+ print(f"✅ added table after «{spec['anchor']}»")
1261
+
1262
+
1263
+ docx.save(self.output_path)