tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tanml has been flagged as potentially problematic; see the package's registry page for more details.
- tanml/__init__.py +1 -1
- tanml/check_runners/cleaning_repro_runner.py +2 -2
- tanml/check_runners/correlation_runner.py +49 -12
- tanml/check_runners/explainability_runner.py +12 -22
- tanml/check_runners/logistic_stats_runner.py +196 -17
- tanml/check_runners/performance_runner.py +82 -26
- tanml/check_runners/raw_data_runner.py +29 -14
- tanml/check_runners/regression_metrics_runner.py +195 -0
- tanml/check_runners/stress_test_runner.py +23 -6
- tanml/check_runners/vif_runner.py +33 -27
- tanml/checks/correlation.py +241 -41
- tanml/checks/explainability/shap_check.py +261 -29
- tanml/checks/logit_stats.py +186 -54
- tanml/checks/performance_classification.py +305 -0
- tanml/checks/raw_data.py +58 -23
- tanml/checks/regression_metrics.py +167 -0
- tanml/checks/stress_test.py +157 -53
- tanml/cli/main.py +99 -27
- tanml/engine/check_agent_registry.py +20 -10
- tanml/engine/core_engine_agent.py +199 -37
- tanml/models/registry.py +329 -0
- tanml/report/report_builder.py +1180 -147
- tanml/report/templates/report_template_cls.docx +0 -0
- tanml/report/templates/report_template_reg.docx +0 -0
- tanml/ui/app.py +1205 -0
- tanml/utils/data_loader.py +105 -15
- tanml-0.1.7.dist-info/METADATA +164 -0
- tanml-0.1.7.dist-info/RECORD +54 -0
- tanml/cli/arg_parser.py +0 -31
- tanml/cli/init_cmd.py +0 -8
- tanml/cli/validate_cmd.py +0 -7
- tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
- tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
- tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
- tanml/engine/segmentation_agent.py +0 -118
- tanml/engine/validation_agent.py +0 -91
- tanml/report/templates/report_template.docx +0 -0
- tanml/utils/model_loader.py +0 -35
- tanml/utils/r_loader.py +0 -30
- tanml/utils/sas_loader.py +0 -50
- tanml/utils/yaml_generator.py +0 -34
- tanml/utils/yaml_loader.py +0 -5
- tanml/validate.py +0 -209
- tanml-0.1.6.dist-info/METADATA +0 -317
- tanml-0.1.6.dist-info/RECORD +0 -62
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
tanml/ui/app.py
ADDED
|
@@ -0,0 +1,1205 @@
|
|
|
1
|
+
# tanml/ui/app.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import os, time, uuid, json, hashlib
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, Dict, Any, Callable, Tuple, List
|
|
7
|
+
import os
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import streamlit as st
|
|
10
|
+
|
|
11
|
+
from sklearn.model_selection import train_test_split
|
|
12
|
+
|
|
13
|
+
# TanML internals
|
|
14
|
+
from tanml.utils.data_loader import load_dataframe
|
|
15
|
+
from tanml.engine.core_engine_agent import ValidationEngine
|
|
16
|
+
from tanml.report.report_builder import ReportBuilder
|
|
17
|
+
from importlib.resources import files
|
|
18
|
+
|
|
19
|
+
# Model registry (20-model suite)
|
|
20
|
+
from tanml.models.registry import (
|
|
21
|
+
list_models, ui_schema_for, build_estimator, infer_task_from_target, get_spec
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
from importlib.resources import files
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Upload / websocket message size limit in MB (env override: TANML_MAX_MB).
try:
    _max_mb = int(os.environ.get("TANML_MAX_MB", "1024"))
except ValueError:
    _max_mb = 1024  # malformed TANML_MAX_MB: keep the default limit


def _apply_streamlit_options(max_mb: int) -> None:
    """Best-effort: raise upload/message limits and disable usage stats.

    Each option is attempted on its own so that one rejected option does not
    prevent the remaining ones from being applied.

    NOTE(review): recent Streamlit releases only allow "scriptable" options in
    ``st.set_option``; server.* / browser.* options are likely rejected here and
    must then be set in ``.streamlit/config.toml`` or as ``streamlit run`` flags
    — confirm against the pinned Streamlit version.
    """
    for option, value in (
        ("server.maxUploadSize", max_mb),
        ("server.maxMessageSize", max_mb),
        ("browser.gatherUsageStats", False),
    ):
        try:
            st.set_option(option, value)
        except Exception:
            # Deliberately best-effort: an unsupported option must not stop the app.
            pass


_apply_streamlit_options(_max_mb)
|
|
37
|
+
|
|
38
|
+
def _choose_report_template(task_type: str) -> Path:
|
|
39
|
+
"""Return the correct .docx template for 'regression' or 'classification'."""
|
|
40
|
+
# packaged location (recommended): tanml/report/templates/*.docx
|
|
41
|
+
try:
|
|
42
|
+
templates_pkg = files("tanml.report.templates")
|
|
43
|
+
name = "report_template_reg.docx" if task_type == "regression" else "report_template_cls.docx"
|
|
44
|
+
p = templates_pkg / name
|
|
45
|
+
if p.is_file():
|
|
46
|
+
return Path(str(p))
|
|
47
|
+
except Exception:
|
|
48
|
+
pass
|
|
49
|
+
|
|
50
|
+
# repo fallback
|
|
51
|
+
repo_guess = Path(__file__).resolve().parents[1] / "report" / "templates"
|
|
52
|
+
p2 = repo_guess / ("report_template_reg.docx" if task_type == "regression" else "report_template_cls.docx")
|
|
53
|
+
if p2.exists():
|
|
54
|
+
return p2
|
|
55
|
+
|
|
56
|
+
# cwd fallback
|
|
57
|
+
return Path.cwd() / ("report_template_reg.docx" if task_type == "regression" else "report_template_cls.docx")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
CAST9_DEFAULT = bool(int(os.getenv("TANML_CAST9", "0")))
|
|
61
|
+
# --- KPI row styling: align values on one baseline ---
# Labels get a fixed height with bottom-aligned text, so the value line below
# starts at the same vertical position in every KPI column.
_KPI_CSS = """
<style>
.tanml-kpi-label{
font-size:0.80rem; opacity:.8; white-space:nowrap;
height:20px; display:flex; align-items:flex-end;
}
.tanml-kpi-value{
font-size:1.6rem; font-weight:700; line-height:1; margin-top:4px;
}
</style>
"""
st.markdown(_KPI_CSS, unsafe_allow_html=True)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _filter_metrics_for_task(summary: Dict[str, Any]) -> Dict[str, Any]:
|
|
79
|
+
"""Keep only metrics relevant to the task_type inside summary."""
|
|
80
|
+
if not isinstance(summary, dict):
|
|
81
|
+
return summary or {}
|
|
82
|
+
|
|
83
|
+
task = summary.get("task_type")
|
|
84
|
+
cls_keys = {"auc", "ks", "f1", "pr_auc", "rules_failed", "task_type"}
|
|
85
|
+
reg_keys = {"rmse", "mae", "r2", "rules_failed", "task_type"}
|
|
86
|
+
|
|
87
|
+
if task == "classification":
|
|
88
|
+
return {k: v for k, v in summary.items() if k in cls_keys}
|
|
89
|
+
if task == "regression":
|
|
90
|
+
return {k: v for k, v in summary.items() if k in reg_keys}
|
|
91
|
+
return summary
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _g(d, *keys, default=None):
|
|
95
|
+
cur = d
|
|
96
|
+
for k in keys:
|
|
97
|
+
if isinstance(cur, dict) and k in cur:
|
|
98
|
+
cur = cur[k]
|
|
99
|
+
else:
|
|
100
|
+
return default
|
|
101
|
+
return cur
|
|
102
|
+
|
|
103
|
+
def _fmt2(v, *, decimals=2, dash="—"):
|
|
104
|
+
if v is None:
|
|
105
|
+
return dash
|
|
106
|
+
try:
|
|
107
|
+
if isinstance(v, float):
|
|
108
|
+
return f"{v:.{decimals}f}"
|
|
109
|
+
if isinstance(v, int):
|
|
110
|
+
return str(v)
|
|
111
|
+
return str(v)
|
|
112
|
+
except Exception:
|
|
113
|
+
return dash
|
|
114
|
+
|
|
115
|
+
# ---------- TVR (Train/Validate/Report) helpers ----------
|
|
116
|
+
|
|
117
|
+
def _tvr_key(section_id: str, name: str) -> str:
|
|
118
|
+
return f"tvr::{section_id}::{name}"
|
|
119
|
+
|
|
120
|
+
def tvr_clear_extras(section_id: str):
    """Forget any extras stored for *section_id*; a no-op when none are stored."""
    extras_key = _tvr_key(section_id, "extras")
    st.session_state.pop(extras_key, None)
|
|
122
|
+
|
|
123
|
+
def tvr_reset(section_id: str):
    """Return *section_id* to its idle state and reset the shared UI widgets.

    The section's stored report (bytes, file name, timestamp, summary, label,
    config) and its extras are cleared. Every widget key listed below is removed
    from session state, so those widgets render with their coded defaults on
    the next rerun.
    """
    tvr_init(section_id)

    # Section state back to initial values.
    fresh_values = {
        "stage": "idle",
        "bytes": None,
        "file": None,
        "ts": None,
        "summary": None,
        "label": None,
        "cfg": None,
    }
    for field, value in fresh_values.items():
        st.session_state[_tvr_key(section_id, field)] = value
    st.session_state.pop(_tvr_key(section_id, "extras"), None)

    widget_keys = (
        # uploaders: single-file flow
        "upl_cleaned", "upl_raw_single",
        # uploaders: train/test flow
        "upl_train", "upl_test", "upl_raw_global",
        # sidebar options
        "opt_eda", "opt_eda_max",
        "opt_corr", "opt_vif",
        "opt_rawcheck", "opt_modelmeta",
        "opt_stress", "opt_stress_eps", "opt_stress_frac",
        "opt_cluster", "opt_cluster_k", "opt_cluster_maxk",
        "opt_shap", "opt_shap_bg", "opt_shap_test",
        "opt_vifnorm",
        # correlation settings
        "opt_corr_method", "opt_corr_cap", "opt_corr_thr",
        # reproducibility seed
        "opt_seed",
        # model selection
        "mdl_task", "mdl_lib", "mdl_algo",
        # thresholds
        "thr_auc", "thr_f1", "thr_ks",
        # internal helpers
        "__thr_block__", "model_selection", "effective_cfg",
    )
    for widget_key in widget_keys:
        st.session_state.pop(widget_key, None)
|
|
170
|
+
|
|
171
|
+
def tvr_init(section_id: str):
    """Make sure every TVR state key for *section_id* exists; existing values are kept."""
    initial_state = (
        ("stage", "idle"),
        ("bytes", None),
        ("file", None),
        ("ts", None),
        ("summary", None),
        ("label", None),
        ("cfg", None),
    )
    for field, value in initial_state:
        st.session_state.setdefault(_tvr_key(section_id, field), value)
|
|
183
|
+
|
|
184
|
+
def tvr_finish(section_id: str, *, report_path: Optional[Path] = None,
               report_bytes: Optional[bytes] = None, file_name: str,
               summary: Optional[dict] = None, label: Optional[str] = None,
               cfg: Optional[dict] = None):
    """Store a finished report for *section_id* and mark the section ``"ready"``.

    Args:
        section_id: TVR section the report belongs to.
        report_path: File to read the report from when *report_bytes* is not given.
        report_bytes: Report content; takes precedence over *report_path*.
        file_name: Download file name offered to the user.
        summary: Metrics summary (stored as ``{}`` when falsy).
        label: Run label (stored as ``"Run"`` when falsy).
        cfg: Effective configuration of the run (stored as-is).

    Raises:
        TypeError: if neither *report_bytes* nor *report_path* is provided.
    """
    # Validate before touching session state so a bad call leaves no partial update.
    if report_bytes is None and report_path is None:
        raise TypeError("tvr_finish() needs either report_bytes or report_path")

    tvr_init(section_id)
    if report_bytes is None:
        report_bytes = Path(report_path).read_bytes()

    ts = int(time.time())
    st.session_state[_tvr_key(section_id, "bytes")] = report_bytes
    st.session_state[_tvr_key(section_id, "file")] = file_name
    st.session_state[_tvr_key(section_id, "ts")] = ts
    st.session_state[_tvr_key(section_id, "summary")] = summary or {}
    st.session_state[_tvr_key(section_id, "label")] = label or "Run"
    st.session_state[_tvr_key(section_id, "cfg")] = cfg
    st.session_state[_tvr_key(section_id, "stage")] = "ready"
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def tvr_render_ready(section_id: str, *, header_text="Refit, Validate & Report"):
    """Render the finished-report panel for *section_id*.

    Renders nothing unless the section stage is ``"ready"``. Otherwise it shows
    the header, a compact summary (PSI-like keys dropped, only task-relevant
    metrics kept, numbers rounded to 2 decimals), a download button for the
    stored report bytes, and a button that resets the section and reruns.
    """
    tvr_init(section_id)
    state = st.session_state
    if state[_tvr_key(section_id, "stage")] != "ready":
        return

    st.subheader(header_text)

    stored = state[_tvr_key(section_id, "summary")] or {}
    without_psi = {k: v for k, v in stored.items() if "psi" not in k.lower()}
    shown = _filter_metrics_for_task(without_psi)
    if shown:
        st.caption("Summary (last run)")
        rounded = {}
        for k, v in shown.items():
            rounded[k] = round(v, 2) if isinstance(v, (int, float)) else v
        st.write(rounded)

    ts = state[_tvr_key(section_id, "ts")]
    download_name = state[_tvr_key(section_id, "file")] or f"tanml_report_{ts}.docx"
    st.download_button(
        "⬇️ Download report",
        data=state[_tvr_key(section_id, "bytes")],
        file_name=download_name,
        mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        key=f"tvr_dl::{section_id}::{ts}",
        width="stretch",
    )

    # The only follow-up action offered: start over with a new model / run.
    if st.button("✨ New model / new run", key=f"tvr_new::{section_id}", width="stretch"):
        tvr_reset(section_id)
        st.rerun()
|
|
232
|
+
|
|
233
|
+
def tvr_render_history(section_id: str, *, title="🗂️ Past runs"):
    """Placeholder for a past-runs panel; run history is currently disabled.

    Kept so existing callers keep working. It renders nothing and returns ``None``.
    """
    return None
|
|
235
|
+
|
|
236
|
+
def tvr_store_extras(section_id: str, extras: Dict[str, Any]):
    """Keep *extras* for *section_id* in session state (read by tvr_render_extras)."""
    extras_key = _tvr_key(section_id, "extras")
    st.session_state[extras_key] = extras
|
|
238
|
+
|
|
239
|
+
# ==========================
|
|
240
|
+
# Filesystem / Session utils
|
|
241
|
+
# ==========================
|
|
242
|
+
|
|
243
|
+
def _session_dir() -> Path:
    """Return this browser session's run directory, creating it if needed.

    Layout: ``.ui_runs/<session id>/artifacts/``, relative to the current
    working directory. The session id is generated once per Streamlit session
    and cached in ``st.session_state["_session_id"]``.
    """
    sid = st.session_state.get("_session_id")
    if not sid:
        # Full 128-bit random hex. An 8-char prefix (32 bits) makes collisions
        # between concurrent users plausible, and colliding sessions would share
        # a run directory and see each other's artifacts.
        sid = uuid.uuid4().hex
        st.session_state["_session_id"] = sid
    run_dir = Path(".ui_runs") / sid
    # parents=True on the artifacts folder also creates run_dir itself.
    (run_dir / "artifacts").mkdir(parents=True, exist_ok=True)
    return run_dir
|
|
253
|
+
|
|
254
|
+
def _save_upload(upload, dest_dir: Path) -> Optional[Path]:
|
|
255
|
+
"""Persist uploaded file to disk. If CSV, convert once to Parquet for efficiency."""
|
|
256
|
+
if upload is None:
|
|
257
|
+
return None
|
|
258
|
+
name = Path(upload.name).name
|
|
259
|
+
path = dest_dir / name
|
|
260
|
+
with open(path, "wb") as f:
|
|
261
|
+
f.write(upload.getbuffer())
|
|
262
|
+
if path.suffix.lower() == ".csv":
|
|
263
|
+
try:
|
|
264
|
+
df = load_dataframe(path)
|
|
265
|
+
pq_path = path.with_suffix(".parquet")
|
|
266
|
+
df.to_parquet(pq_path, index=False)
|
|
267
|
+
return pq_path
|
|
268
|
+
except Exception as e:
|
|
269
|
+
st.warning(f"CSV→Parquet conversion failed (using CSV): {e}")
|
|
270
|
+
return path
|
|
271
|
+
|
|
272
|
+
# ==========================
|
|
273
|
+
# Data helpers
|
|
274
|
+
# ==========================
|
|
275
|
+
|
|
276
|
+
def _pick_target(df: pd.DataFrame) -> str:
|
|
277
|
+
if "target" in df.columns:
|
|
278
|
+
return "target"
|
|
279
|
+
return df.columns[-1]
|
|
280
|
+
|
|
281
|
+
def _normalize_vif(df: pd.DataFrame) -> pd.DataFrame:
|
|
282
|
+
"""Cast numerics to float64 and round to 9 decimals to stabilize VIF."""
|
|
283
|
+
df = df.copy()
|
|
284
|
+
num_cols = df.select_dtypes(include="number").columns
|
|
285
|
+
if len(num_cols):
|
|
286
|
+
df[num_cols] = df[num_cols].astype("float64").round(9)
|
|
287
|
+
return df
|
|
288
|
+
|
|
289
|
+
def _schema_align_or_error(train_df: pd.DataFrame, test_df: pd.DataFrame) -> Tuple[pd.DataFrame, Optional[str]]:
|
|
290
|
+
"""
|
|
291
|
+
Ensure test_df has same columns as train_df (name & order).
|
|
292
|
+
- If extra columns in test: drop them.
|
|
293
|
+
- If missing columns in test: return error string.
|
|
294
|
+
- Coerce test dtypes to train dtypes when safe (numeric<->numeric).
|
|
295
|
+
"""
|
|
296
|
+
train_cols = list(train_df.columns)
|
|
297
|
+
test_cols_set = set(test_df.columns)
|
|
298
|
+
missing = [c for c in train_cols if c not in test_cols_set]
|
|
299
|
+
if missing:
|
|
300
|
+
return test_df, f"Test set is missing required columns: {missing}"
|
|
301
|
+
aligned = test_df[train_cols].copy()
|
|
302
|
+
for c in train_cols:
|
|
303
|
+
td = train_df[c].dtype
|
|
304
|
+
if pd.api.types.is_numeric_dtype(td) and not pd.api.types.is_numeric_dtype(aligned[c].dtype):
|
|
305
|
+
aligned[c] = pd.to_numeric(aligned[c], errors="coerce")
|
|
306
|
+
return aligned, None
|
|
307
|
+
|
|
308
|
+
def _row_overlap_pct(train_df: pd.DataFrame, test_df: pd.DataFrame, cols: List[str]) -> float:
|
|
309
|
+
"""Approximate leakage check via row-hash overlap on selected columns."""
|
|
310
|
+
if not cols:
|
|
311
|
+
return 0.0
|
|
312
|
+
def _hash_rows(df: pd.DataFrame) -> set:
|
|
313
|
+
sub = df[cols].copy()
|
|
314
|
+
num = sub.select_dtypes(include="number").columns
|
|
315
|
+
if len(num):
|
|
316
|
+
sub[num] = sub[num].round(9)
|
|
317
|
+
sub = sub.astype(str)
|
|
318
|
+
return set(hashlib.md5(("|".join(row)).encode("utf-8")).hexdigest() for row in sub.values)
|
|
319
|
+
|
|
320
|
+
try:
|
|
321
|
+
tr = _hash_rows(train_df)
|
|
322
|
+
te = _hash_rows(test_df)
|
|
323
|
+
if not tr or not te:
|
|
324
|
+
return 0.0
|
|
325
|
+
inter = len(tr.intersection(te))
|
|
326
|
+
return 100.0 * inter / max(1, len(te))
|
|
327
|
+
except Exception:
|
|
328
|
+
return 0.0
|
|
329
|
+
|
|
330
|
+
# ==========================
|
|
331
|
+
# Seeds / Rule / Engine helpers
|
|
332
|
+
# ==========================
|
|
333
|
+
|
|
334
|
+
def _derive_component_seeds(global_seed: int, *, split_random: bool,
|
|
335
|
+
stress_enabled: bool, cluster_enabled: bool, shap_enabled: bool) -> Dict[str, Optional[int]]:
|
|
336
|
+
base = int(global_seed)
|
|
337
|
+
return {
|
|
338
|
+
"split": base if split_random else None,
|
|
339
|
+
"model": base + 1,
|
|
340
|
+
"stress": (base + 2) if stress_enabled else None,
|
|
341
|
+
"cluster": (base + 3) if cluster_enabled else None,
|
|
342
|
+
"shap": (base + 4) if shap_enabled else None,
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
def _build_rule_cfg(
|
|
346
|
+
*,
|
|
347
|
+
saved_raw: Optional[Path],
|
|
348
|
+
auc_min: float,
|
|
349
|
+
f1_min: float,
|
|
350
|
+
ks_min: float,
|
|
351
|
+
eda_enabled: bool,
|
|
352
|
+
eda_max_plots: int,
|
|
353
|
+
corr_enabled: bool,
|
|
354
|
+
vif_enabled: bool,
|
|
355
|
+
raw_data_check_enabled: bool,
|
|
356
|
+
model_meta_enabled: bool,
|
|
357
|
+
stress_enabled: bool,
|
|
358
|
+
stress_epsilon: float,
|
|
359
|
+
stress_perturb_fraction: float,
|
|
360
|
+
cluster_enabled: bool,
|
|
361
|
+
cluster_k: int,
|
|
362
|
+
cluster_max_k: int,
|
|
363
|
+
shap_enabled: bool,
|
|
364
|
+
shap_bg_size: int,
|
|
365
|
+
shap_test_size: int,
|
|
366
|
+
artifacts_dir: Path,
|
|
367
|
+
split_strategy: str,
|
|
368
|
+
test_size: float,
|
|
369
|
+
seed_global: int,
|
|
370
|
+
component_seeds: Dict[str, Optional[int]],
|
|
371
|
+
in_scope_cols: List[str],
|
|
372
|
+
) -> Dict[str, Any]:
|
|
373
|
+
cfg: Dict[str, Any] = {
|
|
374
|
+
"paths": {},
|
|
375
|
+
"data": {
|
|
376
|
+
"source": "separate" if split_strategy == "supplied" else "single",
|
|
377
|
+
"split": {
|
|
378
|
+
"strategy": split_strategy,
|
|
379
|
+
"test_size": float(test_size),
|
|
380
|
+
},
|
|
381
|
+
"in_scope_columns": list(in_scope_cols),
|
|
382
|
+
},
|
|
383
|
+
"checks_scope": {"use_only_in_scope": True, "reference_split": "train"},
|
|
384
|
+
"options": {"save_artifacts_dir": str(artifacts_dir)},
|
|
385
|
+
"reproducibility": {
|
|
386
|
+
"seed_global": int(seed_global),
|
|
387
|
+
"component_seeds": component_seeds,
|
|
388
|
+
},
|
|
389
|
+
"auc_roc": {"min": float(auc_min)},
|
|
390
|
+
"f1": {"min": float(f1_min)},
|
|
391
|
+
"ks": {"min": float(ks_min)},
|
|
392
|
+
"EDACheck": {"enabled": bool(eda_enabled), "max_plots": int(eda_max_plots)},
|
|
393
|
+
"correlation": {"enabled": bool(corr_enabled)},
|
|
394
|
+
"VIFCheck": {"enabled": bool(vif_enabled)},
|
|
395
|
+
"raw_data_check": {"enabled": bool(raw_data_check_enabled)},
|
|
396
|
+
"model_meta": {"enabled": bool(model_meta_enabled)},
|
|
397
|
+
"StressTestCheck": {
|
|
398
|
+
"enabled": bool(stress_enabled),
|
|
399
|
+
"epsilon": float(stress_epsilon),
|
|
400
|
+
"perturb_fraction": float(stress_perturb_fraction),
|
|
401
|
+
},
|
|
402
|
+
"InputClusterCoverageCheck": {
|
|
403
|
+
"enabled": bool(cluster_enabled),
|
|
404
|
+
"n_clusters": int(cluster_k),
|
|
405
|
+
"max_k": int(cluster_max_k),
|
|
406
|
+
},
|
|
407
|
+
"explainability": {
|
|
408
|
+
"shap": {
|
|
409
|
+
"enabled": bool(shap_enabled),
|
|
410
|
+
"background_sample_size": int(shap_bg_size),
|
|
411
|
+
"test_sample_size": int(shap_test_size),
|
|
412
|
+
}
|
|
413
|
+
},
|
|
414
|
+
"train_test_split": {"test_size": float(test_size)},
|
|
415
|
+
}
|
|
416
|
+
if saved_raw:
|
|
417
|
+
cfg["paths"]["raw_data"] = str(saved_raw)
|
|
418
|
+
return cfg
|
|
419
|
+
|
|
420
|
+
def _try_run_engine(
|
|
421
|
+
engine: ValidationEngine, progress_cb: Optional[Callable[[str], None]] = None
|
|
422
|
+
) -> Dict[str, Any]:
|
|
423
|
+
try:
|
|
424
|
+
return engine.run_all_checks(progress_callback=progress_cb)
|
|
425
|
+
except TypeError:
|
|
426
|
+
return engine.run_all_checks()
|
|
427
|
+
|
|
428
|
+
def metric_no_trunc(label: str, value) -> None:
    """Show *label* as a caption with *value* below it in large, semi-bold text.

    Unlike ``st.metric`` the label is never truncated. *value* is interpolated
    into raw HTML unescaped, so callers should pass trusted/plain values.
    """
    st.caption(label)
    value_html = f"<div style='font-size:1.6rem; font-weight:600; line-height:1.1'>{value}</div>"
    st.markdown(value_html, unsafe_allow_html=True)
|
|
435
|
+
|
|
436
|
+
# ==========================
|
|
437
|
+
# UI helpers (Correlation & Regression renderers)
|
|
438
|
+
# ==========================
|
|
439
|
+
|
|
440
|
+
def _render_correlation_outputs(results, title="Numeric Correlation"):
    """Render the ``CorrelationCheck`` part of *results* in Streamlit.

    The section shows, in order:
    - the heatmap image, if its file exists;
    - four KPI tiles: numeric features, pairs evaluated, pairs at/above the
      threshold, and method;
    - a 3-row preview of the flagged pairs;
    - a download button for the full pairs CSV;
    - any notes the check produced.

    Missing heatmap or preview data is reported with ``st.info``.
    """
    st.subheader(title)
    corr_res = (results or {}).get("CorrelationCheck", {}) or {}
    # Artifacts may be nested under "artifacts" or stored flat on the result.
    # An explicit None/empty value falls back to the flat layout instead of
    # crashing on .get() below.
    art = corr_res.get("artifacts") or corr_res
    summary = corr_res.get("summary", {}) or {}
    notes = corr_res.get("notes", []) or []

    heatmap_path = art.get("heatmap_path")
    if heatmap_path and os.path.exists(heatmap_path):
        if summary.get("plotted_full_matrix"):
            coverage = "full matrix"
        else:
            coverage = (
                f"showing {summary.get('plotted_features')}/"
                f"{summary.get('n_numeric_features')} features (subset)"
            )
        caption = f"Correlation Heatmap ({summary.get('method','')}) — {coverage}"
        st.image(heatmap_path, caption=caption, width="stretch")
    else:
        st.info("Heatmap not available.")

    # Format the threshold so 0.4 shows as 0.40 (avoid 0.39999999999999997).
    thr_raw = summary.get("threshold", None)
    thr_str = _fmt2(float(thr_raw), decimals=2) if isinstance(thr_raw, (int, float)) else "—"

    tiles = (
        ("Numeric features", summary.get("n_numeric_features", "—")),
        ("Pairs evaluated", summary.get("n_pairs_total", "—")),
        (f"Pairs ≥ {thr_str}", summary.get("n_pairs_flagged_ge_threshold", "—")),
        ("Method", (summary.get("method", "—") or "—").capitalize()),
    )
    cols = st.columns(len(tiles))
    # All labels first, then all values, so the values share one baseline
    # (styled by the .tanml-kpi-* CSS classes).
    for col, (lab, _) in zip(cols, tiles):
        col.markdown(f'<div class="tanml-kpi-label">{lab}</div>', unsafe_allow_html=True)
    for col, (_, val) in zip(cols, tiles):
        col.markdown(f'<div class="tanml-kpi-value">{val}</div>', unsafe_allow_html=True)

    main_csv = art.get("top_pairs_main_csv")
    if main_csv and os.path.exists(main_csv):
        # Only three rows are displayed, so don't parse the whole file.
        prev_df = pd.read_csv(main_csv, nrows=3)
        st.caption("Top high-correlation pairs (preview)")
        st.dataframe(prev_df, width="stretch", height=160, hide_index=True)
    else:
        st.info("No high-correlation pairs at current threshold.")

    full_csv = art.get("top_pairs_csv")
    if full_csv and os.path.exists(full_csv):
        with open(full_csv, "rb") as fh:
            st.download_button(
                "⬇️ Download full pairs CSV",
                fh.read(),
                file_name="correlation_top_pairs.csv",
                mime="text/csv",
                key=f"corrcsv::{full_csv}",
            )
    for n in notes:
        st.caption(f"• {n}")
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def _render_regression_outputs(results):
    """Show regression KPI tiles and a detailed-metrics table (no visuals).

    Renders nothing unless ``results["task_type"]`` is ``"regression"``.
    The tiles read ``results["summary"]`` (rmse / mae / r2). The table reads
    ``results["RegressionMetrics"]`` (r2_adjusted, median_ae, mape_or_smape,
    mape_used). Any ``notes`` found there are listed in an expander.
    """
    task = results.get("task_type", "classification")
    if task != "regression":
        return

    # Key tiles
    st.subheader("Key Metrics")
    c1, c2, c3 = st.columns(3)
    c1.metric("RMSE", _fmt2(_g(results, "summary", "rmse")))
    c2.metric("MAE", _fmt2(_g(results, "summary", "mae")))
    c3.metric("R²", _fmt2(_g(results, "summary", "r2"), decimals=2))

    # Detailed metrics
    st.subheader("Regression Metrics (Detailed)")
    adj_r2 = _g(results, "RegressionMetrics", "r2_adjusted")
    med_ae = _g(results, "RegressionMetrics", "median_ae")
    mv = _g(results, "RegressionMetrics", "mape_or_smape")
    m_is_mape = bool(_g(results, "RegressionMetrics", "mape_used", default=False))

    # "mape_used" only picks the label. A missing value renders as "N/A"
    # whichever metric was used, instead of a dangling "—% (MAPE)".
    if mv is None:
        pct_err = "N/A"
    else:
        pct_err = f"{_fmt2(mv, decimals=2)}% ({'MAPE' if m_is_mape else 'SMAPE'})"

    st.table({
        "Metric": ["Adjusted R²", "Median AE", "MAPE/SMAPE"],
        "Value": [
            _fmt2(adj_r2, decimals=2) if adj_r2 is not None else "N/A",
            _fmt2(med_ae, decimals=2) if med_ae is not None else "—",
            pct_err,
        ],
    })

    # Notes, if any
    notes = _g(results, "RegressionMetrics", "notes", default=[]) or []
    if notes:
        with st.expander("Metric notes"):
            for n in notes:
                st.write("• " + str(n))
|
|
536
|
+
|
|
537
|
+
def tvr_render_extras(section_id: str):
    """Render the detailed results stored via ``tvr_store_extras`` for *section_id*.

    Sections, in order:
    - correlation outputs;
    - task-specific key metrics (regression via ``_render_regression_outputs``,
      otherwise AUC / KS / rules failed);
    - a run-summary table;
    - the first SHAP image found under the artifacts directory;
    - download buttons for up to 100 artifact files.

    Nothing is rendered when no extras are stored.
    """
    extras = st.session_state.get(_tvr_key(section_id, "extras"))
    if not extras:
        return

    results = extras.get("results", {}) or {}
    task = results.get("task_type", "classification")

    # Correlation always (if present)
    _render_correlation_outputs(results)

    # Task-aware metrics/dashboard
    if task == "regression":
        _render_regression_outputs(results)
    else:
        st.subheader("Key Metrics")
        summary = results.get("summary") or {}
        c1, c2, c3 = st.columns(3)
        c1.metric("AUC", _fmt2(summary.get("auc")))
        c2.metric("KS", _fmt2(summary.get("ks")))
        c3.metric("Rules failed", _fmt2(summary.get("rules_failed")))

    st.subheader("Run Summary")
    st.write({
        "Train rows": int(extras.get("train_rows", 0) or 0),
        "Test rows": int(extras.get("test_rows", 0) or 0),
        "Target": extras.get("target", "—"),
        "Features used": int(extras.get("n_features", 0) or 0),
        "Seed used": extras.get("seed_used", "—"),
        "Effective config": extras.get("eff_path", "—"),
        "Artifacts dir": extras.get("artifacts_dir", "—"),
    })

    # ------- Artifacts -------
    # Test the raw value before wrapping it: Path("") is Path("."), which is
    # always truthy and exists, so a missing artifacts_dir would otherwise list
    # (and offer for download) every file under the working directory.
    raw_art_dir = extras.get("artifacts_dir")
    if not raw_art_dir:
        return
    art_dir = Path(raw_art_dir)
    if not art_dir.is_dir():
        return

    # The glob matches any extension; only hand image files to st.image.
    image_suffixes = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".svg"}
    shap_imgs = sorted(
        p for p in art_dir.glob("**/*shap*.*")
        if p.is_file() and p.suffix.lower() in image_suffixes
    )
    if shap_imgs:
        st.image(str(shap_imgs[0]), caption="SHAP summary", width="stretch")

    st.subheader("Artifacts")

    SKIP_SUFFIXES = ("_top_pairs_main.csv",)
    pretty_names = {
        "correlation_top_pairs.csv": "Flagged pairs (full)",
        "heatmap.png": "Correlation heatmap",
        "pearson_corr.csv": "Pearson correlation matrix",
        "spearman_corr.csv": "Spearman correlation matrix",
    }
    order_hint = [
        "correlation_top_pairs.csv",
        "heatmap.png",
        "pearson_corr.csv",
        "spearman_corr.csv",
    ]

    all_paths = [p for p in art_dir.glob("**/*") if p.is_file()]
    files_list = [p for p in all_paths if not any(p.name.endswith(sfx) for sfx in SKIP_SUFFIXES)]
    if not files_list:
        st.caption("No artifacts were saved.")
        return

    # Known artifacts first (in hint order), then everything else alphabetically.
    files_list.sort(key=lambda p: (order_hint.index(p.name) if p.name in order_hint else 999, p.name.lower()))

    for p in files_list[:100]:
        label = pretty_names.get(p.name, p.name)
        st.download_button(
            f"⬇️ Download {label}",
            p.read_bytes(),
            file_name=p.name,
            key=f"art::{section_id}::{p}",
            width="stretch",
        )
|
|
617
|
+
|
|
618
|
+
# ==========================
|
|
619
|
+
# UI — Refit-only (20 models)
|
|
620
|
+
# ==========================
|
|
621
|
+
|
|
622
|
+
# Page chrome, then the per-session working directories.
st.set_page_config(page_title="TanML — Refit & Validate", layout="wide")
st.title("TanML • Refit & Validate")

run_dir = _session_dir()                # .ui_runs/<session id>/
artifacts_dir = run_dir / "artifacts"   # created by _session_dir()
|
|
627
|
+
|
|
628
|
+
# Sidebar: check toggles and their per-check options.
with st.sidebar.expander("Checks & Options", expanded=True):
    eda_enabled = st.checkbox("EDA plots", value=True, key="opt_eda")
    eda_max_plots = st.number_input("EDA max plots (-1=all numeric)", value=-1, step=1, key="opt_eda_max")
    corr_enabled = st.checkbox("Correlation matrix", value=True, key="opt_corr")
    vif_enabled = st.checkbox("VIF check", value=True, key="opt_vif")
    raw_data_check_enabled = st.checkbox("RawDataCheck (needs raw)", value=True, key="opt_rawcheck")
    model_meta_enabled = st.checkbox("Model metadata", value=True, key="opt_modelmeta")

with st.sidebar.expander("Robustness / Stress Testing", expanded=True):
    stress_enabled = st.checkbox("StressTestCheck", value=True, key="opt_stress")
    stress_epsilon = st.number_input(
        "Epsilon (noise)", min_value=0.0, max_value=1.0, value=0.01, step=0.01, key="opt_stress_eps"
    )
    stress_perturb_fraction = st.number_input(
        "Perturb fraction", min_value=0.0, max_value=1.0, value=0.20, step=0.05, key="opt_stress_frac"
    )

with st.sidebar.expander("Input Cluster Coverage", expanded=False):
    cluster_enabled = st.checkbox("InputClusterCoverageCheck", value=True, key="opt_cluster")
    cluster_k = st.number_input("n_clusters", min_value=2, max_value=50, value=5, step=1, key="opt_cluster_k")
    cluster_max_k = st.number_input(
        "max_k (elbow cap)", min_value=2, max_value=100, value=10, step=1, key="opt_cluster_maxk"
    )

with st.sidebar.expander("Explainability (SHAP)", expanded=True):
    shap_enabled = st.checkbox("Enable SHAP", value=True, key="opt_shap")
    shap_bg_size = st.number_input(
        "Background sample size", min_value=10, max_value=100000, value=100, step=10, key="opt_shap_bg"
    )
    shap_test_size = st.number_input(
        "Test rows to explain", min_value=10, max_value=100000, value=200, step=10, key="opt_shap_test"
    )
|
|
652
|
+
|
|
653
|
+
# The former "Numeric normalization (VIF stabilization)" sidebar checkbox
# (widget key "opt_vifnorm") is disabled. The cast-to-float64 / round-to-9
# switch now lives in st.session_state["cast9_round9"], seeded once per
# session from CAST9_DEFAULT (env var TANML_CAST9).
st.session_state.setdefault("cast9_round9", CAST9_DEFAULT)
cast9 = bool(st.session_state["cast9_round9"])
apply_vif_norm = cast9
|
|
667
|
+
|
|
668
|
+
st.sidebar.subheader("Reproducibility")
# Upper bound is the largest signed 32-bit integer.
seed_global = st.sidebar.number_input(
    "Random seed",
    min_value=0,
    max_value=2**31 - 1,
    value=42,
    step=1,
    help="Controls random split, model refit, stress noise, clustering, and SHAP sampling.",
    key="opt_seed",
)
|
|
675
|
+
|
|
676
|
+
with st.sidebar.expander("Correlation Settings", expanded=False):
    corr_method = st.radio(
        "Method", options=["pearson", "spearman"], index=0, horizontal=True, key="opt_corr_method"
    )
    corr_cap = st.slider(
        "Heatmap features (default 20, max 60)",
        min_value=10, max_value=60, value=20, step=5, key="opt_corr_cap",
    )
    corr_thr = st.number_input(
        "High-correlation threshold (|r| ≥)",
        min_value=0.0, max_value=0.99, value=0.80, step=0.05, key="opt_corr_thr",
    )
    st.caption("Tip: ≥ 0.90 often means near-duplicate; confirm with VIF.")

# Correlation-check settings. enabled / method / threshold / heatmap cap / seed
# come from the UI; everything else is fixed here.
corr_ui_cfg = dict(
    enabled=bool(corr_enabled),
    method=corr_method,
    high_corr_threshold=float(corr_thr),
    heatmap_max_features_default=int(corr_cap),
    heatmap_max_features_limit=60,
    subset_strategy="cluster",
    top_pairs_max=200,
    sample_rows=150_000,
    seed=int(seed_global),
    save_csv=True,
    save_fig=True,
    appendix_csv_cap=None,
)
|
|
696
|
+
|
|
697
|
+
def render_model_form(y_train, seed_global: int):
|
|
698
|
+
"""Return (library, algorithm, params, task) using the 20-model registry,
|
|
699
|
+
but never show per-model seed; we inject sidebar seed automatically.
|
|
700
|
+
"""
|
|
701
|
+
task_auto = infer_task_from_target(y_train)
|
|
702
|
+
task = st.radio(
|
|
703
|
+
"Task",
|
|
704
|
+
["classification", "regression"],
|
|
705
|
+
index=0 if task_auto == "classification" else 1,
|
|
706
|
+
horizontal=True,
|
|
707
|
+
key="mdl_task"
|
|
708
|
+
)
|
|
709
|
+
|
|
710
|
+
libraries_all = ["sklearn", "xgboost", "lightgbm", "catboost"]
|
|
711
|
+
library = st.selectbox("Library", libraries_all, index=0, key="mdl_lib")
|
|
712
|
+
|
|
713
|
+
avail = [(lib, algo) for (lib, algo), spec in list_models(task).items() if lib == library]
|
|
714
|
+
if not avail:
|
|
715
|
+
st.error(f"No algorithms available for {library} / {task}. Is the library installed?")
|
|
716
|
+
st.stop()
|
|
717
|
+
algo_names = [a for (_, a) in avail]
|
|
718
|
+
algo = st.selectbox("Algorithm", algo_names, index=0, key="mdl_algo")
|
|
719
|
+
|
|
720
|
+
spec = get_spec(library, algo)
|
|
721
|
+
schema = ui_schema_for(library, algo)
|
|
722
|
+
defaults = spec.defaults or {}
|
|
723
|
+
|
|
724
|
+
seed_keys = [k for k in ("random_state", "seed", "random_seed") if k in defaults]
|
|
725
|
+
params = {}
|
|
726
|
+
|
|
727
|
+
with st.expander("Hyperparameters", expanded=True):
|
|
728
|
+
if seed_keys:
|
|
729
|
+
st.caption("ℹ️ Model seed is taken from the sidebar **Random seed**.")
|
|
730
|
+
|
|
731
|
+
for name, (typ, choices, helptext) in schema.items():
|
|
732
|
+
if name in seed_keys:
|
|
733
|
+
continue
|
|
734
|
+
|
|
735
|
+
default_val = defaults.get(name)
|
|
736
|
+
|
|
737
|
+
if typ == "choice":
|
|
738
|
+
opts = list(choices) if choices else []
|
|
739
|
+
show = ["None" if o is None else o for o in opts]
|
|
740
|
+
if show:
|
|
741
|
+
if default_val is None and "None" in show:
|
|
742
|
+
idx = show.index("None")
|
|
743
|
+
elif default_val in show:
|
|
744
|
+
idx = show.index(default_val)
|
|
745
|
+
else:
|
|
746
|
+
idx = 0
|
|
747
|
+
sel = st.selectbox(name, show, index=idx, help=helptext)
|
|
748
|
+
params[name] = None if sel == "None" else sel
|
|
749
|
+
else:
|
|
750
|
+
params[name] = st.text_input(
|
|
751
|
+
name,
|
|
752
|
+
value=str(default_val) if default_val is not None else "",
|
|
753
|
+
help=helptext
|
|
754
|
+
)
|
|
755
|
+
|
|
756
|
+
elif typ == "bool":
|
|
757
|
+
params[name] = st.checkbox(
|
|
758
|
+
name,
|
|
759
|
+
value=bool(default_val) if default_val is not None else False,
|
|
760
|
+
help=helptext
|
|
761
|
+
)
|
|
762
|
+
|
|
763
|
+
elif typ == "int":
|
|
764
|
+
params[name] = int(st.number_input(
|
|
765
|
+
name,
|
|
766
|
+
value=int(default_val) if default_val is not None else 0,
|
|
767
|
+
step=1,
|
|
768
|
+
help=helptext
|
|
769
|
+
))
|
|
770
|
+
|
|
771
|
+
elif typ == "float":
|
|
772
|
+
params[name] = float(st.number_input(
|
|
773
|
+
name,
|
|
774
|
+
value=float(default_val) if default_val is not None else 0.0,
|
|
775
|
+
help=helptext
|
|
776
|
+
))
|
|
777
|
+
|
|
778
|
+
else: # "str"
|
|
779
|
+
params[name] = st.text_input(
|
|
780
|
+
name,
|
|
781
|
+
value=str(default_val) if default_val is not None else "",
|
|
782
|
+
help=helptext
|
|
783
|
+
)
|
|
784
|
+
|
|
785
|
+
for k in seed_keys:
|
|
786
|
+
params[k] = int(seed_global)
|
|
787
|
+
|
|
788
|
+
return library, algo, params, task
|
|
789
|
+
|
|
790
|
+
# ==========================
|
|
791
|
+
# Main layout — single flow
|
|
792
|
+
# ==========================
|
|
793
|
+
|
|
794
|
+
def _select_target_and_features(df):
    """Render the target / feature pickers for ``df`` and return ``(target, features)``.

    The target defaults to ``_pick_target(df)`` when that column exists
    (otherwise the first column); every remaining column is preselected as a
    feature.
    """
    cols = list(df.columns)
    target_default = _pick_target(df)
    chosen_target = st.selectbox(
        "Target column",
        options=cols,
        index=cols.index(target_default) if target_default in cols else 0,
    )
    candidates = [c for c in cols if c != chosen_target]
    chosen_features = st.multiselect("Features", options=candidates, default=list(candidates))
    return chosen_target, chosen_features


left, right = st.columns([1.35, 1])

with left:
    st.header("1) Choose data source")
    data_source = st.radio(
        "Data source",
        ("Single cleaned file (you split)", "Already split: Train & Test"),
        index=0, horizontal=True
    )

    # Everything the right-hand column and the refit handler read is initialised
    # here, so each data-source branch leaves all of these names defined.
    saved_raw = None
    saved_cleaned = saved_train = saved_test = None
    cleaned_df = train_df = test_df = None
    target, features = None, []
    test_size = None  # set by the slider; only meaningful for a random split

    if data_source.startswith("Single"):
        st.subheader("Upload files")
        cleaned_file = st.file_uploader("Cleaned dataset (required)", key="upl_cleaned")
        raw_file = st.file_uploader("Raw dataset (optional)", key="upl_raw_single")

        saved_cleaned = _save_upload(cleaned_file, run_dir)
        saved_raw = _save_upload(raw_file, run_dir)

        if saved_cleaned:
            st.success(f"Cleaned file saved: `{saved_cleaned}`")
            try:
                cleaned_df = load_dataframe(saved_cleaned)
                st.write("Preview (top 10 rows):")
                st.dataframe(cleaned_df.head(10), width="stretch")
            except Exception as e:
                st.error(f"Could not read cleaned file: {e}")

        st.subheader("Configure data")
        if cleaned_df is not None:
            target, features = _select_target_and_features(cleaned_df)

        test_size = st.slider("Hold-out test size", 0.1, 0.5, 0.3, 0.05)

    else:
        st.subheader("Upload TRAIN/TEST (cleaned)")
        train_cleaned = st.file_uploader("Train (cleaned) — required", key="upl_train")
        test_cleaned = st.file_uploader("Test (cleaned) — required", key="upl_test")
        raw_file = st.file_uploader("Raw dataset (optional, global)", key="upl_raw_global")

        saved_train = _save_upload(train_cleaned, run_dir)
        saved_test = _save_upload(test_cleaned, run_dir)
        saved_raw = _save_upload(raw_file, run_dir)

        if saved_train:
            try:
                train_df = load_dataframe(saved_train)
            except Exception as e:
                st.error(f"Could not read train: {e}")
        if saved_test:
            try:
                test_df = load_dataframe(saved_test)
            except Exception as e:
                st.error(f"Could not read test: {e}")

        # Target / features are chosen from the TRAIN file's columns.
        if train_df is not None:
            st.write("Train preview (top 10):")
            st.dataframe(train_df.head(10), width="stretch")
            target, features = _select_target_and_features(train_df)
|
|
893
|
+
|
|
894
|
+
with right:
    st.header("2) Refit, Validate & Report")
    # Re-render the report/extras stored under the "refit" slot by the last run.
    tvr_render_ready("refit", header_text="Run & Report (last run)")
    tvr_render_extras("refit")

    # The timestamped default is seeded into session state once and the widget
    # is keyed. Without a key, a default that changes on every rerun gives the
    # text_input a new identity, which resets it and discards the typed name.
    # NOTE(review): re-running with an unchanged name targets the same path in run_dir.
    if "refit_report_name" not in st.session_state:
        st.session_state["refit_report_name"] = f"tanml_report_{int(time.time())}.docx"
    report_name = st.text_input("Report file name (.docx)", key="refit_report_name")

    # --- Choose model BEFORE running ---
    # The model form infers classification vs. regression from this target column.
    task_frame = cleaned_df if data_source.startswith("Single") else train_df
    if task_frame is not None and target:
        y_for_task = task_frame[target]
    else:
        y_for_task = pd.Series([], dtype="float64")

    library_selected, algo_selected, user_hp, task_selected = render_model_form(y_for_task, seed_global)

    # ---- Conditional thresholds (sidebar) ----
    with st.sidebar:
        if task_selected == "classification":
            st.subheader("Thresholds")
            c1, c2, c3 = st.columns(3)
            auc_min = c1.number_input("AUC ≥", 0.0, 1.0, 0.60, 0.01, key="thr_auc")
            f1_min = c2.number_input("F1 ≥", 0.0, 1.0, 0.60, 0.01, key="thr_f1")
            ks_min = c3.number_input("KS ≥", 0.0, 1.0, 0.20, 0.01, key="thr_ks")
            st.session_state["__thr_block__"] = {"AUC_min": auc_min, "F1_min": f1_min, "KS_min": ks_min}
        else:
            # Regression (or anything not classification) → hide thresholds entirely.
            auc_min = f1_min = ks_min = 0.0
            st.session_state["__thr_block__"] = {"problem_type": "regression"}

    st.session_state["model_selection"] = {
        "library": library_selected,
        "algo": algo_selected,
        "hp": user_hp,
        "task": task_selected,
    }

    # saved_cleaned / saved_train / saved_test may only be assigned by the
    # matching data-source branch of the left column, hence locals().get().
    if data_source.startswith("Single"):
        have_files = locals().get("saved_cleaned") is not None
    else:
        have_files = (
            locals().get("saved_train") is not None
            and locals().get("saved_test") is not None
        )
    ready = bool(have_files and target and features)
    ready = ready and bool(st.session_state.get("model_selection", {}).get("algo"))

    go = st.button("▶️ Refit & validate", type="primary", disabled=not ready)
    if not ready:
        st.info("Provide data, pick target + features, choose a model, then run.")

    if go:
        try:
            # ---- Build X_train/X_test/y_train/y_test ----
            if data_source.startswith("Single"):
                if cleaned_df is None:
                    cleaned_df = load_dataframe(saved_cleaned)
                if target not in cleaned_df.columns:
                    st.error(f"Target '{target}' not found in cleaned data.")
                    st.stop()

                safe_features = [c for c in features if c in cleaned_df.columns and c != target]
                if not safe_features:
                    safe_features = [c for c in cleaned_df.columns if c != target]

                X = cleaned_df[safe_features].copy()
                y = cleaned_df[target].copy()
                if apply_vif_norm:
                    # Only the feature matrix is used from here on.
                    X = _normalize_vif(X)

                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=float(test_size), random_state=seed_global, shuffle=True
                )
                df_checks = pd.concat([X_train, y_train], axis=1)
                if apply_vif_norm:
                    df_checks = _normalize_vif(df_checks)

                split_strategy = "random"

            else:
                if train_df is None:
                    train_df = load_dataframe(saved_train)
                if test_df is None:
                    test_df = load_dataframe(saved_test)

                if target not in train_df.columns:
                    st.error(f"Target '{target}' not in TRAIN.")
                    st.stop()
                if target not in test_df.columns:
                    st.error(f"Target '{target}' not in TEST.")
                    st.stop()

                in_scope = features + [target]
                X_train = train_df[features].copy()
                y_train = train_df[target].copy()
                te_aligned, err = _schema_align_or_error(train_df[in_scope], test_df[in_scope].copy())
                if err:
                    st.error(err)
                    st.stop()
                X_test = te_aligned[features].copy()
                y_test = te_aligned[target].copy()

                if apply_vif_norm:
                    X_train = _normalize_vif(X_train)
                    X_test = _normalize_vif(X_test)

                df_checks = pd.concat([X_train, y_train], axis=1)
                if apply_vif_norm:
                    df_checks = _normalize_vif(df_checks)

                overlap_pct = _row_overlap_pct(
                    pd.concat([X_train, y_train], axis=1),
                    pd.concat([X_test, y_test], axis=1),
                    cols=list(in_scope),
                )
                if overlap_pct > 0:
                    st.warning(f"Potential Train↔Test row overlap: ~{overlap_pct:.2f}% (by row hash).")

                split_strategy = "supplied"

            saved_raw_ = saved_raw

            # ---- Build estimator from selection ----
            sel = st.session_state.get("model_selection") or {}
            library_selected = sel.get("library")
            algo_selected = sel.get("algo")
            user_hp = sel.get("hp") or {}
            task_selected = sel.get("task")

            if not library_selected or not algo_selected:
                st.error("Please choose a library and algorithm before running.")
                st.stop()

            try:
                model = build_estimator(library_selected, algo_selected, user_hp)
            except ImportError:
                st.error(
                    f"Missing dependency for '{library_selected}.{algo_selected}'. "
                    f"Install the library (e.g., 'pip install {library_selected}') and try again."
                )
                st.stop()

            model.fit(X_train, y_train)
            if not hasattr(model, "feature_names_in_"):
                # Best effort: some estimators do not record column names themselves.
                try:
                    model.feature_names_in_ = X_train.columns.to_numpy()
                except Exception:
                    pass

            component_seeds = _derive_component_seeds(
                seed_global,
                split_random=(split_strategy == "random"),
                stress_enabled=stress_enabled,
                cluster_enabled=cluster_enabled,
                shap_enabled=shap_enabled,
            )

            rule_cfg = _build_rule_cfg(
                saved_raw=saved_raw_,
                auc_min=auc_min, f1_min=f1_min, ks_min=ks_min,
                eda_enabled=eda_enabled, eda_max_plots=int(eda_max_plots),
                corr_enabled=corr_enabled, vif_enabled=vif_enabled,
                raw_data_check_enabled=raw_data_check_enabled and bool(saved_raw_),
                model_meta_enabled=model_meta_enabled,
                stress_enabled=stress_enabled,
                stress_epsilon=stress_epsilon,
                stress_perturb_fraction=stress_perturb_fraction,
                cluster_enabled=cluster_enabled,
                cluster_k=int(cluster_k),
                cluster_max_k=int(cluster_max_k),
                shap_enabled=shap_enabled,
                shap_bg_size=int(shap_bg_size),
                shap_test_size=int(shap_test_size),
                artifacts_dir=artifacts_dir,
                split_strategy=split_strategy,
                test_size=float(test_size) if split_strategy == "random" else 0.0,
                seed_global=seed_global,
                component_seeds=component_seeds,
                in_scope_cols=list(X_train.columns) + [target],
            )

            rule_cfg.setdefault("explainability", {}).setdefault("shap", {})["out_dir"] = str(artifacts_dir)

            # Backward-compat for wrappers that look under "SHAPCheck"
            rule_cfg["SHAPCheck"] = {
                "enabled": bool(shap_enabled),
                "background_size": int(shap_bg_size),
                "sample_size": int(shap_test_size),
                "out_dir": str(artifacts_dir),
            }
            # Inject correlation settings: reuse the dict built from the sidebar
            # widgets instead of re-spelling the same keys here.
            rule_cfg.setdefault("CorrelationCheck", {}).update(corr_ui_cfg)
            rule_cfg.setdefault("correlation", {}).update({"enabled": bool(corr_enabled)})

            raw_df_loaded = load_dataframe(saved_raw_) if saved_raw_ else None

            engine = ValidationEngine(
                model, X_train, X_test, y_train, y_test, rule_cfg, df_checks,
                raw_df=raw_df_loaded, ctx=st.session_state
            )

            eff_path = run_dir / "effective_config.yaml"
            with st.status("Refitting & running checks…", expanded=True) as status:
                def cb(msg: str):
                    # Progress messages are best effort; never fail the run over them.
                    try:
                        status.write(msg)
                    except Exception:
                        pass

                results = _try_run_engine(engine, progress_cb=cb)

                # Stamp task for downstream UI/history
                results["task_type"] = task_selected
                results.setdefault("summary", {})["task_type"] = task_selected

                rules_meta = {}
                if task_selected == "classification":
                    rules_meta = {"auc_roc_min": auc_min, "f1_min": f1_min, "ks_min": ks_min}

                results.update({
                    "validation_date": pd.Timestamp.now(tz="America/Chicago").strftime("%Y-%m-%d %H:%M:%S %Z"),
                    "model_path": "(refit in UI)",
                    "validated_by": "TanML UI (Refit-only)",
                    "rules": rules_meta,  # empty for regression
                    "data_split": "random" if split_strategy == "random" else "supplied_train_test",
                    "reproducibility": {"seed_global": int(seed_global), "component_seeds": component_seeds, "split_strategy": split_strategy},
                    "model_provenance": {
                        "refit_always": True,
                        "library": library_selected,
                        "algorithm": algo_selected,
                        "hyperparameters_used": user_hp,
                        "hyperparameters_source": "user_form"
                    }
                })

                raw_path_str = str(saved_raw_) if saved_raw_ else None
                if split_strategy == "random":
                    eff_paths = {"cleaned": str(saved_cleaned), "raw": raw_path_str}
                else:
                    eff_paths = {"train_cleaned": str(saved_train), "test_cleaned": str(saved_test), "raw": raw_path_str}

                effective_cfg = {
                    "scenario": "Refit",
                    "mode": "Refit & Validate",
                    "paths": {**eff_paths, "artifacts_dir": str(artifacts_dir)},
                    "data": {
                        "target": target,
                        "features": list(X_train.columns),
                        "split": "random" if split_strategy == "random" else "supplied",
                        "test_size": float(test_size) if split_strategy == "random" else None,
                        "random_state": seed_global if split_strategy == "random" else None,
                    },
                    "model_refit": {"library": library_selected, "type": algo_selected, "hyperparameters": user_hp},
                    "checks": rule_cfg,
                    "reproducibility": {"seed_global": int(seed_global), "component_seeds": component_seeds},
                }
                # Prefer YAML; fall back to JSON if ruamel is missing or cannot
                # represent a value. The handle is closed deterministically so a
                # partially written YAML buffer can never be flushed over the JSON
                # fallback later, and default=str keeps the fallback from aborting
                # the whole run on values JSON cannot encode (e.g. Path objects).
                try:
                    from ruamel.yaml import YAML
                    with eff_path.open("w", encoding="utf-8") as fh:
                        YAML().dump(effective_cfg, fh)
                except Exception:
                    eff_path.write_text(json.dumps(effective_cfg, indent=2, default=str), encoding="utf-8")

                # Keep the report inside run_dir: only the last path component of
                # the typed name is used, a blank name falls back to a timestamped
                # one, and ".docx" is appended when missing.
                from pathlib import PurePath
                report_file_name = PurePath(str(report_name or "").strip()).name
                if not report_file_name:
                    report_file_name = f"tanml_report_{int(time.time())}.docx"
                elif not report_file_name.lower().endswith(".docx"):
                    report_file_name += ".docx"
                report_path = run_dir / report_file_name
                report_path.parent.mkdir(parents=True, exist_ok=True)

                # ⬇️ choose correct template based on task (classification/regression)
                template_path = _choose_report_template(task_selected)

                # Build the report with the chosen template
                ReportBuilder(
                    results,
                    template_path=str(template_path),
                    output_path=report_path
                ).build()

                status.update(label="Done!", state="complete", expanded=False)

            tvr_store_extras("refit", {
                "results": results,
                "train_rows": len(y_train),
                "test_rows": len(y_test),
                "target": target,
                "n_features": len(X_train.columns),
                "seed_used": seed_global if split_strategy == "random" else "N/A (user-supplied split)",
                "eff_path": str(eff_path),
                "artifacts_dir": str(artifacts_dir),
            })

            tvr_finish(
                "refit",
                report_path=report_path,
                file_name=report_path.name,
                summary=results.get("summary", {}),
                label=f"Refit — {library_selected}.{algo_selected}",
                cfg=effective_cfg,
            )
            tvr_render_ready("refit", header_text="Run & Report (last run)")
            tvr_render_extras("refit")

        except Exception as e:
            st.error(f"Refit/validate failed: {e}")
            # Keep the traceback reachable for debugging without cluttering the page.
            with st.expander("Error details", expanded=False):
                st.exception(e)
            st.stop()
|
|
1203
|
+
|
|
1204
|
+
|
|
1205
|
+
# Markdown horizontal rule used as a visual separator.
st.markdown("---")
|