tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tanml might be problematic. Click here for more details.

Files changed (49)
  1. tanml/__init__.py +1 -1
  2. tanml/check_runners/cleaning_repro_runner.py +2 -2
  3. tanml/check_runners/correlation_runner.py +49 -12
  4. tanml/check_runners/explainability_runner.py +12 -22
  5. tanml/check_runners/logistic_stats_runner.py +196 -17
  6. tanml/check_runners/performance_runner.py +82 -26
  7. tanml/check_runners/raw_data_runner.py +29 -14
  8. tanml/check_runners/regression_metrics_runner.py +195 -0
  9. tanml/check_runners/stress_test_runner.py +23 -6
  10. tanml/check_runners/vif_runner.py +33 -27
  11. tanml/checks/correlation.py +241 -41
  12. tanml/checks/explainability/shap_check.py +261 -29
  13. tanml/checks/logit_stats.py +186 -54
  14. tanml/checks/performance_classification.py +305 -0
  15. tanml/checks/raw_data.py +58 -23
  16. tanml/checks/regression_metrics.py +167 -0
  17. tanml/checks/stress_test.py +157 -53
  18. tanml/cli/main.py +99 -27
  19. tanml/engine/check_agent_registry.py +20 -10
  20. tanml/engine/core_engine_agent.py +199 -37
  21. tanml/models/registry.py +329 -0
  22. tanml/report/report_builder.py +1180 -147
  23. tanml/report/templates/report_template_cls.docx +0 -0
  24. tanml/report/templates/report_template_reg.docx +0 -0
  25. tanml/ui/app.py +1205 -0
  26. tanml/utils/data_loader.py +105 -15
  27. tanml-0.1.7.dist-info/METADATA +164 -0
  28. tanml-0.1.7.dist-info/RECORD +54 -0
  29. tanml/cli/arg_parser.py +0 -31
  30. tanml/cli/init_cmd.py +0 -8
  31. tanml/cli/validate_cmd.py +0 -7
  32. tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
  33. tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
  34. tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
  35. tanml/engine/segmentation_agent.py +0 -118
  36. tanml/engine/validation_agent.py +0 -91
  37. tanml/report/templates/report_template.docx +0 -0
  38. tanml/utils/model_loader.py +0 -35
  39. tanml/utils/r_loader.py +0 -30
  40. tanml/utils/sas_loader.py +0 -50
  41. tanml/utils/yaml_generator.py +0 -34
  42. tanml/utils/yaml_loader.py +0 -5
  43. tanml/validate.py +0 -209
  44. tanml-0.1.6.dist-info/METADATA +0 -317
  45. tanml-0.1.6.dist-info/RECORD +0 -62
  46. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
  47. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
  48. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
  49. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,195 @@
1
+ # tanml/check_runners/regression_metrics_runner.py
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ from typing import Any, Dict, Optional
6
+
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+
10
+ from tanml.checks.regression_metrics import RegressionMetricsCheck
11
+
12
+ try:
13
+ from scipy import stats as _scipy_stats
14
+ _HAS_SCIPY = True
15
+ except Exception:
16
+ _HAS_SCIPY = False
17
+
18
+
19
+ # ---------- utils ----------
20
+ def _ensure_outdir(config: Dict[str, Any]) -> str:
21
+ base = (config.get("options") or {}).get("save_artifacts_dir") or "reports"
22
+ outdir = os.path.join(base, "regression_metrics")
23
+ os.makedirs(outdir, exist_ok=True)
24
+ return outdir
25
+
26
+
27
+ def _to_1d(x: Any) -> np.ndarray:
28
+ return np.asarray(x).reshape(-1)
29
+
30
+
31
+ def _residuals(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
32
+ return y_true - y_pred
33
+
34
+
35
def _plot_pred_vs_actual(y_true: np.ndarray, y_pred: np.ndarray, save_path: str) -> str:
    """Save a predicted-vs-actual scatter with a y = x reference line; return ``save_path``.

    The figure is always closed, even when drawing or saving raises, so a
    failing chart cannot leave figures open in pyplot's global registry.
    """
    fig, ax = plt.subplots()
    try:
        ax.scatter(y_true, y_pred, s=12, alpha=0.75)
        # nanmin/nanmax: a few NaN values should not blank out the reference line.
        lo = float(min(np.nanmin(y_true), np.nanmin(y_pred)))
        hi = float(max(np.nanmax(y_true), np.nanmax(y_pred)))
        ax.plot([lo, hi], [lo, hi])  # reference y = x
        ax.set_xlabel("Actual")
        ax.set_ylabel("Predicted")
        ax.set_title("Predicted vs Actual")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
    finally:
        plt.close(fig)
    return save_path
48
+
49
+
50
def _plot_residuals_vs_pred(y_pred: np.ndarray, resid: np.ndarray, save_path: str) -> str:
    """Save a residuals-vs-predicted scatter with a zero line; return ``save_path``.

    The figure is closed in a ``finally`` so errors do not leak open figures.
    """
    fig, ax = plt.subplots()
    try:
        ax.scatter(y_pred, resid, s=12, alpha=0.75)
        ax.axhline(0.0)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Residual (y_true - y_pred)")
        ax.set_title("Residuals vs Predicted")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
    finally:
        plt.close(fig)
    return save_path
61
+
62
+
63
def _plot_residual_hist(resid: np.ndarray, save_path: str) -> str:
    """Save a 30-bin histogram of the residuals; return ``save_path``.

    The figure is closed in a ``finally`` so errors do not leak open figures.
    """
    fig, ax = plt.subplots()
    try:
        ax.hist(resid, bins=30, alpha=0.9)
        ax.set_xlabel("Residual")
        ax.set_ylabel("Count")
        ax.set_title("Residual Distribution")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
    finally:
        plt.close(fig)
    return save_path
73
+
74
+
75
def _plot_qq(resid: np.ndarray, save_path: str) -> str:
    """Save a normal Q–Q plot of the residuals with a y = x reference; return ``save_path``.

    Requires SciPy (``_scipy_stats``); callers check ``_HAS_SCIPY`` first.
    The figure is closed in a ``finally`` so errors do not leak open figures.
    """
    # fit=False: probplot returns only (theoretical quantiles, ordered values).
    osm, osr = _scipy_stats.probplot(resid, dist="norm", fit=False)
    fig, ax = plt.subplots()
    try:
        ax.scatter(osm, osr, s=12, alpha=0.8)
        lo = float(min(np.nanmin(osm), np.nanmin(osr)))
        hi = float(max(np.nanmax(osm), np.nanmax(osr)))
        ax.plot([lo, hi], [lo, hi])
        ax.set_xlabel("Theoretical Quantiles (Normal)")
        ax.set_ylabel("Ordered Residuals")
        ax.set_title("Residuals Q–Q Plot")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
    finally:
        plt.close(fig)
    return save_path
89
+
90
+
91
def _plot_abs_error_box(abs_err: np.ndarray, save_path: str) -> str:
    """Save a vertical box plot (outliers shown) of absolute errors; return ``save_path``.

    The figure is closed in a ``finally`` so errors do not leak open figures.
    """
    fig, ax = plt.subplots()
    try:
        ax.boxplot(abs_err, vert=True, showfliers=True)
        ax.set_ylabel("|Residual|")
        ax.set_title("Absolute Error — Box Plot")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
    finally:
        plt.close(fig)
    return save_path
100
+
101
+
102
def _plot_abs_error_violin(abs_err: np.ndarray, save_path: str) -> str:
    """Save a violin plot (mean and median marked) of absolute errors; return ``save_path``.

    The figure is closed in a ``finally`` so errors do not leak open figures.
    """
    fig, ax = plt.subplots()
    try:
        ax.violinplot(abs_err, showmeans=True, showmedians=True)
        ax.set_ylabel("|Residual|")
        ax.set_title("Absolute Error — Violin Plot")
        fig.tight_layout()
        fig.savefig(save_path, dpi=160)
    finally:
        plt.close(fig)
    return save_path
111
+
112
+
113
def RegressionMetricsCheckRunner(
    model: Any,
    X_train: Any,
    X_test: Any,
    y_train: Any,
    y_test: Any,
    config: Dict[str, Any],
    cleaned_df: Optional[Any] = None,
    raw_df: Optional[Any] = None,
    ctx: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Evaluate a regression model on the test split and save diagnostic charts.

    1) Predict on ``X_test``.
    2) Compute metrics with ``RegressionMetricsCheck``; ``n_features`` (for
       Adjusted R²) is ``X_train.shape[1]`` when available, else ``None``.
    3) Save six charts under ``<options.save_artifacts_dir or "reports">/regression_metrics``:
       predicted vs actual, residuals vs predicted, residual histogram,
       residual Q–Q plot (skipped when SciPy is missing), and absolute-error
       box and violin plots.
    4) Return ``{"RegressionMetrics": {**metrics, "notes": [...], "artifacts": {...}}}``.

    Each chart is attempted independently. A chart that is skipped or fails
    is explained in ``notes`` and its ``artifacts`` entry is ``None``, so
    callers never receive a path to an image this run did not write.

    ``y_train``, ``cleaned_df``, ``raw_df`` and ``ctx`` are not used here
    (presumably accepted for a uniform runner signature).

    Raises:
        RuntimeError: if ``model.predict(X_test)`` raises (original error chained).
    """
    cfg = config or {}

    # 1) predictions
    try:
        y_pred = model.predict(X_test)
    except Exception as e:
        raise RuntimeError(f"Model prediction failed in RegressionMetricsCheckRunner: {e}") from e

    y_true = _to_1d(y_test)
    y_pred = _to_1d(y_pred)

    # n_features for Adjusted R²; None when X_train has no usable 2-D shape.
    try:
        n_features = int(getattr(X_train, "shape", [None, None])[1])
    except Exception:
        n_features = None

    # 2) metrics
    metrics = RegressionMetricsCheck(
        y_true=y_true,
        y_pred=y_pred,
        n_features=n_features,
        config=cfg,
    ).run()
    # Collect notes in a list we own and return: appending to
    # `metrics.get("notes") or []` silently dropped notes when metrics had none.
    notes = list(metrics.get("notes") or [])

    # 3) plots
    outdir = _ensure_outdir(cfg)
    resid = _residuals(y_true, y_pred)
    abs_err = np.abs(resid)

    def _try_plot(name: str, plot_fn, *data) -> Optional[str]:
        # Render one chart to <outdir>/<name>.png; on failure, note it and return None.
        path = os.path.join(outdir, f"{name}.png")
        try:
            plot_fn(*data, path)
        except Exception as e:
            notes.append(f"Plotting failed ({name}): {e}")
            return None
        return path

    artifacts: Dict[str, Optional[str]] = {}
    artifacts["pred_vs_actual"] = _try_plot("pred_vs_actual", _plot_pred_vs_actual, y_true, y_pred)
    artifacts["residuals_vs_pred"] = _try_plot("residuals_vs_pred", _plot_residuals_vs_pred, y_pred, resid)
    artifacts["residual_hist"] = _try_plot("residual_hist", _plot_residual_hist, resid)
    if _HAS_SCIPY:
        artifacts["qq_plot"] = _try_plot("qq_plot", _plot_qq, resid)
    else:
        notes.append("Q–Q plot skipped: SciPy not available.")
        artifacts["qq_plot"] = None
    artifacts["abs_error_box"] = _try_plot("abs_error_box", _plot_abs_error_box, abs_err)
    artifacts["abs_error_violin"] = _try_plot("abs_error_violin", _plot_abs_error_violin, abs_err)

    # 4) return
    return {
        "RegressionMetrics": {
            **metrics,
            "notes": notes,
            "artifacts": artifacts,
        }
    }
@@ -1,26 +1,43 @@
1
+ from __future__ import annotations
2
+ from typing import Any, Dict
3
+ import pandas as pd
1
4
  from tanml.checks.stress_test import StressTestCheck
2
5
 
3
def run_stress_test_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
    """Run ``StressTestCheck`` on the test split and normalize its output.

    Settings come from ``rule_config["StressTestCheck"]``: ``enabled``
    (default True), ``epsilon`` (default 0.01) and ``perturb_fraction``
    (default 0.2). ``X_test`` is wrapped in a DataFrame whose column names come
    from ``X_test`` itself, else from ``X_train``, else pandas defaults.

    Returns one of:
        {"StressTestCheck": {"skipped": True}}        when disabled
        {"StressTestCheck": {"table": [...]}}         list or DataFrame-like result
        {"StressTestCheck": {"output": result}}       any other result
        {"StressTestCheck": {"error": "<message>"}}   when anything fails
    """
    # `or {}` at both levels: tolerates rule_config=None and a section that is
    # present but empty (`StressTestCheck:` with no body loads as None).
    cfg = (rule_config or {}).get("StressTestCheck") or {}
    if not cfg.get("enabled", True):
        print("ℹ️ Skipping StressTestCheck (disabled in rules.yaml)")
        return {"StressTestCheck": {"skipped": True}}

    try:
        epsilon = cfg.get("epsilon", 0.01)
        perturb_fraction = cfg.get("perturb_fraction", 0.2)

        # Prefer X_test's own column names; fall back to X_train's when X_test
        # has none (e.g. a bare ndarray).
        cols_test = getattr(X_test, "columns", None)
        cols_train = getattr(X_train, "columns", None)
        if cols_test is not None:
            columns = list(cols_test)
        elif cols_train is not None:
            columns = list(cols_train)
        else:
            columns = None
        X_test_df = pd.DataFrame(X_test, columns=columns)

        result = StressTestCheck(model, X_test_df, y_test, epsilon, perturb_fraction).run()

        if isinstance(result, list):
            table = result
        elif hasattr(result, "to_dict"):
            table = result.to_dict(orient="records")
        else:
            return {"StressTestCheck": {"output": result}}
        return {"StressTestCheck": {"table": table}}

    except Exception as e:
        print(f"⚠️ StressTestCheck failed: {e}")
        return {"StressTestCheck": {"error": str(e)}}
@@ -1,27 +1,40 @@
1
1
  # tanml/check_runners/vif_runner.py
2
+ from __future__ import annotations
2
3
 
3
- from tanml.checks.vif import VIFCheck
4
- import pandas as pd
4
+ import os
5
5
  from pathlib import Path
6
+ import pandas as pd
7
+
8
+ from tanml.utils.data_loader import load_dataframe
9
+ from tanml.checks.vif import VIFCheck
10
+
6
11
 
7
12
  def VIFCheckRunner(
8
13
  model, X_train, X_test, y_train, y_test,
9
14
  rule_config, cleaned_df, *args, **kwargs
10
15
  ):
11
- # Ensure cleaned_df is a DataFrame
12
- if isinstance(cleaned_df, (str, Path)):
13
- try:
14
- cleaned_df = pd.read_csv(cleaned_df)
15
- except Exception as e:
16
- err = f"Could not read cleaned_df CSV: {e}"
17
- print(f"⚠️ {err}")
18
- return {"vif_table": [], "high_vif_features": [], "error": err}
16
+ """
17
+ Ensure cleaned_df is a DataFrame; if a path (csv/xlsx/parquet/etc.),
18
+ load it via the universal loader, then run VIFCheck.
19
+ """
20
+ # 1) Normalize cleaned_df to a DataFrame
21
+ try:
22
+ if isinstance(cleaned_df, (str, bytes, os.PathLike, Path)):
23
+ cleaned_df = load_dataframe(cleaned_df)
24
+ elif not isinstance(cleaned_df, pd.DataFrame):
25
+ msg = "cleaned_df is not a DataFrame or loadable path; skipping VIF."
26
+ print(f"ℹ️ {msg}")
27
+ return {"vif_table": [], "high_vif_features": [], "error": msg}
28
+ except Exception as e:
29
+ err = f"Could not load cleaned_df: {e}"
30
+ print(f"⚠️ {err}")
31
+ return {"vif_table": [], "high_vif_features": [], "error": err}
19
32
 
33
+ # 2) Run the check
20
34
  try:
21
35
  check = VIFCheck(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df)
22
- result = check.run() # Could be dict or list
23
-
24
- # Normalize result regardless of format
36
+ result = check.run()
37
+ # 3) Normalize result
25
38
  if isinstance(result, dict) and "vif_table" in result:
26
39
  vif_rows = result["vif_table"]
27
40
  elif isinstance(result, list):
@@ -29,25 +42,18 @@ def VIFCheckRunner(
29
42
  else:
30
43
  raise ValueError("Unexpected VIFCheck return shape")
31
44
 
32
- # Rename 'feature' to 'Feature', round VIF values
45
+ # 4) Canonicalize keys and values
33
46
  for row in vif_rows:
34
47
  if "Feature" not in row and "feature" in row:
35
48
  row["Feature"] = row.pop("feature")
36
- row["VIF"] = round(float(row["VIF"]), 2)
49
+ if "VIF" in row and row["VIF"] is not None:
50
+ row["VIF"] = round(float(row["VIF"]), 2)
37
51
 
38
- # Identify high VIF features
52
+ # 5) Identify high VIF features
39
53
  threshold = rule_config.get("vif_threshold", 5)
40
- high_vif = [
41
- row["Feature"] for row in vif_rows
42
- if row.get("VIF") is not None and row["VIF"] > threshold
43
- ]
44
-
45
- # Return final output
46
- return {
47
- "vif_table": vif_rows,
48
- "high_vif_features": high_vif,
49
- "error": None,
50
- }
54
+ high_vif = [r["Feature"] for r in vif_rows if r.get("VIF") is not None and r["VIF"] > threshold]
55
+
56
+ return {"vif_table": vif_rows, "high_vif_features": high_vif, "error": None}
51
57
 
52
58
  except Exception as e:
53
59
  print(f"⚠️ VIFCheck failed: {e}")
@@ -1,61 +1,261 @@
1
- from .base import BaseCheck
1
+ # checks/correlation.py
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Dict, List, Tuple
7
+
8
+ import numpy as np
2
9
  import pandas as pd
3
- import seaborn as sns
4
10
  import matplotlib.pyplot as plt
5
- import os
11
+
12
+ try:
13
+ from scipy.cluster.hierarchy import linkage, leaves_list
14
+ from scipy.spatial.distance import squareform
15
+ _HAS_SCIPY = True
16
+ except Exception:
17
+ _HAS_SCIPY = False
18
+
19
+ from .base import BaseCheck
20
+
21
+
22
# Default settings for CorrelationCheck; user-supplied cfg keys override these.
DEFAULT_CFG = dict(
    method="pearson",                  # "pearson" | "spearman": drives threshold + heatmap
    high_corr_threshold=0.80,          # pairs with |r| >= threshold are flagged
    top_pairs_max=200,                 # row cap for the "main" top-pairs CSV
    heatmap_max_features_default=20,   # at or below this many features, plot the full matrix
    heatmap_max_features_limit=60,     # cap on features in a subset heatmap (UI/slider max)
    subset_strategy="cluster",         # "cluster" | "degree"
    sample_rows=150_000,               # sample this many rows when the input is larger
    seed=42,                           # random_state for row sampling
    save_csv=True,
    save_fig=True,
    appendix_csv_cap=None,             # None = no cap; or an int (e.g. 5000)
)
35
+
36
+
37
+ def _ensure_dir(p: Path) -> None:
38
+ p.mkdir(parents=True, exist_ok=True)
39
+
40
+
41
+ def _numeric_columns(df: pd.DataFrame) -> List[str]:
42
+ return [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
43
+
44
+
45
+ def _drop_constant_columns(df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
46
+ keep, dropped = [], []
47
+ for c in df.columns:
48
+ s = df[c]
49
+ if s.dropna().nunique() <= 1:
50
+ dropped.append(c)
51
+ else:
52
+ keep.append(c)
53
+ return df[keep], dropped
54
+
55
+
56
+ def _subset_by_degree(corr_abs: pd.DataFrame, max_feats: int) -> List[str]:
57
+ if corr_abs.shape[0] <= max_feats:
58
+ return list(corr_abs.index)
59
+ scores = corr_abs.sum().sort_values(ascending=False)
60
+ return list(scores.head(max_feats).index)
61
+
62
+
63
def _subset_by_cluster(corr_abs: pd.DataFrame, max_feats: int) -> List[str]:
    """Pick up to ``max_feats`` features spread evenly across a clustering order.

    Features are ordered by average-linkage hierarchical clustering on the
    distance ``1 - |corr|`` (dendrogram leaf order), then every ``step``-th
    feature is taken. When there are at most ``max_feats`` features, all are
    returned; without SciPy this falls back to ``_subset_by_degree``.
    """
    if corr_abs.shape[0] <= max_feats:
        return list(corr_abs.index)
    if not _HAS_SCIPY:
        return _subset_by_degree(corr_abs, max_feats)
    # Work on a private float copy:
    #  * NaN |corr| (e.g. pairs with no overlapping rows) would make linkage()
    #    reject the distance matrix; treat such pairs as unrelated (distance 1).
    #  * Clip float overshoot (|corr| marginally above 1) to keep distances >= 0.
    #  * Writing the diagonal into DataFrame.values is unsafe: under pandas
    #    Copy-on-Write that array is read-only.
    sim = np.nan_to_num(corr_abs.to_numpy(dtype=float, copy=True), nan=0.0)
    dist = np.clip(1.0 - sim, 0.0, 1.0)
    dist = (dist + dist.T) / 2.0
    np.fill_diagonal(dist, 0.0)
    Z = linkage(squareform(dist, checks=False), method="average")
    ordered = corr_abs.index[leaves_list(Z)]
    step = max(1, len(ordered) // max_feats)
    return list(ordered[::step][:max_feats])
77
+
78
+
79
def _render_heatmap(corr: pd.DataFrame, out_path: Path, title: str) -> None:
    """Draw *corr* as a labelled image on a fixed [-1, 1] color scale and save it.

    The figure is closed in a ``finally`` so a failure while drawing or saving
    cannot leave figures open in pyplot's global registry.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        im = ax.imshow(corr.values, vmin=-1, vmax=1)
        ax.set_xticks(range(corr.shape[1]))
        ax.set_yticks(range(corr.shape[0]))
        ax.set_xticklabels(corr.columns, rotation=90, fontsize=6)
        ax.set_yticklabels(corr.index, fontsize=6)
        ax.set_title(title)
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label("Correlation")
        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        plt.close(fig)
92
+
6
93
 
7
class CorrelationCheck(BaseCheck):
    """
    Numeric-only correlation analysis.

    • Computes Pearson and Spearman matrices with ``DataFrame.corr`` (pairwise
      complete observations); ``cfg["method"]`` selects which one drives the
      threshold and the heatmap.
    • Heatmap: the full matrix when there are at most
      ``heatmap_max_features_default`` features, otherwise at most
      ``heatmap_max_features_limit`` features chosen by clustering (SciPy,
      ``subset_strategy == "cluster"``) or by total |r|.
    • CSV of high-correlation pairs (|r| ≥ threshold), sorted by |r|.
    • Drops constant/all-NA columns; samples rows above ``sample_rows``.
    """

    def __init__(
        self,
        cleaned_data: pd.DataFrame,
        cfg: Dict | None = None,
        output_dir: str = "reports/correlation",
    ):
        """
        Args:
            cleaned_data: input frame; non-numeric columns are ignored.
            cfg: overrides merged on top of ``DEFAULT_CFG``.
            output_dir: directory for CSV/PNG artifacts (created immediately).
        """
        self.cleaned_data = cleaned_data
        self.cfg = {**DEFAULT_CFG, **(cfg or {})}
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def _top_corr_pairs(self, corr: pd.DataFrame, thr: float) -> pd.DataFrame:
        """Return each unordered off-diagonal pair with |corr| >= thr, sorted by |corr| desc.

        Columns: ``feature_i``, ``feature_j``, ``corr``, ``abs_corr``; every pair
        appears once as (earlier column, later column). NaN correlations are
        excluded. Reading only the strict upper triangle replaces the former
        row-wise ``apply``/``sorted`` de-duplication (slow, and failing on
        mixed-type column names) and never writes into ``DataFrame.values``
        (read-only under pandas Copy-on-Write).
        """
        values = corr.to_numpy(dtype=float)
        rows, cols = np.triu_indices(values.shape[0], k=1)
        upper = values[rows, cols]
        pairs = pd.DataFrame(
            {
                "feature_i": corr.index.to_numpy()[rows],
                "feature_j": corr.columns.to_numpy()[cols],
                "corr": upper,
                "abs_corr": np.abs(upper),
            }
        )
        # NaN >= thr is False, so NaN correlations are dropped here too.
        flagged = pairs[pairs["abs_corr"] >= thr]
        return flagged.sort_values("abs_corr", ascending=False)

    def run(self):
        """Compute correlations, write CSV/PNG artifacts and return a result dict.

        Returns:
            dict with ``pearson_csv``, ``spearman_csv``, ``heatmap_path``,
            ``top_pairs_csv``, ``top_pairs_main_csv`` (paths or ``None``),
            ``summary`` and ``notes``. When fewer than two usable numeric
            columns remain, an early result carrying ``error`` is returned.
        """
        cfg = self.cfg
        requested_method = str(cfg["method"]).strip().lower()
        # Only Pearson and Spearman are computed; anything else falls back to
        # Spearman (as before), and the summary now names the method actually used.
        method = "pearson" if requested_method == "pearson" else "spearman"
        thr = float(cfg["high_corr_threshold"])
        top_cap = int(cfg["top_pairs_max"])
        default_cap = int(cfg["heatmap_max_features_default"])
        max_cap = int(cfg["heatmap_max_features_limit"])
        subset_strategy = cfg["subset_strategy"]
        # sample_rows=None disables sampling.
        sample_rows = None if cfg["sample_rows"] is None else int(cfg["sample_rows"])
        seed = int(cfg["seed"])
        save_csv = bool(cfg["save_csv"])
        save_fig = bool(cfg["save_fig"])
        appendix_cap = cfg.get("appendix_csv_cap", None)

        # ===== 1) Select numeric columns & (optional) sample rows =====
        # Selecting columns first avoids copying non-numeric data. DataFrame.sample
        # draws positions from the row count and seed only, so the same rows are
        # sampled either way.
        data = self.cleaned_data
        sampled = sample_rows is not None and len(data) > sample_rows
        X = data[_numeric_columns(data)]
        if sampled:
            X = X.sample(sample_rows, random_state=seed)
        X, dropped_constants = _drop_constant_columns(X)

        if X.shape[1] < 2:
            msg = "⚠️ Not enough numeric features for correlation."
            print(msg)
            early_notes = [msg]
            if dropped_constants:
                early_notes.append(f"Dropped constant/all-NA columns: {dropped_constants}")
            return {
                "pearson_csv": None,
                "spearman_csv": None,
                "heatmap_path": None,
                "top_pairs_csv": None,
                "summary": {"n_numeric_features": X.shape[1]},
                "notes": early_notes,
                "error": "Not enough numeric features for correlation",
            }

        # ===== 2) Correlation matrices =====
        corr_pearson = X.corr(method="pearson")
        corr_spearman = X.corr(method="spearman")
        corr = corr_pearson if method == "pearson" else corr_spearman
        corr_abs = corr.abs()

        # ===== 3) High-correlation pairs + pairwise n_used / missingness =====
        pairs = self._top_corr_pairs(corr, thr)
        if not pairs.empty:
            present = X.notna()
            non_null_counts = present.sum()
            total_rows = len(X)
            # n_used = rows where both features are non-null (what corr() used).
            pairs["n_used"] = [
                int((present[a] & present[b]).sum())
                for a, b in zip(pairs["feature_i"], pairs["feature_j"])
            ]
            pairs["pct_missing_i"] = 1 - pairs["feature_i"].map(non_null_counts) / total_rows
            pairs["pct_missing_j"] = 1 - pairs["feature_j"].map(non_null_counts) / total_rows

        artifacts: Dict[str, str] = {}
        outdir = Path(self.output_dir)
        _ensure_dir(outdir)

        # Full correlation matrices + top-pairs CSVs (main + full/appendix)
        pearson_csv_path = outdir / "pearson_corr.csv"
        spearman_csv_path = outdir / "spearman_corr.csv"
        if save_csv:
            corr_pearson.to_csv(pearson_csv_path, index=True)
            corr_spearman.to_csv(spearman_csv_path, index=True)

            full_csv = outdir / "correlation_top_pairs.csv"
            # appendix_csv_cap bounds the "full" file; None writes every flagged pair.
            appendix = pairs if appendix_cap is None else pairs.head(int(appendix_cap))
            appendix.to_csv(full_csv, index=False)
            artifacts["top_pairs_csv"] = str(full_csv)

            main_csv = outdir / "correlation_top_pairs_main.csv"
            pairs.head(top_cap).to_csv(main_csv, index=False)
            artifacts["top_pairs_main_csv"] = str(main_csv)

        # ===== 4) Adaptive heatmap =====
        n_features_total = X.shape[1]
        if n_features_total <= default_cap:
            corr_plot = corr
        else:
            cap = min(max_cap, n_features_total)
            if subset_strategy == "cluster" and _HAS_SCIPY:
                subset = _subset_by_cluster(corr_abs, cap)
            else:
                subset = _subset_by_degree(corr_abs, cap)
            corr_plot = corr.loc[subset, subset]
        n_plotted = int(corr_plot.shape[0])
        # Label by what was actually drawn: when default_cap < n <= max_cap the
        # "subset" contains every feature, i.e. it is the full matrix.
        plotted_full_matrix = n_plotted == n_features_total
        if plotted_full_matrix:
            title = f"Correlation Heatmap ({method}) — full matrix ({n_features_total} features)"
        else:
            title = f"Correlation Heatmap ({method}) — {n_plotted}/{n_features_total} features (subset)"

        heatmap_path = None
        if save_fig:
            heatmap_path = outdir / "heatmap.png"
            _render_heatmap(corr_plot, heatmap_path, title)
            artifacts["heatmap_path"] = str(heatmap_path)

        # ===== 5) Summary/notes =====
        n_pairs_total = n_features_total * (n_features_total - 1) // 2
        n_pairs_flagged = len(pairs)
        notes = []
        if requested_method not in ("pearson", "spearman"):
            notes.append(f"Unsupported correlation method {cfg['method']!r}; used spearman.")
        if dropped_constants:
            notes.append(f"Dropped constant/all-NA columns: {sorted(dropped_constants)}")
        if sampled:
            notes.append(f"Computed on a {sample_rows}-row sample (seed={seed}).")
        if not plotted_full_matrix:
            notes.append(
                f"Heatmap shows a subset ({n_plotted}/{n_features_total}); see CSV for full list of pairs."
            )

        return {
            "pearson_csv": str(pearson_csv_path) if save_csv else None,
            "spearman_csv": str(spearman_csv_path) if save_csv else None,
            "heatmap_path": str(heatmap_path) if heatmap_path else None,
            "top_pairs_csv": artifacts.get("top_pairs_csv"),
            "top_pairs_main_csv": artifacts.get("top_pairs_main_csv"),
            "summary": {
                "n_numeric_features": int(n_features_total),
                "n_pairs_total": int(n_pairs_total),
                "n_pairs_flagged_ge_threshold": int(n_pairs_flagged),
                "threshold": float(thr),
                "method": method,
                "plotted_features": n_plotted,
                "plotted_full_matrix": bool(plotted_full_matrix),
            },
            "notes": notes,
        }