tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (49)
  1. tanml/__init__.py +1 -1
  2. tanml/check_runners/cleaning_repro_runner.py +2 -2
  3. tanml/check_runners/correlation_runner.py +49 -12
  4. tanml/check_runners/explainability_runner.py +12 -22
  5. tanml/check_runners/logistic_stats_runner.py +196 -17
  6. tanml/check_runners/performance_runner.py +82 -26
  7. tanml/check_runners/raw_data_runner.py +29 -14
  8. tanml/check_runners/regression_metrics_runner.py +195 -0
  9. tanml/check_runners/stress_test_runner.py +23 -6
  10. tanml/check_runners/vif_runner.py +33 -27
  11. tanml/checks/correlation.py +241 -41
  12. tanml/checks/explainability/shap_check.py +261 -29
  13. tanml/checks/logit_stats.py +186 -54
  14. tanml/checks/performance_classification.py +305 -0
  15. tanml/checks/raw_data.py +58 -23
  16. tanml/checks/regression_metrics.py +167 -0
  17. tanml/checks/stress_test.py +157 -53
  18. tanml/cli/main.py +99 -27
  19. tanml/engine/check_agent_registry.py +20 -10
  20. tanml/engine/core_engine_agent.py +199 -37
  21. tanml/models/registry.py +329 -0
  22. tanml/report/report_builder.py +1180 -147
  23. tanml/report/templates/report_template_cls.docx +0 -0
  24. tanml/report/templates/report_template_reg.docx +0 -0
  25. tanml/ui/app.py +1205 -0
  26. tanml/utils/data_loader.py +105 -15
  27. tanml-0.1.7.dist-info/METADATA +164 -0
  28. tanml-0.1.7.dist-info/RECORD +54 -0
  29. tanml/cli/arg_parser.py +0 -31
  30. tanml/cli/init_cmd.py +0 -8
  31. tanml/cli/validate_cmd.py +0 -7
  32. tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
  33. tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
  34. tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
  35. tanml/engine/segmentation_agent.py +0 -118
  36. tanml/engine/validation_agent.py +0 -91
  37. tanml/report/templates/report_template.docx +0 -0
  38. tanml/utils/model_loader.py +0 -35
  39. tanml/utils/r_loader.py +0 -30
  40. tanml/utils/sas_loader.py +0 -50
  41. tanml/utils/yaml_generator.py +0 -34
  42. tanml/utils/yaml_loader.py +0 -5
  43. tanml/validate.py +0 -209
  44. tanml-0.1.6.dist-info/METADATA +0 -317
  45. tanml-0.1.6.dist-info/RECORD +0 -62
  46. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
  47. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
  48. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
  49. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
tanml/__init__.py CHANGED
@@ -1 +1 @@
-__version__ = "0.1.1"
+__version__ = "0.1.7"
tanml/check_runners/cleaning_repro_runner.py CHANGED
@@ -3,12 +3,12 @@ from tanml.checks.cleaning_repro import CleaningReproCheck
 
 def run_cleaning_repro_check(model, X_train, X_test, y_train, y_test,
                              config, cleaned_data, *args, **kwargs):
-    # honour rules.yaml toggle
+
     if not config.get("rules", {}).get("CleaningReproCheck", {}).get("enabled", True):
         print("ℹ️ CleaningReproCheck skipped (disabled in rules.yaml)")
         return None
 
-    # raw_df can come from rules.yaml *or* via kwargs (passed by ValidationEngine)
+
    raw_data = config.get("raw_data") or kwargs.get("raw_df")
     if raw_data is None:
         print("⚠️ Skipping CleaningReproCheck — raw_data missing in config and kwargs")
tanml/check_runners/correlation_runner.py CHANGED
@@ -1,15 +1,52 @@
-from tanml.checks.correlation import CorrelationCheck
+# tanml/check_runners/correlation_runner.py
+from __future__ import annotations
+import os
+from typing import Any, Dict, List
+import pandas as pd
 
-def CorrelationCheckRunner(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
-    try:
-        cfg = rule_config.get("CorrelationCheck", {})
-        if not cfg.get("enabled", True):
-            print("ℹ️ CorrelationCheck skipped (disabled in rules.yaml)")
-            return None
+from tanml.checks.correlation import CorrelationCheck
 
-        check = CorrelationCheck(cleaned_df)
-        return check.run()
+def _resolve_outdir(config: Dict[str, Any]) -> str:
+    base = (config.get("options") or {}).get("save_artifacts_dir") or "reports"
+    outdir = os.path.join(base, "correlation")
+    os.makedirs(outdir, exist_ok=True)
+    return outdir
 
-    except Exception as e:
-        print(f"⚠️ CorrelationCheck failed: {e}")
-        return {"CorrelationCheck": {"error": str(e)}}
+def _df_features_only(cleaned_df: pd.DataFrame) -> pd.DataFrame:
+    if cleaned_df is None or cleaned_df.empty:
+        return cleaned_df
+    cols = list(cleaned_df.columns)
+    if len(cols) >= 2:
+        return cleaned_df[cols[:-1]]
+    return cleaned_df
+
+def CorrelationCheckRunner(
+    model,
+    X_train,
+    X_test,
+    y_train,
+    y_test,
+    config: Dict[str, Any],
+    cleaned_df: pd.DataFrame,
+    raw_df: pd.DataFrame | None = None,
+):
+    ui_block: Dict[str, Any] = (config.get("CorrelationCheck") or {})
+    legacy: Dict[str, Any] = (config.get("correlation") or {})
+    if not bool(ui_block.get("enabled", legacy.get("enabled", True))):
+        return None
+
+    df = _df_features_only(cleaned_df)
+    cfg: Dict[str, Any] = {
+        "method": ui_block.get("method", "pearson"),
+        "high_corr_threshold": float(ui_block.get("high_corr_threshold", 0.8)),
+        "heatmap_max_features_default": int(ui_block.get("heatmap_max_features_default", 20)),
+        "heatmap_max_features_limit": int(ui_block.get("heatmap_max_features_limit", 60)),
+        "subset_strategy": ui_block.get("subset_strategy", "cluster"),
+        "sample_rows": int(ui_block.get("sample_rows", 150_000)),
+        "seed": int(ui_block.get("seed", 42)),
+        "save_csv": True,
+        "save_fig": True,
+        "appendix_csv_cap": ui_block.get("appendix_csv_cap", None),
+    }
+    outdir = _resolve_outdir(config)
+    return CorrelationCheck(cleaned_data=df, cfg=cfg, output_dir=outdir).run()
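A minimal invocation sketch, assuming a toy DataFrame whose last column is the target (which _df_features_only drops); column names and values are illustrative:

import pandas as pd
from tanml.check_runners.correlation_runner import CorrelationCheckRunner

df = pd.DataFrame({"x1": [1, 2, 3, 4], "x2": [4, 3, 2, 1], "y": [0, 1, 0, 1]})
config = {
    "CorrelationCheck": {"enabled": True, "method": "spearman"},
    "options": {"save_artifacts_dir": "reports"},
}
result = CorrelationCheckRunner(None, None, None, None, None, config, df)
# CSV and heatmap artifacts land under reports/correlation/ via _resolve_outdir.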
tanml/check_runners/explainability_runner.py CHANGED
@@ -1,28 +1,18 @@
-# tanml/check_runners/explainability_runner.py
-
+# explainability runner
 from tanml.checks.explainability.shap_check import SHAPCheck
 
-def run_shap_check(
-    model, X_train, X_test, y_train, y_test,
-    rule_config, cleaned_df, *args, **kwargs
-):
+def run_shap_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
     try:
-        cfg = rule_config.get("SHAPCheck", {})
-        if not cfg.get("enabled", True):
-            print("ℹ️ SHAPCheck skipped (disabled in rules.yaml)")
-            return None
-
-        check = SHAPCheck(
-            model=model,
-            X_train=X_train,
-            X_test=X_test,
-            y_train=y_train,
-            y_test=y_test,
-            rule_config=rule_config,
-            cleaned_df=cleaned_df,
-        )
-        return check.run()
+        cfg_shapcheck = (rule_config or {}).get("SHAPCheck", {}) or {}
+        cfg_expl = (rule_config or {}).get("explainability", {}).get("shap", {}) or {}
+        enabled = cfg_shapcheck.get("enabled", cfg_expl.get("enabled", True))
+        if not enabled:
+            print("ℹ️ SHAPCheck skipped (disabled)")
+            return {"SHAPCheck": {"skipped": True}}
 
+        check = SHAPCheck(model, X_train, X_test, y_train, y_test, rule_config=rule_config, cleaned_df=cleaned_df)
+        result = check.run()
+        return {"SHAPCheck": result}
     except Exception as e:
         print(f"⚠️ SHAPCheck failed: {e}")
-        return {"SHAPCheck": {"error": str(e)}}
+        return {"SHAPCheck": {"status": "error", "error": str(e)}}
tanml/check_runners/logistic_stats_runner.py CHANGED
@@ -1,28 +1,207 @@
-from tanml.checks.logit_stats import LogisticStatsCheck
-from sklearn.linear_model import LogisticRegression
+# tanml/check_runners/logistic_stats_runner.py
+from __future__ import annotations
 
-def run_logistic_stats_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+from typing import Any, Dict
+from tanml.checks.logit_stats import _prep_design_matrix_df
+import numpy as np
+import pandas as pd
+import statsmodels.api as sm
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import (
+    roc_auc_score,
+    roc_curve,
+    precision_recall_fscore_support,
+    accuracy_score,
+    average_precision_score,
+    brier_score_loss,
+)
+
+
+def _is_binary_series(y: pd.Series) -> bool:
     try:
-        # Check if model is a LogisticRegression or statsmodels object
-        is_logistic = isinstance(model, LogisticRegression) or (
-            hasattr(model, "llf") and hasattr(model, "params") and hasattr(model, "summary")
+        u = pd.unique(pd.Series(y).dropna())
+        return len(u) == 2
+    except Exception:
+        return False
+
+
+def _prep_design_matrix(
+    X_like: Any, ref_columns: pd.Index | None, add_const: bool = True
+) -> pd.DataFrame:
+    """
+    1) Convert to DataFrame
+    2) One-hot encode (drop_first=True)
+    3) Align to ref_columns (if given), filling missing cols with 0 and dropping extras
+    4) Coerce to numeric & sanitize
+    5) Optionally add constant
+    """
+    Xd = X_like if isinstance(X_like, pd.DataFrame) else pd.DataFrame(X_like)
+    Xd = pd.get_dummies(Xd, drop_first=True)
+
+    if ref_columns is not None:
+        ref_wo_const = [c for c in ref_columns if c != "const"]
+        Xd = Xd.reindex(columns=ref_wo_const, fill_value=0.0)
+
+    for c in Xd.columns:
+        Xd[c] = pd.to_numeric(Xd[c], errors="coerce")
+    Xd = Xd.replace([np.inf, -np.inf], np.nan).fillna(0.0)
+
+    if add_const:
+        Xd = sm.add_constant(Xd, has_constant="add")
+
+    return Xd
+
+
+def run_logistic_stats_check(
+    model,
+    X_train,
+    X_test,
+    y_train,
+    y_test,
+    rule_config: Dict[str, Any],
+    cleaned_df,
+    *args,
+    **kwargs,
+) -> Dict[str, Any]:
+    """
+    Logistic challenger (stats-only):
+
+    - Fits a statsmodels Logit on a one-hot design of X_train (with intercept)
+    - Produces: summary_text and coefficient table with CIs
+    - Computes baseline classification metrics on the test set (NO plots/CSVs)
+
+    Returns:
+        {
+          "LogitStats": {
+            "summary_text": str,
+            "coef_table_headers": ["feature","coef","std err","z","P>|z|","ci_low","ci_high"],
+            "coef_table_rows": [ {...}, ... ],
+            "baseline_metrics": { "summary": {...} },  # rounded, no 'plots'/'tables'
+            "baseline_note": "..."
+          }
+        }
+    """
+    try:
+        # 1) Skip if model is obviously not logistic-like
+        is_logistic_like = (
+            isinstance(model, LogisticRegression)
+            or getattr(model, "__class__", type("X", (object,), {})).__name__.lower().startswith("logit")
+            or hasattr(model, "predict_proba")
+        )
+        if not is_logistic_like:
+            print("ℹ️ LogisticStatsCheck skipped — model not logistic-like")
+            return {"LogitStats": {"skipped": True}}
+
+        # 2) Ensure binary target
+        y_train_s = pd.Series(y_train)
+        if not _is_binary_series(y_train_s):
+            print("ℹ️ LogisticStatsCheck skipped — target is not binary")
+            return {"LogitStats": {"skipped": True}}
+
+        # Robust 0/1 encoding (majority -> 0, minority -> 1)
+        counts = y_train_s.value_counts().sort_values(ascending=False).index.tolist()
+        enc_map = {counts[0]: 0, counts[1]: 1}
+        yb_train = y_train_s.map(enc_map).astype(int)
+
+        # 3) Train design matrix (with intercept)
+        Xd_train = _prep_design_matrix_df(X_train, ref_columns=None, add_const=True)
+
+        # 4) Fit statsmodels Logit (MLE)
+        res = sm.Logit(yb_train, Xd_train).fit(disp=0, method="lbfgs", maxiter=1000)
+
+        # 5) Summary text (human-readable)
+        try:
+            summary_text = res.summary2().as_text()
+        except Exception:
+            summary_text = str(res.summary())
+
+        # 6) Coefficient table (const first)
+        params = res.params
+        bse = res.bse
+        # Avoid divide-by-zero in z; replace zeros with NaN then fill after rounding
+        zvals = params / bse.replace(0, np.nan)
+        pvals = res.pvalues
+        ci = res.conf_int(alpha=0.05)
+        ci.columns = ["ci_low", "ci_high"]
+
+        coef_df = pd.DataFrame(
+            {
+                "feature": params.index,
+                "coef": params.values,
+                "std err": bse.values,
+                "z": zvals.values,
+                "P>|z|": pvals.values,
+                "ci_low": ci["ci_low"].values,
+                "ci_high": ci["ci_high"].values,
+            }
         )
 
-        if not is_logistic:
-            print("ℹ️ LogisticStatsCheck skipped — model is not logistic or not recognized")
-            return None
+        if "const" in coef_df["feature"].values:
+            coef_df = pd.concat(
+                [
+                    coef_df.loc[coef_df["feature"] == "const"],
+                    coef_df.loc[coef_df["feature"] != "const"],
+                ],
+                ignore_index=True,
+            )
+
+        for c in ["coef", "std err", "z", "P>|z|", "ci_low", "ci_high"]:
+            coef_df[c] = pd.to_numeric(coef_df[c], errors="coerce").round(4)
+
+        # 7) Test-set baseline metrics (NO PLOTS/CSVs)
+        # Build test matrix aligned to the training design columns.
+        Xd_test = _prep_design_matrix_df(X_test, ref_columns=Xd_train.columns, add_const=True)
 
-        # Use training data only for fitting stats
-        check = LogisticStatsCheck(model, X_train, y_train, rule_config)
-        output = check.run()
+        # Statsmodels Logit returns probability for class "1"
+        y_score = res.predict(Xd_test)  # shape (n_test,)
+
+        # Threshold policy (aligned with PerformanceCheck if present)
+        threshold = (rule_config.get("PerformanceCheck", {}) or {}).get("threshold", 0.5)
+        try:
+            thr = float(threshold)
+        except Exception:
+            thr = 0.5
+
+        y_pred = (y_score >= thr).astype(int)
+
+        yb_test = pd.Series(y_test).map(enc_map).astype(int).to_numpy()
+
+        has_posneg = len(np.unique(yb_test)) > 1
+        auc = roc_auc_score(yb_test, y_score) if has_posneg else np.nan
+        fpr, tpr, _ = roc_curve(yb_test, y_score) if has_posneg else (np.array([]), np.array([]), None)
+        ks = float(np.max(np.abs(tpr - fpr))) if len(fpr) else np.nan
+        ap = average_precision_score(yb_test, y_score) if has_posneg else np.nan
+        brier = brier_score_loss(yb_test, y_score)
+        precision, recall, f1, _ = precision_recall_fscore_support(
+            yb_test, y_pred, average="binary", pos_label=1, zero_division=0
+        )
+        acc = accuracy_score(yb_test, y_pred)
+        gini = 2 * auc - 1 if (auc == auc) else np.nan  # handle NaN
+
+        baseline_metrics = {
+            "summary": {
+                "auc": None if auc != auc else round(float(auc), 2),
+                "ks": None if ks != ks else round(float(ks), 2),
+                "accuracy": round(float(acc), 2),
+                "precision": round(float(precision), 2),
+                "recall": round(float(recall), 2),
+                "f1": round(float(f1), 2),
+                "pr_auc": None if ap != ap else round(float(ap), 2),
+                "brier": round(float(brier), 2),
+                "gini": None if gini != gini else round(float(gini), 2),
+            }
+        }
 
         return {
-            "LogisticStatsCheck": output["table"],
-            "LogisticStatsFit": output["fit"],
-            "LogisticStatsSummary": output["summary"],
-            "LogisticStatsCheck_obj": output["object"]
+            "LogitStats": {
+                "summary_text": summary_text,
+                "coef_table_headers": ["feature", "coef", "std err", "z", "P>|z|", "ci_low", "ci_high"],
+                "coef_table_rows": coef_df.to_dict(orient="records"),
+                "baseline_metrics": baseline_metrics,  # <-- metrics only; no plots/tables
+                "baseline_note": f"Computed on the same test split and preprocessing as the primary model; threshold={thr}.",
+            }
         }
 
     except Exception as e:
         print(f"⚠️ LogisticStatsCheck failed: {e}")
-        return {"LogisticStatsCheck": {"error": str(e)}}
+        return {"LogitStats": {"error": str(e)}}
tanml/check_runners/performance_runner.py CHANGED
@@ -1,28 +1,84 @@
-from tanml.checks.performance import PerformanceCheck
-
-def run_performance_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
-    perf_cfg = rule_config.get("PerformanceCheck", {})
-    if not perf_cfg.get("enabled", True):
-        print("ℹ️ Skipping PerformanceCheck (disabled in rules.yaml)")
-        return {"PerformanceCheck": {"skipped": True}}
-
-    if X_test is None or y_test is None or len(X_test) == 0 or len(y_test) == 0:
-        print("⚠️ Skipping PerformanceCheck due to empty test data.")
-        return {"PerformanceCheck": {"error": "Test data is empty :skipping performance evaluation."}}
-
-    try:
-        check = PerformanceCheck(
-            model=model,
-            X_train=X_train,
-            X_test=X_test,
-            y_train=y_train,
-            y_test=y_test,
-            rule_config=perf_cfg,
-            cleaned_data=cleaned_df
+from __future__ import annotations
+import os
+from typing import Any, Dict
+import numpy as np
+from tanml.checks.performance_classification import compute_classification_report
+
+def _resolve_outdir(config: Dict[str, Any]) -> str:
+    base = (config.get("paths") or {}).get("artifacts_dir") \
+        or (config.get("options") or {}).get("save_artifacts_dir") \
+        or "reports"
+    outdir = os.path.join(base, "performance")
+    os.makedirs(outdir, exist_ok=True)
+    return outdir
+
+def PerformanceCheckRunner(
+    model,
+    X_train, X_test, y_train, y_test,
+    config: Dict[str, Any],
+    cleaned_df,
+    raw_df=None,
+    ctx=None,
+):
+    outdir = _resolve_outdir(config)
+    task_type = ((config.get("model") or {}).get("type") or "binary_classification").lower()
+
+    payload: Dict[str, Any] = {}
+
+    if "class" in task_type:
+        # --- build scores ---
+        def _scores(m, X):
+            if hasattr(m, "predict_proba"):
+                p = m.predict_proba(X)
+                return p[:, 1] if p.ndim == 2 and p.shape[1] > 1 else p.ravel()
+            if hasattr(m, "decision_function"):
+                return m.decision_function(X).ravel()
+            return m.predict(X).ravel()
+
+        y_score = _scores(model, X_test)
+        y_pred = getattr(model, "predict")(X_test)
+
+        cls_dir = os.path.join(outdir, "classification")
+        os.makedirs(cls_dir, exist_ok=True)
+
+        results_cls = compute_classification_report(
+            y_true=np.asarray(y_test),
+            y_score=np.asarray(y_score),
+            y_pred=np.asarray(y_pred),
+            outdir=cls_dir,
+            pos_label=1,
+            title_prefix=(config.get("model") or {}).get("name", "Model"),
         )
-        result = check.run()
-        return {"PerformanceCheck": result}
 
-    except Exception as e:
-        print(f"⚠️ PerformanceCheck failed: {e}")
-        return {"PerformanceCheck": {"error": str(e)}}
+        payload = {
+            "performance": {
+                "classification": results_cls
+            },
+            "task_type": "classification",
+        }
+
+    else:
+
+        # payload = {"performance": {"regression": results_reg}, "task_type": "regression"}
+        payload = {"task_type": "regression"}
+
+    return payload
+
+# ---- Back-compat alias so registry can import old name ----
+def run_performance_check(
+    model,
+    X_train, X_test, y_train, y_test,
+    config,
+    cleaned_df,
+    raw_df=None,
+    ctx=None,
+):
+    return PerformanceCheckRunner(
+        model=model,
+        X_train=X_train, X_test=X_test,
+        y_train=y_train, y_test=y_test,
+        config=config,
+        cleaned_df=cleaned_df,
+        raw_df=raw_df,
+        ctx=ctx,
+    )
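The _scores helper picks the best available score source in a fixed order (predict_proba, then decision_function, then predict). A standalone sketch of the same fallback with sklearn estimators (data and model choices are illustrative):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

def scores(m, X):
    if hasattr(m, "predict_proba"):
        p = m.predict_proba(X)
        return p[:, 1] if p.ndim == 2 and p.shape[1] > 1 else p.ravel()
    if hasattr(m, "decision_function"):
        return m.decision_function(X).ravel()
    return m.predict(X).ravel()

print(scores(LogisticRegression().fit(X, y), X))  # probabilities for class 1
print(scores(LinearSVC().fit(X, y), X))           # falls back to decision_function margins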
tanml/check_runners/raw_data_runner.py CHANGED
@@ -1,27 +1,43 @@
-from tanml.checks.raw_data import RawDataCheck
+from __future__ import annotations
+
+import os
 import pandas as pd
 
-def run_raw_data_check(model, X_train, X_test, y_train, y_test,
-                       rule_config, cleaned_data, *args, **kwargs):
+from tanml.checks.raw_data import RawDataCheck
+from tanml.utils.data_loader import load_dataframe
+
+
+def run_raw_data_check(
+    model,
+    X_train,
+    X_test,
+    y_train,
+    y_test,
+    rule_config,
+    cleaned_data,
+    *args,
+    **kwargs
+):
     try:
-        # ---- locate raw data (DF or path) -------------
+        # Locate raw data from config
         raw_obj = (
-            rule_config.get("raw_data") or
-            rule_config.get("paths", {}).get("raw_data")
+            rule_config.get("raw_data")
+            or (rule_config.get("paths", {}) or {}).get("raw_data")
+            or (rule_config.get("paths", {}) or {}).get("raw")
         )
+
         if raw_obj is None:
             print("ℹ️ RawDataCheck skipped — raw_data not provided in config.")
             return None
 
-        # CSV path load once
-        if isinstance(raw_obj, (str, bytes)):
-            raw_obj = pd.read_csv(raw_obj)
+        if isinstance(raw_obj, (str, bytes, os.PathLike)):
+            raw_obj = load_dataframe(raw_obj)
 
         if not isinstance(raw_obj, pd.DataFrame):
-            print("ℹ️ RawDataCheck skipped — raw_data is not a DataFrame.")
+            print("ℹ️ RawDataCheck skipped — raw_data is not a DataFrame or loadable path.")
             return None
 
-        # ---- run the check -----------------------------
+        # Run the check
         check = RawDataCheck(
             model=model,
             X_train=X_train,
@@ -32,9 +48,8 @@ def run_raw_data_check(model, X_train, X_test, y_train, y_test,
             cleaned_data=cleaned_data,
             raw_data=raw_obj
         )
-
-        stats = check.run()
-        return stats["RawDataCheck"]  # hand the inner dict to ValidationEngine
+        stats = check.run()
+        return stats.get("RawDataCheck", stats)
 
     except Exception as e:
         print(f"⚠️ RawDataCheck failed: {e}")