tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tanml might be problematic.
- tanml/__init__.py +1 -1
- tanml/check_runners/cleaning_repro_runner.py +2 -2
- tanml/check_runners/correlation_runner.py +49 -12
- tanml/check_runners/explainability_runner.py +12 -22
- tanml/check_runners/logistic_stats_runner.py +196 -17
- tanml/check_runners/performance_runner.py +82 -26
- tanml/check_runners/raw_data_runner.py +29 -14
- tanml/check_runners/regression_metrics_runner.py +195 -0
- tanml/check_runners/stress_test_runner.py +23 -6
- tanml/check_runners/vif_runner.py +33 -27
- tanml/checks/correlation.py +241 -41
- tanml/checks/explainability/shap_check.py +261 -29
- tanml/checks/logit_stats.py +186 -54
- tanml/checks/performance_classification.py +305 -0
- tanml/checks/raw_data.py +58 -23
- tanml/checks/regression_metrics.py +167 -0
- tanml/checks/stress_test.py +157 -53
- tanml/cli/main.py +99 -27
- tanml/engine/check_agent_registry.py +20 -10
- tanml/engine/core_engine_agent.py +199 -37
- tanml/models/registry.py +329 -0
- tanml/report/report_builder.py +1180 -147
- tanml/report/templates/report_template_cls.docx +0 -0
- tanml/report/templates/report_template_reg.docx +0 -0
- tanml/ui/app.py +1205 -0
- tanml/utils/data_loader.py +105 -15
- tanml-0.1.7.dist-info/METADATA +164 -0
- tanml-0.1.7.dist-info/RECORD +54 -0
- tanml/cli/arg_parser.py +0 -31
- tanml/cli/init_cmd.py +0 -8
- tanml/cli/validate_cmd.py +0 -7
- tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
- tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
- tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
- tanml/engine/segmentation_agent.py +0 -118
- tanml/engine/validation_agent.py +0 -91
- tanml/report/templates/report_template.docx +0 -0
- tanml/utils/model_loader.py +0 -35
- tanml/utils/r_loader.py +0 -30
- tanml/utils/sas_loader.py +0 -50
- tanml/utils/yaml_generator.py +0 -34
- tanml/utils/yaml_loader.py +0 -5
- tanml/validate.py +0 -209
- tanml-0.1.6.dist-info/METADATA +0 -317
- tanml-0.1.6.dist-info/RECORD +0 -62
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
tanml/check_runners/regression_metrics_runner.py
ADDED
@@ -0,0 +1,195 @@
+# tanml/check_runners/regression_metrics_runner.py
+from __future__ import annotations
+
+import os
+from typing import Any, Dict, Optional
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from tanml.checks.regression_metrics import RegressionMetricsCheck
+
+try:
+    from scipy import stats as _scipy_stats
+    _HAS_SCIPY = True
+except Exception:
+    _HAS_SCIPY = False
+
+
+# ---------- utils ----------
+def _ensure_outdir(config: Dict[str, Any]) -> str:
+    base = (config.get("options") or {}).get("save_artifacts_dir") or "reports"
+    outdir = os.path.join(base, "regression_metrics")
+    os.makedirs(outdir, exist_ok=True)
+    return outdir
+
+
+def _to_1d(x: Any) -> np.ndarray:
+    return np.asarray(x).reshape(-1)
+
+
+def _residuals(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
+    return y_true - y_pred
+
+
+def _plot_pred_vs_actual(y_true: np.ndarray, y_pred: np.ndarray, save_path: str) -> str:
+    plt.figure()
+    plt.scatter(y_true, y_pred, s=12, alpha=0.75)
+    mn = float(min(np.min(y_true), np.min(y_pred)))
+    mx = float(max(np.max(y_true), np.max(y_pred)))
+    plt.plot([mn, mx], [mn, mx])  # reference y=x
+    plt.xlabel("Actual")
+    plt.ylabel("Predicted")
+    plt.title("Predicted vs Actual")
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=160)
+    plt.close()
+    return save_path
+
+
+def _plot_residuals_vs_pred(y_pred: np.ndarray, resid: np.ndarray, save_path: str) -> str:
+    plt.figure()
+    plt.scatter(y_pred, resid, s=12, alpha=0.75)
+    plt.axhline(0.0)
+    plt.xlabel("Predicted")
+    plt.ylabel("Residual (y_true - y_pred)")
+    plt.title("Residuals vs Predicted")
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=160)
+    plt.close()
+    return save_path
+
+
+def _plot_residual_hist(resid: np.ndarray, save_path: str) -> str:
+    plt.figure()
+    plt.hist(resid, bins=30, alpha=0.9)
+    plt.xlabel("Residual")
+    plt.ylabel("Count")
+    plt.title("Residual Distribution")
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=160)
+    plt.close()
+    return save_path
+
+
+def _plot_qq(resid: np.ndarray, save_path: str) -> str:
+    osm, osr = _scipy_stats.probplot(resid, dist="norm", fit=False)
+    plt.figure()
+    plt.scatter(osm, osr, s=12, alpha=0.8)
+    mn = float(min(np.min(osm), np.min(osr)))
+    mx = float(max(np.max(osm), np.max(osr)))
+    plt.plot([mn, mx], [mn, mx])
+    plt.xlabel("Theoretical Quantiles (Normal)")
+    plt.ylabel("Ordered Residuals")
+    plt.title("Residuals Q–Q Plot")
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=160)
+    plt.close()
+    return save_path
+
+
+def _plot_abs_error_box(abs_err: np.ndarray, save_path: str) -> str:
+    plt.figure()
+    plt.boxplot(abs_err, vert=True, showfliers=True)
+    plt.ylabel("|Residual|")
+    plt.title("Absolute Error — Box Plot")
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=160)
+    plt.close()
+    return save_path
+
+
+def _plot_abs_error_violin(abs_err: np.ndarray, save_path: str) -> str:
+    plt.figure()
+    plt.violinplot(abs_err, showmeans=True, showmedians=True)
+    plt.ylabel("|Residual|")
+    plt.title("Absolute Error — Violin Plot")
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=160)
+    plt.close()
+    return save_path
+
+
+def RegressionMetricsCheckRunner(
+    model: Any,
+    X_train: Any,
+    X_test: Any,
+    y_train: Any,
+    y_test: Any,
+    config: Dict[str, Any],
+    cleaned_df: Optional[Any] = None,
+    raw_df: Optional[Any] = None,
+    ctx: Optional[Dict[str, Any]] = None,
+) -> Dict[str, Any]:
+    """
+    1) Predict on X_test
+    2) Compute frozen regression metrics
+    3) Save 5 standard charts (Q–Q skipped if SciPy missing)
+    4) Return structured results for engine/report
+    """
+    # 1) predictions
+    try:
+        y_pred = model.predict(X_test)
+    except Exception as e:
+        raise RuntimeError(f"Model prediction failed in RegressionMetricsCheckRunner: {e}")
+
+    y_true = _to_1d(y_test)
+    y_pred = _to_1d(y_pred)
+
+    # n_features for Adjusted R²
+    try:
+        n_features = int(getattr(X_train, "shape", [None, None])[1])
+    except Exception:
+        n_features = None
+
+    # 2) metrics
+    chk = RegressionMetricsCheck(
+        y_true=y_true,
+        y_pred=y_pred,
+        n_features=n_features,
+        config=(config or {}),
+    )
+    metrics = chk.run()
+
+    # 3) plots
+    outdir = _ensure_outdir(config)
+    resid = _residuals(y_true, y_pred)
+    abs_err = np.abs(resid)
+
+    p1 = os.path.join(outdir, "pred_vs_actual.png")
+    p2 = os.path.join(outdir, "residuals_vs_pred.png")
+    p3 = os.path.join(outdir, "residual_hist.png")
+    p4 = os.path.join(outdir, "qq_plot.png")
+    b1 = os.path.join(outdir, "abs_error_box.png")
+    v1 = os.path.join(outdir, "abs_error_violin.png")
+
+    try:
+        _plot_pred_vs_actual(y_true, y_pred, p1)
+        _plot_residuals_vs_pred(y_pred, resid, p2)
+        _plot_residual_hist(resid, p3)
+
+        if _HAS_SCIPY:
+            _plot_qq(resid, p4)
+        else:
+            (metrics.get("notes") or []).append("Q–Q plot skipped: SciPy not available.")
+            p4 = None
+
+        _plot_abs_error_box(abs_err, b1)
+        _plot_abs_error_violin(abs_err, v1)
+    except Exception as e:
+        (metrics.get("notes") or []).append(f"Plotting failed: {e}")
+
+    # 4) return
+    return {
+        "RegressionMetrics": {
+            **metrics,
+            "artifacts": {
+                "pred_vs_actual": p1,
+                "residuals_vs_pred": p2,
+                "residual_hist": p3,
+                "qq_plot": p4,
+                "abs_error_box": b1,
+                "abs_error_violin": v1,
+            },
+        }
+    }
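For context, a minimal standalone sketch of how this new runner could be exercised outside the TanML engine, assuming scikit-learn is installed and that calling the runner directly with a fitted estimator is supported; the metric keys inside the returned "RegressionMetrics" block come from RegressionMetricsCheck, which is not shown in this diff, so only the "artifacts" key is relied on below.

import numpy as np
from sklearn.linear_model import LinearRegression
from tanml.check_runners.regression_metrics_runner import RegressionMetricsCheckRunner

# Toy regression data: y is a noisy linear function of three features
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([3.0, -1.0, 0.5]) + rng.normal(scale=0.1, size=200)

model = LinearRegression().fit(X[:150], y[:150])

result = RegressionMetricsCheckRunner(
    model=model,
    X_train=X[:150], X_test=X[150:],
    y_train=y[:150], y_test=y[150:],
    config={"options": {"save_artifacts_dir": "reports"}},  # charts land in reports/regression_metrics/
)

block = result["RegressionMetrics"]
print(block["artifacts"]["pred_vs_actual"])  # path to the saved Predicted vs Actual chart

The chart paths mirror _ensure_outdir above: everything is written under <save_artifacts_dir>/regression_metrics/, falling back to reports/ when no directory is configured.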
tanml/check_runners/stress_test_runner.py
CHANGED
@@ -1,26 +1,43 @@
+from __future__ import annotations
+from typing import Any, Dict
+import pandas as pd
 from tanml.checks.stress_test import StressTestCheck

 def run_stress_test_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
-    cfg = rule_config.get("StressTestCheck", {})
+    cfg = (rule_config or {}).get("StressTestCheck", {})
     if not cfg.get("enabled", True):
         print("ℹ️ Skipping StressTestCheck (disabled in rules.yaml)")
         return {"StressTestCheck": {"skipped": True}}

     try:
-        epsilon = cfg.get("epsilon", 0.01)
+        epsilon = cfg.get("epsilon", 0.01)
         perturb_fraction = cfg.get("perturb_fraction", 0.2)

-
+
+        cols_test = getattr(X_test, "columns", None)
+        cols_train = getattr(X_train, "columns", None)
+
+        if cols_test is not None:
+            columns = list(cols_test)
+        elif cols_train is not None:
+            columns = list(cols_train)
+        else:
+            columns = None
+
+        X_test_df = pd.DataFrame(X_test, columns=columns)
+
+        checker = StressTestCheck(model, X_test_df, y_test, epsilon, perturb_fraction)
         result = checker.run()

-        # Ensure output is always a dictionary
         if isinstance(result, list):
-
+            table = result
         elif hasattr(result, "to_dict"):
-
+            table = result.to_dict(orient="records")
         else:
             return {"StressTestCheck": {"output": result}}

+        return {"StressTestCheck": {"table": table}}
+
     except Exception as e:
         print(f"⚠️ StressTestCheck failed: {e}")
         return {"StressTestCheck": {"error": str(e)}}
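A small sketch of how the updated runner might be driven with an in-memory config, assuming a fitted scikit-learn classifier; the shape of the per-feature results depends on StressTestCheck, whose internals are not part of this diff, and the epsilon / perturb_fraction values below are purely illustrative.

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from tanml.check_runners.stress_test_runner import run_stress_test_check

rng = np.random.default_rng(1)
X = pd.DataFrame(rng.normal(size=(300, 4)), columns=["f1", "f2", "f3", "f4"])
y = (X["f1"] + X["f2"] > 0).astype(int)

model = LogisticRegression().fit(X.iloc[:200], y.iloc[:200])

rule_config = {
    "StressTestCheck": {
        "enabled": True,
        "epsilon": 0.05,          # size of the perturbation
        "perturb_fraction": 0.2,  # share of rows perturbed
    }
}

out = run_stress_test_check(
    model, X.iloc[:200], X.iloc[200:], y.iloc[:200], y.iloc[200:], rule_config, None
)
print(out["StressTestCheck"])  # {"table": [...]} on success, {"error": ...} otherwise

Because the runner now coerces X_test to a DataFrame (reusing X_test or X_train column names when available), plain NumPy arrays should also be accepted.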
tanml/check_runners/vif_runner.py
CHANGED
@@ -1,27 +1,40 @@
 # tanml/check_runners/vif_runner.py
+from __future__ import annotations

-
-import pandas as pd
+import os
 from pathlib import Path
+import pandas as pd
+
+from tanml.utils.data_loader import load_dataframe
+from tanml.checks.vif import VIFCheck
+

 def VIFCheckRunner(
     model, X_train, X_test, y_train, y_test,
     rule_config, cleaned_df, *args, **kwargs
 ):
-
-
-
-
-
-
-
-
+    """
+    Ensure cleaned_df is a DataFrame; if a path (csv/xlsx/parquet/etc.),
+    load it via the universal loader, then run VIFCheck.
+    """
+    # 1) Normalize cleaned_df to a DataFrame
+    try:
+        if isinstance(cleaned_df, (str, bytes, os.PathLike, Path)):
+            cleaned_df = load_dataframe(cleaned_df)
+        elif not isinstance(cleaned_df, pd.DataFrame):
+            msg = "cleaned_df is not a DataFrame or loadable path; skipping VIF."
+            print(f"ℹ️ {msg}")
+            return {"vif_table": [], "high_vif_features": [], "error": msg}
+    except Exception as e:
+        err = f"Could not load cleaned_df: {e}"
+        print(f"⚠️ {err}")
+        return {"vif_table": [], "high_vif_features": [], "error": err}

+    # 2) Run the check
     try:
         check = VIFCheck(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df)
-        result = check.run()
-
-        # Normalize result regardless of format
+        result = check.run()
+        # 3) Normalize result
         if isinstance(result, dict) and "vif_table" in result:
             vif_rows = result["vif_table"]
         elif isinstance(result, list):
@@ -29,25 +42,18 @@ def VIFCheckRunner(
         else:
             raise ValueError("Unexpected VIFCheck return shape")

-        #
+        # 4) Canonicalize keys and values
         for row in vif_rows:
             if "Feature" not in row and "feature" in row:
                 row["Feature"] = row.pop("feature")
-
+            if "VIF" in row and row["VIF"] is not None:
+                row["VIF"] = round(float(row["VIF"]), 2)

-        # Identify high VIF features
+        # 5) Identify high VIF features
         threshold = rule_config.get("vif_threshold", 5)
-        high_vif = [
-
-
-        ]
-
-        # Return final output
-        return {
-            "vif_table": vif_rows,
-            "high_vif_features": high_vif,
-            "error": None,
-        }
+        high_vif = [r["Feature"] for r in vif_rows if r.get("VIF") is not None and r["VIF"] > threshold]
+
+        return {"vif_table": vif_rows, "high_vif_features": high_vif, "error": None}

     except Exception as e:
         print(f"⚠️ VIFCheck failed: {e}")
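A brief sketch of the path-or-DataFrame handling the rewritten runner adds, assuming VIFCheck computes one VIF value per feature of cleaned_df and can tolerate None for the model and split arguments (neither assumption is verifiable from this diff); the threshold key it reads is "vif_threshold", defaulting to 5 as shown above.

import numpy as np
import pandas as pd
from tanml.check_runners.vif_runner import VIFCheckRunner

rng = np.random.default_rng(2)
df = pd.DataFrame({"x1": rng.normal(size=500), "x2": rng.normal(size=500)})
df["x3"] = 0.9 * df["x1"] + rng.normal(scale=0.1, size=500)  # nearly collinear with x1

# cleaned_df passed as an in-memory DataFrame ...
res = VIFCheckRunner(None, None, None, None, None, {"vif_threshold": 5}, df)
print(res["vif_table"], res["high_vif_features"])

# ... or as a file path, which the runner now loads via load_dataframe()
df.to_csv("cleaned.csv", index=False)
res2 = VIFCheckRunner(None, None, None, None, None, {"vif_threshold": 5}, "cleaned.csv")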
tanml/checks/correlation.py
CHANGED
@@ -1,61 +1,261 @@
-
+# checks/correlation.py
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
 import pandas as pd
-import seaborn as sns
 import matplotlib.pyplot as plt
-
+
+try:
+    from scipy.cluster.hierarchy import linkage, leaves_list
+    from scipy.spatial.distance import squareform
+    _HAS_SCIPY = True
+except Exception:
+    _HAS_SCIPY = False
+
+from .base import BaseCheck
+
+
+DEFAULT_CFG = {
+    "method": "pearson",                  # "pearson" | "spearman"
+    "high_corr_threshold": 0.80,          # |r| >= threshold flagged
+    "top_pairs_max": 200,                 # rows in the "main" table CSV
+    "heatmap_max_features_default": 20,   # default plotted features
+    "heatmap_max_features_limit": 60,     # max allowed via UI/slider
+    "subset_strategy": "cluster",         # "cluster" | "degree"
+    "sample_rows": 150_000,               # downsample for speed on huge data
+    "seed": 42,
+    "save_csv": True,
+    "save_fig": True,
+    "appendix_csv_cap": None,             # None = no cap; or int (e.g., 5000)
+}
+
+
+def _ensure_dir(p: Path) -> None:
+    p.mkdir(parents=True, exist_ok=True)
+
+
+def _numeric_columns(df: pd.DataFrame) -> List[str]:
+    return [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
+
+
+def _drop_constant_columns(df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
+    keep, dropped = [], []
+    for c in df.columns:
+        s = df[c]
+        if s.dropna().nunique() <= 1:
+            dropped.append(c)
+        else:
+            keep.append(c)
+    return df[keep], dropped
+
+
+def _subset_by_degree(corr_abs: pd.DataFrame, max_feats: int) -> List[str]:
+    if corr_abs.shape[0] <= max_feats:
+        return list(corr_abs.index)
+    scores = corr_abs.sum().sort_values(ascending=False)
+    return list(scores.head(max_feats).index)
+
+
+def _subset_by_cluster(corr_abs: pd.DataFrame, max_feats: int) -> List[str]:
+    if corr_abs.shape[0] <= max_feats:
+        return list(corr_abs.index)
+    if not _HAS_SCIPY:
+        return _subset_by_degree(corr_abs, max_feats)
+    # distance = 1 - |corr|
+    dist = 1.0 - corr_abs
+    dist = (dist + dist.T) / 2.0
+    np.fill_diagonal(dist.values, 0.0)
+    Z = linkage(squareform(dist.values, checks=False), method="average")
+    order = leaves_list(Z)
+    ordered = corr_abs.index[order]
+    step = max(1, len(ordered) // max_feats)
+    return list(ordered[::step][:max_feats])
+
+
+def _render_heatmap(corr: pd.DataFrame, out_path: Path, title: str) -> None:
+    fig, ax = plt.subplots(figsize=(10, 8))
+    im = ax.imshow(corr.values, vmin=-1, vmax=1)
+    ax.set_xticks(range(corr.shape[1]))
+    ax.set_yticks(range(corr.shape[0]))
+    ax.set_xticklabels(corr.columns, rotation=90, fontsize=6)
+    ax.set_yticklabels(corr.index, fontsize=6)
+    ax.set_title(title)
+    cbar = plt.colorbar(im, ax=ax)
+    cbar.set_label("Correlation")
+    fig.tight_layout()
+    fig.savefig(out_path, dpi=200)
+    plt.close(fig)
+

 class CorrelationCheck(BaseCheck):
-
-
-
-
-
+    """
+    Numeric-only correlation analysis:
+      • Pearson or Spearman (pairwise complete obs)
+      • Heatmap on ≤20 features by default (clustered subset up to 60 max)
+      • CSV of high-correlation pairs (|r| ≥ threshold), sorted by |r|
+      • Handles constant/all-NA columns, optional sampling for speed
+    """
+
+    def __init__(
+        self,
+        cleaned_data: pd.DataFrame,
+        cfg: Dict | None = None,
+        output_dir: str = "reports/correlation",
+    ):
         self.cleaned_data = cleaned_data
+        self.cfg = {**DEFAULT_CFG, **(cfg or {})}
         self.output_dir = output_dir
         os.makedirs(self.output_dir, exist_ok=True)

+    def _top_corr_pairs(self, corr: pd.DataFrame, thr: float) -> pd.DataFrame:
+        a = corr.copy()
+        np.fill_diagonal(a.values, np.nan)
+        s = a.stack().reset_index()
+        s.columns = ["feature_i", "feature_j", "corr"]
+        s = s.dropna()
+        # remove duplicate symmetric pairs
+        s["pair"] = s.apply(lambda r: tuple(sorted([r["feature_i"], r["feature_j"]])), axis=1)
+        s = s.drop_duplicates(subset=["pair"]).drop(columns=["pair"])
+        s["abs_corr"] = s["corr"].abs()
+        s = s[s["abs_corr"] >= thr].sort_values("abs_corr", ascending=False)
+        return s
+
     def run(self):
-
-
+        cfg = self.cfg
+        method = cfg["method"]
+        thr = float(cfg["high_corr_threshold"])
+        top_cap = int(cfg["top_pairs_max"])
+        default_cap = int(cfg["heatmap_max_features_default"])
+        max_cap = int(cfg["heatmap_max_features_limit"])
+        subset_strategy = cfg["subset_strategy"]
+        sample_rows = int(cfg["sample_rows"])
+        seed = int(cfg["seed"])
+        save_csv = bool(cfg["save_csv"])
+        save_fig = bool(cfg["save_fig"])
+        appendix_cap = cfg.get("appendix_csv_cap", None)
+
+        # ===== 1) Select numeric & (optional) sample rows =====
+        X = self.cleaned_data.copy()
+        if len(X) > sample_rows:
+            X = X.sample(sample_rows, random_state=seed)
+
+        num_cols = _numeric_columns(X)
+        X = X[num_cols]
+        X, dropped_constants = _drop_constant_columns(X)

-        if
-
+        if X.shape[1] < 2:
+            msg = "⚠️ Not enough numeric features for correlation."
+            print(msg)
             return {
                 "pearson_csv": None,
                 "spearman_csv": None,
                 "heatmap_path": None,
+                "top_pairs_csv": None,
+                "summary": {"n_numeric_features": X.shape[1]},
+                "notes": [msg, f"Dropped constant/all-NA columns: {dropped_constants}"] if dropped_constants else [msg],
                 "error": "Not enough numeric features for correlation",
             }

-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # ===== 2) Correlation matrix =====
+        # Compute both; pick one to drive plotting/threshold logic
+        corr_pearson = X.corr(method="pearson")
+        corr_spearman = X.corr(method="spearman")
+        corr = corr_pearson if method == "pearson" else corr_spearman
+        corr_abs = corr.abs()
+
+        # ===== 3) High-correlation pairs CSV =====
+        pairs = self._top_corr_pairs(corr, thr)
+        # augment with pairwise n_used and feature missingness %
+        non_null_counts = X.notna().sum()
+        total_rows = len(X)
+        if not pairs.empty:
+            pairs["n_used"] = pairs.apply(
+                lambda r: X[[r["feature_i"], r["feature_j"]]].dropna().shape[0], axis=1
+            )
+            pairs["pct_missing_i"] = pairs.apply(
+                lambda r: 1 - non_null_counts[r["feature_i"]] / total_rows, axis=1
+            )
+            pairs["pct_missing_j"] = pairs.apply(
+                lambda r: 1 - non_null_counts[r["feature_j"]] / total_rows, axis=1
+            )
+
+        artifacts: Dict[str, str] = {}
+        outdir = Path(self.output_dir)
+        _ensure_dir(outdir)
+
+        # Save full correlation matrices (if enabled)
+        pearson_csv_path = outdir / "pearson_corr.csv"
+        spearman_csv_path = outdir / "spearman_corr.csv"
+        if save_csv:
+            corr_pearson.to_csv(pearson_csv_path, index=True)
+            corr_spearman.to_csv(spearman_csv_path, index=True)
+
+        # Save top-pairs CSVs (main + full/appendix)
+        if save_csv:
+            full_csv = outdir / "correlation_top_pairs.csv"
+            if appendix_cap is not None:
+                pairs.head(int(appendix_cap)).to_csv(full_csv, index=False)
+            else:
+                pairs.to_csv(full_csv, index=False)
+            artifacts["top_pairs_csv"] = str(full_csv)
+
+            main_csv = outdir / "correlation_top_pairs_main.csv"
+            pairs.head(top_cap).to_csv(main_csv, index=False)
+            artifacts["top_pairs_main_csv"] = str(main_csv)
+
+        # ===== 4) Adaptive heatmap =====
+        n_features_total = X.shape[1]
+        plotted_full_matrix = n_features_total <= default_cap
+
+        if not plotted_full_matrix:
+            cap = min(max_cap, n_features_total)
+            if subset_strategy == "cluster" and _HAS_SCIPY:
+                subset = _subset_by_cluster(corr_abs, cap)
+            else:
+                subset = _subset_by_degree(corr_abs, cap)
+            corr_plot = corr.loc[subset, subset]
+            title = f"Correlation Heatmap ({method}) — {len(subset)}/{n_features_total} features (subset)"
+        else:
+            corr_plot = corr
+            title = f"Correlation Heatmap ({method}) — full matrix ({n_features_total} features)"
+
+        heatmap_path = None
+        if save_fig:
+            heatmap_path = outdir / "heatmap.png"
+            _render_heatmap(corr_plot, heatmap_path, title)
+            artifacts["heatmap_path"] = str(heatmap_path)
+
+        # ===== 5) Summary/notes =====
+        n_pairs_total = n_features_total * (n_features_total - 1) // 2
+        n_pairs_flagged = int(pairs.shape[0]) if not pairs.empty else 0
+        notes = []
+        if dropped_constants:
+            notes.append(f"Dropped constant/all-NA columns: {sorted(dropped_constants)}")
+        if len(self.cleaned_data) > sample_rows:
+            notes.append(f"Computed on a {sample_rows}-row sample (seed={seed}).")
+        if not plotted_full_matrix:
+            notes.append(
+                f"Heatmap shows a subset ({corr_plot.shape[0]}/{n_features_total}); see CSV for full list of pairs."
+            )

         return {
-            "pearson_csv":
-            "spearman_csv":
-            "heatmap_path": heatmap_path,
+            "pearson_csv": str(pearson_csv_path) if save_csv else None,
+            "spearman_csv": str(spearman_csv_path) if save_csv else None,
+            "heatmap_path": str(heatmap_path) if heatmap_path else None,
+            "top_pairs_csv": artifacts.get("top_pairs_csv"),
+            "top_pairs_main_csv": artifacts.get("top_pairs_main_csv"),
+            "summary": {
+                "n_numeric_features": int(n_features_total),
+                "n_pairs_total": int(n_pairs_total),
+                "n_pairs_flagged_ge_threshold": int(n_pairs_flagged),
+                "threshold": float(thr),
+                "method": method,
+                "plotted_features": int(corr_plot.shape[0]),
+                "plotted_full_matrix": bool(plotted_full_matrix),
+            },
+            "notes": notes,
         }
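Finally, a short sketch of invoking the rewritten check directly, assuming only pandas, numpy and matplotlib are present (SciPy merely enables the "cluster" subset strategy); the column names and the Spearman override below are illustrative.

import numpy as np
import pandas as pd
from tanml.checks.correlation import CorrelationCheck

rng = np.random.default_rng(3)
df = pd.DataFrame(rng.normal(size=(1_000, 5)), columns=list("abcde"))
df["f"] = 0.95 * df["a"] + rng.normal(scale=0.05, size=1_000)  # one strongly correlated pair

check = CorrelationCheck(
    df,
    cfg={"method": "spearman", "high_corr_threshold": 0.9},  # overrides merged onto DEFAULT_CFG
    output_dir="reports/correlation",
)
result = check.run()

print(result["summary"]["n_pairs_flagged_ge_threshold"])  # pairs with |r| >= 0.9
print(result["heatmap_path"], result["top_pairs_csv"])

Because overrides are merged onto DEFAULT_CFG, any key omitted from cfg keeps its default, so a plain CorrelationCheck(df) run uses Pearson, the 0.80 threshold, and the 20-feature heatmap cap.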