tanml-0.1.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tanml might be problematic.

Files changed (62)
  1. tanml/__init__.py +1 -0
  2. tanml/check_runners/__init__.py +0 -0
  3. tanml/check_runners/base_runner.py +6 -0
  4. tanml/check_runners/cleaning_repro_runner.py +18 -0
  5. tanml/check_runners/correlation_runner.py +15 -0
  6. tanml/check_runners/data_quality_runner.py +24 -0
  7. tanml/check_runners/eda_runner.py +21 -0
  8. tanml/check_runners/explainability_runner.py +28 -0
  9. tanml/check_runners/input_cluster_runner.py +43 -0
  10. tanml/check_runners/logistic_stats_runner.py +28 -0
  11. tanml/check_runners/model_meta_runner.py +23 -0
  12. tanml/check_runners/performance_runner.py +28 -0
  13. tanml/check_runners/raw_data_runner.py +41 -0
  14. tanml/check_runners/rule_engine_runner.py +5 -0
  15. tanml/check_runners/stress_test_runner.py +26 -0
  16. tanml/check_runners/vif_runner.py +54 -0
  17. tanml/checks/__init__.py +0 -0
  18. tanml/checks/base.py +20 -0
  19. tanml/checks/cleaning_repro.py +47 -0
  20. tanml/checks/correlation.py +61 -0
  21. tanml/checks/data_quality.py +26 -0
  22. tanml/checks/eda.py +67 -0
  23. tanml/checks/explainability/shap_check.py +55 -0
  24. tanml/checks/input_cluster.py +109 -0
  25. tanml/checks/logit_stats.py +59 -0
  26. tanml/checks/model_contents.py +40 -0
  27. tanml/checks/model_meta.py +50 -0
  28. tanml/checks/performance.py +90 -0
  29. tanml/checks/raw_data.py +47 -0
  30. tanml/checks/rule_engine.py +45 -0
  31. tanml/checks/stress_test.py +64 -0
  32. tanml/checks/vif.py +51 -0
  33. tanml/cli/__init__.py +0 -0
  34. tanml/cli/arg_parser.py +31 -0
  35. tanml/cli/init_cmd.py +8 -0
  36. tanml/cli/main.py +27 -0
  37. tanml/cli/validate_cmd.py +7 -0
  38. tanml/config_templates/__init__.py +0 -0
  39. tanml/config_templates/rules_multiple_models_datasets.yaml +144 -0
  40. tanml/config_templates/rules_one_dataset_segment_column.yaml +140 -0
  41. tanml/config_templates/rules_one_model_one_dataset.yaml +143 -0
  42. tanml/engine/__init__.py +0 -0
  43. tanml/engine/check_agent_registry.py +42 -0
  44. tanml/engine/core_engine_agent.py +115 -0
  45. tanml/engine/segmentation_agent.py +118 -0
  46. tanml/engine/validation_agent.py +91 -0
  47. tanml/report/report_builder.py +230 -0
  48. tanml/report/templates/report_template.docx +0 -0
  49. tanml/utils/__init__.py +0 -0
  50. tanml/utils/data_loader.py +17 -0
  51. tanml/utils/model_loader.py +35 -0
  52. tanml/utils/r_loader.py +30 -0
  53. tanml/utils/sas_loader.py +50 -0
  54. tanml/utils/yaml_generator.py +34 -0
  55. tanml/utils/yaml_loader.py +5 -0
  56. tanml/validate.py +209 -0
  57. tanml-0.1.6.dist-info/METADATA +317 -0
  58. tanml-0.1.6.dist-info/RECORD +62 -0
  59. tanml-0.1.6.dist-info/WHEEL +5 -0
  60. tanml-0.1.6.dist-info/entry_points.txt +2 -0
  61. tanml-0.1.6.dist-info/licenses/LICENSE +21 -0
  62. tanml-0.1.6.dist-info/top_level.txt +1 -0
tanml/__init__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.1.1"
tanml/check_runners/__init__.py ADDED
File without changes
tanml/check_runners/base_runner.py ADDED
@@ -0,0 +1,6 @@
+ class BaseCheckRunner:
+     def __init__(self, *args, **kwargs):
+         pass
+
+     def run(self, *args, **kwargs):
+         raise NotImplementedError("Each check runner must implement its own run method.")
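Every runner module below follows this contract: provide a run entry point or inherit the NotImplementedError from the base class. A hedged sketch of the pattern (NullCountRunner and its logic are illustrative, not part of the package):

class NullCountRunner(BaseCheckRunner):
    """Hypothetical runner: count missing values per column of a DataFrame."""
    def run(self, df, *args, **kwargs):
        # Mirrors the {"CheckName": {...}} result shape the real runners return
        return {"NullCountCheck": df.isnull().sum().to_dict()}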
tanml/check_runners/cleaning_repro_runner.py ADDED
@@ -0,0 +1,18 @@
+ # tanml/check_runners/cleaning_repro_runner.py
+ from tanml.checks.cleaning_repro import CleaningReproCheck
+
+ def run_cleaning_repro_check(model, X_train, X_test, y_train, y_test,
+                              config, cleaned_data, *args, **kwargs):
+     # honour rules.yaml toggle
+     if not config.get("rules", {}).get("CleaningReproCheck", {}).get("enabled", True):
+         print("ℹ️ CleaningReproCheck skipped (disabled in rules.yaml)")
+         return None
+
+     # raw_df can come from rules.yaml *or* via kwargs (passed by ValidationEngine)
+     raw_data = config.get("raw_data") or kwargs.get("raw_df")
+     if raw_data is None:
+         print("⚠️ Skipping CleaningReproCheck — raw_data missing in config and kwargs")
+         return {"CleaningReproCheck": {"error": "raw_data not available"}}
+
+     check = CleaningReproCheck(raw_data, cleaned_data)
+     return {"CleaningReproCheck": check.run()}
tanml/check_runners/correlation_runner.py ADDED
@@ -0,0 +1,15 @@
+ from tanml.checks.correlation import CorrelationCheck
+
+ def CorrelationCheckRunner(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+     try:
+         cfg = rule_config.get("CorrelationCheck", {})
+         if not cfg.get("enabled", True):
+             print("ℹ️ CorrelationCheck skipped (disabled in rules.yaml)")
+             return None
+
+         check = CorrelationCheck(cleaned_df)
+         return check.run()
+
+     except Exception as e:
+         print(f"⚠️ CorrelationCheck failed: {e}")
+         return {"CorrelationCheck": {"error": str(e)}}
tanml/check_runners/data_quality_runner.py ADDED
@@ -0,0 +1,24 @@
+ from tanml.checks.data_quality import DataQualityCheck
+
+ def run_data_quality_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+     dq_cfg = rule_config.get("DataQualityCheck", {})
+     if not dq_cfg.get("enabled", True):
+         print("ℹ️ Skipping DataQualityCheck (disabled in rules.yaml)")
+         return {"DataQualityCheck": {"skipped": True}}
+
+     try:
+         check = DataQualityCheck(
+             model=model,
+             X_train=X_train,
+             X_test=X_test,
+             y_train=y_train,
+             y_test=y_test,
+             rule_config=dq_cfg,
+             cleaned_data=cleaned_df
+         )
+         result = check.run()
+         return {"DataQualityCheck": result}
+
+     except Exception as e:
+         print(f"⚠️ DataQualityCheck failed: {e}")
+         return {"DataQualityCheck": {"error": str(e)}}
tanml/check_runners/eda_runner.py ADDED
@@ -0,0 +1,21 @@
+ from tanml.checks.eda import EDACheck
+
+ def EDACheckRunner(
+     model, X_train, X_test, y_train, y_test,
+     rule_config, cleaned_df, *args, **kwargs
+ ):
+     try:
+         cfg = rule_config.get("EDACheck", {})
+         if not cfg.get("enabled", True):
+             print("ℹ️ EDACheck skipped (disabled in rules.yaml)")
+             return None
+
+         check = EDACheck(
+             cleaned_data=cleaned_df,
+             rule_config=rule_config
+         )
+         return check.run()
+
+     except Exception as e:
+         print(f"⚠️ EDACheck failed: {e}")
+         return {"EDACheck": {"error": str(e)}}
tanml/check_runners/explainability_runner.py ADDED
@@ -0,0 +1,28 @@
+ # tanml/check_runners/explainability_runner.py
+
+ from tanml.checks.explainability.shap_check import SHAPCheck
+
+ def run_shap_check(
+     model, X_train, X_test, y_train, y_test,
+     rule_config, cleaned_df, *args, **kwargs
+ ):
+     try:
+         cfg = rule_config.get("SHAPCheck", {})
+         if not cfg.get("enabled", True):
+             print("ℹ️ SHAPCheck skipped (disabled in rules.yaml)")
+             return None
+
+         check = SHAPCheck(
+             model=model,
+             X_train=X_train,
+             X_test=X_test,
+             y_train=y_train,
+             y_test=y_test,
+             rule_config=rule_config,
+             cleaned_df=cleaned_df,
+         )
+         return check.run()
+
+     except Exception as e:
+         print(f"⚠️ SHAPCheck failed: {e}")
+         return {"SHAPCheck": {"error": str(e)}}
tanml/check_runners/input_cluster_runner.py ADDED
@@ -0,0 +1,43 @@
+ # -----------------------------------------------------------------------------
+ # File: tanml/check_runners/input_cluster_runner.py
+ # -----------------------------------------------------------------------------
+ from pathlib import Path
+ from typing import Any, Dict
+
+ from tanml.checks.input_cluster import InputClusterCoverageCheck
+
+
+ def run_input_cluster_check(
+     model,
+     X_train,
+     X_test,
+     y_train,
+     y_test,
+     rule_config,
+     cleaned_df,
+     expected_features,
+     *args,
+     **kwargs,
+ ):
+     # Note: unlike the other runners, this check is opt-in (enabled defaults to False)
+     cfg = rule_config.get("InputClusterCoverageCheck", {})
+     if not cfg.get("enabled", False):
+         print("ℹ️ InputClusterCoverageCheck skipped (disabled in rules.yaml)")
+         return {"skipped": True}
+
+     try:
+         result = InputClusterCoverageCheck(
+             cleaned_df=cleaned_df,
+             feature_names=expected_features,
+             rule_config=rule_config,
+         ).run()
+         return result
+
+     except Exception as e:
+         print(f"⚠️ InputClusterCoverageCheck failed: {e}")
+         return {"error": str(e)}
tanml/check_runners/logistic_stats_runner.py ADDED
@@ -0,0 +1,28 @@
+ from tanml.checks.logit_stats import LogisticStatsCheck
+ from sklearn.linear_model import LogisticRegression
+
+ def run_logistic_stats_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+     try:
+         # Check if model is a LogisticRegression or statsmodels object
+         is_logistic = isinstance(model, LogisticRegression) or (
+             hasattr(model, "llf") and hasattr(model, "params") and hasattr(model, "summary")
+         )
+
+         if not is_logistic:
+             print("ℹ️ LogisticStatsCheck skipped — model is not logistic or not recognized")
+             return None
+
+         # Use training data only for fitting stats
+         check = LogisticStatsCheck(model, X_train, y_train, rule_config)
+         output = check.run()
+
+         return {
+             "LogisticStatsCheck": output["table"],
+             "LogisticStatsFit": output["fit"],
+             "LogisticStatsSummary": output["summary"],
+             "LogisticStatsCheck_obj": output["object"]
+         }
+
+     except Exception as e:
+         print(f"⚠️ LogisticStatsCheck failed: {e}")
+         return {"LogisticStatsCheck": {"error": str(e)}}
tanml/check_runners/model_meta_runner.py ADDED
@@ -0,0 +1,23 @@
+ from tanml.checks.model_meta import ModelMetaCheck
+
+ def ModelMetaCheckRunner(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+     try:
+         cfg = rule_config.get("ModelMetaCheck", {})
+         if not cfg.get("enabled", True):
+             print("ℹ️ ModelMetaCheck skipped (disabled in rules.yaml)")
+             return None
+
+         check = ModelMetaCheck(
+             model=model,
+             X_train=X_train,
+             X_test=X_test,
+             y_train=y_train,
+             y_test=y_test,
+             rule_config=rule_config,
+             cleaned_data=cleaned_df
+         )
+         return check.run()
+
+     except Exception as e:
+         print(f"⚠️ ModelMetaCheck failed: {e}")
+         return {"ModelMetaCheck": {"error": str(e)}}
tanml/check_runners/performance_runner.py ADDED
@@ -0,0 +1,28 @@
+ from tanml.checks.performance import PerformanceCheck
+
+ def run_performance_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+     perf_cfg = rule_config.get("PerformanceCheck", {})
+     if not perf_cfg.get("enabled", True):
+         print("ℹ️ Skipping PerformanceCheck (disabled in rules.yaml)")
+         return {"PerformanceCheck": {"skipped": True}}
+
+     if X_test is None or y_test is None or len(X_test) == 0 or len(y_test) == 0:
+         print("⚠️ Skipping PerformanceCheck due to empty test data.")
+         return {"PerformanceCheck": {"error": "Test data is empty; skipping performance evaluation."}}
+
+     try:
+         check = PerformanceCheck(
+             model=model,
+             X_train=X_train,
+             X_test=X_test,
+             y_train=y_train,
+             y_test=y_test,
+             rule_config=perf_cfg,
+             cleaned_data=cleaned_df
+         )
+         result = check.run()
+         return {"PerformanceCheck": result}
+
+     except Exception as e:
+         print(f"⚠️ PerformanceCheck failed: {e}")
+         return {"PerformanceCheck": {"error": str(e)}}
tanml/check_runners/raw_data_runner.py ADDED
@@ -0,0 +1,41 @@
+ from tanml.checks.raw_data import RawDataCheck
+ import pandas as pd
+
+ def run_raw_data_check(model, X_train, X_test, y_train, y_test,
+                        rule_config, cleaned_data, *args, **kwargs):
+     try:
+         # ---- locate raw data (DF or path) -------------
+         raw_obj = (
+             rule_config.get("raw_data") or
+             rule_config.get("paths", {}).get("raw_data")
+         )
+         if raw_obj is None:
+             print("ℹ️ RawDataCheck skipped — raw_data not provided in config.")
+             return None
+
+         # CSV path → load once
+         if isinstance(raw_obj, (str, bytes)):
+             raw_obj = pd.read_csv(raw_obj)
+
+         if not isinstance(raw_obj, pd.DataFrame):
+             print("ℹ️ RawDataCheck skipped — raw_data is not a DataFrame.")
+             return None
+
+         # ---- run the check -----------------------------
+         check = RawDataCheck(
+             model=model,
+             X_train=X_train,
+             X_test=X_test,
+             y_train=y_train,
+             y_test=y_test,
+             rule_config=rule_config,
+             cleaned_data=cleaned_data,
+             raw_data=raw_obj
+         )
+
+         stats = check.run()
+         return stats["RawDataCheck"]  # hand the inner dict to ValidationEngine
+
+     except Exception as e:
+         print(f"⚠️ RawDataCheck failed: {e}")
+         return {"error": str(e)}
tanml/check_runners/rule_engine_runner.py ADDED
@@ -0,0 +1,5 @@
+ from tanml.checks.rule_engine import RuleEngineCheck
+
+ def RuleEngineCheckRunner(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+     check = RuleEngineCheck(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df)
+     return check.run()
tanml/check_runners/stress_test_runner.py ADDED
@@ -0,0 +1,26 @@
+ from tanml.checks.stress_test import StressTestCheck
+
+ def run_stress_test_check(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df, *args, **kwargs):
+     cfg = rule_config.get("StressTestCheck", {})
+     if not cfg.get("enabled", True):
+         print("ℹ️ Skipping StressTestCheck (disabled in rules.yaml)")
+         return {"StressTestCheck": {"skipped": True}}
+
+     try:
+         epsilon = cfg.get("epsilon", 0.01)
+         perturb_fraction = cfg.get("perturb_fraction", 0.2)
+
+         checker = StressTestCheck(model, X_test, y_test, epsilon, perturb_fraction)
+         result = checker.run()
+
+         # Ensure output is always a dictionary
+         if isinstance(result, list):
+             return {"StressTestCheck": {"table": result}}
+         elif hasattr(result, "to_dict"):
+             return {"StressTestCheck": {"table": result.to_dict(orient="records")}}
+         else:
+             return {"StressTestCheck": {"output": result}}
+
+     except Exception as e:
+         print(f"⚠️ StressTestCheck failed: {e}")
+         return {"StressTestCheck": {"error": str(e)}}
tanml/check_runners/vif_runner.py ADDED
@@ -0,0 +1,54 @@
+ # tanml/check_runners/vif_runner.py
+
+ from tanml.checks.vif import VIFCheck
+ import pandas as pd
+ from pathlib import Path
+
+ def VIFCheckRunner(
+     model, X_train, X_test, y_train, y_test,
+     rule_config, cleaned_df, *args, **kwargs
+ ):
+     # Ensure cleaned_df is a DataFrame
+     if isinstance(cleaned_df, (str, Path)):
+         try:
+             cleaned_df = pd.read_csv(cleaned_df)
+         except Exception as e:
+             err = f"Could not read cleaned_df CSV: {e}"
+             print(f"⚠️ {err}")
+             return {"vif_table": [], "high_vif_features": [], "error": err}
+
+     try:
+         check = VIFCheck(model, X_train, X_test, y_train, y_test, rule_config, cleaned_df)
+         result = check.run()  # Could be dict or list
+
+         # Normalize result regardless of format
+         if isinstance(result, dict) and "vif_table" in result:
+             vif_rows = result["vif_table"]
+         elif isinstance(result, list):
+             vif_rows = result
+         else:
+             raise ValueError("Unexpected VIFCheck return shape")
+
+         # Rename 'feature' to 'Feature', round VIF values
+         for row in vif_rows:
+             if "Feature" not in row and "feature" in row:
+                 row["Feature"] = row.pop("feature")
+             row["VIF"] = round(float(row["VIF"]), 2)
+
+         # Identify high VIF features
+         threshold = rule_config.get("vif_threshold", 5)
+         high_vif = [
+             row["Feature"] for row in vif_rows
+             if row.get("VIF") is not None and row["VIF"] > threshold
+         ]
+
+         # Return final output
+         return {
+             "vif_table": vif_rows,
+             "high_vif_features": high_vif,
+             "error": None,
+         }
+
+     except Exception as e:
+         print(f"⚠️ VIFCheck failed: {e}")
+         return {"vif_table": [], "high_vif_features": [], "error": str(e)}
tanml/checks/__init__.py ADDED
File without changes
tanml/checks/base.py ADDED
@@ -0,0 +1,20 @@
+ from abc import ABC, abstractmethod
+
+ class BaseCheck(ABC):
+     def __init__(self, model, X_train, X_test, y_train, y_test, rule_config, cleaned_data):
+         self.model = model
+         self.X_train = X_train
+         self.X_test = X_test
+         self.y_train = y_train
+         self.y_test = y_test
+         self.rule_config = rule_config
+         self.cleaned_data = cleaned_data
+
+     @abstractmethod
+     def run(self):
+         """
+         This method must be implemented by every check.
+         It should return a dictionary of results.
+         """
tanml/checks/cleaning_repro.py ADDED
@@ -0,0 +1,47 @@
+ import pandas as pd
+
+ class CleaningReproCheck:
+     def __init__(self, raw_data, cleaned_data):
+         self.raw_data = raw_data
+         self.cleaned_data = cleaned_data
+
+     def run(self):
+         result = {}
+
+         try:
+             raw_df = self.raw_data
+             cleaned_df = self.cleaned_data
+
+             if not isinstance(raw_df, pd.DataFrame) or not isinstance(cleaned_df, pd.DataFrame):
+                 raise ValueError("Missing or invalid raw_data / cleaned_data.")
+
+             validator_df = (
+                 raw_df
+                 .drop(columns=["constant_col"], errors="ignore")
+                 .drop_duplicates()
+                 .reset_index(drop=True)
+             )
+
+             result["dev_shape"] = cleaned_df.shape
+             result["validator_shape"] = validator_df.shape
+             result["same_shape"] = cleaned_df.shape == validator_df.shape
+
+             dev_cols = set(cleaned_df.columns)
+             val_cols = set(validator_df.columns)
+             result["extra_columns_in_dev"] = sorted(list(dev_cols - val_cols))
+             result["missing_columns_in_dev"] = sorted(list(val_cols - dev_cols))
+
+             common = cleaned_df.columns.intersection(validator_df.columns)
+             if cleaned_df[common].shape == validator_df[common].shape:
+                 mismatch_count = (
+                     cleaned_df[common].reset_index(drop=True) !=
+                     validator_df[common].reset_index(drop=True)
+                 ).to_numpy().sum()
+                 result["cell_mismatches"] = int(mismatch_count)
+             else:
+                 result["cell_mismatches"] = "Column/shape mismatch — cannot compare"
+
+         except Exception as e:
+             result["error"] = str(e)
+
+         return result
tanml/checks/correlation.py ADDED
@@ -0,0 +1,61 @@
+ from .base import BaseCheck
+ import pandas as pd
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ import os
+
+ class CorrelationCheck(BaseCheck):
+     def __init__(self, cleaned_data: pd.DataFrame, output_dir: str = "reports/correlation"):
+         """
+         Computes Pearson and Spearman correlation matrices and saves them to disk,
+         along with a heatmap for visualization.
+         """
+         self.cleaned_data = cleaned_data
+         self.output_dir = output_dir
+         os.makedirs(self.output_dir, exist_ok=True)
+
+     def run(self):
+         # Select numeric features only
+         numeric_data = self.cleaned_data.select_dtypes(include="number")
+
+         if numeric_data.shape[1] < 2:
+             print("⚠️ Not enough numeric features for correlation.")
+             return {
+                 "pearson_csv": None,
+                 "spearman_csv": None,
+                 "heatmap_path": None,
+                 "error": "Not enough numeric features for correlation",
+             }
+
+         # Compute correlations
+         pearson_corr = numeric_data.corr(method="pearson")
+         spearman_corr = numeric_data.corr(method="spearman")
+
+         # Save CSVs
+         pearson_path = os.path.join(self.output_dir, "pearson_corr.csv")
+         spearman_path = os.path.join(self.output_dir, "spearman_corr.csv")
+         pearson_corr.to_csv(pearson_path)
+         spearman_corr.to_csv(spearman_path)
+
+         # Create heatmap
+         heatmap_path = os.path.join(self.output_dir, "heatmap.png")
+         plt.figure(figsize=(10, 8))
+         sns.heatmap(
+             pearson_corr,
+             annot=True,
+             fmt=".2f",
+             cmap="coolwarm",
+             cbar_kws={"label": "Pearson Coefficient"},
+         )
+         plt.title("Pearson Correlation Heatmap")
+         plt.xticks(rotation=45, ha="right")
+         plt.yticks(rotation=0)
+         plt.tight_layout()
+         plt.savefig(heatmap_path)
+         plt.close()
+
+         return {
+             "pearson_csv": pearson_path,
+             "spearman_csv": spearman_path,
+             "heatmap_path": heatmap_path,
+         }
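A hedged usage sketch on a toy frame (column names and values are illustrative):

import pandas as pd
from tanml.checks.correlation import CorrelationCheck

df = pd.DataFrame({"age": [25, 32, 47, 51], "income": [30, 42, 58, 61]})
paths = CorrelationCheck(df, output_dir="reports/correlation").run()
# paths → {"pearson_csv": ..., "spearman_csv": ..., "heatmap_path": ...}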
tanml/checks/data_quality.py ADDED
@@ -0,0 +1,26 @@
+ from .base import BaseCheck
+ import pandas as pd
+
+ class DataQualityCheck(BaseCheck):
+     def __init__(self, model, X_train, X_test, y_train, y_test, rule_config, cleaned_data):
+         super().__init__(model, X_train, X_test, y_train, y_test, rule_config, cleaned_data)
+
+     def run(self):
+         df = self.cleaned_data
+         results = {}
+
+         if not isinstance(df, pd.DataFrame):
+             return {"error": "Cleaned data is not a valid DataFrame"}
+
+         # Missing value analysis
+         missing_ratio = df.isnull().mean()
+         results["avg_missing"] = round(missing_ratio.mean(), 4)
+         results["columns_with_missing"] = {
+             col: round(val, 4) for col, val in missing_ratio.items() if val > 0
+         }
+
+         # Constant columns (same value across all rows)
+         constant_cols = [col for col in df.columns if df[col].nunique(dropna=False) == 1]
+         results["constant_columns"] = constant_cols
+
+         return results
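Since run() only reads cleaned_data, the remaining constructor arguments can be stubbed when exercising the check in isolation. A hedged sketch with a toy frame:

import numpy as np
import pandas as pd
from tanml.checks.data_quality import DataQualityCheck

df = pd.DataFrame({"a": [1, 2, np.nan, 4], "b": [7, 7, 7, 7]})
check = DataQualityCheck(None, None, None, None, None, rule_config={}, cleaned_data=df)
check.run()
# → {"avg_missing": 0.125, "columns_with_missing": {"a": 0.25}, "constant_columns": ["b"]}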
tanml/checks/eda.py ADDED
@@ -0,0 +1,67 @@
+ from .base import BaseCheck
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import os
+
+ class EDACheck(BaseCheck):
+     def __init__(
+         self,
+         cleaned_data: pd.DataFrame,
+         rule_config: dict | None = None,
+         output_dir: str = "reports/eda",
+     ):
+         """
+         Basic Exploratory Data Analysis (EDA) check.
+
+         Generates summary statistics, missing values, and distribution plots.
+
+         Args:
+             cleaned_data (pd.DataFrame): Cleaned dataset to analyze.
+             rule_config (dict, optional): Rule configuration (YAML) to control plotting limits.
+             output_dir (str): Directory where EDA output files will be saved.
+         """
+         self.cleaned_data = cleaned_data
+         self.output_dir = output_dir
+         os.makedirs(self.output_dir, exist_ok=True)
+
+         # Max number of plots to generate (default = -1 = "all")
+         self.max_plots = (
+             rule_config.get("EDACheck", {}).get("max_plots", -1)
+             if rule_config is not None else -1
+         )
+
+     def run(self):
+         print(f"📊 DEBUG — max_plots from YAML = {self.max_plots}")
+
+         # 1. Generate summary stats and uniqueness/missing info
+         summary = self.cleaned_data.describe(include='all').T.fillna('')
+         missing = self.cleaned_data.isnull().mean().round(3)
+         n_unique = self.cleaned_data.nunique()
+
+         # 2. Save CSVs
+         summary_path = os.path.join(self.output_dir, "summary_stats.csv")
+         missing_path = os.path.join(self.output_dir, "missing_values.csv")
+         summary.to_csv(summary_path)
+         pd.DataFrame({
+             "missing_ratio": missing,
+             "n_unique": n_unique
+         }).to_csv(missing_path)
+
+         # 3. Create histograms for numeric columns
+         num_cols = self.cleaned_data.select_dtypes(include="number").columns
+         cols_to_plot = num_cols if self.max_plots in (-1, None) else num_cols[:self.max_plots]
+
+         for col in cols_to_plot:
+             plt.figure(figsize=(6, 4))
+             sns.histplot(self.cleaned_data[col], kde=True)
+             plt.title(f"Distribution of {col}")
+             plt.savefig(os.path.join(self.output_dir, f"dist_{col}.png"))
+             plt.close()
+
+         # 4. Return metadata
+         return {
+             "summary_stats": summary_path,
+             "missing_values": missing_path,
+             "visualizations": [f"dist_{c}.png" for c in cols_to_plot],
+         }
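A hedged usage sketch showing the max_plots toggle (the toy frame is illustrative):

import pandas as pd
from tanml.checks.eda import EDACheck

df = pd.DataFrame({"x": [1.0, 2.0, 3.5, 4.1], "y": [0, 1, 0, 1]})
eda = EDACheck(df, rule_config={"EDACheck": {"max_plots": 1}})
eda.run()  # writes summary_stats.csv, missing_values.csv, and at most one dist_*.png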
tanml/checks/explainability/shap_check.py ADDED
@@ -0,0 +1,55 @@
+ from tanml.checks.base import BaseCheck
+ import shap
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import traceback
+ from pathlib import Path
+ from datetime import datetime
+
+
+ class SHAPCheck(BaseCheck):
+     def __init__(self, model, X_train, X_test, y_train, y_test, rule_config=None, cleaned_df=None):
+         super().__init__(model, X_train, X_test, y_train, y_test, rule_config, cleaned_data=cleaned_df)
+         self.cleaned_df = cleaned_df
+
+     def run(self):
+         result = {}
+
+         try:
+             expl_cfg = self.rule_config.get("explainability", {})
+             bg_n = expl_cfg.get("background_sample_size", 100)
+             test_n = expl_cfg.get("test_sample_size", 200)
+
+             X_sample = self.X_test[:test_n]
+             background = shap.utils.sample(self.X_train, bg_n, random_state=42)
+
+             X_sample = pd.DataFrame(X_sample)
+             background = pd.DataFrame(background)
+
+             explainer = shap.Explainer(self.model, background)
+             shap_exp = explainer(X_sample)
+
+             if shap_exp.values.ndim == 3:
+                 shap_exp.values = shap_exp.values[:, :, 1]
+                 shap_exp.base_values = shap_exp.base_values[:, 1]
+
+             segment = self.rule_config.get("meta", {}).get("segment", "global")
+             ts = datetime.now().strftime("%Y%m%d_%H%M%S")
+             output_path = Path(f"reports/images/shap_summary_{segment}_{ts}.png")
+             output_path.parent.mkdir(parents=True, exist_ok=True)
+
+             plt.figure(figsize=(8, 6))
+             shap.plots.beeswarm(shap_exp, show=False)
+             plt.savefig(output_path, bbox_inches="tight")
+             plt.close()
+
+             print(f"✅ SHAP plot saved at: {output_path}")
+             result["shap_plot_path"] = str(output_path)
+             result["status"] = "SHAP plot generated successfully"
+
+         except Exception:
+             err = traceback.format_exc()
+             print(f"⚠️ SHAPCheck failed:\n{err}")
+             result["status"] = f"SHAP plot failed:\n{err}"
+
+         return result