tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (49)
  1. tanml/__init__.py +1 -1
  2. tanml/check_runners/cleaning_repro_runner.py +2 -2
  3. tanml/check_runners/correlation_runner.py +49 -12
  4. tanml/check_runners/explainability_runner.py +12 -22
  5. tanml/check_runners/logistic_stats_runner.py +196 -17
  6. tanml/check_runners/performance_runner.py +82 -26
  7. tanml/check_runners/raw_data_runner.py +29 -14
  8. tanml/check_runners/regression_metrics_runner.py +195 -0
  9. tanml/check_runners/stress_test_runner.py +23 -6
  10. tanml/check_runners/vif_runner.py +33 -27
  11. tanml/checks/correlation.py +241 -41
  12. tanml/checks/explainability/shap_check.py +261 -29
  13. tanml/checks/logit_stats.py +186 -54
  14. tanml/checks/performance_classification.py +305 -0
  15. tanml/checks/raw_data.py +58 -23
  16. tanml/checks/regression_metrics.py +167 -0
  17. tanml/checks/stress_test.py +157 -53
  18. tanml/cli/main.py +99 -27
  19. tanml/engine/check_agent_registry.py +20 -10
  20. tanml/engine/core_engine_agent.py +199 -37
  21. tanml/models/registry.py +329 -0
  22. tanml/report/report_builder.py +1180 -147
  23. tanml/report/templates/report_template_cls.docx +0 -0
  24. tanml/report/templates/report_template_reg.docx +0 -0
  25. tanml/ui/app.py +1205 -0
  26. tanml/utils/data_loader.py +105 -15
  27. tanml-0.1.7.dist-info/METADATA +164 -0
  28. tanml-0.1.7.dist-info/RECORD +54 -0
  29. tanml/cli/arg_parser.py +0 -31
  30. tanml/cli/init_cmd.py +0 -8
  31. tanml/cli/validate_cmd.py +0 -7
  32. tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
  33. tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
  34. tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
  35. tanml/engine/segmentation_agent.py +0 -118
  36. tanml/engine/validation_agent.py +0 -91
  37. tanml/report/templates/report_template.docx +0 -0
  38. tanml/utils/model_loader.py +0 -35
  39. tanml/utils/r_loader.py +0 -30
  40. tanml/utils/sas_loader.py +0 -50
  41. tanml/utils/yaml_generator.py +0 -34
  42. tanml/utils/yaml_loader.py +0 -5
  43. tanml/validate.py +0 -209
  44. tanml-0.1.6.dist-info/METADATA +0 -317
  45. tanml-0.1.6.dist-info/RECORD +0 -62
  46. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
  47. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
  48. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
  49. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
tanml/checks/performance_classification.py ADDED
@@ -0,0 +1,305 @@
+ from __future__ import annotations
+ import os, math
+ from dataclasses import dataclass, asdict
+ from typing import Dict, Any
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ from sklearn.metrics import (
+     roc_auc_score, roc_curve, confusion_matrix, precision_recall_fscore_support,
+     accuracy_score, average_precision_score, brier_score_loss, precision_recall_curve
+ )
+ from sklearn.calibration import calibration_curve
+
+ # ---------- utilities ----------
+ def _ensure_dir(d: str) -> str:
+     os.makedirs(d, exist_ok=True)
+     return d
+
+ def _savefig(path: str) -> str:
+     plt.tight_layout()
+     plt.savefig(path, bbox_inches="tight", dpi=160)
+     plt.close()
+     return path
+
+ def _gini_from_auc(auc: float) -> float:
+     return 2 * auc - 1 if (auc is not None and not np.isnan(auc)) else np.nan
+
+ def _ks_from_roc(fpr, tpr) -> float:
+     return float(np.max(np.abs(tpr - fpr))) if len(fpr) else np.nan
+
+ def _decile_lift_table(y_true: np.ndarray, y_score: np.ndarray, pos_label: int = 1, n_bins: int = 10) -> pd.DataFrame:
+     df = pd.DataFrame({"y": (y_true == pos_label).astype(int), "score": y_score})
+     df = df.sort_values("score", ascending=False).reset_index(drop=True)
+     df["decile"] = pd.qcut(df.index + 1, q=n_bins, labels=list(range(1, n_bins + 1)))
+     g = df.groupby("decile", observed=True).agg(
+         total=("y", "size"),
+         events=("y", "sum"),
+         avg_score=("score", "mean"),
+     ).reset_index()
+     g["event_rate"] = g["events"] / g["total"]
+     overall_rate = df["y"].mean() if df["y"].size else np.nan
+     g["lift"] = g["event_rate"] / overall_rate if (overall_rate and not math.isclose(overall_rate, 0.0)) else np.nan
+
+     # cumulative capture & gain
+     g["cum_events"] = g["events"].cumsum()
+     g["cum_total"] = g["total"].cumsum()
+     total_events = g["events"].sum()
+     g["cum_capture_rate"] = g["cum_events"] / total_events if total_events > 0 else np.nan
+     g["cum_population"] = g["cum_total"] / g["total"].sum()
+     g["cum_gain"] = g["cum_capture_rate"]  # same as cumulative gains curve
+     return g
+
+ def _ks_curve_frame(y_true: np.ndarray, y_score: np.ndarray, pos_label: int = 1) -> pd.DataFrame:
+     """
+     Returns a dataframe with columns:
+         population (fraction 0..1), cum_event, cum_non_event, ks_gap
+     sorted by score DESC, which is standard for risk ranking.
+     """
+     df = pd.DataFrame({"y": (y_true == pos_label).astype(int), "score": y_score})
+     if df.empty:
+         return pd.DataFrame(columns=["population", "cum_event", "cum_non_event", "ks_gap"])
+
+     df = df.sort_values("score", ascending=False).reset_index(drop=True)
+     n = len(df)
+
+     # counts
+     total_events = df["y"].sum()
+     total_non_events = n - total_events
+
+     # avoid divide-by-zero; if all one class, return empty (plotter will handle)
+     if total_events == 0 or total_non_events == 0:
+         return pd.DataFrame(columns=["population", "cum_event", "cum_non_event", "ks_gap"])
+
+     cum_events = np.cumsum(df["y"].values) / total_events
+     cum_non_events = np.cumsum(1 - df["y"].values) / total_non_events
+     population = (np.arange(1, n + 1)) / n
+     ks_gap = np.abs(cum_events - cum_non_events)
+
+     return pd.DataFrame({
+         "population": population,
+         "cum_event": cum_events,
+         "cum_non_event": cum_non_events,
+         "ks_gap": ks_gap
+     })
+
+ @dataclass
+ class ClassificationSummary:
+     auc: float
+     ks: float
+     accuracy: float
+     precision: float
+     recall: float
+     f1: float
+     pr_auc: float
+     brier: float
+     gini: float
+     # paths
+     roc_png: str
+     pr_png: str
+     lift_png: str
+     calib_png: str
+     cm_png: str
+     ks_png: str
+     # tables
+     confusion_csv: str
+     lift_csv: str
+
+
+ def compute_classification_report(
+     *,
+     y_true: np.ndarray,
+     y_score: np.ndarray,
+     y_pred: np.ndarray,
+     outdir: str,
+     pos_label: int = 1,
+     title_prefix: str = "Model"
+ ) -> Dict[str, Any]:
+     """
+     Computes metrics + saves plots/CSVs for classification.
+     Returns a dict ready for ReportBuilder/Jinja.
+     """
+     _ensure_dir(outdir)
+
+     # --- metrics
+     has_posneg = len(np.unique(y_true)) > 1
+     auc = roc_auc_score(y_true, y_score) if has_posneg else np.nan
+     fpr, tpr, _ = roc_curve(y_true, y_score) if has_posneg else (np.array([]), np.array([]), None)
+     ks = _ks_from_roc(fpr, tpr)
+     pr_auc = average_precision_score(y_true, y_score) if has_posneg else np.nan
+     brier = brier_score_loss(y_true, y_score)
+     gini = _gini_from_auc(auc)
+
+     precision, recall, f1, _ = precision_recall_fscore_support(
+         y_true, y_pred, average="binary", pos_label=pos_label, zero_division=0
+     )
+     acc = accuracy_score(y_true, y_pred)
+
+     # --- confusion matrix & CSV
+     cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
+     cm_df = pd.DataFrame(cm, index=["Actual 0", "Actual 1"], columns=["Pred 0", "Pred 1"])
+     cm_csv = os.path.join(outdir, "confusion_matrix.csv")
+     cm_df.to_csv(cm_csv, index=True)
+
+     # --- decile lift table & CSV (rounded to 2 decimals)
+     lift_df = _decile_lift_table(y_true, y_score, pos_label=pos_label, n_bins=10)
+     lift_df_round = lift_df.copy()
+     num_cols = lift_df_round.select_dtypes(include=[np.number]).columns
+     lift_df_round[num_cols] = lift_df_round[num_cols].round(2)
+     lift_csv = os.path.join(outdir, "lift_table_deciles.csv")
+     lift_df_round.to_csv(lift_csv, index=False)
+
+     # --- plots
+     # ROC (path defined up front; fall back to a placeholder figure when the
+     # curve is unavailable for a single-class target, so roc_png always exists)
+     roc_png = os.path.join(outdir, "roc_curve.png")
+     if len(fpr):
+         plt.figure()
+         plt.plot(fpr, tpr, label=f"ROC (AUC={auc:.3f})")
+         plt.plot([0, 1], [0, 1], linestyle="--")
+         plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate"); plt.title(f"{title_prefix}: ROC Curve")
+         plt.legend(loc="lower right")
+         _savefig(roc_png)
+     else:
+         fig, ax = plt.subplots(figsize=(6, 4))
+         ax.set_title(f"{title_prefix}: ROC Curve (not available)")
+         _savefig(roc_png)
+
+     # PR
+     pr, rc, _ = precision_recall_curve(y_true, y_score)
+     plt.figure()
+     plt.plot(rc, pr, label=f"PR (AP={pr_auc:.3f})")
+     plt.xlabel("Recall"); plt.ylabel("Precision"); plt.title(f"{title_prefix}: Precision–Recall Curve")
+     plt.legend(loc="lower left")
+     pr_png = os.path.join(outdir, "pr_curve.png")
+     _savefig(pr_png)
+
+     # Calibration (Reliability) curve
+     prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=10, strategy="uniform")
+     plt.figure()
+     plt.plot(prob_pred, prob_true, marker="o", label="Reliability")
+     plt.plot([0, 1], [0, 1], linestyle="--", label="Perfect")
+     plt.xlabel("Predicted probability"); plt.ylabel("Observed frequency"); plt.title(f"{title_prefix}: Calibration")
+     plt.legend(loc="upper left")
+     calib_png = os.path.join(outdir, "calibration_curve.png")
+     _savefig(calib_png)
+
+     # Lift / Gain chart (use unrounded frame for smooth curve)
+     plt.figure()
+     plt.plot(lift_df["cum_population"], lift_df["cum_gain"], marker="o", label="Cumulative Gain")
+     plt.plot([0, 1], [0, 1], linestyle="--", label="Baseline")
+     plt.xlabel("Cumulative Population"); plt.ylabel("Cumulative Gain"); plt.title(f"{title_prefix}: Cumulative Gain")
+     plt.legend(loc="lower right")
+     lift_png = os.path.join(outdir, "lift_gain_curve.png")
+     _savefig(lift_png)
+
+     # Confusion heatmap
+     plt.figure()
+     plt.imshow(cm, interpolation="nearest")
+     plt.title(f"{title_prefix}: Confusion Matrix")
+     plt.colorbar()
+     tick_marks = np.arange(2)
+     plt.xticks(tick_marks, ["Pred 0", "Pred 1"])
+     plt.yticks(tick_marks, ["Actual 0", "Actual 1"])
+     for i in range(2):
+         for j in range(2):
+             plt.text(j, i, cm[i, j], ha="center", va="center")
+     plt.ylabel("Actual"); plt.xlabel("Predicted")
+     cm_png = os.path.join(outdir, "confusion_matrix.png")
+     _savefig(cm_png)
+
+     # --- KS curve (cumulative event vs non-event by population)
+     ks_df = _ks_curve_frame(y_true, y_score, pos_label=pos_label)
+     ks_csv = os.path.join(outdir, "ks_curve.csv")
+     ks_png = os.path.join(outdir, "ks_curve.png")
+
+     if not ks_df.empty:
+         # locate max KS point
+         ks_idx = int(ks_df["ks_gap"].values.argmax())
+         ks_x = float(ks_df["population"].iloc[ks_idx])
+         ks_y_event = float(ks_df["cum_event"].iloc[ks_idx])
+         ks_y_nonevent = float(ks_df["cum_non_event"].iloc[ks_idx])
+         ks_val_annot = abs(ks_y_event - ks_y_nonevent)
+
+         # plot with explicit figure/axes only once
+         fig, ax = plt.subplots(figsize=(6, 4))
+         ax.plot(ks_df["population"], ks_df["cum_event"], label="Cumulative Event")
+         ax.plot(ks_df["population"], ks_df["cum_non_event"], label="Cumulative Non-Event")
+
+         # vertical line & markers at max KS
+         ax.axvline(ks_x, linestyle="--", alpha=0.7)
+         ax.scatter([ks_x], [ks_y_event], s=25)
+         ax.scatter([ks_x], [ks_y_nonevent], s=25)
+
+         # readable annotation (two lines, boxed)
+         ax.annotate(
+             f"KS = {ks_val_annot:.1%}\nPop = {ks_x:.1%}",
+             xy=(ks_x, (ks_y_event + ks_y_nonevent) / 2.0),
+             xytext=(ks_x + 0.05, min(0.9, (ks_y_event + ks_y_nonevent) / 2.0 + 0.1)),
+             arrowprops=dict(arrowstyle="->", color="black", lw=1),
+             ha="left", va="center", fontsize=10,
+             bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=0.8),
+         )
+
+         ax.set_xlabel("Population (fraction)")
+         ax.set_ylabel("Cumulative share")
+         ax.set_title(f"{title_prefix}: KS Curve")
+         ax.legend(loc="lower right")
+
+         ks_df.to_csv(ks_csv, index=False)
+         _savefig(ks_png)
+     else:
+         # Write header-only CSV and a placeholder figure so ks_png always exists
+         pd.DataFrame(columns=["population", "cum_event", "cum_non_event", "ks_gap"]).to_csv(ks_csv, index=False)
+         fig, ax = plt.subplots(figsize=(6, 4))
+         ax.set_title(f"{title_prefix}: KS Curve (not available)")
+         _savefig(ks_png)
+
+     # --- build summaries: raw + rounded for display
+     summary = ClassificationSummary(
+         auc=float(auc) if auc == auc else np.nan,  # handle NaN
+         ks=float(ks) if ks == ks else np.nan,
+         accuracy=float(acc),
+         precision=float(precision),
+         recall=float(recall),
+         f1=float(f1),
+         pr_auc=float(pr_auc) if pr_auc == pr_auc else np.nan,
+         brier=float(brier),
+         gini=float(gini) if gini == gini else np.nan,
+         roc_png=roc_png,
+         pr_png=pr_png,
+         lift_png=lift_png,
+         calib_png=calib_png,
+         cm_png=cm_png,
+         ks_png=ks_png,  # <---- NEW
+         confusion_csv=cm_csv,
+         lift_csv=lift_csv,
+     )
+     summary_raw = asdict(summary)
+
+     # round only numeric metric fields to 2 decimals; keep paths as-is
+     metric_fields = {"auc", "ks", "accuracy", "precision", "recall", "f1", "pr_auc", "brier", "gini"}
+     summary_rounded = {}
+     for k, v in summary_raw.items():
+         if k in metric_fields:
+             try:
+                 summary_rounded[k] = None if v is None or (isinstance(v, float) and np.isnan(v)) else round(float(v), 2)
+             except Exception:
+                 summary_rounded[k] = v
+         else:
+             summary_rounded[k] = v
+
+     return {
+         "summary": summary_rounded,   # rounded for report/UI display
+         "summary_raw": summary_raw,   # full precision preserved for rules/debug
+         "tables": {
+             "confusion_csv": cm_csv,
+             "lift_csv": lift_csv,
+             "ks_csv": ks_csv,
+         },
+         "plots": {
+             "roc": roc_png,
+             "pr": pr_png,
+             "lift": lift_png,
+             "calibration": calib_png,
+             "confusion": cm_png,
+             "ks": ks_png,
+         },
+         # return rounded rows so the DOCX table looks clean
+         "deciles": lift_df_round.to_dict(orient="records"),
+     }
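
For orientation, a minimal usage sketch for the new compute_classification_report entry point. The toy data, the LogisticRegression model, and the output directory below are illustrative assumptions, not part of the diff (inside the package the function is presumably driven by performance_runner.py):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from tanml.checks.performance_classification import compute_classification_report

    # Toy binary-classification data (hypothetical; any fitted binary
    # classifier exposing predict_proba would do).
    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 4))
    y = (X[:, 0] + rng.normal(scale=0.5, size=500) > 0).astype(int)

    clf = LogisticRegression().fit(X, y)
    scores = clf.predict_proba(X)[:, 1]
    preds = (scores >= 0.5).astype(int)

    result = compute_classification_report(
        y_true=y,
        y_score=scores,
        y_pred=preds,
        outdir="reports/perf_demo",  # hypothetical path; created if missing
        title_prefix="Demo",
    )
    print(result["summary"])      # rounded metrics: auc, ks, gini, brier, ...
    print(result["plots"]["ks"])  # path to the saved KS curve PNG
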
tanml/checks/raw_data.py CHANGED
@@ -1,39 +1,75 @@
- # tanml/checks/raw_data.py ← make sure this is the ONLY copy on disk
- from .base import BaseCheck
+ from __future__ import annotations
+
+ from typing import Any, Dict
  import pandas as pd

+ from .base import BaseCheck
+ from tanml.utils.data_loader import load_dataframe
+
+
  class RawDataCheck(BaseCheck):
-     def __init__(self,
-                  model, X_train, X_test, y_train, y_test,
-                  rule_config, cleaned_data,
-                  raw_data=None):
-         # bring in rule_config & cleaned_data
-         super().__init__(model, X_train, X_test, y_train, y_test,
-                          rule_config, cleaned_data)
+     """
+     Raw data sanity metrics. Accepts:
+       - DataFrame via `raw_data`, or
+       - a file path via `raw_data`, or
+       - YAML: paths.raw_data
+     """
+
+     def __init__(
+         self,
+         model,
+         X_train,
+         X_test,
+         y_train,
+         y_test,
+         rule_config: Dict[str, Any],
+         cleaned_data: pd.DataFrame,
+         raw_data: Any = None,
+     ):
+         super().__init__(model, X_train, X_test, y_train, y_test, rule_config, cleaned_data)

          if not hasattr(self, "config") or self.config is None:
              self.config = {}

          if raw_data is not None:
-             if isinstance(raw_data, (str, bytes)):
-                 raw_data = pd.read_csv(raw_data)
-             if not isinstance(raw_data, pd.DataFrame):
-                 raise ValueError("raw_data must be a pandas DataFrame or CSV path")
-             self.config["raw_data"] = raw_data
-
-     def run(self):
-         results = {}
+             if isinstance(raw_data, pd.DataFrame):
+                 self.config["raw_data"] = raw_data
+             elif isinstance(raw_data, (str, bytes)):
+                 self.config["raw_data"] = load_dataframe(raw_data)
+             else:
+                 raise ValueError("raw_data must be a pandas DataFrame or a file path")
+
+     def run(self) -> Dict[str, Any]:
+         results: Dict[str, Any] = {}
          try:
-             df = self.config.get("raw_data")
+             df = self.config.get("raw_data")
+
+             # If df is a path-like string, load it now
+             if isinstance(df, (str, bytes)):
+                 df = load_dataframe(df)
+                 self.config["raw_data"] = df
+
+             # If still missing, try YAML fallbacks
              if not isinstance(df, pd.DataFrame):
-                 raise ValueError("raw_data not found or not a DataFrame")
+                 paths = self.config.get("paths") if isinstance(self.config.get("paths"), dict) else None
+                 raw_path = None
+                 if paths:
+                     raw_path = paths.get("raw_data") or paths.get("raw")
+                 raw_path = raw_path or self.config.get("raw_data_path") or self.config.get("raw_path")

-             results["total_rows"] = int(df.shape[0])
+                 if raw_path:
+                     df = load_dataframe(raw_path)
+                     self.config["raw_data"] = df
+                 else:
+                     raise ValueError("raw_data not found: provide DataFrame or set paths.raw_data in YAML")
+
+             # --- metrics ---
+             results["total_rows"] = int(df.shape[0])
              results["total_columns"] = int(df.shape[1])

              miss = df.isnull().mean().round(4)
-             results["avg_missing"] = float(miss.mean())
-             results["columns_with_missing"] = miss[miss > 0].to_dict()
+             results["avg_missing"] = float(miss.mean())
+             results["columns_with_missing"] = miss[miss > 0].to_dict()

              results["duplicate_rows"] = int(df.duplicated().sum())

@@ -44,4 +80,3 @@ class RawDataCheck(BaseCheck):
              results["error"] = str(e)

          return {"RawDataCheck": results}
-
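
The reworked RawDataCheck accepts the raw dataset three ways: a DataFrame passed as raw_data, a file path dispatched through load_dataframe, or a paths.raw_data entry in the YAML rule config. A minimal sketch, assuming BaseCheck tolerates the None placeholders used here for the model and splits (the toy frame is illustrative, not from the diff):

    import pandas as pd
    from tanml.checks.raw_data import RawDataCheck

    raw_df = pd.DataFrame({"age": [25, 32, None], "income": [50000, 64000, 58000]})

    # 1) Pass a DataFrame directly.
    check = RawDataCheck(
        model=None, X_train=None, X_test=None, y_train=None, y_test=None,
        rule_config={}, cleaned_data=raw_df.dropna(), raw_data=raw_df,
    )
    print(check.run())  # {"RawDataCheck": {"total_rows": 3, "avg_missing": ..., ...}}

    # 2) Or pass a file path; load_dataframe picks a reader by extension.
    # check = RawDataCheck(..., raw_data="data/raw.csv")

    # 3) Or set it in the YAML rule config and omit raw_data:
    #    paths:
    #      raw_data: data/raw.csv
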
tanml/checks/regression_metrics.py ADDED
@@ -0,0 +1,167 @@
+ # tanml/checks/regression_metrics.py
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional, Tuple
+ import numpy as np
+
+ try:
+     # Prefer sklearn implementations when available
+     from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
+     _HAS_SKLEARN = True
+ except Exception:
+     _HAS_SKLEARN = False
+
+ from .base import BaseCheck
+
+
+ class RegressionMetricsCheck(BaseCheck):
+     """
+     Computes TanML's frozen regression metrics:
+       - RMSE
+       - MAE
+       - Median Absolute Error
+       - R²
+       - Adjusted R²
+       - MAPE (or SMAPE fallback when zeros/near-zeros exist in y_true)
+
+     Pure compute: no file I/O, no plotting. Returns a dict.
+     """
+
+     def __init__(
+         self,
+         *,
+         y_true: np.ndarray,
+         y_pred: np.ndarray,
+         n_features: Optional[int] = None,
+         mape_eps: float = 1e-8,
+         config: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         """
+         Parameters
+         ----------
+         y_true : array-like of shape (n_samples,)
+         y_pred : array-like of shape (n_samples,)
+         n_features : int, optional
+             Number of model features (for Adjusted R²). If None, Adjusted R² may be None.
+         mape_eps : float
+             Small constant to guard divisions in MAPE/SMAPE.
+         config : dict, optional
+             Reserved for future options.
+         """
+         self.y_true = np.asarray(y_true).reshape(-1)
+         self.y_pred = np.asarray(y_pred).reshape(-1)
+         self.n_features = int(n_features) if n_features is not None else None
+         self.mape_eps = float(mape_eps)
+         self.config = config or {}
+
+         self._notes: List[str] = []
+
+     # ---------------------------
+     # Public API
+     # ---------------------------
+     def run(self) -> Dict[str, Any]:
+         self._validate_inputs()
+
+         rmse = self._rmse(self.y_true, self.y_pred)
+         mae = self._mae(self.y_true, self.y_pred)
+         median_ae = self._median_ae(self.y_true, self.y_pred)
+
+         r2, r2_adj = self._r2_and_adjusted(self.y_true, self.y_pred, self.n_features)
+
+         mape_val, smape_val, used = self._mape_or_smape(self.y_true, self.y_pred, self.mape_eps)
+
+         return {
+             "rmse": rmse,
+             "mae": mae,
+             "median_ae": median_ae,
+             "r2": r2,
+             "r2_adjusted": r2_adj,
+             "mape_or_smape": mape_val if used == "MAPE" else smape_val,
+             "mape_used": (used == "MAPE"),
+             "notes": self._notes,  # human-readable notes (e.g., SMAPE fallback, zero variance)
+         }
+
+     # ---------------------------
+     # Internals
+     # ---------------------------
+     def _validate_inputs(self) -> None:
+         if self.y_true.shape != self.y_pred.shape:
+             raise ValueError(f"Shapes differ: y_true{self.y_true.shape} vs y_pred{self.y_pred.shape}")
+         if self.y_true.ndim != 1 or self.y_pred.ndim != 1:
+             raise ValueError("y_true and y_pred must be 1-D arrays.")
+         if self.y_true.size < 2:
+             self._notes.append("Too few samples (<2) — some metrics may be undefined.")
+
+     @staticmethod
+     def _rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+         if _HAS_SKLEARN:
+             return float(np.sqrt(mean_squared_error(y_true, y_pred)))
+         return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))
+
+     @staticmethod
+     def _mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+         if _HAS_SKLEARN:
+             return float(mean_absolute_error(y_true, y_pred))
+         return float(np.mean(np.abs(y_true - y_pred)))
+
+     @staticmethod
+     def _median_ae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
+         return float(np.median(np.abs(y_true - y_pred)))
+
+     def _r2_and_adjusted(
+         self,
+         y_true: np.ndarray,
+         y_pred: np.ndarray,
+         n_features: Optional[int],
+     ) -> Tuple[Optional[float], Optional[float]]:
+         n = int(y_true.size)
+         # If variance is zero, R² is undefined
+         if float(np.var(y_true)) == 0.0:
+             self._notes.append("R² undefined: target has zero variance.")
+             return None, None
+
+         if _HAS_SKLEARN:
+             try:
+                 r2_val = float(r2_score(y_true, y_pred))
+             except Exception:
+                 r2_val = None
+                 self._notes.append("R² could not be computed via sklearn.r2_score.")
+         else:
+             ss_res = float(np.sum((y_true - y_pred) ** 2))
+             ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2))
+             r2_val = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else None
+
+         if r2_val is None or n_features is None:
+             return r2_val, None
+
+         p = int(n_features)
+         if n <= p + 1:
+             self._notes.append("Adjusted R² unavailable: insufficient degrees of freedom (n <= p + 1).")
+             return r2_val, None
+
+         try:
+             r2_adj = float(1.0 - (1.0 - r2_val) * (n - 1) / (n - p - 1))
+         except Exception:
+             r2_adj = None
+             self._notes.append("Adjusted R² computation failed due to numeric issues.")
+         return r2_val, r2_adj
+
+     def _mape_or_smape(
+         self, y_true: np.ndarray, y_pred: np.ndarray, eps: float
+     ) -> Tuple[Optional[float], Optional[float], str]:
+         """Return (MAPE, SMAPE, used_flag) and record notes for fallbacks."""
+         has_near_zero = np.any(np.abs(y_true) <= eps)
+         smape_val = self._smape(y_true, y_pred, eps)
+
+         if has_near_zero:
+             self._notes.append("MAPE skipped due to zeros/near-zeros in target; SMAPE reported instead.")
+             return None, smape_val, "SMAPE"
+
+         mape_val = float(np.mean(np.abs((y_true - y_pred) / (y_true + eps))) * 100.0)
+         return mape_val, smape_val, "MAPE"
+
+     @staticmethod
+     def _smape(y_true: np.ndarray, y_pred: np.ndarray, eps: float) -> float:
+         num = 2.0 * np.abs(y_pred - y_true)
+         den = np.abs(y_true) + np.abs(y_pred) + eps
+         return float(np.mean(num / den) * 100.0)
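
A minimal sketch of the RegressionMetricsCheck API, including the SMAPE fallback path; the arrays below are illustrative, not from the diff:

    import numpy as np
    from tanml.checks.regression_metrics import RegressionMetricsCheck

    y_true = np.array([10.0, 12.0, 9.5, 11.2])
    y_pred = np.array([9.6, 12.4, 9.9, 10.8])

    out = RegressionMetricsCheck(y_true=y_true, y_pred=y_pred, n_features=2).run()
    print(out["rmse"], out["r2"], out["r2_adjusted"])
    print(out["mape_used"])  # True: no zeros in y_true, so plain MAPE is reported

    # A zero (or near-zero) in the target triggers the SMAPE fallback and a note.
    out2 = RegressionMetricsCheck(
        y_true=np.array([0.0, 2.0, 4.0]),
        y_pred=np.array([0.1, 1.8, 4.3]),
    ).run()
    print(out2["mape_used"])  # False
    print(out2["notes"])      # ["MAPE skipped due to zeros/near-zeros in target; ..."]
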