tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release.
- tanml/__init__.py +1 -1
- tanml/check_runners/cleaning_repro_runner.py +2 -2
- tanml/check_runners/correlation_runner.py +49 -12
- tanml/check_runners/explainability_runner.py +12 -22
- tanml/check_runners/logistic_stats_runner.py +196 -17
- tanml/check_runners/performance_runner.py +82 -26
- tanml/check_runners/raw_data_runner.py +29 -14
- tanml/check_runners/regression_metrics_runner.py +195 -0
- tanml/check_runners/stress_test_runner.py +23 -6
- tanml/check_runners/vif_runner.py +33 -27
- tanml/checks/correlation.py +241 -41
- tanml/checks/explainability/shap_check.py +261 -29
- tanml/checks/logit_stats.py +186 -54
- tanml/checks/performance_classification.py +305 -0
- tanml/checks/raw_data.py +58 -23
- tanml/checks/regression_metrics.py +167 -0
- tanml/checks/stress_test.py +157 -53
- tanml/cli/main.py +99 -27
- tanml/engine/check_agent_registry.py +20 -10
- tanml/engine/core_engine_agent.py +199 -37
- tanml/models/registry.py +329 -0
- tanml/report/report_builder.py +1180 -147
- tanml/report/templates/report_template_cls.docx +0 -0
- tanml/report/templates/report_template_reg.docx +0 -0
- tanml/ui/app.py +1205 -0
- tanml/utils/data_loader.py +105 -15
- tanml-0.1.7.dist-info/METADATA +164 -0
- tanml-0.1.7.dist-info/RECORD +54 -0
- tanml/cli/arg_parser.py +0 -31
- tanml/cli/init_cmd.py +0 -8
- tanml/cli/validate_cmd.py +0 -7
- tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
- tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
- tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
- tanml/engine/segmentation_agent.py +0 -118
- tanml/engine/validation_agent.py +0 -91
- tanml/report/templates/report_template.docx +0 -0
- tanml/utils/model_loader.py +0 -35
- tanml/utils/r_loader.py +0 -30
- tanml/utils/sas_loader.py +0 -50
- tanml/utils/yaml_generator.py +0 -34
- tanml/utils/yaml_loader.py +0 -5
- tanml/validate.py +0 -209
- tanml-0.1.6.dist-info/METADATA +0 -317
- tanml-0.1.6.dist-info/RECORD +0 -62
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
tanml/checks/performance_classification.py
ADDED
@@ -0,0 +1,305 @@
+from __future__ import annotations
+import os, math
+from dataclasses import dataclass, asdict
+from typing import Dict, Any
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+from sklearn.metrics import (
+    roc_auc_score, roc_curve, confusion_matrix, precision_recall_fscore_support,
+    accuracy_score, average_precision_score, brier_score_loss, precision_recall_curve
+)
+from sklearn.calibration import calibration_curve
+
+# ---------- utilities ----------
+def _ensure_dir(d: str) -> str:
+    os.makedirs(d, exist_ok=True)
+    return d
+
+def _savefig(path: str) -> str:
+    plt.tight_layout()
+    plt.savefig(path, bbox_inches="tight", dpi=160)
+    plt.close()
+    return path
+
+def _gini_from_auc(auc: float) -> float:
+    return 2 * auc - 1 if (auc is not None and not np.isnan(auc)) else np.nan
+
+def _ks_from_roc(fpr, tpr) -> float:
+    return float(np.max(np.abs(tpr - fpr))) if len(fpr) else np.nan
+
+def _decile_lift_table(y_true: np.ndarray, y_score: np.ndarray, pos_label: int = 1, n_bins: int = 10) -> pd.DataFrame:
+    df = pd.DataFrame({"y": (y_true == pos_label).astype(int), "score": y_score})
+    df = df.sort_values("score", ascending=False).reset_index(drop=True)
+    df["decile"] = pd.qcut(df.index + 1, q=n_bins, labels=list(range(1, n_bins + 1)))
+    g = df.groupby("decile", observed=True).agg(
+        total=("y", "size"),
+        events=("y", "sum"),
+        avg_score=("score", "mean"),
+    ).reset_index()
+    g["event_rate"] = g["events"] / g["total"]
+    overall_rate = df["y"].mean() if df["y"].size else np.nan
+    g["lift"] = g["event_rate"] / overall_rate if (overall_rate and not math.isclose(overall_rate, 0.0)) else np.nan
+
+    # cumulative capture & gain
+    g["cum_events"] = g["events"].cumsum()
+    g["cum_total"] = g["total"].cumsum()
+    total_events = g["events"].sum()
+    g["cum_capture_rate"] = g["cum_events"] / total_events if total_events > 0 else np.nan
+    g["cum_population"] = g["cum_total"] / g["total"].sum()
+    g["cum_gain"] = g["cum_capture_rate"]  # same as cumulative gains curve
+    return g
+
+def _ks_curve_frame(y_true: np.ndarray, y_score: np.ndarray, pos_label: int = 1) -> pd.DataFrame:
+    """
+    Returns a dataframe with columns:
+      population (fraction 0..1), cum_event, cum_non_event, ks_gap
+    sorted by score DESC, which is standard for risk ranking.
+    """
+    df = pd.DataFrame({"y": (y_true == pos_label).astype(int), "score": y_score})
+    if df.empty:
+        return pd.DataFrame(columns=["population", "cum_event", "cum_non_event", "ks_gap"])
+
+    df = df.sort_values("score", ascending=False).reset_index(drop=True)
+    n = len(df)
+
+    # counts
+    total_events = df["y"].sum()
+    total_non_events = n - total_events
+
+    # avoid divide-by-zero; if all one class, return empty (plotter will handle)
+    if total_events == 0 or total_non_events == 0:
+        return pd.DataFrame(columns=["population", "cum_event", "cum_non_event", "ks_gap"])
+
+    cum_events = np.cumsum(df["y"].values) / total_events
+    cum_non_events = np.cumsum(1 - df["y"].values) / total_non_events
+    population = (np.arange(1, n + 1)) / n
+    ks_gap = np.abs(cum_events - cum_non_events)
+
+    return pd.DataFrame({
+        "population": population,
+        "cum_event": cum_events,
+        "cum_non_event": cum_non_events,
+        "ks_gap": ks_gap
+    })
+
+@dataclass
+class ClassificationSummary:
+    auc: float
+    ks: float
+    accuracy: float
+    precision: float
+    recall: float
+    f1: float
+    pr_auc: float
+    brier: float
+    gini: float
+    # paths
+    roc_png: str
+    pr_png: str
+    lift_png: str
+    calib_png: str
+    cm_png: str
+    ks_png: str
+    # tables
+    confusion_csv: str
+    lift_csv: str
+
+
+def compute_classification_report(
+    *,
+    y_true: np.ndarray,
+    y_score: np.ndarray,
+    y_pred: np.ndarray,
+    outdir: str,
+    pos_label: int = 1,
+    title_prefix: str = "Model"
+) -> Dict[str, Any]:
+    """
+    Computes metrics + saves plots/CSVs for classification.
+    Returns a dict ready for ReportBuilder/Jinja.
+    """
+    _ensure_dir(outdir)
+
+    # --- metrics
+    has_posneg = len(np.unique(y_true)) > 1
+    auc = roc_auc_score(y_true, y_score) if has_posneg else np.nan
+    fpr, tpr, _ = roc_curve(y_true, y_score) if has_posneg else (np.array([]), np.array([]), None)
+    ks = _ks_from_roc(fpr, tpr)
+    pr_auc = average_precision_score(y_true, y_score) if has_posneg else np.nan
+    brier = brier_score_loss(y_true, y_score)
+    gini = _gini_from_auc(auc)
+
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        y_true, y_pred, average="binary", pos_label=pos_label, zero_division=0
+    )
+    acc = accuracy_score(y_true, y_pred)
+
+    # --- confusion matrix & CSV
+    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
+    cm_df = pd.DataFrame(cm, index=["Actual 0", "Actual 1"], columns=["Pred 0", "Pred 1"])
+    cm_csv = os.path.join(outdir, "confusion_matrix.csv")
+    cm_df.to_csv(cm_csv, index=True)
+
+    # --- decile lift table & CSV (rounded to 2 decimals)
+    lift_df = _decile_lift_table(y_true, y_score, pos_label=pos_label, n_bins=10)
+    lift_df_round = lift_df.copy()
+    num_cols = lift_df_round.select_dtypes(include=[np.number]).columns
+    lift_df_round[num_cols] = lift_df_round[num_cols].round(2)
+    lift_csv = os.path.join(outdir, "lift_table_deciles.csv")
+    lift_df_round.to_csv(lift_csv, index=False)
+
+    # --- plots
+    # ROC
+    if len(fpr):
+        plt.figure()
+        plt.plot(fpr, tpr, label=f"ROC (AUC={auc:.3f})")
+        plt.plot([0, 1], [0, 1], linestyle="--")
+        plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate"); plt.title(f"{title_prefix}: ROC Curve")
+        plt.legend(loc="lower right")
+    roc_png = os.path.join(outdir, "roc_curve.png")
+    _savefig(roc_png)
+
+    # PR
+    pr, rc, _ = precision_recall_curve(y_true, y_score)
+    plt.figure()
+    plt.plot(rc, pr, label=f"PR (AP={pr_auc:.3f})")
+    plt.xlabel("Recall"); plt.ylabel("Precision"); plt.title(f"{title_prefix}: Precision–Recall Curve")
+    plt.legend(loc="lower left")
+    pr_png = os.path.join(outdir, "pr_curve.png")
+    _savefig(pr_png)
+
+    # Calibration (Reliability) curve
+    prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=10, strategy="uniform")
+    plt.figure()
+    plt.plot(prob_pred, prob_true, marker="o", label="Reliability")
+    plt.plot([0, 1], [0, 1], linestyle="--", label="Perfect")
+    plt.xlabel("Predicted probability"); plt.ylabel("Observed frequency"); plt.title(f"{title_prefix}: Calibration")
+    plt.legend(loc="upper left")
+    calib_png = os.path.join(outdir, "calibration_curve.png")
+    _savefig(calib_png)
+
+    # Lift / Gain chart (use unrounded frame for smooth curve)
+    plt.figure()
+    plt.plot(lift_df["cum_population"], lift_df["cum_gain"], marker="o", label="Cumulative Gain")
+    plt.plot([0, 1], [0, 1], linestyle="--", label="Baseline")
+    plt.xlabel("Cumulative Population"); plt.ylabel("Cumulative Gain"); plt.title(f"{title_prefix}: Cumulative Gain")
+    plt.legend(loc="lower right")
+    lift_png = os.path.join(outdir, "lift_gain_curve.png")
+    _savefig(lift_png)
+
+    # Confusion heatmap
+    plt.figure()
+    plt.imshow(cm, interpolation="nearest")
+    plt.title(f"{title_prefix}: Confusion Matrix")
+    plt.colorbar()
+    tick_marks = np.arange(2)
+    plt.xticks(tick_marks, ["Pred 0", "Pred 1"])
+    plt.yticks(tick_marks, ["Actual 0", "Actual 1"])
+    for i in range(2):
+        for j in range(2):
+            plt.text(j, i, cm[i, j], ha="center", va="center")
+    plt.ylabel("Actual"); plt.xlabel("Predicted")
+    cm_png = os.path.join(outdir, "confusion_matrix.png")
+    _savefig(cm_png)
+
+    # --- KS curve (cumulative event vs non-event by population)
+    ks_df = _ks_curve_frame(y_true, y_score, pos_label=pos_label)
+    ks_csv = os.path.join(outdir, "ks_curve.csv")
+    ks_png = os.path.join(outdir, "ks_curve.png")
+
+    if not ks_df.empty:
+        # locate max KS point
+        ks_idx = int(ks_df["ks_gap"].values.argmax())
+        ks_x = float(ks_df["population"].iloc[ks_idx])
+        ks_y_event = float(ks_df["cum_event"].iloc[ks_idx])
+        ks_y_nonevent = float(ks_df["cum_non_event"].iloc[ks_idx])
+        ks_val_annot = abs(ks_y_event - ks_y_nonevent)
+
+        # plot with explicit figure/axes only once
+        fig, ax = plt.subplots(figsize=(6, 4))
+        ax.plot(ks_df["population"], ks_df["cum_event"], label="Cumulative Event")
+        ax.plot(ks_df["population"], ks_df["cum_non_event"], label="Cumulative Non-Event")
+
+        # vertical line & markers at max KS
+        ax.axvline(ks_x, linestyle="--", alpha=0.7)
+        ax.scatter([ks_x], [ks_y_event], s=25)
+        ax.scatter([ks_x], [ks_y_nonevent], s=25)
+
+        # readable annotation (two lines, boxed)
+        ax.annotate(
+            f"KS = {ks_val_annot:.1%}\nPop = {ks_x:.1%}",
+            xy=(ks_x, (ks_y_event + ks_y_nonevent) / 2.0),
+            xytext=(ks_x + 0.05, min(0.9, (ks_y_event + ks_y_nonevent) / 2.0 + 0.1)),
+            arrowprops=dict(arrowstyle="->", color="black", lw=1),
+            ha="left", va="center", fontsize=10,
+            bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=0.8),
+        )
+
+        ax.set_xlabel("Population (fraction)")
+        ax.set_ylabel("Cumulative share")
+        ax.set_title(f"{title_prefix}: KS Curve")
+        ax.legend(loc="lower right")
+
+        ks_df.to_csv(ks_csv, index=False)
+        _savefig(ks_png)
+    else:
+        # Write header-only CSV and a placeholder figure so ks_png always exists
+        pd.DataFrame(columns=["population", "cum_event", "cum_non_event", "ks_gap"]).to_csv(ks_csv, index=False)
+        fig, ax = plt.subplots(figsize=(6, 4))
+        ax.set_title(f"{title_prefix}: KS Curve (not available)")
+        _savefig(ks_png)
+
+    # --- build summaries: raw + rounded for display
+    summary = ClassificationSummary(
+        auc=float(auc) if auc == auc else np.nan,  # handle NaN
+        ks=float(ks) if ks == ks else np.nan,
+        accuracy=float(acc),
+        precision=float(precision),
+        recall=float(recall),
+        f1=float(f1),
+        pr_auc=float(pr_auc) if pr_auc == pr_auc else np.nan,
+        brier=float(brier),
+        gini=float(gini) if gini == gini else np.nan,
+        roc_png=roc_png,
+        pr_png=pr_png,
+        lift_png=lift_png,
+        calib_png=calib_png,
+        cm_png=cm_png,
+        ks_png=ks_png,  # <---- NEW
+        confusion_csv=cm_csv,
+        lift_csv=lift_csv,
+    )
+    summary_raw = asdict(summary)
+
+    # round only numeric metric fields to 2 decimals; keep paths as-is
+    metric_fields = {"auc", "ks", "accuracy", "precision", "recall", "f1", "pr_auc", "brier", "gini"}
+    summary_rounded = {}
+    for k, v in summary_raw.items():
+        if k in metric_fields:
+            try:
+                summary_rounded[k] = None if v is None or (isinstance(v, float) and np.isnan(v)) else round(float(v), 2)
+            except Exception:
+                summary_rounded[k] = v
+        else:
+            summary_rounded[k] = v
+
+    return {
+        "summary": summary_rounded,   # rounded for report/UI display
+        "summary_raw": summary_raw,   # full precision preserved for rules/debug
+        "tables": {
+            "confusion_csv": cm_csv,
+            "lift_csv": lift_csv,
+            "ks_csv": ks_csv,
+        },
+        "plots": {
+            "roc": roc_png,
+            "pr": pr_png,
+            "lift": lift_png,
+            "calibration": calib_png,
+            "confusion": cm_png,
+            "ks": ks_png,
+        },
+        # return rounded rows so the DOCX table looks clean
+        "deciles": lift_df_round.to_dict(orient="records"),
+    }
tanml/checks/raw_data.py
CHANGED
@@ -1,39 +1,75 @@
-
-
+from __future__ import annotations
+
+from typing import Any, Dict
 import pandas as pd
 
+from .base import BaseCheck
+from tanml.utils.data_loader import load_dataframe
+
+
 class RawDataCheck(BaseCheck):
-
-
-
-
-
-
-
+    """
+    Raw data sanity metrics. Accepts:
+      - DataFrame via `raw_data`, or
+      - a file path via `raw_data`, or
+      - YAML: paths.raw_data
+    """
+
+    def __init__(
+        self,
+        model,
+        X_train,
+        X_test,
+        y_train,
+        y_test,
+        rule_config: Dict[str, Any],
+        cleaned_data: pd.DataFrame,
+        raw_data: Any = None,
+    ):
+        super().__init__(model, X_train, X_test, y_train, y_test, rule_config, cleaned_data)
 
         if not hasattr(self, "config") or self.config is None:
             self.config = {}
 
         if raw_data is not None:
-            if isinstance(raw_data,
-                raw_data =
-
-
-
-
-
+            if isinstance(raw_data, pd.DataFrame):
+                self.config["raw_data"] = raw_data
+            elif isinstance(raw_data, (str, bytes)):
+                self.config["raw_data"] = load_dataframe(raw_data)
+            else:
+                raise ValueError("raw_data must be a pandas DataFrame or a file path")
+
+    def run(self) -> Dict[str, Any]:
+        results: Dict[str, Any] = {}
         try:
-            df = self.config.get("raw_data")
+            df = self.config.get("raw_data")
+
+            # If df is a path-like string, load it now
+            if isinstance(df, (str, bytes)):
+                df = load_dataframe(df)
+                self.config["raw_data"] = df
+
+            # If still missing, try YAML fallbacks
             if not isinstance(df, pd.DataFrame):
-
+                paths = self.config.get("paths") if isinstance(self.config.get("paths"), dict) else None
+                raw_path = None
+                if paths:
+                    raw_path = paths.get("raw_data") or paths.get("raw")
+                raw_path = raw_path or self.config.get("raw_data_path") or self.config.get("raw_path")
 
-
+                if raw_path:
+                    df = load_dataframe(raw_path)
+                    self.config["raw_data"] = df
+                else:
+                    raise ValueError("raw_data not found: provide DataFrame or set paths.raw_data in YAML")
+
+            # --- metrics ---
+            results["total_rows"] = int(df.shape[0])
             results["total_columns"] = int(df.shape[1])
 
             miss = df.isnull().mean().round(4)
-            results["avg_missing"]
-            results["columns_with_missing"]
+            results["avg_missing"] = float(miss.mean())
+            results["columns_with_missing"] = miss[miss > 0].to_dict()
 
             results["duplicate_rows"] = int(df.duplicated().sum())
 
@@ -44,4 +80,3 @@ class RawDataCheck(BaseCheck):
             results["error"] = str(e)
 
         return {"RawDataCheck": results}
-
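A minimal sketch of the reworked RawDataCheck, assuming BaseCheck merely stores the arguments passed through super().__init__; the None placeholders for the model and train/test splits and the toy frame are illustrative assumptions, not part of the release.

import pandas as pd
from tanml.checks.raw_data import RawDataCheck

raw = pd.DataFrame({"a": [1, 1, None], "b": [3, 3, 3]})  # toy frame with a missing value
check = RawDataCheck(
    model=None, X_train=None, X_test=None, y_train=None, y_test=None,  # placeholders (assumption)
    rule_config={}, cleaned_data=raw.dropna(), raw_data=raw,
)
print(check.run()["RawDataCheck"])  # total_rows, total_columns, avg_missing, duplicate_rows, ...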
tanml/checks/regression_metrics.py
ADDED
@@ -0,0 +1,167 @@
+# tanml/checks/regression_metrics.py
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional, Tuple
+import numpy as np
+
+try:
+    # Prefer sklearn implementations when available
+    from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
+    _HAS_SKLEARN = True
+except Exception:
+    _HAS_SKLEARN = False
+
+from .base import BaseCheck
+
+
+class RegressionMetricsCheck(BaseCheck):
+    """
+    Computes TanML's frozen regression metrics:
+      - RMSE
+      - MAE
+      - Median Absolute Error
+      - R²
+      - Adjusted R²
+      - MAPE (or SMAPE fallback when zeros/near-zeros exist in y_true)
+
+    Pure compute: no file I/O, no plotting. Returns a dict.
+    """
+
+    def __init__(
+        self,
+        *,
+        y_true: np.ndarray,
+        y_pred: np.ndarray,
+        n_features: Optional[int] = None,
+        mape_eps: float = 1e-8,
+        config: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Parameters
+        ----------
+        y_true : array-like of shape (n_samples,)
+        y_pred : array-like of shape (n_samples,)
+        n_features : int, optional
+            Number of model features (for Adjusted R²). If None, Adjusted R² may be None.
+        mape_eps : float
+            Small constant to guard divisions in MAPE/SMAPE.
+        config : dict, optional
+            Reserved for future options.
+        """
+        self.y_true = np.asarray(y_true).reshape(-1)
+        self.y_pred = np.asarray(y_pred).reshape(-1)
+        self.n_features = int(n_features) if n_features is not None else None
+        self.mape_eps = float(mape_eps)
+        self.config = config or {}
+
+        self._notes: List[str] = []
+
+    # ---------------------------
+    # Public API
+    # ---------------------------
+    def run(self) -> Dict[str, Any]:
+        self._validate_inputs()
+
+        rmse = self._rmse(self.y_true, self.y_pred)
+        mae = self._mae(self.y_true, self.y_pred)
+        median_ae = self._median_ae(self.y_true, self.y_pred)
+
+        r2, r2_adj = self._r2_and_adjusted(self.y_true, self.y_pred, self.n_features)
+
+        mape_val, smape_val, used = self._mape_or_smape(self.y_true, self.y_pred, self.mape_eps)
+
+        return {
+            "rmse": rmse,
+            "mae": mae,
+            "median_ae": median_ae,
+            "r2": r2,
+            "r2_adjusted": r2_adj,
+            "mape_or_smape": mape_val if used == "MAPE" else smape_val,
+            "mape_used": (used == "MAPE"),
+            "notes": self._notes,  # human-readable notes (e.g., SMAPE fallback, zero variance)
+        }
+
+    # ---------------------------
+    # Internals
+    # ---------------------------
+    def _validate_inputs(self) -> None:
+        if self.y_true.shape != self.y_pred.shape:
+            raise ValueError(f"Shapes differ: y_true{self.y_true.shape} vs y_pred{self.y_pred.shape}")
+        if self.y_true.ndim != 1 or self.y_pred.ndim != 1:
+            raise ValueError("y_true and y_pred must be 1-D arrays.")
+        if self.y_true.size < 2:
+            self._notes.append("Too few samples (<2) — some metrics may be undefined.")
+
+    @staticmethod
+    def _rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        if _HAS_SKLEARN:
+            return float(np.sqrt(mean_squared_error(y_true, y_pred)))
+        return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))
+
+    @staticmethod
+    def _mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        if _HAS_SKLEARN:
+            return float(mean_absolute_error(y_true, y_pred))
+        return float(np.mean(np.abs(y_true - y_pred)))
+
+    @staticmethod
+    def _median_ae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+        return float(np.median(np.abs(y_true - y_pred)))
+
+    def _r2_and_adjusted(
+        self,
+        y_true: np.ndarray,
+        y_pred: np.ndarray,
+        n_features: Optional[int],
+    ) -> Tuple[Optional[float], Optional[float]]:
+        n = int(y_true.size)
+        # If variance is zero, R² is undefined
+        if float(np.var(y_true)) == 0.0:
+            self._notes.append("R² undefined: target has zero variance.")
+            return None, None
+
+        if _HAS_SKLEARN:
+            try:
+                r2_val = float(r2_score(y_true, y_pred))
+            except Exception:
+                r2_val = None
+                self._notes.append("R² could not be computed via sklearn.r2_score.")
+        else:
+            ss_res = float(np.sum((y_true - y_pred) ** 2))
+            ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2))
+            r2_val = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else None
+
+        if r2_val is None or n_features is None:
+            return r2_val, None
+
+        p = int(n_features)
+        if n <= p + 1:
+            self._notes.append("Adjusted R² unavailable: insufficient degrees of freedom (n <= p + 1).")
+            return r2_val, None
+
+        try:
+            r2_adj = float(1.0 - (1.0 - r2_val) * (n - 1) / (n - p - 1))
+        except Exception:
+            r2_adj = None
+            self._notes.append("Adjusted R² computation failed due to numeric issues.")
+        return r2_val, r2_adj
+
+    def _mape_or_smape(
+        self, y_true: np.ndarray, y_pred: np.ndarray, eps: float
+    ) -> Tuple[Optional[float], Optional[float], str]:
+        """Return (MAPE, SMAPE, used_flag) and record notes for fallbacks."""
+        has_near_zero = np.any(np.abs(y_true) <= eps)
+        smape_val = self._smape(y_true, y_pred, eps)
+
+        if has_near_zero:
+            self._notes.append("MAPE skipped due to zeros/near-zeros in target; SMAPE reported instead.")
+            return None, smape_val, "SMAPE"
+
+        mape_val = float(np.mean(np.abs((y_true - y_pred) / (y_true + eps))) * 100.0)
+        return mape_val, smape_val, "MAPE"
+
+    @staticmethod
+    def _smape(y_true: np.ndarray, y_pred: np.ndarray, eps: float) -> float:
+        num = 2.0 * np.abs(y_pred - y_true)
+        den = np.abs(y_true) + np.abs(y_pred) + eps
+        return float(np.mean(num / den) * 100.0)