tanml 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tanml might be problematic.
- tanml/__init__.py +1 -0
- tanml/check_runners/__init__.py +0 -0
- tanml/check_runners/base_runner.py +6 -0
- tanml/check_runners/cleaning_repro_runner.py +18 -0
- tanml/check_runners/correlation_runner.py +15 -0
- tanml/check_runners/data_quality_runner.py +24 -0
- tanml/check_runners/eda_runner.py +21 -0
- tanml/check_runners/explainability_runner.py +28 -0
- tanml/check_runners/input_cluster_runner.py +43 -0
- tanml/check_runners/logistic_stats_runner.py +28 -0
- tanml/check_runners/model_meta_runner.py +23 -0
- tanml/check_runners/performance_runner.py +28 -0
- tanml/check_runners/raw_data_runner.py +41 -0
- tanml/check_runners/rule_engine_runner.py +5 -0
- tanml/check_runners/stress_test_runner.py +26 -0
- tanml/check_runners/vif_runner.py +54 -0
- tanml/checks/__init__.py +0 -0
- tanml/checks/base.py +20 -0
- tanml/checks/cleaning_repro.py +47 -0
- tanml/checks/correlation.py +61 -0
- tanml/checks/data_quality.py +26 -0
- tanml/checks/eda.py +67 -0
- tanml/checks/explainability/shap_check.py +55 -0
- tanml/checks/input_cluster.py +109 -0
- tanml/checks/logit_stats.py +59 -0
- tanml/checks/model_contents.py +40 -0
- tanml/checks/model_meta.py +50 -0
- tanml/checks/performance.py +90 -0
- tanml/checks/raw_data.py +47 -0
- tanml/checks/rule_engine.py +45 -0
- tanml/checks/stress_test.py +64 -0
- tanml/checks/vif.py +51 -0
- tanml/cli/__init__.py +0 -0
- tanml/cli/arg_parser.py +31 -0
- tanml/cli/init_cmd.py +8 -0
- tanml/cli/main.py +27 -0
- tanml/cli/validate_cmd.py +7 -0
- tanml/config_templates/__init__.py +0 -0
- tanml/config_templates/rules_multiple_models_datasets.yaml +144 -0
- tanml/config_templates/rules_one_dataset_segment_column.yaml +140 -0
- tanml/config_templates/rules_one_model_one_dataset.yaml +143 -0
- tanml/engine/__init__.py +0 -0
- tanml/engine/check_agent_registry.py +42 -0
- tanml/engine/core_engine_agent.py +115 -0
- tanml/engine/segmentation_agent.py +118 -0
- tanml/engine/validation_agent.py +91 -0
- tanml/report/report_builder.py +230 -0
- tanml/report/templates/report_template.docx +0 -0
- tanml/utils/__init__.py +0 -0
- tanml/utils/data_loader.py +17 -0
- tanml/utils/model_loader.py +35 -0
- tanml/utils/r_loader.py +30 -0
- tanml/utils/sas_loader.py +50 -0
- tanml/utils/yaml_generator.py +34 -0
- tanml/utils/yaml_loader.py +5 -0
- tanml/validate.py +209 -0
- tanml-0.1.6.dist-info/METADATA +317 -0
- tanml-0.1.6.dist-info/RECORD +62 -0
- tanml-0.1.6.dist-info/WHEEL +5 -0
- tanml-0.1.6.dist-info/entry_points.txt +2 -0
- tanml-0.1.6.dist-info/licenses/LICENSE +21 -0
- tanml-0.1.6.dist-info/top_level.txt +1 -0
tanml/__init__.py
ADDED
@@ -0,0 +1 @@
+__version__ = "0.1.1"

tanml/check_runners/__init__.py
ADDED

File without changes

tanml/check_runners/cleaning_repro_runner.py
ADDED

@@ -0,0 +1,18 @@
+# tanml/check_runners/cleaning_repro_runner.py
+from tanml.checks.cleaning_repro import CleaningReproCheck
+
+def run_cleaning_repro_check(model, X_train, X_test, y_train, y_test,
+                             config, cleaned_data, *args, **kwargs):
+    # honour rules.yaml toggle
+    if not config.get("rules", {}).get("CleaningReproCheck", {}).get("enabled", True):
+        print("ℹ️ CleaningReproCheck skipped (disabled in rules.yaml)")
+        return None
+
+    # raw_df can come from rules.yaml *or* via kwargs (passed by ValidationEngine)
+    raw_data = config.get("raw_data") or kwargs.get("raw_df")
+    if raw_data is None:
+        print("⚠️ Skipping CleaningReproCheck — raw_data missing in config and kwargs")
+        return {"CleaningReproCheck": {"error": "raw_data not available"}}
+
+    check = CleaningReproCheck(raw_data, cleaned_data)
+    return {"CleaningReproCheck": check.run()}

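Every runner in tanml/check_runners follows the same pattern: it reads an "enabled" toggle for its check from the parsed rules.yaml before doing any work. The sketch below is not part of the package; it only illustrates, with made-up values, the config shape run_cleaning_repro_check reads, based on the keys used in the code above.

import pandas as pd

# Illustrative only (not from the package): parsed rules.yaml shape for this runner.
raw_df = pd.DataFrame({"x": [1, 1, 2], "y": [0, 0, 1]})
config = {
    "rules": {"CleaningReproCheck": {"enabled": True}},
    "raw_data": raw_df,  # may also be injected by the engine as kwargs["raw_df"]
}
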
tanml/check_runners/correlation_runner.py
ADDED

@@ -0,0 +1,15 @@
+from tanml.checks.correlation import CorrelationCheck
+
+def CorrelationCheckRunner(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+    try:
+        cfg = rule_config.get("CorrelationCheck", {})
+        if not cfg.get("enabled", True):
+            print("ℹ️ CorrelationCheck skipped (disabled in rules.yaml)")
+            return None
+
+        check = CorrelationCheck(cleaned_df)
+        return check.run()
+
+    except Exception as e:
+        print(f"⚠️ CorrelationCheck failed: {e}")
+        return {"CorrelationCheck": {"error": str(e)}}

tanml/check_runners/data_quality_runner.py
ADDED

@@ -0,0 +1,24 @@
+from tanml.checks.data_quality import DataQualityCheck
+
+def run_data_quality_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+    dq_cfg = rule_config.get("DataQualityCheck", {})
+    if not dq_cfg.get("enabled", True):
+        print("ℹ️ Skipping DataQualityCheck (disabled in rules.yaml)")
+        return {"DataQualityCheck": {"skipped": True}}
+
+    try:
+        check = DataQualityCheck(
+            model=model,
+            X_train=X_train,
+            X_test=X_test,
+            y_train=y_train,
+            y_test=y_test,
+            rule_config=dq_cfg,
+            cleaned_data=cleaned_df
+        )
+        result = check.run()
+        return {"DataQualityCheck": result}
+
+    except Exception as e:
+        print(f"⚠️ DataQualityCheck failed: {e}")
+        return {"DataQualityCheck": {"error": str(e)}}

tanml/check_runners/eda_runner.py
ADDED

@@ -0,0 +1,21 @@
+from tanml.checks.eda import EDACheck
+
+def EDACheckRunner(
+    model, X_train, X_test, y_train, y_test,
+    rule_config, cleaned_df, *args, **kwargs
+):
+    try:
+        cfg = rule_config.get("EDACheck", {})
+        if not cfg.get("enabled", True):
+            print("ℹ️ EDACheck skipped (disabled in rules.yaml)")
+            return None
+
+        check = EDACheck(
+            cleaned_data=cleaned_df,
+            rule_config=rule_config
+        )
+        return check.run()
+
+    except Exception as e:
+        print(f"⚠️ EDACheck failed: {e}")
+        return {"EDACheck": {"error": str(e)}}

tanml/check_runners/explainability_runner.py
ADDED

@@ -0,0 +1,28 @@
+# tanml/check_runners/explainability_runner.py
+
+from tanml.checks.explainability.shap_check import SHAPCheck
+
+def run_shap_check(
+    model, X_train, X_test, y_train, y_test,
+    rule_config, cleaned_df, *args, **kwargs
+):
+    try:
+        cfg = rule_config.get("SHAPCheck", {})
+        if not cfg.get("enabled", True):
+            print("ℹ️ SHAPCheck skipped (disabled in rules.yaml)")
+            return None
+
+        check = SHAPCheck(
+            model=model,
+            X_train=X_train,
+            X_test=X_test,
+            y_train=y_train,
+            y_test=y_test,
+            rule_config=rule_config,
+            cleaned_df=cleaned_df,
+        )
+        return check.run()
+
+    except Exception as e:
+        print(f"⚠️ SHAPCheck failed: {e}")
+        return {"SHAPCheck": {"error": str(e)}}

tanml/check_runners/input_cluster_runner.py
ADDED

@@ -0,0 +1,43 @@
+
+# -----------------------------------------------------------------------------
+# File: tanml/check_runners/input_cluster_runner.py
+# -----------------------------------------------------------------------------
+from pathlib import Path
+from typing import Any, Dict
+
+from tanml.checks.input_cluster import InputClusterCoverageCheck
+
+
+def run_input_cluster_check(
+    model,
+    X_train,
+    X_test,
+    y_train,
+    y_test,
+    rule_config,
+    cleaned_df,
+    expected_features,
+    *args,
+    **kwargs,
+):
+
+    cfg = rule_config.get("InputClusterCoverageCheck", {})
+    if not cfg.get("enabled", False):
+        print("ℹ️ InputClusterCoverageCheck skipped (disabled in rules.yaml)")
+        return {"skipped": True}
+
+    try:
+        result = InputClusterCoverageCheck(
+            cleaned_df=cleaned_df,
+            feature_names=expected_features,
+            rule_config=rule_config,
+        ).run()
+        return result
+
+    except Exception as e:
+        print(f"⚠️ InputClusterCoverageCheck failed: {e}")
+        return {"error": str(e)}
+
+    except Exception as e:
+        print(f"⚠️ InputClusterCoverageCheck failed: {e}")
+        return {"InputClusterCheck": {"error": str(e)}}

tanml/check_runners/logistic_stats_runner.py
ADDED

@@ -0,0 +1,28 @@
+from tanml.checks.logit_stats import LogisticStatsCheck
+from sklearn.linear_model import LogisticRegression
+
+def run_logistic_stats_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+    try:
+        # Check if model is a LogisticRegression or statsmodels object
+        is_logistic = isinstance(model, LogisticRegression) or (
+            hasattr(model, "llf") and hasattr(model, "params") and hasattr(model, "summary")
+        )
+
+        if not is_logistic:
+            print("ℹ️ LogisticStatsCheck skipped — model is not logistic or not recognized")
+            return None
+
+        # Use training data only for fitting stats
+        check = LogisticStatsCheck(model, X_train, y_train, rule_config)
+        output = check.run()
+
+        return {
+            "LogisticStatsCheck": output["table"],
+            "LogisticStatsFit": output["fit"],
+            "LogisticStatsSummary": output["summary"],
+            "LogisticStatsCheck_obj": output["object"]
+        }
+
+    except Exception as e:
+        print(f"⚠️ LogisticStatsCheck failed: {e}")
+        return {"LogisticStatsCheck": {"error": str(e)}}

tanml/check_runners/model_meta_runner.py
ADDED

@@ -0,0 +1,23 @@
+from tanml.checks.model_meta import ModelMetaCheck
+
+def ModelMetaCheckRunner(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+    try:
+        cfg = rule_config.get("ModelMetaCheck", {})
+        if not cfg.get("enabled", True):
+            print("ℹ️ ModelMetaCheck skipped (disabled in rules.yaml)")
+            return None
+
+        check = ModelMetaCheck(
+            model=model,
+            X_train=X_train,
+            X_test=X_test,
+            y_train=y_train,
+            y_test=y_test,
+            rule_config=rule_config,
+            cleaned_data=cleaned_df
+        )
+        return check.run()
+
+    except Exception as e:
+        print(f"⚠️ ModelMetaCheck failed: {e}")
+        return {"ModelMetaCheck": {"error": str(e)}}

tanml/check_runners/performance_runner.py
ADDED

@@ -0,0 +1,28 @@
+from tanml.checks.performance import PerformanceCheck
+
+def run_performance_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+    perf_cfg = rule_config.get("PerformanceCheck", {})
+    if not perf_cfg.get("enabled", True):
+        print("ℹ️ Skipping PerformanceCheck (disabled in rules.yaml)")
+        return {"PerformanceCheck": {"skipped": True}}
+
+    if X_test is None or y_test is None or len(X_test) == 0 or len(y_test) == 0:
+        print("⚠️ Skipping PerformanceCheck due to empty test data.")
+        return {"PerformanceCheck": {"error": "Test data is empty :skipping performance evaluation."}}
+
+    try:
+        check = PerformanceCheck(
+            model=model,
+            X_train=X_train,
+            X_test=X_test,
+            y_train=y_train,
+            y_test=y_test,
+            rule_config=perf_cfg,
+            cleaned_data=cleaned_df
+        )
+        result = check.run()
+        return {"PerformanceCheck": result}
+
+    except Exception as e:
+        print(f"⚠️ PerformanceCheck failed: {e}")
+        return {"PerformanceCheck": {"error": str(e)}}

tanml/check_runners/raw_data_runner.py
ADDED

@@ -0,0 +1,41 @@
+from tanml.checks.raw_data import RawDataCheck
+import pandas as pd
+
+def run_raw_data_check(model, X_train, X_test, y_train, y_test,
+                       rule_config, cleaned_data, *args, **kwargs):
+    try:
+        # ---- locate raw data (DF or path) -------------
+        raw_obj = (
+            rule_config.get("raw_data") or
+            rule_config.get("paths", {}).get("raw_data")
+        )
+        if raw_obj is None:
+            print("ℹ️ RawDataCheck skipped — raw_data not provided in config.")
+            return None
+
+        # CSV path → load once
+        if isinstance(raw_obj, (str, bytes)):
+            raw_obj = pd.read_csv(raw_obj)
+
+        if not isinstance(raw_obj, pd.DataFrame):
+            print("ℹ️ RawDataCheck skipped — raw_data is not a DataFrame.")
+            return None
+
+        # ---- run the check -----------------------------
+        check = RawDataCheck(
+            model=model,
+            X_train=X_train,
+            X_test=X_test,
+            y_train=y_train,
+            y_test=y_test,
+            rule_config=rule_config,
+            cleaned_data=cleaned_data,
+            raw_data=raw_obj
+        )
+
+        stats = check.run()
+        return stats["RawDataCheck"]  # hand the inner dict to ValidationEngine
+
+    except Exception as e:
+        print(f"⚠️ RawDataCheck failed: {e}")
+        return {"error": str(e)}

tanml/check_runners/rule_engine_runner.py
ADDED

@@ -0,0 +1,5 @@
+from tanml.checks.rule_engine import RuleEngineCheck
+
+def RuleEngineCheckRunner(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+    check = RuleEngineCheck(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df)
+    return check.run()

tanml/check_runners/stress_test_runner.py
ADDED

@@ -0,0 +1,26 @@
+from tanml.checks.stress_test import StressTestCheck
+
+def run_stress_test_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+    cfg = rule_config.get("StressTestCheck", {})
+    if not cfg.get("enabled", True):
+        print("ℹ️ Skipping StressTestCheck (disabled in rules.yaml)")
+        return {"StressTestCheck": {"skipped": True}}
+
+    try:
+        epsilon = cfg.get("epsilon", 0.01)
+        perturb_fraction = cfg.get("perturb_fraction", 0.2)
+
+        checker = StressTestCheck(model, X_test, y_test, epsilon, perturb_fraction)
+        result = checker.run()
+
+        # Ensure output is always a dictionary
+        if isinstance(result, list):
+            return {"StressTestCheck": {"table": result}}
+        elif hasattr(result, "to_dict"):
+            return {"StressTestCheck": {"table": result.to_dict(orient="records")}}
+        else:
+            return {"StressTestCheck": {"output": result}}
+
+    except Exception as e:
+        print(f"⚠️ StressTestCheck failed: {e}")
+        return {"StressTestCheck": {"error": str(e)}}

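For reference, the stress-test runner reads two optional tuning keys, epsilon and perturb_fraction, falling back to 0.01 and 0.2. The sketch below is not from the package; it only shows, with made-up values, the parsed config section the runner reads.

# Illustrative only (not from the package): StressTestCheck section as a parsed dict.
rule_config = {
    "StressTestCheck": {
        "enabled": True,
        "epsilon": 0.05,           # runner default: 0.01
        "perturb_fraction": 0.3,   # runner default: 0.2
    }
}
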
tanml/check_runners/vif_runner.py
ADDED

@@ -0,0 +1,54 @@
+# tanml/check_runners/vif_runner.py
+
+from tanml.checks.vif import VIFCheck
+import pandas as pd
+from pathlib import Path
+
+def VIFCheckRunner(
+    model, X_train, X_test, y_train, y_test,
+    rule_config, cleaned_df, *args, **kwargs
+):
+    # Ensure cleaned_df is a DataFrame
+    if isinstance(cleaned_df, (str, Path)):
+        try:
+            cleaned_df = pd.read_csv(cleaned_df)
+        except Exception as e:
+            err = f"Could not read cleaned_df CSV: {e}"
+            print(f"⚠️ {err}")
+            return {"vif_table": [], "high_vif_features": [], "error": err}
+
+    try:
+        check = VIFCheck(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df)
+        result = check.run()  # Could be dict or list
+
+        # Normalize result regardless of format
+        if isinstance(result, dict) and "vif_table" in result:
+            vif_rows = result["vif_table"]
+        elif isinstance(result, list):
+            vif_rows = result
+        else:
+            raise ValueError("Unexpected VIFCheck return shape")
+
+        # Rename 'feature' to 'Feature', round VIF values
+        for row in vif_rows:
+            if "Feature" not in row and "feature" in row:
+                row["Feature"] = row.pop("feature")
+            row["VIF"] = round(float(row["VIF"]), 2)
+
+        # Identify high VIF features
+        threshold = rule_config.get("vif_threshold", 5)
+        high_vif = [
+            row["Feature"] for row in vif_rows
+            if row.get("VIF") is not None and row["VIF"] > threshold
+        ]
+
+        # Return final output
+        return {
+            "vif_table": vif_rows,
+            "high_vif_features": high_vif,
+            "error": None,
+        }
+
+    except Exception as e:
+        print(f"⚠️ VIFCheck failed: {e}")
+        return {"vif_table": [], "high_vif_features": [], "error": str(e)}

tanml/checks/__init__.py
ADDED
File without changes

tanml/checks/base.py
ADDED
@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+
+class BaseCheck(ABC):
+    def __init__(self, model, X_train, X_test, y_train, y_test, rule_config, cleaned_data):
+        self.model = model
+        self.X_train = X_train
+        self.X_test = X_test
+        self.y_train = y_train
+        self.y_test = y_test
+        self.rule_config = rule_config
+        self.cleaned_data = cleaned_data
+
+
+
+    @abstractmethod
+    def run(self):
+        """
+        This method must be implemented by every check.
+        It should return a dictionary of results.
+        """

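BaseCheck is the abstract base the bundled checks build on: the constructor stores the model, the data splits, the rule config and the cleaned data, and a subclass only implements run() returning a dict. A minimal, hypothetical subclass sketch follows; RowCountCheck and the toy DataFrame are not part of the package.

import pandas as pd
from tanml.checks.base import BaseCheck

class RowCountCheck(BaseCheck):
    """Hypothetical check: reports how many rows the cleaned dataset has."""
    def run(self):
        return {"RowCountCheck": {"n_rows": len(self.cleaned_data)}}

df = pd.DataFrame({"x": [1, 2, 3], "y": [0, 1, 0]})
check = RowCountCheck(model=None, X_train=None, X_test=None,
                      y_train=None, y_test=None, rule_config={}, cleaned_data=df)
print(check.run())  # {'RowCountCheck': {'n_rows': 3}}
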
tanml/checks/cleaning_repro.py
ADDED

@@ -0,0 +1,47 @@
+import pandas as pd
+
+class CleaningReproCheck:
+    def __init__(self, raw_data, cleaned_data):
+        self.raw_data = raw_data
+        self.cleaned_data = cleaned_data
+
+    def run(self):
+        result = {}
+
+        try:
+            raw_df = self.raw_data
+            cleaned_df = self.cleaned_data
+
+            if not isinstance(raw_df, pd.DataFrame) or not isinstance(cleaned_df, pd.DataFrame):
+                raise ValueError("Missing or invalid raw_data / cleaned_data.")
+
+            validator_df = (
+                raw_df
+                .drop(columns=["constant_col"], errors="ignore")
+                .drop_duplicates()
+                .reset_index(drop=True)
+            )
+
+            result["dev_shape"] = cleaned_df.shape
+            result["validator_shape"] = validator_df.shape
+            result["same_shape"] = cleaned_df.shape == validator_df.shape
+
+            dev_cols = set(cleaned_df.columns)
+            val_cols = set(validator_df.columns)
+            result["extra_columns_in_dev"] = sorted(list(dev_cols - val_cols))
+            result["missing_columns_in_dev"] = sorted(list(val_cols - dev_cols))
+
+            common = cleaned_df.columns.intersection(validator_df.columns)
+            if cleaned_df[common].shape == validator_df[common].shape:
+                mismatch_count = (
+                    cleaned_df[common].reset_index(drop=True) !=
+                    validator_df[common].reset_index(drop=True)
+                ).to_numpy().sum()
+                result["cell_mismatches"] = int(mismatch_count)
+            else:
+                result["cell_mismatches"] = "Column/shape mismatch — cannot compare"
+
+        except Exception as e:
+            result["error"] = str(e)
+
+        return result

tanml/checks/correlation.py
ADDED

@@ -0,0 +1,61 @@
+from .base import BaseCheck
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
+import os
+
+class CorrelationCheck(BaseCheck):
+    def __init__(self, cleaned_data: pd.DataFrame, output_dir: str = "reports/correlation"):
+        """
+        Computes Pearson and Spearman correlation matrices and saves them to disk,
+        along with a heatmap for visualization.
+        """
+        self.cleaned_data = cleaned_data
+        self.output_dir = output_dir
+        os.makedirs(self.output_dir, exist_ok=True)
+
+    def run(self):
+        # Select numeric features only
+        numeric_data = self.cleaned_data.select_dtypes(include="number")
+
+        if numeric_data.shape[1] < 2:
+            print("⚠️ Not enough numeric features for correlation.")
+            return {
+                "pearson_csv": None,
+                "spearman_csv": None,
+                "heatmap_path": None,
+                "error": "Not enough numeric features for correlation",
+            }
+
+        # Compute correlations
+        pearson_corr = numeric_data.corr(method="pearson")
+        spearman_corr = numeric_data.corr(method="spearman")
+
+        # Save CSVs
+        pearson_path = os.path.join(self.output_dir, "pearson_corr.csv")
+        spearman_path = os.path.join(self.output_dir, "spearman_corr.csv")
+        pearson_corr.to_csv(pearson_path)
+        spearman_corr.to_csv(spearman_path)
+
+        # Create heatmap
+        heatmap_path = os.path.join(self.output_dir, "heatmap.png")
+        plt.figure(figsize=(10, 8))
+        sns.heatmap(
+            pearson_corr,
+            annot=True,
+            fmt=".2f",
+            cmap="coolwarm",
+            cbar_kws={"label": "Pearson Coefficient"},
+        )
+        plt.title("Pearson Correlation Heatmap")
+        plt.xticks(rotation=45, ha="right")
+        plt.yticks(rotation=0)
+        plt.tight_layout()
+        plt.savefig(heatmap_path)
+        plt.close()
+
+        return {
+            "pearson_csv": pearson_path,
+            "spearman_csv": spearman_path,
+            "heatmap_path": heatmap_path,
+        }

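CorrelationCheck can be driven through the runner shown earlier or exercised standalone. The following is a hypothetical standalone usage sketch (not from the package): the toy DataFrame and output_dir value are illustrative, and seaborn and matplotlib must be installed.

import pandas as pd
from tanml.checks.correlation import CorrelationCheck

# The check writes pearson_corr.csv, spearman_corr.csv and heatmap.png
# into output_dir and returns their paths.
df = pd.DataFrame({
    "age": [25, 32, 47, 51, 38],
    "income": [30, 42, 58, 61, 45],
    "score": [0.2, 0.5, 0.7, 0.9, 0.4],
})
result = CorrelationCheck(df, output_dir="reports/correlation").run()
print(result["pearson_csv"], result["heatmap_path"])
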
tanml/checks/data_quality.py
ADDED

@@ -0,0 +1,26 @@
+from .base import BaseCheck
+import pandas as pd
+
+class DataQualityCheck(BaseCheck):
+    def __init__(self, model, X_train, X_test, y_train, y_test, rule_config, cleaned_data):
+        super().__init__(model, X_train, X_test, y_train, y_test, rule_config, cleaned_data)
+
+    def run(self):
+        df = self.cleaned_data
+        results = {}
+
+        if not isinstance(df, pd.DataFrame):
+            return {"error": "Cleaned data is not a valid DataFrame"}
+
+        # Missing value analysis
+        missing_ratio = df.isnull().mean()
+        results["avg_missing"] = round(missing_ratio.mean(), 4)
+        results["columns_with_missing"] = {
+            col: round(val, 4) for col, val in missing_ratio.items() if val > 0
+        }
+
+        # Constant columns (same value across all rows)
+        constant_cols = [col for col in df.columns if df[col].nunique(dropna=False) == 1]
+        results["constant_columns"] = constant_cols
+
+        return results

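DataQualityCheck only inspects cleaned_data, so the remaining constructor arguments can be placeholders when trying it in isolation. A hypothetical sketch (not from the package; toy data, and the printed output is approximate):

import pandas as pd
from tanml.checks.data_quality import DataQualityCheck

df = pd.DataFrame({"a": [1, None, 3], "b": [7, 7, 7]})
check = DataQualityCheck(model=None, X_train=None, X_test=None,
                         y_train=None, y_test=None, rule_config={}, cleaned_data=df)
print(check.run())
# e.g. {'avg_missing': 0.1667, 'columns_with_missing': {'a': 0.3333}, 'constant_columns': ['b']}
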
tanml/checks/eda.py
ADDED
@@ -0,0 +1,67 @@
+from .base import BaseCheck
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+import os
+
+class EDACheck(BaseCheck):
+    def __init__(
+        self,
+        cleaned_data: pd.DataFrame,
+        rule_config: dict | None = None,
+        output_dir: str = "reports/eda",
+    ):
+        """
+        Basic Exploratory Data Analysis (EDA) check.
+
+        Generates summary statistics, missing values, and distribution plots.
+
+        Args:
+            cleaned_data (pd.DataFrame): Cleaned dataset to analyze.
+            rule_config (dict, optional): Rule configuration (YAML) to control plotting limits.
+            output_dir (str): Directory where EDA output files will be saved.
+        """
+        self.cleaned_data = cleaned_data
+        self.output_dir = output_dir
+        os.makedirs(self.output_dir, exist_ok=True)
+
+        # Max number of plots to generate (default = -1 = "all")
+        self.max_plots = (
+            rule_config.get("EDACheck", {}).get("max_plots", -1)
+            if rule_config is not None else -1
+        )
+
+    def run(self):
+        print(f"📊 DEBUG — max_plots from YAML = {self.max_plots}")
+
+        # 1. Generate summary stats and uniqueness/missing info
+        summary = self.cleaned_data.describe(include='all').T.fillna('')
+        missing = self.cleaned_data.isnull().mean().round(3)
+        n_unique = self.cleaned_data.nunique()
+
+        # 2. Save CSVs
+        summary_path = os.path.join(self.output_dir, "summary_stats.csv")
+        missing_path = os.path.join(self.output_dir, "missing_values.csv")
+        summary.to_csv(summary_path)
+        pd.DataFrame({
+            "missing_ratio": missing,
+            "n_unique": n_unique
+        }).to_csv(missing_path)
+
+        # 3. Create histograms for numeric columns
+        num_cols = self.cleaned_data.select_dtypes(include="number").columns
+        cols_to_plot = num_cols if self.max_plots in (-1, None) else num_cols[:self.max_plots]
+
+        for col in cols_to_plot:
+            plt.figure(figsize=(6, 4))
+            sns.histplot(self.cleaned_data[col], kde=True)
+            plt.title(f"Distribution of {col}")
+            plt.savefig(os.path.join(self.output_dir, f"dist_{col}.png"))
+            plt.close()
+
+        # 4. Return metadata
+        return {
+            "summary_stats": summary_path,
+            "missing_values": missing_path,
+            "visualizations": [f"dist_{c}.png" for c in cols_to_plot],
+        }

tanml/checks/explainability/shap_check.py
ADDED

@@ -0,0 +1,55 @@
+from tanml.checks.base import BaseCheck
+import shap
+import pandas as pd
+import matplotlib.pyplot as plt
+import traceback
+from pathlib import Path
+from datetime import datetime
+
+
+class SHAPCheck(BaseCheck):
+    def __init__(self, model, X_train, X_test, y_train, y_test, rule_config=None, cleaned_df=None):
+        super().__init__(model, X_train, X_test, y_train, y_test, rule_config, cleaned_data=cleaned_df)
+        self.cleaned_df = cleaned_df
+
+    def run(self):
+        result = {}
+
+        try:
+            expl_cfg = self.rule_config.get("explainability", {})
+            bg_n = expl_cfg.get("background_sample_size", 100)
+            test_n = expl_cfg.get("test_sample_size", 200)
+
+            X_sample = self.X_test[:test_n]
+            background = shap.utils.sample(self.X_train, bg_n, random_state=42)
+
+            X_sample = pd.DataFrame(X_sample)
+            background = pd.DataFrame(background)
+
+            explainer = shap.Explainer(self.model, background)
+            shap_exp = explainer(X_sample)
+
+            if shap_exp.values.ndim == 3:
+                shap_exp.values = shap_exp.values[:, :, 1]
+                shap_exp.base_values = shap_exp.base_values[:, 1]
+
+            segment = self.rule_config.get("meta", {}).get("segment", "global")
+            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+            output_path = Path(f"reports/images/shap_summary_{segment}_{ts}.png")
+            output_path.parent.mkdir(parents=True, exist_ok=True)
+
+            plt.figure(figsize=(8, 6))
+            shap.plots.beeswarm(shap_exp, show=False)
+            plt.savefig(output_path, bbox_inches="tight")
+            plt.close()
+
+            print(f"✅ SHAP plot saved at: {output_path}")
+            result["shap_plot_path"] = str(output_path)
+            result["status"] = "SHAP plot generated successfully"
+
+        except Exception:
+            err = traceback.format_exc()
+            print(f"⚠️ SHAPCheck failed:\n{err}")
+            result["status"] = f"SHAP plot failed:\n{err}"
+
+        return result

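SHAPCheck wraps shap.Explainer around the supplied model and writes a beeswarm plot under reports/images/. The following is a hypothetical end-to-end sketch, not from the package: it assumes shap, scikit-learn and matplotlib are installed, and the toy data and config values are illustrative.

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from tanml.checks.explainability.shap_check import SHAPCheck

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(200, 3)), columns=["f1", "f2", "f3"])
y = (X["f1"] + X["f2"] > 0).astype(int)
X_train, X_test, y_train, y_test = X[:150], X[150:], y[:150], y[150:]

model = LogisticRegression().fit(X_train, y_train)
rule_config = {"explainability": {"background_sample_size": 50, "test_sample_size": 50}}
result = SHAPCheck(model, X_train, X_test, y_train, y_test,
                   rule_config=rule_config, cleaned_df=X).run()
print(result.get("status"), result.get("shap_plot_path"))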