tanml-0.1.6-py3-none-any.whl → tanml-0.1.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tanml/__init__.py +1 -1
- tanml/check_runners/cleaning_repro_runner.py +2 -2
- tanml/check_runners/correlation_runner.py +49 -12
- tanml/check_runners/explainability_runner.py +12 -22
- tanml/check_runners/logistic_stats_runner.py +196 -17
- tanml/check_runners/performance_runner.py +82 -26
- tanml/check_runners/raw_data_runner.py +29 -14
- tanml/check_runners/regression_metrics_runner.py +195 -0
- tanml/check_runners/stress_test_runner.py +23 -6
- tanml/check_runners/vif_runner.py +33 -27
- tanml/checks/correlation.py +241 -41
- tanml/checks/explainability/shap_check.py +261 -29
- tanml/checks/logit_stats.py +186 -54
- tanml/checks/performance_classification.py +305 -0
- tanml/checks/raw_data.py +58 -23
- tanml/checks/regression_metrics.py +167 -0
- tanml/checks/stress_test.py +157 -53
- tanml/cli/main.py +99 -27
- tanml/engine/check_agent_registry.py +20 -10
- tanml/engine/core_engine_agent.py +199 -37
- tanml/models/registry.py +329 -0
- tanml/report/report_builder.py +1180 -147
- tanml/report/templates/report_template_cls.docx +0 -0
- tanml/report/templates/report_template_reg.docx +0 -0
- tanml/ui/app.py +1205 -0
- tanml/utils/data_loader.py +105 -15
- tanml-0.1.7.dist-info/METADATA +164 -0
- tanml-0.1.7.dist-info/RECORD +54 -0
- tanml/cli/arg_parser.py +0 -31
- tanml/cli/init_cmd.py +0 -8
- tanml/cli/validate_cmd.py +0 -7
- tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
- tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
- tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
- tanml/engine/segmentation_agent.py +0 -118
- tanml/engine/validation_agent.py +0 -91
- tanml/report/templates/report_template.docx +0 -0
- tanml/utils/model_loader.py +0 -35
- tanml/utils/r_loader.py +0 -30
- tanml/utils/sas_loader.py +0 -50
- tanml/utils/yaml_generator.py +0 -34
- tanml/utils/yaml_loader.py +0 -5
- tanml/validate.py +0 -209
- tanml-0.1.6.dist-info/METADATA +0 -317
- tanml-0.1.6.dist-info/RECORD +0 -62
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
tanml/__init__.py
CHANGED

@@ -1 +1 @@
-__version__ = "0.1.6"
+__version__ = "0.1.7"
tanml/check_runners/cleaning_repro_runner.py
CHANGED

@@ -3,12 +3,12 @@ from tanml.checks.cleaning_repro import CleaningReproCheck
 
 def run_cleaning_repro_check(model, X_train, X_test, y_train, y_test,
                              config, cleaned_data, *args, **kwargs):
-
+
     if not config.get("rules", {}).get("CleaningReproCheck", {}).get("enabled", True):
         print("ℹ️ CleaningReproCheck skipped (disabled in rules.yaml)")
         return None
 
-
+
     raw_data = config.get("raw_data") or kwargs.get("raw_df")
     if raw_data is None:
         print("⚠️ Skipping CleaningReproCheck — raw_data missing in config and kwargs")
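The gate at the top of run_cleaning_repro_check reads a nested rules block out of the config. A minimal sketch of a config dict that would disable the check — the key names come from the lookup chain above, everything else is an assumption:

config = {
    "rules": {
        # runner prints a skip notice and returns None
        "CleaningReproCheck": {"enabled": False},
    },
    # raw_data may also be supplied here; otherwise the runner falls back
    # to kwargs["raw_df"] and skips when both are missing
    "raw_data": None,
}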
tanml/check_runners/correlation_runner.py
CHANGED

@@ -1,15 +1,52 @@
-
+# tanml/check_runners/correlation_runner.py
+from __future__ import annotations
+import os
+from typing import Any, Dict, List
+import pandas as pd
 
-
-try:
-    cfg = rule_config.get("CorrelationCheck", {})
-    if not cfg.get("enabled", True):
-        print("ℹ️ CorrelationCheck skipped (disabled in rules.yaml)")
-        return None
+from tanml.checks.correlation import CorrelationCheck
 
-
-
+def _resolve_outdir(config: Dict[str, Any]) -> str:
+    base = (config.get("options") or {}).get("save_artifacts_dir") or "reports"
+    outdir = os.path.join(base, "correlation")
+    os.makedirs(outdir, exist_ok=True)
+    return outdir
 
-
-
-    return
+def _df_features_only(cleaned_df: pd.DataFrame) -> pd.DataFrame:
+    if cleaned_df is None or cleaned_df.empty:
+        return cleaned_df
+    cols = list(cleaned_df.columns)
+    if len(cols) >= 2:
+        return cleaned_df[cols[:-1]]
+    return cleaned_df
+
+def CorrelationCheckRunner(
+    model,
+    X_train,
+    X_test,
+    y_train,
+    y_test,
+    config: Dict[str, Any],
+    cleaned_df: pd.DataFrame,
+    raw_df: pd.DataFrame | None = None,
+):
+    ui_block: Dict[str, Any] = (config.get("CorrelationCheck") or {})
+    legacy: Dict[str, Any] = (config.get("correlation") or {})
+    if not bool(ui_block.get("enabled", legacy.get("enabled", True))):
+        return None
+
+    df = _df_features_only(cleaned_df)
+    cfg: Dict[str, Any] = {
+        "method": ui_block.get("method", "pearson"),
+        "high_corr_threshold": float(ui_block.get("high_corr_threshold", 0.8)),
+        "heatmap_max_features_default": int(ui_block.get("heatmap_max_features_default", 20)),
+        "heatmap_max_features_limit": int(ui_block.get("heatmap_max_features_limit", 60)),
+        "subset_strategy": ui_block.get("subset_strategy", "cluster"),
+        "sample_rows": int(ui_block.get("sample_rows", 150_000)),
+        "seed": int(ui_block.get("seed", 42)),
+        "save_csv": True,
+        "save_fig": True,
+        "appendix_csv_cap": ui_block.get("appendix_csv_cap", None),
+    }
+    outdir = _resolve_outdir(config)
+    return CorrelationCheck(cleaned_data=df, cfg=cfg, output_dir=outdir).run()
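A minimal usage sketch of the new runner, assuming a toy DataFrame whose last column is the target (which _df_features_only drops). The runner itself only touches config and cleaned_df, so passing None for the unused model/X/y arguments works at this level; whether CorrelationCheck needs more is not shown in this diff:

import pandas as pd
from tanml.check_runners.correlation_runner import CorrelationCheckRunner

# Last column is treated as the target and excluded from the matrix
df = pd.DataFrame({"x1": [1, 2, 3, 4], "x2": [4, 3, 2, 1], "y": [0, 1, 0, 1]})

config = {
    "CorrelationCheck": {"enabled": True, "method": "spearman", "high_corr_threshold": 0.9},
    "options": {"save_artifacts_dir": "reports"},  # artifacts land in reports/correlation/
}

result = CorrelationCheckRunner(
    model=None, X_train=None, X_test=None, y_train=None, y_test=None,
    config=config, cleaned_df=df,
)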
tanml/check_runners/explainability_runner.py
CHANGED

@@ -1,28 +1,18 @@
-#
-
+# explainability runner
 from tanml.checks.explainability.shap_check import SHAPCheck
 
-def run_shap_check(
-    model, X_train, X_test, y_train, y_test,
-    rule_config, cleaned_df, *args, **kwargs
-):
+def run_shap_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
     try:
-
-
-
-
-
-        check = SHAPCheck(
-            model=model,
-            X_train=X_train,
-            X_test=X_test,
-            y_train=y_train,
-            y_test=y_test,
-            rule_config=rule_config,
-            cleaned_df=cleaned_df,
-        )
-        return check.run()
+        cfg_shapcheck = (rule_config or {}).get("SHAPCheck", {}) or {}
+        cfg_expl = (rule_config or {}).get("explainability", {}).get("shap", {}) or {}
+        enabled = cfg_shapcheck.get("enabled", cfg_expl.get("enabled", True))
+        if not enabled:
+            print("ℹ️ SHAPCheck skipped (disabled)")
+            return {"SHAPCheck": {"skipped": True}}
 
+        check = SHAPCheck(model, X_train, X_test, y_train, y_test, rule_config=rule_config, cleaned_df=cleaned_df)
+        result = check.run()
+        return {"SHAPCheck": result}
     except Exception as e:
         print(f"⚠️ SHAPCheck failed: {e}")
-        return {"SHAPCheck": {"error": str(e)}}
+        return {"SHAPCheck": {"status": "error", "error": str(e)}}
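run_shap_check now honors two config layouts, with the flat SHAPCheck block taking precedence over the nested legacy block. A sketch of both shapes, with key names taken directly from the lookups above (illustrative only):

# New-style block — consulted first
rule_config = {"SHAPCheck": {"enabled": False}}

# Legacy block — used only when "SHAPCheck" does not set "enabled"
rule_config_legacy = {"explainability": {"shap": {"enabled": False}}}

# Either shape makes run_shap_check return {"SHAPCheck": {"skipped": True}}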
tanml/check_runners/logistic_stats_runner.py
CHANGED

@@ -1,28 +1,207 @@
-
-from
+# tanml/check_runners/logistic_stats_runner.py
+from __future__ import annotations
 
-
+from typing import Any, Dict
+from tanml.checks.logit_stats import _prep_design_matrix_df
+import numpy as np
+import pandas as pd
+import statsmodels.api as sm
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import (
+    roc_auc_score,
+    roc_curve,
+    precision_recall_fscore_support,
+    accuracy_score,
+    average_precision_score,
+    brier_score_loss,
+)
+
+
+def _is_binary_series(y: pd.Series) -> bool:
     try:
-
-
-
+        u = pd.unique(pd.Series(y).dropna())
+        return len(u) == 2
+    except Exception:
+        return False
+
+
+def _prep_design_matrix(
+    X_like: Any, ref_columns: pd.Index | None, add_const: bool = True
+) -> pd.DataFrame:
+    """
+    1) Convert to DataFrame
+    2) One-hot encode (drop_first=True)
+    3) Align to ref_columns (if given), filling missing cols with 0 and dropping extras
+    4) Coerce to numeric & sanitize
+    5) Optionally add constant
+    """
+    Xd = X_like if isinstance(X_like, pd.DataFrame) else pd.DataFrame(X_like)
+    Xd = pd.get_dummies(Xd, drop_first=True)
+
+    if ref_columns is not None:
+        ref_wo_const = [c for c in ref_columns if c != "const"]
+        Xd = Xd.reindex(columns=ref_wo_const, fill_value=0.0)
+
+    for c in Xd.columns:
+        Xd[c] = pd.to_numeric(Xd[c], errors="coerce")
+    Xd = Xd.replace([np.inf, -np.inf], np.nan).fillna(0.0)
+
+    if add_const:
+        Xd = sm.add_constant(Xd, has_constant="add")
+
+    return Xd
+
+
+def run_logistic_stats_check(
+    model,
+    X_train,
+    X_test,
+    y_train,
+    y_test,
+    rule_config: Dict[str, Any],
+    cleaned_df,
+    *args,
+    **kwargs,
+) -> Dict[str, Any]:
+    """
+    Logistic challenger (stats-only):
+
+    - Fits a statsmodels Logit on a one-hot design of X_train (with intercept)
+    - Produces: summary_text and coefficient table with CIs
+    - Computes baseline classification metrics on the test set (NO plots/CSVs)
+
+    Returns:
+        {
+            "LogitStats": {
+                "summary_text": str,
+                "coef_table_headers": ["feature","coef","std err","z","P>|z|","ci_low","ci_high"],
+                "coef_table_rows": [ {...}, ... ],
+                "baseline_metrics": { "summary": {...} },  # rounded, no 'plots'/'tables'
+                "baseline_note": "..."
+            }
+        }
+    """
+    try:
+        # 1) Skip if model is obviously not logistic-like
+        is_logistic_like = (
+            isinstance(model, LogisticRegression)
+            or getattr(model, "__class__", type("X", (object,), {})).__name__.lower().startswith("logit")
+            or hasattr(model, "predict_proba")
+        )
+        if not is_logistic_like:
+            print("ℹ️ LogisticStatsCheck skipped — model not logistic-like")
+            return {"LogitStats": {"skipped": True}}
+
+        # 2) Ensure binary target
+        y_train_s = pd.Series(y_train)
+        if not _is_binary_series(y_train_s):
+            print("ℹ️ LogisticStatsCheck skipped — target is not binary")
+            return {"LogitStats": {"skipped": True}}
+
+        # Robust 0/1 encoding (majority -> 0, minority -> 1)
+        counts = y_train_s.value_counts().sort_values(ascending=False).index.tolist()
+        enc_map = {counts[0]: 0, counts[1]: 1}
+        yb_train = y_train_s.map(enc_map).astype(int)
+
+        # 3) Train design matrix (with intercept)
+        Xd_train = _prep_design_matrix_df(X_train, ref_columns=None, add_const=True)
+
+        # 4) Fit statsmodels Logit (MLE)
+        res = sm.Logit(yb_train, Xd_train).fit(disp=0, method="lbfgs", maxiter=1000)
+
+        # 5) Summary text (human-readable)
+        try:
+            summary_text = res.summary2().as_text()
+        except Exception:
+            summary_text = str(res.summary())
+
+        # 6) Coefficient table (const first)
+        params = res.params
+        bse = res.bse
+        # Avoid divide-by-zero in z; replace zeros with NaN then fill after rounding
+        zvals = params / bse.replace(0, np.nan)
+        pvals = res.pvalues
+        ci = res.conf_int(alpha=0.05)
+        ci.columns = ["ci_low", "ci_high"]
+
+        coef_df = pd.DataFrame(
+            {
+                "feature": params.index,
+                "coef": params.values,
+                "std err": bse.values,
+                "z": zvals.values,
+                "P>|z|": pvals.values,
+                "ci_low": ci["ci_low"].values,
+                "ci_high": ci["ci_high"].values,
+            }
         )
 
-        if
-
-
+        if "const" in coef_df["feature"].values:
+            coef_df = pd.concat(
+                [
+                    coef_df.loc[coef_df["feature"] == "const"],
+                    coef_df.loc[coef_df["feature"] != "const"],
+                ],
+                ignore_index=True,
+            )
+
+        for c in ["coef", "std err", "z", "P>|z|", "ci_low", "ci_high"]:
+            coef_df[c] = pd.to_numeric(coef_df[c], errors="coerce").round(4)
+
+        # 7) Test-set baseline metrics (NO PLOTS/CSVs)
+        # Build test matrix aligned to the training design columns.
+        Xd_test = _prep_design_matrix_df(X_test, ref_columns=Xd_train.columns, add_const=True)
 
-        #
-
-
+        # Statsmodels Logit returns probability for class "1"
+        y_score = res.predict(Xd_test)  # shape (n_test,)
+
+        # Threshold policy (aligned with PerformanceCheck if present)
+        threshold = (rule_config.get("PerformanceCheck", {}) or {}).get("threshold", 0.5)
+        try:
+            thr = float(threshold)
+        except Exception:
+            thr = 0.5
+
+        y_pred = (y_score >= thr).astype(int)
+
+        yb_test = pd.Series(y_test).map(enc_map).astype(int).to_numpy()
+
+        has_posneg = len(np.unique(yb_test)) > 1
+        auc = roc_auc_score(yb_test, y_score) if has_posneg else np.nan
+        fpr, tpr, _ = roc_curve(yb_test, y_score) if has_posneg else (np.array([]), np.array([]), None)
+        ks = float(np.max(np.abs(tpr - fpr))) if len(fpr) else np.nan
+        ap = average_precision_score(yb_test, y_score) if has_posneg else np.nan
+        brier = brier_score_loss(yb_test, y_score)
+        precision, recall, f1, _ = precision_recall_fscore_support(
+            yb_test, y_pred, average="binary", pos_label=1, zero_division=0
+        )
+        acc = accuracy_score(yb_test, y_pred)
+        gini = 2 * auc - 1 if (auc == auc) else np.nan  # handle NaN
+
+        baseline_metrics = {
+            "summary": {
+                "auc": None if auc != auc else round(float(auc), 2),
+                "ks": None if ks != ks else round(float(ks), 2),
+                "accuracy": round(float(acc), 2),
+                "precision": round(float(precision), 2),
+                "recall": round(float(recall), 2),
+                "f1": round(float(f1), 2),
+                "pr_auc": None if ap != ap else round(float(ap), 2),
+                "brier": round(float(brier), 2),
+                "gini": None if gini != gini else round(float(gini), 2),
+            }
+        }
 
         return {
-            "
-
-
-
+            "LogitStats": {
+                "summary_text": summary_text,
+                "coef_table_headers": ["feature", "coef", "std err", "z", "P>|z|", "ci_low", "ci_high"],
+                "coef_table_rows": coef_df.to_dict(orient="records"),
+                "baseline_metrics": baseline_metrics,  # <-- metrics only; no plots/tables
+                "baseline_note": f"Computed on the same test split and preprocessing as the primary model; threshold={thr}.",
+            }
         }
 
     except Exception as e:
         print(f"⚠️ LogisticStatsCheck failed: {e}")
-        return {"
+        return {"LogitStats": {"error": str(e)}}
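The 0/1 encoding above is worth replaying in isolation: the majority class maps to 0 and the minority class to 1, so the downstream metrics treat the rarer outcome as the positive label. A self-contained sketch using the same three lines on toy data:

import pandas as pd

y = pd.Series(["good", "good", "good", "bad"])  # "good" is the majority class

counts = y.value_counts().sort_values(ascending=False).index.tolist()
enc_map = {counts[0]: 0, counts[1]: 1}  # {"good": 0, "bad": 1}
yb = y.map(enc_map).astype(int)         # 0, 0, 0, 1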
tanml/check_runners/performance_runner.py
CHANGED

@@ -1,28 +1,84 @@
-from
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+from __future__ import annotations
+import os
+from typing import Any, Dict
+import numpy as np
+from tanml.checks.performance_classification import compute_classification_report
+
+def _resolve_outdir(config: Dict[str, Any]) -> str:
+    base = (config.get("paths") or {}).get("artifacts_dir") \
+        or (config.get("options") or {}).get("save_artifacts_dir") \
+        or "reports"
+    outdir = os.path.join(base, "performance")
+    os.makedirs(outdir, exist_ok=True)
+    return outdir
+
+def PerformanceCheckRunner(
+    model,
+    X_train, X_test, y_train, y_test,
+    config: Dict[str, Any],
+    cleaned_df,
+    raw_df=None,
+    ctx=None,
+):
+    outdir = _resolve_outdir(config)
+    task_type = ((config.get("model") or {}).get("type") or "binary_classification").lower()
+
+    payload: Dict[str, Any] = {}
+
+    if "class" in task_type:
+        # --- build scores ---
+        def _scores(m, X):
+            if hasattr(m, "predict_proba"):
+                p = m.predict_proba(X)
+                return p[:, 1] if p.ndim == 2 and p.shape[1] > 1 else p.ravel()
+            if hasattr(m, "decision_function"):
+                return m.decision_function(X).ravel()
+            return m.predict(X).ravel()
+
+        y_score = _scores(model, X_test)
+        y_pred = getattr(model, "predict")(X_test)
+
+        cls_dir = os.path.join(outdir, "classification")
+        os.makedirs(cls_dir, exist_ok=True)
+
+        results_cls = compute_classification_report(
+            y_true=np.asarray(y_test),
+            y_score=np.asarray(y_score),
+            y_pred=np.asarray(y_pred),
+            outdir=cls_dir,
+            pos_label=1,
+            title_prefix=(config.get("model") or {}).get("name", "Model"),
         )
-        result = check.run()
-        return {"PerformanceCheck": result}
 
-
-
-
+        payload = {
+            "performance": {
+                "classification": results_cls
+            },
+            "task_type": "classification",
+        }
+
+    else:
+
+        # payload = {"performance": {"regression": results_reg}, "task_type": "regression"}
+        payload = {"task_type": "regression"}
+
+    return payload
+
+# ---- Back-compat alias so registry can import old name ----
+def run_performance_check(
+    model,
+    X_train, X_test, y_train, y_test,
+    config,
+    cleaned_df,
+    raw_df=None,
+    ctx=None,
+):
+    return PerformanceCheckRunner(
+        model=model,
+        X_train=X_train, X_test=X_test,
+        y_train=y_train, y_test=y_test,
+        config=config,
+        cleaned_df=cleaned_df,
+        raw_df=raw_df,
+        ctx=ctx,
+    )
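The _scores helper prefers calibrated probabilities and degrades gracefully: predict_proba, then decision_function, then raw predict. A small demonstration of the first two branches with standard scikit-learn estimators (toy data, not part of the package):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

# Branch 1: predict_proba — column 1 is the positive-class probability
print(LogisticRegression().fit(X, y).predict_proba(X)[:, 1])

# Branch 2: LinearSVC has no predict_proba, so _scores would fall back
# to decision_function margins (uncalibrated but rank-preserving)
print(LinearSVC().fit(X, y).decision_function(X))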
tanml/check_runners/raw_data_runner.py
CHANGED

@@ -1,27 +1,43 @@
-from
+from __future__ import annotations
+
+import os
 import pandas as pd
 
-
-
+from tanml.checks.raw_data import RawDataCheck
+from tanml.utils.data_loader import load_dataframe
+
+
+def run_raw_data_check(
+    model,
+    X_train,
+    X_test,
+    y_train,
+    y_test,
+    rule_config,
+    cleaned_data,
+    *args,
+    **kwargs
+):
     try:
-        #
+        # Locate raw data from config
         raw_obj = (
-            rule_config.get("raw_data")
-            rule_config.get("paths", {}).get("raw_data")
+            rule_config.get("raw_data")
+            or (rule_config.get("paths", {}) or {}).get("raw_data")
+            or (rule_config.get("paths", {}) or {}).get("raw")
         )
+
         if raw_obj is None:
             print("ℹ️ RawDataCheck skipped — raw_data not provided in config.")
             return None
 
-
-
-        raw_obj = pd.read_csv(raw_obj)
+        if isinstance(raw_obj, (str, bytes, os.PathLike)):
+            raw_obj = load_dataframe(raw_obj)
 
         if not isinstance(raw_obj, pd.DataFrame):
-            print("ℹ️ RawDataCheck skipped — raw_data is not a DataFrame.")
+            print("ℹ️ RawDataCheck skipped — raw_data is not a DataFrame or loadable path.")
             return None
 
-        #
+        # Run the check
         check = RawDataCheck(
            model=model,
            X_train=X_train,

@@ -32,9 +48,8 @@ def run_raw_data_check(model, X_train, X_test, y_train, y_test,
             cleaned_data=cleaned_data,
             raw_data=raw_obj
         )
-
-        stats
-        return stats["RawDataCheck"]  # hand the inner dict to ValidationEngine
+        stats = check.run()
+        return stats.get("RawDataCheck", stats)
 
     except Exception as e:
         print(f"⚠️ RawDataCheck failed: {e}")