tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tanml may be problematic; see the registry page for more details.

Files changed (49)
  1. tanml/__init__.py +1 -1
  2. tanml/check_runners/cleaning_repro_runner.py +2 -2
  3. tanml/check_runners/correlation_runner.py +49 -12
  4. tanml/check_runners/explainability_runner.py +12 -22
  5. tanml/check_runners/logistic_stats_runner.py +196 -17
  6. tanml/check_runners/performance_runner.py +82 -26
  7. tanml/check_runners/raw_data_runner.py +29 -14
  8. tanml/check_runners/regression_metrics_runner.py +195 -0
  9. tanml/check_runners/stress_test_runner.py +23 -6
  10. tanml/check_runners/vif_runner.py +33 -27
  11. tanml/checks/correlation.py +241 -41
  12. tanml/checks/explainability/shap_check.py +261 -29
  13. tanml/checks/logit_stats.py +186 -54
  14. tanml/checks/performance_classification.py +305 -0
  15. tanml/checks/raw_data.py +58 -23
  16. tanml/checks/regression_metrics.py +167 -0
  17. tanml/checks/stress_test.py +157 -53
  18. tanml/cli/main.py +99 -27
  19. tanml/engine/check_agent_registry.py +20 -10
  20. tanml/engine/core_engine_agent.py +199 -37
  21. tanml/models/registry.py +329 -0
  22. tanml/report/report_builder.py +1180 -147
  23. tanml/report/templates/report_template_cls.docx +0 -0
  24. tanml/report/templates/report_template_reg.docx +0 -0
  25. tanml/ui/app.py +1205 -0
  26. tanml/utils/data_loader.py +105 -15
  27. tanml-0.1.7.dist-info/METADATA +164 -0
  28. tanml-0.1.7.dist-info/RECORD +54 -0
  29. tanml/cli/arg_parser.py +0 -31
  30. tanml/cli/init_cmd.py +0 -8
  31. tanml/cli/validate_cmd.py +0 -7
  32. tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
  33. tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
  34. tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
  35. tanml/engine/segmentation_agent.py +0 -118
  36. tanml/engine/validation_agent.py +0 -91
  37. tanml/report/templates/report_template.docx +0 -0
  38. tanml/utils/model_loader.py +0 -35
  39. tanml/utils/r_loader.py +0 -30
  40. tanml/utils/sas_loader.py +0 -50
  41. tanml/utils/yaml_generator.py +0 -34
  42. tanml/utils/yaml_loader.py +0 -5
  43. tanml/validate.py +0 -209
  44. tanml-0.1.6.dist-info/METADATA +0 -317
  45. tanml-0.1.6.dist-info/RECORD +0 -62
  46. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
  47. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
  48. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
  49. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
tanml/ui/app.py ADDED
@@ -0,0 +1,1205 @@
1
+ # tanml/ui/app.py
2
+ from __future__ import annotations
3
+
4
+ import os, time, uuid, json, hashlib
5
+ from pathlib import Path
6
+ from typing import Optional, Dict, Any, Callable, Tuple, List
7
+ import os
8
+ import pandas as pd
9
+ import streamlit as st
10
+
11
+ from sklearn.model_selection import train_test_split
12
+
13
+ # TanML internals
14
+ from tanml.utils.data_loader import load_dataframe
15
+ from tanml.engine.core_engine_agent import ValidationEngine
16
+ from tanml.report.report_builder import ReportBuilder
17
+ from importlib.resources import files
18
+
19
+ # Model registry (20-model suite)
20
+ from tanml.models.registry import (
21
+ list_models, ui_schema_for, build_estimator, infer_task_from_target, get_spec
22
+ )
23
+
24
+
25
+ from pathlib import Path
26
+ from importlib.resources import files
27
+
28
+
29
# Best-effort tuning of Streamlit limits: TANML_MAX_MB (default 1024) sets both
# the upload-size and message-size caps, and usage-stats collection is turned
# off. Any failure (a non-integer env value, or Streamlit refusing to change an
# option at this point) is ignored on purpose so the app still starts; options
# after the failing one are then not applied.
# NOTE(review): Streamlit may reject ``server.*`` options once the server is
# already running -- confirm these calls take effect (config.toml / CLI flags
# are the documented alternatives).
try:
    _max_mb = int(os.environ.get("TANML_MAX_MB", "1024"))
    for _opt_name, _opt_value in (
        ("server.maxUploadSize", _max_mb),
        ("server.maxMessageSize", _max_mb),
        ("browser.gatherUsageStats", False),
    ):
        st.set_option(_opt_name, _opt_value)
except Exception:
    pass
37
+
38
def _choose_report_template(task_type: str) -> Path:
    """Locate the Word template used for the validation report.

    ``"regression"`` selects ``report_template_reg.docx``; every other
    ``task_type`` selects ``report_template_cls.docx``. Locations are tried in
    this order:

    1. the packaged resource ``tanml.report.templates/<name>``;
    2. ``<package root>/report/templates/<name>`` relative to this file;
    3. ``<cwd>/<name>``, returned WITHOUT checking that it exists.
    """
    template_name = (
        "report_template_reg.docx" if task_type == "regression" else "report_template_cls.docx"
    )

    # 1) Installed package resource. Any lookup error (package not importable,
    #    unusual resource loader) just falls through to the next location.
    try:
        packaged = files("tanml.report.templates") / template_name
        if packaged.is_file():
            return Path(str(packaged))
    except Exception:
        pass

    # 2) Source-tree layout: tanml/ui/app.py -> tanml/report/templates/<name>.
    in_repo = Path(__file__).resolve().parents[1] / "report" / "templates" / template_name
    if in_repo.exists():
        return in_repo

    # 3) Last resort: the current working directory (may not exist).
    return Path.cwd() / template_name
58
+
59
+
60
def _env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment without crashing on odd values.

    Integer strings keep their historical meaning ("0" -> False, any other
    integer -> True). The words true/yes/on/y and false/no/off/n (any case,
    surrounding whitespace ignored) are also accepted. Unset, empty or
    unrecognised values yield *default* instead of raising ``ValueError`` at
    import time (which previously prevented the app from starting).
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    text = raw.strip().lower()
    if text in ("true", "yes", "on", "y"):
        return True
    if text in ("false", "no", "off", "n"):
        return False
    try:
        return bool(int(text))
    except ValueError:
        return default


# Session default for the "cast numerics to float64 and round to 9 decimals"
# VIF normalisation, controlled by the TANML_CAST9 environment variable.
CAST9_DEFAULT = _env_flag("TANML_CAST9", default=False)
61
# Shared CSS for the custom KPI rows rendered with the tanml-kpi-label /
# tanml-kpi-value classes: labels sit in a fixed-height, bottom-aligned box so
# the values underneath all start on the same baseline regardless of label length.
_TANML_KPI_CSS = """
<style>
.tanml-kpi-label{
font-size:0.80rem; opacity:.8; white-space:nowrap;
height:20px; display:flex; align-items:flex-end;
}
.tanml-kpi-value{
font-size:1.6rem; font-weight:700; line-height:1; margin-top:4px;
}
</style>
"""
st.markdown(_TANML_KPI_CSS, unsafe_allow_html=True)
73
+
74
+
75
+
76
+
77
+
78
def _filter_metrics_for_task(summary: Dict[str, Any]) -> Dict[str, Any]:
    """Drop summary entries that do not belong to the run's task type.

    ``summary["task_type"]`` picks the allow-list: classification keeps
    auc/ks/f1/pr_auc, regression keeps rmse/mae/r2, and both keep
    ``rules_failed`` and ``task_type``. For any other task type the very same
    dict is returned unchanged. Non-dict input is returned as-is, except that
    falsy values (e.g. ``None``) become ``{}``.
    """
    if not isinstance(summary, dict):
        return summary or {}

    task = summary.get("task_type")
    allow_lists = (
        ("classification", {"auc", "ks", "f1", "pr_auc", "rules_failed", "task_type"}),
        ("regression", {"rmse", "mae", "r2", "rules_failed", "task_type"}),
    )
    for task_name, allowed in allow_lists:
        if task == task_name:
            return {key: val for key, val in summary.items() if key in allowed}
    return summary
92
+
93
+
94
def _g(d, *keys, default=None):
    """Walk nested dicts along *keys*, returning *default* on any missing step.

    Every step requires the current node to be a ``dict`` containing the next
    key. A present key whose value is ``None`` is returned as ``None`` (not
    *default*). With no keys, *d* itself is returned.
    """
    node = d
    for key in keys:
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node
102
+
103
def _fmt2(v, *, decimals=2, dash="—"):
    """Format a metric value for display.

    - ``None`` and NaN -> *dash*.
    - Integral numbers (``int``, ``bool``, numpy integers) -> ``str(v)``.
    - Other real numbers (``float``, numpy ``float32``/``float64``,
      ``Fraction`` ...) -> fixed-point with *decimals* digits. Previously only
      built-in ``float`` was formatted, so e.g. numpy ``float32`` metrics were
      shown with full precision.
    - Anything else -> ``str(v)``.
    Any formatting error also yields *dash*.
    """
    import math
    import numbers

    if v is None:
        return dash
    try:
        if isinstance(v, numbers.Integral):
            return str(v)
        if isinstance(v, numbers.Real):
            fv = float(v)
            if math.isnan(fv):
                return dash
            return f"{fv:.{decimals}f}"
        return str(v)
    except Exception:
        return dash
114
+
115
+ # ---------- TVR (Train/Validate/Report) helpers ----------
116
+
117
def _tvr_key(section_id: str, name: str) -> str:
    """Build the namespaced session-state key ``tvr::<section>::<name>``."""
    return "tvr::{}::{}".format(section_id, name)
119
+
120
def tvr_clear_extras(section_id: str):
    """Forget the extras stored for *section_id*; a no-op when none exist."""
    extras_key = _tvr_key(section_id, "extras")
    st.session_state.pop(extras_key, None)
122
+
123
def tvr_reset(section_id: str):
    """Return *section_id* to a fresh ``"idle"`` state and reset UI widgets.

    The stored report data (bytes, file name, timestamp, summary, label,
    config) is cleared and any stored extras are dropped. The session-state
    entries of the listed uploader / sidebar / model-selection widgets are
    removed so those widgets re-render with their coded defaults on the next
    script run.
    """
    tvr_init(section_id)

    state = st.session_state
    fresh_values = {
        "stage": "idle",
        "bytes": None,
        "file": None,
        "ts": None,
        "summary": None,
        "label": None,
        "cfg": None,
    }
    for field, value in fresh_values.items():
        state[_tvr_key(section_id, field)] = value
    state.pop(_tvr_key(section_id, "extras"), None)

    widget_keys = (
        # uploaders: single-file flow, then train/test flow
        "upl_cleaned", "upl_raw_single",
        "upl_train", "upl_test", "upl_raw_global",
        # sidebar check toggles and parameters
        "opt_eda", "opt_eda_max",
        "opt_corr", "opt_vif",
        "opt_rawcheck", "opt_modelmeta",
        "opt_stress", "opt_stress_eps", "opt_stress_frac",
        "opt_cluster", "opt_cluster_k", "opt_cluster_maxk",
        "opt_shap", "opt_shap_bg", "opt_shap_test",
        # NOTE(review): the widget that used this key is commented out in this file.
        "opt_vifnorm",
        # correlation settings
        "opt_corr_method", "opt_corr_cap", "opt_corr_thr",
        # reproducibility seed
        "opt_seed",
        # model selection
        "mdl_task", "mdl_lib", "mdl_algo",
        # thresholds
        "thr_auc", "thr_f1", "thr_ks",
        # internal helpers
        "__thr_block__", "model_selection", "effective_cfg",
    )
    for widget_key in widget_keys:
        state.pop(widget_key, None)
170
+
171
def tvr_init(section_id: str):
    """Ensure every TVR state key for *section_id* exists.

    Keys that are already present are left untouched (``setdefault``), so it
    is safe to call this on every rerun.
    """
    initial_values = dict(
        stage="idle",
        bytes=None,
        file=None,
        ts=None,
        summary=None,
        label=None,
        cfg=None,
    )
    state = st.session_state
    for field, initial in initial_values.items():
        state.setdefault(_tvr_key(section_id, field), initial)
183
+
184
def tvr_finish(section_id: str, *, report_path: Optional[Path] = None, report_bytes: Optional[bytes] = None,
               file_name: str, summary: Optional[dict] = None, label: Optional[str] = None,
               cfg: Optional[dict] = None):
    """Store a finished report for *section_id* and mark the section ``"ready"``.

    Args:
        section_id: TVR section whose state is updated.
        report_path: File to read the report from when *report_bytes* is not given.
        report_bytes: Report content; takes precedence over *report_path*.
        file_name: Download file name shown to the user.
        summary: Metrics summary (stored as ``{}`` when falsy).
        label: Run label (defaults to ``"Run"`` when falsy).
        cfg: Effective configuration used for the run.

    Raises:
        TypeError: if neither *report_bytes* nor *report_path* is supplied.
        OSError: if *report_path* cannot be read.
    """
    # Resolve the content before touching session state so a failure leaves no
    # partially-updated section behind.
    if report_bytes is None:
        if report_path is None:
            raise TypeError("tvr_finish() requires either report_bytes or report_path")
        report_bytes = Path(report_path).read_bytes()

    tvr_init(section_id)
    ts = int(time.time())
    st.session_state[_tvr_key(section_id, "bytes")] = report_bytes
    st.session_state[_tvr_key(section_id, "file")] = file_name
    st.session_state[_tvr_key(section_id, "ts")] = ts
    st.session_state[_tvr_key(section_id, "summary")] = summary or {}
    st.session_state[_tvr_key(section_id, "label")] = label or "Run"
    st.session_state[_tvr_key(section_id, "cfg")] = cfg
    # Set last so readers never observe "ready" with incomplete data.
    st.session_state[_tvr_key(section_id, "stage")] = "ready"
198
+
199
+
200
+
201
def tvr_render_ready(section_id: str, *, header_text="Refit, Validate & Report"):
    """Show the finished-report panel for *section_id* (only when ``"ready"``).

    Displays a rounded summary of the task-relevant metrics (keys containing
    "psi" are hidden), a download button for the stored report bytes, and a
    "new run" button that resets the section and reruns the script.
    """
    tvr_init(section_id)
    state = st.session_state

    def _read(field):
        return state[_tvr_key(section_id, field)]

    if _read("stage") != "ready":
        return

    st.subheader(header_text)

    stored = _read("summary") or {}
    without_psi = {k: v for k, v in stored.items() if "psi" not in k.lower()}
    visible = _filter_metrics_for_task(without_psi)

    if visible:
        st.caption("Summary (last run)")
        shown = {}
        for k, v in visible.items():
            shown[k] = round(v, 2) if isinstance(v, (int, float)) else v
        st.write(shown)

    ts = _read("ts")
    st.download_button(
        "⬇️ Download report",
        data=_read("bytes"),
        file_name=_read("file") or f"tanml_report_{ts}.docx",
        mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        key=f"tvr_dl::{section_id}::{ts}",
        width="stretch",
    )

    # The only follow-up action offered is starting over.
    if st.button("✨ New model / new run", key=f"tvr_new::{section_id}", width="stretch"):
        tvr_reset(section_id)
        st.rerun()
232
+
233
def tvr_render_history(section_id: str, *, title="🗂️ Past runs"):
    """Placeholder for a past-runs panel; currently renders nothing.

    The signature is kept so existing callers keep working. Always returns None.
    """
    del section_id, title  # intentionally unused
    return None
235
+
236
def tvr_store_extras(section_id: str, extras: Dict[str, Any]):
    """Save *extras* (results, run metadata) for *section_id*, replacing any previous value."""
    extras_key = _tvr_key(section_id, "extras")
    st.session_state[extras_key] = extras
238
+
239
+ # ==========================
240
+ # Filesystem / Session utils
241
+ # ==========================
242
+
243
def _session_dir() -> Path:
    """Return this Streamlit session's working folder, creating it if needed.

    The folder is ``.ui_runs/<id>`` relative to the current working directory,
    where ``<id>`` is an 8-hex-digit random id kept in
    ``st.session_state["_session_id"]``. An ``artifacts`` subfolder is created
    inside it. Nothing here deletes these folders.
    """
    state = st.session_state
    sid = state.get("_session_id")
    if not sid:
        sid = uuid.uuid4().hex[:8]
        state["_session_id"] = sid

    root = Path(".ui_runs") / sid
    # parents=True also creates ``root`` itself.
    (root / "artifacts").mkdir(parents=True, exist_ok=True)
    return root
253
+
254
def _save_upload(upload, dest_dir: Path) -> Optional[Path]:
    """Write an uploaded file into *dest_dir* and return the path to use.

    Only the base name of ``upload.name`` is kept, so directory components in
    the client-supplied name cannot escape *dest_dir*. CSV uploads are
    additionally re-saved as Parquet next to the CSV and the Parquet path is
    returned; if that conversion fails a warning is shown and the CSV path is
    returned. ``None`` in -> ``None`` out.

    NOTE(review): two uploads with the same base name overwrite each other in
    *dest_dir*.
    """
    if upload is None:
        return None

    target = dest_dir / Path(upload.name).name
    target.write_bytes(upload.getbuffer())

    if target.suffix.lower() != ".csv":
        return target

    try:
        frame = load_dataframe(target)
        parquet_path = target.with_suffix(".parquet")
        frame.to_parquet(parquet_path, index=False)
    except Exception as exc:
        st.warning(f"CSV→Parquet conversion failed (using CSV): {exc}")
        return target
    return parquet_path
271
+
272
+ # ==========================
273
+ # Data helpers
274
+ # ==========================
275
+
276
def _pick_target(df: pd.DataFrame) -> str:
    """Guess the target column: ``"target"`` if present, else the last column."""
    columns = df.columns
    return "target" if "target" in columns else columns[-1]
280
+
281
def _normalize_vif(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy whose numeric columns are float64 rounded to 9 decimals.

    Used to stabilise VIF results; non-numeric columns and the input frame are
    left untouched.
    """
    out = df.copy()
    numeric = out.select_dtypes(include="number").columns
    if len(numeric) > 0:
        out[numeric] = out[numeric].astype("float64").round(9)
    return out
288
+
289
def _schema_align_or_error(train_df: pd.DataFrame, test_df: pd.DataFrame) -> Tuple[pd.DataFrame, Optional[str]]:
    """Align *test_df* to *train_df*'s columns (names and order).

    Returns ``(frame, error)``:
    - if the test set lacks any train column, ``(test_df, message)`` with the
      original frame untouched;
    - otherwise ``(aligned_copy, None)``. Extra test columns are dropped, and
      a test column that is non-numeric where the train column is numeric is
      converted with ``pd.to_numeric(errors="coerce")``, so unparseable values
      become NaN.
    """
    wanted = list(train_df.columns)
    present = set(test_df.columns)
    absent = [col for col in wanted if col not in present]
    if absent:
        return test_df, f"Test set is missing required columns: {absent}"

    out = test_df[wanted].copy()
    is_num = pd.api.types.is_numeric_dtype
    for col in wanted:
        if is_num(train_df[col].dtype) and not is_num(out[col].dtype):
            out[col] = pd.to_numeric(out[col], errors="coerce")
    return out, None
307
+
308
def _row_overlap_pct(train_df: pd.DataFrame, test_df: pd.DataFrame, cols: List[str]) -> float:
    """Approximate train/test leakage as a row-overlap percentage.

    Returns the percentage of *distinct* test rows (restricted to *cols*) that
    also occur in the training data. Numeric columns are rounded to 9 decimals
    and every value is compared by its string form.

    Rows are fingerprinted with ``pd.util.hash_pandas_object``, which hashes
    each column separately. This avoids the false matches of the previous
    ``"|".join`` approach (e.g. ``("a|b", "c")`` vs ``("a", "b|c")``). It also
    avoids ``hashlib.md5``, which can be unavailable under FIPS policies, and
    is vectorised rather than a per-row Python loop.

    Returns 0.0 when *cols* is empty, either side has no rows, or the
    computation fails (e.g. a column is missing), so 0.0 can also mean
    "could not compute".
    """
    if not cols:
        return 0.0

    def _row_hashes(df: pd.DataFrame) -> set:
        sub = df[cols].copy()
        num = sub.select_dtypes(include="number").columns
        if len(num):
            sub[num] = sub[num].round(9)
        sub = sub.astype(str)
        return set(pd.util.hash_pandas_object(sub, index=False).tolist())

    try:
        tr = _row_hashes(train_df)
        te = _row_hashes(test_df)
    except Exception:
        return 0.0
    if not tr or not te:
        return 0.0
    return 100.0 * len(tr & te) / len(te)
329
+
330
+ # ==========================
331
+ # Seeds / Rule / Engine helpers
332
+ # ==========================
333
+
334
def _derive_component_seeds(global_seed: int, *, split_random: bool,
                            stress_enabled: bool, cluster_enabled: bool, shap_enabled: bool) -> Dict[str, Optional[int]]:
    """Derive per-component seeds from one global seed.

    Components get fixed offsets (split +0, model +1, stress +2, cluster +3,
    shap +4) so they draw independent but reproducible streams. The model seed
    is always set; the others are ``None`` when their flag is falsy.
    """
    root = int(global_seed)
    plan = (
        ("split", 0, split_random),
        ("model", 1, True),
        ("stress", 2, stress_enabled),
        ("cluster", 3, cluster_enabled),
        ("shap", 4, shap_enabled),
    )
    return {name: (root + offset if enabled else None) for name, offset, enabled in plan}
344
+
345
def _build_rule_cfg(
    *,
    saved_raw: Optional[Path],
    auc_min: float,
    f1_min: float,
    ks_min: float,
    eda_enabled: bool,
    eda_max_plots: int,
    corr_enabled: bool,
    vif_enabled: bool,
    raw_data_check_enabled: bool,
    model_meta_enabled: bool,
    stress_enabled: bool,
    stress_epsilon: float,
    stress_perturb_fraction: float,
    cluster_enabled: bool,
    cluster_k: int,
    cluster_max_k: int,
    shap_enabled: bool,
    shap_bg_size: int,
    shap_test_size: int,
    artifacts_dir: Path,
    split_strategy: str,
    test_size: float,
    seed_global: int,
    component_seeds: Dict[str, Optional[int]],
    in_scope_cols: List[str],
) -> Dict[str, Any]:
    """Assemble the rule/config dict handed to the validation engine.

    Option values are coerced with ``float``/``int``/``bool``/``str`` so the
    result holds plain Python values. ``paths.raw_data`` is present only when
    *saved_raw* is truthy. ``data.source`` is ``"separate"`` for the
    ``"supplied"`` split strategy and ``"single"`` otherwise.
    *component_seeds* is stored by reference, not copied.
    """
    paths: Dict[str, Any] = {}
    if saved_raw:
        paths["raw_data"] = str(saved_raw)

    cfg: Dict[str, Any] = {"paths": paths}

    # Data layout and scope
    cfg["data"] = {
        "source": "separate" if split_strategy == "supplied" else "single",
        "split": {"strategy": split_strategy, "test_size": float(test_size)},
        "in_scope_columns": list(in_scope_cols),
    }
    cfg["checks_scope"] = {"use_only_in_scope": True, "reference_split": "train"}
    cfg["options"] = {"save_artifacts_dir": str(artifacts_dir)}
    cfg["reproducibility"] = {
        "seed_global": int(seed_global),
        "component_seeds": component_seeds,
    }

    # Performance thresholds
    cfg["auc_roc"] = {"min": float(auc_min)}
    cfg["f1"] = {"min": float(f1_min)}
    cfg["ks"] = {"min": float(ks_min)}

    # Individual checks
    cfg["EDACheck"] = {"enabled": bool(eda_enabled), "max_plots": int(eda_max_plots)}
    cfg["correlation"] = {"enabled": bool(corr_enabled)}
    cfg["VIFCheck"] = {"enabled": bool(vif_enabled)}
    cfg["raw_data_check"] = {"enabled": bool(raw_data_check_enabled)}
    cfg["model_meta"] = {"enabled": bool(model_meta_enabled)}
    cfg["StressTestCheck"] = {
        "enabled": bool(stress_enabled),
        "epsilon": float(stress_epsilon),
        "perturb_fraction": float(stress_perturb_fraction),
    }
    cfg["InputClusterCoverageCheck"] = {
        "enabled": bool(cluster_enabled),
        "n_clusters": int(cluster_k),
        "max_k": int(cluster_max_k),
    }
    cfg["explainability"] = {
        "shap": {
            "enabled": bool(shap_enabled),
            "background_sample_size": int(shap_bg_size),
            "test_sample_size": int(shap_test_size),
        }
    }
    cfg["train_test_split"] = {"test_size": float(test_size)}
    return cfg
419
+
420
def _try_run_engine(
    engine: ValidationEngine, progress_cb: Optional[Callable[[str], None]] = None
) -> Dict[str, Any]:
    """Run all engine checks, forwarding *progress_cb* when the engine accepts it.

    Whether ``run_all_checks`` accepts a ``progress_callback`` keyword is
    decided from its signature (a named keyword parameter or ``**kwargs``).
    The previous approach called it and retried without the callback on any
    ``TypeError``. That also caught TypeErrors raised *inside* the checks,
    silently ran the whole suite a second time, and hid the original error.
    The legacy call-and-retry is used only when the signature cannot be
    introspected.
    """
    import inspect

    run = engine.run_all_checks
    try:
        params = inspect.signature(run).parameters
    except (TypeError, ValueError):
        # Signature unavailable (e.g. some C-implemented callables): legacy path.
        try:
            return run(progress_callback=progress_cb)
        except TypeError:
            return run()

    cb_param = params.get("progress_callback")
    accepts_cb = (
        cb_param is not None and cb_param.kind is not inspect.Parameter.POSITIONAL_ONLY
    ) or any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    if accepts_cb:
        return run(progress_callback=progress_cb)
    return run()
427
+
428
def metric_no_trunc(label: str, value) -> None:
    """Render a metric-style label/value pair without Streamlit's label truncation.

    The label goes through ``st.caption``. The value is injected into raw HTML
    (``unsafe_allow_html=True``), so it is HTML-escaped first: text containing
    ``<``, ``>`` or ``&`` is displayed literally instead of being interpreted
    as markup.
    """
    import html

    st.caption(label)
    st.markdown(
        "<div style='font-size:1.6rem; font-weight:600; line-height:1.1'>"
        f"{html.escape(str(value))}</div>",
        unsafe_allow_html=True,
    )
435
+
436
+ # ==========================
437
+ # UI helpers (Correlation & Regression renderers)
438
+ # ==========================
439
+
440
def _render_correlation_outputs(results, title="Numeric Correlation"):
    """Render the CorrelationCheck section: heatmap, KPI row, top pairs, notes.

    Reads ``results["CorrelationCheck"]``. Artifact paths come from its
    ``"artifacts"`` entry, or from the check dict itself when that key is
    absent; a non-dict entry (e.g. ``None``) is treated as "no artifacts"
    instead of crashing. Missing files degrade to info messages. Text placed
    inside raw HTML (KPI labels/values) is HTML-escaped.
    """
    import html

    st.subheader(title)
    corr_res = (results or {}).get("CorrelationCheck", {}) or {}
    art = corr_res.get("artifacts", corr_res)
    if not isinstance(art, dict):
        art = {}
    summary = corr_res.get("summary", {}) or {}
    notes = corr_res.get("notes", []) or []

    heatmap_path = art.get("heatmap_path")
    if heatmap_path and os.path.exists(heatmap_path):
        caption = (
            f"Correlation Heatmap ({summary.get('method','')}) — "
            + ("full matrix" if summary.get('plotted_full_matrix') else
               f"showing {summary.get('plotted_features')}/{summary.get('n_numeric_features')} features (subset)")
        )
        st.image(heatmap_path, caption=caption, width="stretch")
    else:
        st.info("Heatmap not available.")

    # Format the threshold so 0.4 shows as 0.40 (avoid 0.39999999999999997)
    thr_raw = summary.get("threshold", None)
    thr_str = _fmt2(float(thr_raw), decimals=2) if isinstance(thr_raw, (int, float)) else "—"

    c1, c2, c3, c4 = st.columns(4)
    labels = ["Numeric features", "Pairs evaluated", f"Pairs ≥ {thr_str}", "Method"]
    values = [
        summary.get("n_numeric_features", "—"),
        summary.get("n_pairs_total", "—"),
        summary.get("n_pairs_flagged_ge_threshold", "—"),
        (summary.get("method", "—") or "—").capitalize(),
    ]
    # Labels first, then values, so each row lines up across the four columns
    # (see the tanml-kpi-* CSS).
    for col, lab in zip((c1, c2, c3, c4), labels):
        col.markdown(f'<div class="tanml-kpi-label">{html.escape(str(lab))}</div>', unsafe_allow_html=True)
    for col, val in zip((c1, c2, c3, c4), values):
        col.markdown(f'<div class="tanml-kpi-value">{html.escape(str(val))}</div>', unsafe_allow_html=True)

    main_csv = art.get("top_pairs_main_csv")
    if main_csv and os.path.exists(main_csv):
        prev_df = pd.read_csv(main_csv)
        st.caption("Top high-correlation pairs (preview)")
        st.dataframe(prev_df.head(3), width="stretch", height=160, hide_index=True)
    else:
        st.info("No high-correlation pairs at current threshold.")

    full_csv = art.get("top_pairs_csv")
    if full_csv and os.path.exists(full_csv):
        with open(full_csv, "rb") as f:
            st.download_button(
                "⬇️ Download full pairs CSV",
                f,
                file_name="correlation_top_pairs.csv",
                mime="text/csv",
                key=f"corrcsv::{full_csv}"
            )
    for n in notes:
        st.caption(f"• {n}")
497
+
498
+
499
def _render_regression_outputs(results):
    """Render regression KPI tiles and the detailed-metrics table (no visuals).

    Does nothing unless ``results["task_type"] == "regression"`` (``None``
    results are treated as empty). Tiles read rmse/mae/r2 from
    ``results["summary"]``. The table and optional notes read
    ``results["RegressionMetrics"]``. The percentage metric is labelled MAPE
    or SMAPE according to the ``mape_used`` flag; when the value is missing it
    shows "N/A" in both cases (previously "—% (MAPE)" when ``mape_used`` was
    true).
    """
    results = results or {}
    if results.get("task_type", "classification") != "regression":
        return

    # Key tiles
    st.subheader("Key Metrics")
    c1, c2, c3 = st.columns(3)
    c1.metric("RMSE", _fmt2(_g(results, "summary", "rmse")))
    c2.metric("MAE", _fmt2(_g(results, "summary", "mae")))
    c3.metric("R²", _fmt2(_g(results, "summary", "r2"), decimals=2))

    # Detailed metrics
    st.subheader("Regression Metrics (Detailed)")
    adj_r2 = _g(results, "RegressionMetrics", "r2_adjusted")
    med_ae = _g(results, "RegressionMetrics", "median_ae")
    mv = _g(results, "RegressionMetrics", "mape_or_smape")
    m_is_mape = bool(_g(results, "RegressionMetrics", "mape_used", default=False))

    if mv is None:
        pct_text = "N/A"
    else:
        pct_text = f"{_fmt2(mv, decimals=2)}% ({'MAPE' if m_is_mape else 'SMAPE'})"

    st.table({
        "Metric": ["Adjusted R²", "Median AE", "MAPE/SMAPE"],
        "Value": [
            _fmt2(adj_r2, decimals=2) if adj_r2 is not None else "N/A",
            _fmt2(med_ae, decimals=2) if med_ae is not None else "—",
            pct_text,
        ],
    })

    # Notes, if any
    notes = _g(results, "RegressionMetrics", "notes", default=[]) or []
    if notes:
        with st.expander("Metric notes"):
            for n in notes:
                st.write("• " + str(n))
536
+
537
def _tvr_extras_cls_kpis(results: Dict[str, Any]) -> None:
    """Render the AUC / KS / rules-failed tiles of a classification run."""
    st.subheader("Key Metrics")
    summary = (results.get("summary") or {})
    auc_val = _fmt2(summary.get("auc"))
    ks_val = _fmt2(summary.get("ks"))
    rules_failed_v = _fmt2(summary.get("rules_failed"))
    c1, c2, c3 = st.columns(3)
    c1.metric("AUC", auc_val)
    c2.metric("KS", ks_val)
    c3.metric("Rules failed", rules_failed_v)


def _tvr_extras_run_summary(extras: Dict[str, Any]) -> None:
    """Render row counts, target, seed and paths recorded for the run."""
    st.subheader("Run Summary")
    st.write({
        "Train rows": int(extras.get("train_rows", 0) or 0),
        "Test rows": int(extras.get("test_rows", 0) or 0),
        "Target": extras.get("target", "—"),
        "Features used": int(extras.get("n_features", 0) or 0),
        "Seed used": extras.get("seed_used", "—"),
        "Effective config": extras.get("eff_path", "—"),
        "Artifacts dir": extras.get("artifacts_dir", "—"),
    })


def _tvr_extras_artifacts(section_id: str, art_dir: Path) -> None:
    """Preview the first SHAP image and offer download buttons for artifacts."""
    # Only real image files may go to st.image; a SHAP CSV/JSON sorting first
    # would otherwise make st.image fail.
    image_suffixes = {".png", ".jpg", ".jpeg", ".gif", ".webp"}
    shap_imgs = sorted(
        p for p in art_dir.glob("**/*shap*.*")
        if p.is_file() and p.suffix.lower() in image_suffixes
    )
    if shap_imgs:
        st.image(str(shap_imgs[0]), caption="SHAP summary", width="stretch")

    st.subheader("Artifacts")

    skip_suffixes = ("_top_pairs_main.csv",)  # preview-only file, not offered
    pretty_names = {
        "correlation_top_pairs.csv": "Flagged pairs (full)",
        "heatmap.png": "Correlation heatmap",
        "pearson_corr.csv": "Pearson correlation matrix",
        "spearman_corr.csv": "Spearman correlation matrix",
    }
    order_hint = [
        "correlation_top_pairs.csv",
        "heatmap.png",
        "pearson_corr.csv",
        "spearman_corr.csv",
    ]
    rank = {name: i for i, name in enumerate(order_hint)}

    files_list = [
        p for p in art_dir.glob("**/*")
        if p.is_file() and not p.name.endswith(skip_suffixes)
    ]
    if not files_list:
        st.caption("No artifacts were saved.")
        return

    # Known correlation artifacts first, then everything else alphabetically.
    files_list.sort(key=lambda p: (rank.get(p.name, 999), p.name.lower()))

    # At most 100 buttons; each file is read fully on every rerun.
    for p in files_list[:100]:
        label = pretty_names.get(p.name, p.name)
        st.download_button(
            f"⬇️ Download {label}",
            p.read_bytes(),
            file_name=p.name,
            key=f"art::{section_id}::{p}",
            width="stretch",
        )


def tvr_render_extras(section_id: str):
    """Render the detailed results stored via tvr_store_extras for *section_id*.

    Shows correlation outputs, task-specific metrics (regression tiles/table or
    classification AUC/KS tiles), a run summary and, if the run recorded an
    existing artifacts directory, a SHAP preview plus artifact downloads.

    A missing/empty ``artifacts_dir`` now skips the artifact section. Before,
    ``Path("")`` resolved to the current working directory, which always
    exists, so the whole CWD tree (including other sessions' uploads under
    ``.ui_runs``) was offered for download.
    """
    extras = st.session_state.get(_tvr_key(section_id, "extras"))
    if not extras:
        return

    results = extras.get("results", {}) or {}
    task = results.get("task_type", "classification")

    # Correlation always (renders its own "not available" messages).
    _render_correlation_outputs(results)

    # Task-aware metrics/dashboard
    if task == "regression":
        _render_regression_outputs(results)
    else:
        _tvr_extras_cls_kpis(results)

    _tvr_extras_run_summary(extras)

    raw_dir = extras.get("artifacts_dir")
    if not raw_dir:
        return
    art_dir = Path(raw_dir)
    if not art_dir.is_dir():
        return
    _tvr_extras_artifacts(section_id, art_dir)
617
+
618
+ # ==========================
619
+ # UI — Refit-only (20 models)
620
+ # ==========================
621
+
622
# Page chrome.
# NOTE(review): st.markdown() (KPI CSS) already runs earlier at import time;
# older Streamlit releases required set_page_config to be the first Streamlit
# command. Confirm the pinned Streamlit version accepts this order.
st.set_page_config(layout="wide", page_title="TanML — Refit & Validate")
st.title("TanML • Refit & Validate")

# Per-session working folder and its "artifacts" subfolder (both created by
# _session_dir()).
run_dir = _session_dir()
artifacts_dir = run_dir.joinpath("artifacts")
627
+
628
# ---- Sidebar: check toggles and their parameters ----
# Every widget has an explicit key so tvr_reset() can clear it back to defaults.

with st.sidebar.expander("Checks & Options", expanded=True):
    eda_enabled = st.checkbox("EDA plots", value=True, key="opt_eda")
    eda_max_plots = st.number_input("EDA max plots (-1=all numeric)", value=-1, step=1, key="opt_eda_max")
    corr_enabled = st.checkbox("Correlation matrix", value=True, key="opt_corr")
    vif_enabled = st.checkbox("VIF check", value=True, key="opt_vif")
    raw_data_check_enabled = st.checkbox("RawDataCheck (needs raw)", value=True, key="opt_rawcheck")
    model_meta_enabled = st.checkbox("Model metadata", value=True, key="opt_modelmeta")

with st.sidebar.expander("Robustness / Stress Testing", expanded=True):
    stress_enabled = st.checkbox("StressTestCheck", value=True, key="opt_stress")
    stress_epsilon = st.number_input(
        "Epsilon (noise)", min_value=0.0, max_value=1.0, value=0.01, step=0.01, key="opt_stress_eps"
    )
    stress_perturb_fraction = st.number_input(
        "Perturb fraction", min_value=0.0, max_value=1.0, value=0.20, step=0.05, key="opt_stress_frac"
    )

with st.sidebar.expander("Input Cluster Coverage", expanded=False):
    cluster_enabled = st.checkbox("InputClusterCoverageCheck", value=True, key="opt_cluster")
    cluster_k = st.number_input("n_clusters", min_value=2, max_value=50, value=5, step=1, key="opt_cluster_k")
    cluster_max_k = st.number_input(
        "max_k (elbow cap)", min_value=2, max_value=100, value=10, step=1, key="opt_cluster_maxk"
    )

with st.sidebar.expander("Explainability (SHAP)", expanded=True):
    shap_enabled = st.checkbox("Enable SHAP", value=True, key="opt_shap")
    shap_bg_size = st.number_input(
        "Background sample size", min_value=10, max_value=100000, value=100, step=10, key="opt_shap_bg"
    )
    shap_test_size = st.number_input(
        "Test rows to explain", min_value=10, max_value=100000, value=200, step=10, key="opt_shap_test"
    )
652
+
653
# VIF numeric normalisation (cast numerics to float64 and round to 9 decimals).
# The former sidebar checkbox for this (key "opt_vifnorm") is no longer
# rendered. The value lives in session state and is seeded once from
# CAST9_DEFAULT (env var TANML_CAST9).
st.session_state.setdefault("cast9_round9", CAST9_DEFAULT)
apply_vif_norm = cast9 = bool(st.session_state["cast9_round9"])

# ---- Sidebar: reproducibility ----
st.sidebar.subheader("Reproducibility")
seed_global = st.sidebar.number_input(
    "Random seed",
    min_value=0,
    max_value=2_147_483_647,  # largest signed 32-bit integer
    value=42,
    step=1,
    help="Controls random split, model refit, stress noise, clustering, and SHAP sampling.",
    key="opt_seed",
)
675
+
676
# ---- Sidebar: correlation check settings ----
with st.sidebar.expander("Correlation Settings", expanded=False):
    corr_method = st.radio(
        "Method", options=["pearson", "spearman"], index=0, horizontal=True, key="opt_corr_method"
    )
    corr_cap = st.slider(
        "Heatmap features (default 20, max 60)",
        min_value=10, max_value=60, value=20, step=5,
        key="opt_corr_cap",
    )
    corr_thr = st.number_input(
        "High-correlation threshold (|r| ≥)",
        min_value=0.0, max_value=0.99, value=0.80, step=0.05,
        key="opt_corr_thr",
    )
    st.caption("Tip: ≥ 0.90 often means near-duplicate; confirm with VIF.")

# Correlation options as chosen in the UI; the remaining entries are fixed settings.
corr_ui_cfg = dict(
    enabled=bool(corr_enabled),
    method=corr_method,
    high_corr_threshold=float(corr_thr),
    heatmap_max_features_default=int(corr_cap),
    heatmap_max_features_limit=60,
    subset_strategy="cluster",
    top_pairs_max=200,
    sample_rows=150_000,
    seed=int(seed_global),
    save_csv=True,
    save_fig=True,
    appendix_csv_cap=None,
)
696
+
697
def render_model_form(y_train, seed_global: int):
    """Render the model picker and hyperparameter form.

    Uses the model registry to offer the algorithms of the chosen library for
    the chosen task (pre-selected from ``infer_task_from_target(y_train)``).
    Hyperparameter widgets are generated from the registry's UI schema. Seed
    parameters (random_state / seed / random_seed present in the model's
    defaults) are hidden and filled with ``int(seed_global)`` so the sidebar
    seed controls them.

    Returns:
        ``(library, algorithm, params, task)``. Stops the script with an error
        when the library has no algorithms for the task.
    """

    def _widget(name, typ, choices, helptext, default_val):
        # One widget per schema entry; the returned value goes into params.
        if typ == "choice":
            shown = ["None" if opt is None else opt for opt in (list(choices) if choices else [])]
            if not shown:
                # A choice without options degrades to free text.
                return st.text_input(
                    name,
                    value="" if default_val is None else str(default_val),
                    help=helptext,
                )
            if default_val is None and "None" in shown:
                idx = shown.index("None")
            elif default_val in shown:
                idx = shown.index(default_val)
            else:
                idx = 0
            picked = st.selectbox(name, shown, index=idx, help=helptext)
            return None if picked == "None" else picked

        if typ == "bool":
            return st.checkbox(
                name,
                value=False if default_val is None else bool(default_val),
                help=helptext,
            )

        if typ == "int":
            # NOTE(review): a None default becomes 0 here; confirm the registry
            # never declares an int parameter whose default is None.
            return int(st.number_input(
                name,
                value=0 if default_val is None else int(default_val),
                step=1,
                help=helptext,
            ))

        if typ == "float":
            return float(st.number_input(
                name,
                value=0.0 if default_val is None else float(default_val),
                help=helptext,
            ))

        # Anything else ("str") is free text.
        return st.text_input(
            name,
            value="" if default_val is None else str(default_val),
            help=helptext,
        )

    inferred = infer_task_from_target(y_train)
    task = st.radio(
        "Task",
        ["classification", "regression"],
        index=(0 if inferred == "classification" else 1),
        horizontal=True,
        key="mdl_task",
    )

    library = st.selectbox(
        "Library", ["sklearn", "xgboost", "lightgbm", "catboost"], index=0, key="mdl_lib"
    )

    algo_names = [
        algo_name
        for (lib_name, algo_name), _spec in list_models(task).items()
        if lib_name == library
    ]
    if not algo_names:
        st.error(f"No algorithms available for {library} / {task}. Is the library installed?")
        st.stop()
    algo = st.selectbox("Algorithm", algo_names, index=0, key="mdl_algo")

    spec = get_spec(library, algo)
    schema = ui_schema_for(library, algo)
    defaults = spec.defaults or {}

    seed_keys = [k for k in ("random_state", "seed", "random_seed") if k in defaults]
    params = {}

    with st.expander("Hyperparameters", expanded=True):
        if seed_keys:
            st.caption("ℹ️ Model seed is taken from the sidebar **Random seed**.")
        for name, (typ, choices, helptext) in schema.items():
            if name in seed_keys:
                continue
            params[name] = _widget(name, typ, choices, helptext, defaults.get(name))

    # The sidebar seed overrides every seed-like parameter.
    params.update({k: int(seed_global) for k in seed_keys})

    return library, algo, params, task
789
+
790
def _select_target_and_features(df):
    """Render the "Target column" and "Features" pickers for the columns of *df*.

    The target defaults to ``_pick_target(df)`` when that column exists in *df*,
    otherwise to the first column. Features offer (and pre-select) every column
    except the chosen target.

    Returns:
        tuple: ``(target, features)`` as currently chosen in the widgets.
    """
    target_default = _pick_target(df)
    cols = list(df.columns)
    target = st.selectbox(
        "Target column",
        options=cols,
        index=cols.index(target_default) if target_default in cols else 0,
    )
    feature_options = [c for c in cols if c != target]
    features = st.multiselect(
        "Features",
        options=feature_options,
        default=list(feature_options),
    )
    return target, features


# ==========================
# Main layout — single flow
# ==========================

left, right = st.columns([1.35, 1])

with left:
    st.header("1) Choose data source")
    data_source = st.radio(
        "Data source",
        ("Single cleaned file (you split)", "Already split: Train & Test"),
        index=0, horizontal=True
    )

    # Always bind these so the run panel can test them against None without
    # having to know which upload branch executed.
    saved_raw = None
    cleaned_df = train_df = test_df = None

    if data_source.startswith("Single"):
        st.subheader("Upload files")
        cleaned_file = st.file_uploader(
            "Cleaned dataset (required)",
            key="upl_cleaned"
        )
        raw_file = st.file_uploader(
            "Raw dataset (optional)",
            key="upl_raw_single"
        )

        saved_cleaned = _save_upload(cleaned_file, run_dir)
        saved_raw = _save_upload(raw_file, run_dir)

        df_preview = None
        if saved_cleaned:
            st.success(f"Cleaned file saved: `{saved_cleaned}`")
            try:
                df_preview = load_dataframe(saved_cleaned)
                cleaned_df = df_preview
                st.write("Preview (top 10 rows):")
                st.dataframe(df_preview.head(10), width="stretch")
            except Exception as e:
                st.error(f"Could not read cleaned file: {e}")

        st.subheader("Configure data")
        if df_preview is not None:
            target, features = _select_target_and_features(df_preview)
        else:
            target, features = None, []

        # Only meaningful for the random split performed in the run panel.
        test_size = st.slider("Hold-out test size", 0.1, 0.5, 0.3, 0.05)

    else:
        st.subheader("Upload TRAIN/TEST (cleaned)")
        train_cleaned = st.file_uploader(
            "Train (cleaned) — required",
            key="upl_train"
        )
        test_cleaned = st.file_uploader(
            "Test (cleaned) — required",
            key="upl_test"
        )
        raw_file = st.file_uploader(
            "Raw dataset (optional, global)",
            key="upl_raw_global"
        )

        saved_train = _save_upload(train_cleaned, run_dir)
        saved_test = _save_upload(test_cleaned, run_dir)
        saved_raw = _save_upload(raw_file, run_dir)

        df_tr = df_te = None
        if saved_train:
            try:
                df_tr = load_dataframe(saved_train)
                train_df = df_tr
            except Exception as e:
                st.error(f"Could not read train: {e}")
        if saved_test:
            try:
                df_te = load_dataframe(saved_test)
                test_df = df_te
            except Exception as e:
                st.error(f"Could not read test: {e}")

        # Target/feature choices are driven by the TRAIN file's columns.
        if df_tr is not None:
            st.write("Train preview (top 10):")
            st.dataframe(df_tr.head(10), width="stretch")
            target, features = _select_target_and_features(df_tr)
        else:
            target, features = None, []
894
def _safe_report_filename(name, fallback):
    """Turn the user-typed report *name* into a safe, bare ``.docx`` file name.

    Directory components (``/`` or ``\\`` separated) are discarded so the report
    cannot be written outside the run directory. A blank or dots-only stem yields
    *fallback* (expected to already end in ``.docx``). A missing ``.docx``
    suffix is appended.
    """
    base = str(name or "").replace("\\", "/").rsplit("/", 1)[-1].strip()
    if base.lower().endswith(".docx"):
        base = base[:-len(".docx")].strip()
    if not base.strip("."):
        return fallback
    return f"{base}.docx"


def _write_effective_config(path, cfg):
    """Write *cfg* to *path* as YAML when ruamel.yaml works, otherwise as JSON.

    The YAML file handle is closed by ``with`` before the fallback runs, so a
    partially written YAML buffer can never be flushed over the JSON copy.
    """
    try:
        from ruamel.yaml import YAML
        with path.open("w", encoding="utf-8") as fh:
            YAML().dump(cfg, fh)
    except Exception:
        # ruamel.yaml missing or cfg not YAML-representable. This file is an
        # auxiliary snapshot, so stringify odd values (e.g. pathlib.Path)
        # instead of letting them abort the whole validation run.
        path.write_text(json.dumps(cfg, indent=2, default=str), encoding="utf-8")


with right:
    st.header("2) Refit, Validate & Report")
    tvr_render_ready("refit", header_text="Run & Report (last run)")
    tvr_render_extras("refit")

    report_name = st.text_input(
        "Report file name (.docx)",
        value=f"tanml_report_{int(time.time())}.docx"
    )

    # --- Choose model BEFORE running ---
    # cleaned_df / train_df are always bound (None until a file is loaded) by the
    # data-source panel, so they can be tested directly.
    source_df = cleaned_df if data_source.startswith("Single") else train_df
    if source_df is not None and target:
        y_for_task = source_df[target]
    else:
        y_for_task = pd.Series([], dtype="float64")

    library_selected, algo_selected, user_hp, task_selected = render_model_form(y_for_task, seed_global)

    # ---- Conditional thresholds (sidebar) ----
    with st.sidebar:
        if task_selected == "classification":
            st.subheader("Thresholds")
            c1, c2, c3 = st.columns(3)
            auc_min = c1.number_input("AUC ≥", 0.0, 1.0, 0.60, 0.01, key="thr_auc")
            f1_min = c2.number_input("F1 ≥", 0.0, 1.0, 0.60, 0.01, key="thr_f1")
            ks_min = c3.number_input("KS ≥", 0.0, 1.0, 0.20, 0.01, key="thr_ks")
            st.session_state["__thr_block__"] = {"AUC_min": auc_min, "F1_min": f1_min, "KS_min": ks_min}
        else:
            # Regression (or anything not classification) → hide thresholds entirely.
            auc_min = 0.0
            f1_min = 0.0
            ks_min = 0.0
            st.session_state["__thr_block__"] = {"problem_type": "regression"}

    st.session_state["model_selection"] = {
        "library": library_selected,
        "algo": algo_selected,
        "hp": user_hp,
        "task": task_selected,
    }

    # saved_cleaned / saved_train / saved_test are only bound by the matching
    # upload branch, hence the locals() lookups.
    if data_source.startswith("Single"):
        ready = bool((locals().get("saved_cleaned") is not None) and target and features)
    else:
        ready = bool((locals().get("saved_train") is not None) and (locals().get("saved_test") is not None) and target and features)
    ready = ready and bool(st.session_state.get("model_selection", {}).get("algo"))

    go = st.button("▶️ Refit & validate", type="primary", disabled=not ready)
    if not ready:
        st.info("Provide data, pick target + features, choose a model, then run.")

    if go:
        try:
            # ---- Build X_train/X_test/y_train/y_test ----
            if data_source.startswith("Single"):
                cleaned_df = cleaned_df if cleaned_df is not None else load_dataframe(saved_cleaned)
                if target not in cleaned_df.columns:
                    st.error(f"Target '{target}' not found in cleaned data.")
                    st.stop()

                safe_features = [c for c in features if c in cleaned_df.columns and c != target]
                if not safe_features:
                    safe_features = [c for c in cleaned_df.columns if c != target]

                X = cleaned_df[safe_features].copy()
                y = cleaned_df[target].copy()

                if apply_vif_norm:
                    # NOTE(review): the normalized cleaned_df is not used below.
                    cleaned_df = _normalize_vif(cleaned_df)
                    X = _normalize_vif(X)

                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=float(test_size), random_state=seed_global, shuffle=True
                )
                df_checks = pd.concat([X_train, y_train], axis=1)
                if apply_vif_norm:
                    df_checks = _normalize_vif(df_checks)

                split_strategy = "random"

            else:
                train_df = train_df if train_df is not None else load_dataframe(saved_train)
                test_df = test_df if test_df is not None else load_dataframe(saved_test)

                if target not in train_df.columns:
                    st.error(f"Target '{target}' not in TRAIN.")
                    st.stop()
                if target not in test_df.columns:
                    st.error(f"Target '{target}' not in TEST.")
                    st.stop()
                # Check up front: subsetting test_df below would otherwise fail
                # with a bare pandas KeyError before schema alignment can report it.
                missing_in_test = [c for c in features if c not in test_df.columns]
                if missing_in_test:
                    st.error(f"Feature column(s) missing in TEST: {missing_in_test}")
                    st.stop()

                X_train = train_df[features].copy()
                y_train = train_df[target].copy()
                te_sub = test_df[features + [target]].copy()
                te_aligned, err = _schema_align_or_error(train_df[features + [target]], te_sub)
                if err:
                    st.error(err)
                    st.stop()
                X_test = te_aligned[features].copy()
                y_test = te_aligned[target].copy()

                if apply_vif_norm:
                    # NOTE(review): the normalized train_df/test_df are not used below.
                    train_df = _normalize_vif(train_df)
                    test_df = _normalize_vif(test_df)
                    X_train = _normalize_vif(X_train)
                    X_test = _normalize_vif(X_test)

                df_checks = pd.concat([X_train, y_train], axis=1)
                if apply_vif_norm:
                    df_checks = _normalize_vif(df_checks)

                overlap_pct = _row_overlap_pct(
                    pd.concat([X_train, y_train], axis=1),
                    pd.concat([X_test, y_test], axis=1),
                    cols=features + [target]
                )
                if overlap_pct > 0:
                    st.warning(f"Potential Train↔Test row overlap: ~{overlap_pct:.2f}% (by row hash).")

                split_strategy = "supplied"

            saved_raw_ = saved_raw

            # ---- Build estimator from selection ----
            sel = st.session_state.get("model_selection") or {}
            library_selected = sel.get("library")
            algo_selected = sel.get("algo")
            user_hp = sel.get("hp") or {}
            task_selected = sel.get("task")

            if not library_selected or not algo_selected:
                st.error("Please choose a library and algorithm before running.")
                st.stop()

            try:
                model = build_estimator(library_selected, algo_selected, user_hp)
            except ImportError:
                st.error(
                    f"Missing dependency for '{library_selected}.{algo_selected}'. "
                    f"Install the library (e.g., 'pip install {library_selected}') and try again."
                )
                st.stop()

            model.fit(X_train, y_train)
            if not hasattr(model, "feature_names_in_"):
                # Best effort only: some estimators reject attribute assignment.
                try:
                    model.feature_names_in_ = X_train.columns.to_numpy()
                except Exception:
                    pass

            component_seeds = _derive_component_seeds(
                seed_global,
                split_random=(split_strategy == "random"),
                stress_enabled=stress_enabled,
                cluster_enabled=cluster_enabled,
                shap_enabled=shap_enabled,
            )

            rule_cfg = _build_rule_cfg(
                saved_raw=saved_raw_,
                auc_min=auc_min, f1_min=f1_min, ks_min=ks_min,
                eda_enabled=eda_enabled, eda_max_plots=int(eda_max_plots),
                corr_enabled=corr_enabled, vif_enabled=vif_enabled,
                raw_data_check_enabled=raw_data_check_enabled and bool(saved_raw_),
                model_meta_enabled=model_meta_enabled,
                stress_enabled=stress_enabled,
                stress_epsilon=stress_epsilon,
                stress_perturb_fraction=stress_perturb_fraction,
                cluster_enabled=cluster_enabled,
                cluster_k=int(cluster_k),
                cluster_max_k=int(cluster_max_k),
                shap_enabled=shap_enabled,
                shap_bg_size=int(shap_bg_size),
                shap_test_size=int(shap_test_size),
                artifacts_dir=artifacts_dir,
                split_strategy=split_strategy,
                test_size=float(test_size) if split_strategy == "random" else 0.0,
                seed_global=seed_global,
                component_seeds=component_seeds,
                in_scope_cols=list(X_train.columns) + [target],
            )

            rule_cfg.setdefault("explainability", {}).setdefault("shap", {})["out_dir"] = str(artifacts_dir)

            # Backward-compat for wrappers that look under "SHAPCheck"
            rule_cfg["SHAPCheck"] = {
                "enabled": bool(shap_enabled),
                "background_size": int(shap_bg_size),
                "sample_size": int(shap_test_size),
                "out_dir": str(artifacts_dir),
            }
            # Inject correlation settings
            rule_cfg.setdefault("CorrelationCheck", {}).update({
                "enabled": bool(corr_enabled),
                "method": corr_method,
                "high_corr_threshold": float(corr_thr),
                "heatmap_max_features_default": int(corr_cap),
                "heatmap_max_features_limit": 60,
                "subset_strategy": "cluster",
                "top_pairs_max": 200,
                "sample_rows": 150_000,
                "seed": int(seed_global),
                "save_csv": True,
                "save_fig": True,
                "appendix_csv_cap": None,
            })
            rule_cfg.setdefault("correlation", {}).update({"enabled": bool(corr_enabled)})

            raw_df_loaded = load_dataframe(saved_raw_) if saved_raw_ else None

            engine = ValidationEngine(
                model, X_train, X_test, y_train, y_test, rule_cfg, df_checks,
                raw_df=raw_df_loaded, ctx=st.session_state
            )

            eff_path = run_dir / "effective_config.yaml"
            with st.status("Refitting & running checks…", expanded=True) as status:
                def cb(msg: str):
                    # Progress messages are cosmetic; never let them break the run.
                    try:
                        status.write(msg)
                    except Exception:
                        pass

                results = _try_run_engine(engine, progress_cb=cb)

                # Stamp task for downstream UI/history
                results["task_type"] = task_selected
                results.setdefault("summary", {})["task_type"] = task_selected

                rules_meta = {}
                if task_selected == "classification":
                    rules_meta = {"auc_roc_min": auc_min, "f1_min": f1_min, "ks_min": ks_min}

                results.update({
                    "validation_date": pd.Timestamp.now(tz="America/Chicago").strftime("%Y-%m-%d %H:%M:%S %Z"),
                    "model_path": "(refit in UI)",
                    "validated_by": "TanML UI (Refit-only)",
                    "rules": rules_meta,  # thresholds only for classification
                    "data_split": "random" if split_strategy == "random" else "supplied_train_test",
                    "reproducibility": {"seed_global": int(seed_global), "component_seeds": component_seeds, "split_strategy": split_strategy},
                    "model_provenance": {
                        "refit_always": True,
                        "library": library_selected,
                        "algorithm": algo_selected,
                        "hyperparameters_used": user_hp,
                        "hyperparameters_source": "user_form"
                    }
                })

                raw_path_str = str(saved_raw_) if saved_raw_ else None
                if split_strategy == "random":
                    eff_paths = {"cleaned": str(saved_cleaned), "raw": raw_path_str}
                else:
                    eff_paths = {"train_cleaned": str(saved_train), "test_cleaned": str(saved_test), "raw": raw_path_str}
                effective_cfg = {
                    "scenario": "Refit",
                    "mode": "Refit & Validate",
                    "paths": {**eff_paths, "artifacts_dir": str(artifacts_dir)},
                    "data": {
                        "target": target,
                        "features": list(X_train.columns),
                        "split": "random" if split_strategy == "random" else "supplied",
                        "test_size": float(test_size) if split_strategy == "random" else None,
                        "random_state": seed_global if split_strategy == "random" else None,
                    },
                    "model_refit": {"library": library_selected, "type": algo_selected, "hyperparameters": user_hp},
                    "checks": rule_cfg,
                    "reproducibility": {"seed_global": int(seed_global), "component_seeds": component_seeds},
                }
                _write_effective_config(eff_path, effective_cfg)

                # Keep the report inside run_dir regardless of what was typed.
                report_path = run_dir / _safe_report_filename(
                    report_name, f"tanml_report_{int(time.time())}.docx"
                )
                report_path.parent.mkdir(parents=True, exist_ok=True)

                # Choose the template matching the task (classification/regression)
                template_path = _choose_report_template(task_selected)

                ReportBuilder(
                    results,
                    template_path=str(template_path),
                    output_path=report_path
                ).build()

                status.update(label="Done!", state="complete", expanded=False)

            tvr_store_extras("refit", {
                "results": results,
                "train_rows": len(y_train),
                "test_rows": len(y_test),
                "target": target,
                "n_features": len(X_train.columns),
                "seed_used": seed_global if split_strategy == "random" else "N/A (user-supplied split)",
                "eff_path": str(eff_path),
                "artifacts_dir": str(artifacts_dir),
            })

            tvr_finish(
                "refit",
                report_path=report_path,
                file_name=report_path.name,
                summary=results.get("summary", {}),
                label=f"Refit — {library_selected}.{algo_selected}",
                cfg=effective_cfg,
            )
            tvr_render_ready("refit", header_text="Run & Report (last run)")
            tvr_render_extras("refit")

        except Exception as e:
            st.error(f"Refit/validate failed: {e}")
            st.stop()


st.markdown("---")