tanml-0.1.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tanml might be problematic.
- tanml/__init__.py +1 -0
- tanml/check_runners/__init__.py +0 -0
- tanml/check_runners/base_runner.py +6 -0
- tanml/check_runners/cleaning_repro_runner.py +18 -0
- tanml/check_runners/correlation_runner.py +15 -0
- tanml/check_runners/data_quality_runner.py +24 -0
- tanml/check_runners/eda_runner.py +21 -0
- tanml/check_runners/explainability_runner.py +28 -0
- tanml/check_runners/input_cluster_runner.py +43 -0
- tanml/check_runners/logistic_stats_runner.py +28 -0
- tanml/check_runners/model_meta_runner.py +23 -0
- tanml/check_runners/performance_runner.py +28 -0
- tanml/check_runners/raw_data_runner.py +41 -0
- tanml/check_runners/rule_engine_runner.py +5 -0
- tanml/check_runners/stress_test_runner.py +26 -0
- tanml/check_runners/vif_runner.py +54 -0
- tanml/checks/__init__.py +0 -0
- tanml/checks/base.py +20 -0
- tanml/checks/cleaning_repro.py +47 -0
- tanml/checks/correlation.py +61 -0
- tanml/checks/data_quality.py +26 -0
- tanml/checks/eda.py +67 -0
- tanml/checks/explainability/shap_check.py +55 -0
- tanml/checks/input_cluster.py +109 -0
- tanml/checks/logit_stats.py +59 -0
- tanml/checks/model_contents.py +40 -0
- tanml/checks/model_meta.py +50 -0
- tanml/checks/performance.py +90 -0
- tanml/checks/raw_data.py +47 -0
- tanml/checks/rule_engine.py +45 -0
- tanml/checks/stress_test.py +64 -0
- tanml/checks/vif.py +51 -0
- tanml/cli/__init__.py +0 -0
- tanml/cli/arg_parser.py +31 -0
- tanml/cli/init_cmd.py +8 -0
- tanml/cli/main.py +27 -0
- tanml/cli/validate_cmd.py +7 -0
- tanml/config_templates/__init__.py +0 -0
- tanml/config_templates/rules_multiple_models_datasets.yaml +144 -0
- tanml/config_templates/rules_one_dataset_segment_column.yaml +140 -0
- tanml/config_templates/rules_one_model_one_dataset.yaml +143 -0
- tanml/engine/__init__.py +0 -0
- tanml/engine/check_agent_registry.py +42 -0
- tanml/engine/core_engine_agent.py +115 -0
- tanml/engine/segmentation_agent.py +118 -0
- tanml/engine/validation_agent.py +91 -0
- tanml/report/report_builder.py +230 -0
- tanml/report/templates/report_template.docx +0 -0
- tanml/utils/__init__.py +0 -0
- tanml/utils/data_loader.py +17 -0
- tanml/utils/model_loader.py +35 -0
- tanml/utils/r_loader.py +30 -0
- tanml/utils/sas_loader.py +50 -0
- tanml/utils/yaml_generator.py +34 -0
- tanml/utils/yaml_loader.py +5 -0
- tanml/validate.py +209 -0
- tanml-0.1.6.dist-info/METADATA +317 -0
- tanml-0.1.6.dist-info/RECORD +62 -0
- tanml-0.1.6.dist-info/WHEEL +5 -0
- tanml-0.1.6.dist-info/entry_points.txt +2 -0
- tanml-0.1.6.dist-info/licenses/LICENSE +21 -0
- tanml-0.1.6.dist-info/top_level.txt +1 -0
tanml/engine/segmentation_agent.py
ADDED

@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+import os
+import joblib
+import importlib
+from tanml.utils.sas_loader import SASLogisticModel
+from tanml.utils.r_loader import RLogisticModel
+from tanml.engine.validation_agent import SegmentValidator
+from tanml.utils.data_loader import load_dataframe
+
+
+def handle_segmentation(segment_config, rule_config, args=None, report_output=None):
+    global_raw_path = rule_config.get("paths", {}).get("raw_data")
+    segment_col = segment_config.get("column")  # only required for Scenario C
+    report_template = report_output  # e.g., "reports/report_{segment}.docx"
+
+    print("🔍 Detected segmentation setup in rules.yaml. Running each segment run separately...")
+
+    for name, run_cfg in segment_config["runs"].items():
+        print(f"\n🔹 Validating segment: {name}")
+        model_path = run_cfg["model"]
+
+        # CASE 1: Retrain model from cleaned data (model: train)
+        if isinstance(model_path, str) and model_path.lower() == "train":
+            print(f"🛠️ Retraining model from cleaned data for segment: {name}")
+
+            if "cleaned" in run_cfg:
+                cleaned_df_run = load_dataframe(run_cfg["cleaned"])
+            else:
+                full_cleaned = load_dataframe(rule_config["paths"]["cleaned_data"])
+                if segment_col:
+                    cleaned_df_run = full_cleaned[full_cleaned[segment_col] == name]
+                else:
+                    raise ValueError("Missing segment column for slicing cleaned data.")
+
+            X = cleaned_df_run[rule_config["model"]["features"]]
+            y = cleaned_df_run[rule_config["model"]["target"]]
+
+            model_source = rule_config.get("model_source", {})
+            model_type = model_source.get("type")
+            model_module = model_source.get("module")
+            model_params = model_source.get("hyperparameters", {})
+
+            if not model_type or not model_module:
+                raise ValueError("❌ 'model_source.type' and 'model_source.module' must be defined for retraining.")
+
+            model_class = getattr(importlib.import_module(model_module), model_type)
+            model = model_class(**model_params)
+            print(f"📦 Using model: {model}")
+
+            model.fit(X, y)
+            print(f"✅ Retrained {model_type} for segment '{name}'")
+
+        elif isinstance(model_path, str) and "r_logistic" in model_path.lower():
+            model = RLogisticModel(model_path)
+
+        elif isinstance(model_path, str) and model_path.endswith(".pkl"):
+            model = joblib.load(model_path)
+
+        elif isinstance(model_path, str) and model_path.endswith(".csv"):
+            base = os.path.splitext(model_path)[0]
+            model = SASLogisticModel(
+                coeffs_path=model_path,
+                intercept_path=base + "_intercept.txt",
+                feature_order_path=base + "_features.txt"
+            )
+
+        else:
+            raise ValueError(f"❌ Unsupported model format for segment '{name}': {model_path}")
+
+        # Load and optionally slice raw data
+        raw_df_run = None
+        if global_raw_path and os.path.exists(global_raw_path):
+            full_raw = load_dataframe(global_raw_path)
+            if segment_col:
+                raw_df_run = full_raw[full_raw[segment_col] == name]
+            else:
+                raw_df_run = full_raw
+
+        print(f"[DEBUG] raw_df_run for {name}: {type(raw_df_run)}, rows = {len(raw_df_run) if raw_df_run is not None else 'None'}")
+
+        # Load cleaned data
+        if "cleaned" in run_cfg:
+            cleaned_df_run = load_dataframe(run_cfg["cleaned"])
+        else:
+            full_cleaned = load_dataframe(rule_config["paths"]["cleaned_data"])
+            if segment_col:
+                cleaned_df_run = full_cleaned[full_cleaned[segment_col] == name]
+            else:
+                cleaned_df_run = full_cleaned
+
+        # Format output path
+        report_base = report_template.format(segment=name)
+
+        # Run validation
+        validator = SegmentValidator(
+            segment_column=segment_col,
+            segment_values=[name],
+            model=model,
+            raw_df=raw_df_run,
+            cleaned_df=cleaned_df_run,
+            target_col=rule_config.get("model", {}).get("target"),
+            config=rule_config,
+            segment_name=name,
+            report_base=report_base
+        )
+        results = validator.run()
+
+        # Extract report path from SegmentValidator result
+        report_base_path = report_template.format(segment=name)
+        report_path = os.path.join(report_base_path, f"report_{name}.docx")
+
+        print(f"📄 Report saved: {report_path}")
+        print(f"✅ Segment '{name}' validated.")
+
+
+    print("✅ All segment runs completed.")
+    return True
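Usage sketch (not part of the wheel): `handle_segmentation` can be driven directly with plain dictionaries that mirror the keys it reads from a parsed rules.yaml. Every path, feature name, segment value, and the `report_output` pattern below is a placeholder for illustration only.

from tanml.engine.segmentation_agent import handle_segmentation

rule_config = {
    "paths": {
        "raw_data": "data/raw.csv",          # optional; sliced per segment when present
        "cleaned_data": "data/cleaned.csv",  # sliced by the segment column for each run
    },
    "model": {"features": ["age", "income"], "target": "default_flag"},
    "model_source": {                        # only used when a run sets model: train
        "type": "LogisticRegression",
        "module": "sklearn.linear_model",
        "hyperparameters": {"max_iter": 1000},
    },
}

segment_config = {
    "column": "region",                                # segment column in the data
    "runs": {
        "north": {"model": "train"},                   # retrain from cleaned data
        "south": {"model": "models/south_model.pkl"},  # load a pickled model
    },
}

handle_segmentation(segment_config, rule_config, report_output="reports/{segment}")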
tanml/engine/validation_agent.py
ADDED

@@ -0,0 +1,91 @@
+from __future__ import annotations
+from importlib.resources import files
+from datetime import datetime
+import tzlocal
+import json
+from pathlib import Path
+
+from tanml.engine.core_engine_agent import ValidationEngine
+from tanml.report.report_builder import ReportBuilder
+from tanml.engine.check_agent_registry import CHECK_RUNNER_REGISTRY
+
+
+class SegmentValidator:
+    def __init__(
+        self,
+        model,
+        raw_df,
+        cleaned_df,
+        config,
+        target_col=None,
+        segment_column=None,
+        segment_values=None,
+        segment_name=None,
+        report_base="reports"
+    ):
+        self.model = model
+        self.raw_df = raw_df
+        self.cleaned_df = cleaned_df
+        self.config = config
+        self.target_col = target_col
+        self.segment_column = segment_column
+        self.segment_values = segment_values
+        self.segment_name = segment_name
+        self.report_base = report_base
+
+    def run(self):
+        if self.segment_name:
+            return self._run_single(self.cleaned_df, self.segment_name)
+
+        if not self.segment_column:
+            raise ValueError("Segmentation column not specified in rules.yaml")
+
+        results = {}
+        for segment in self.segment_values:
+            segment_df = self.cleaned_df[self.cleaned_df[self.segment_column] == segment]
+            print(f"🔹 Running validation for segment value: {segment}")
+            result = self._run_single(segment_df, segment)
+            results[segment] = result
+
+        return results
+
+    def _run_single(self, segment_df, segment_name):
+        if self.target_col is None or self.target_col not in segment_df.columns:
+            raise ValueError(f"❌ Target column '{self.target_col}' missing in cleaned data for segment '{segment_name}'")
+
+        y = segment_df[self.target_col]
+        cols_to_drop = [c for c in (self.target_col, self.segment_column) if c in segment_df.columns]
+        X = segment_df.drop(columns=cols_to_drop)
+
+        # Pass raw_df to ValidationEngine so CleaningReproCheck works
+        engine = ValidationEngine(self.model, X, X, y, y, self.config, segment_df, self.raw_df)
+        results = engine.run_all_checks()
+
+        local_tz = tzlocal.get_localzone()
+        now = datetime.now(local_tz)
+        results["validation_date"] = now.strftime("%Y-%m-%d %H:%M:%S %Z (UTC%z)")
+        results["model_path"] = self.model.__class__.__name__
+        results["validated_by"] = "TanML Automated Validator"
+        results["rules"] = self.config
+
+        report_base_formatted = self.report_base.format(segment=segment_name) if "{segment}" in self.report_base else self.report_base
+        output_dir = Path(report_base_formatted)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        tpl_cfg = self.config.get("output", {}).get("template_path")  # may be None
+        template_path = (
+            Path(tpl_cfg).expanduser() if tpl_cfg
+            else files("tanml.report.templates").joinpath("report_template.docx")
+        )
+
+        output_path = output_dir / f"report_{segment_name}.docx"
+
+        builder = ReportBuilder(results, template_path, output_path)
+        builder.build()
+
+        print(f"📄 Report saved for segment '{segment_name}': {output_path}")
+        print(json.dumps(results, indent=2, default=str))
+        return results
+
+
+__all__ = ["SegmentValidator", "CHECK_RUNNER_REGISTRY"]
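A minimal sketch of calling `SegmentValidator` directly on in-memory data, assuming scikit-learn is available. The synthetic DataFrame and the minimal `config` dict are illustrative; whether every downstream check in `ValidationEngine.run_all_checks()` runs cleanly depends on configuration keys not shown in this snippet.

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from tanml.engine.validation_agent import SegmentValidator

rng = np.random.default_rng(0)
cleaned = pd.DataFrame({
    "age": rng.integers(20, 70, 200),
    "income": rng.normal(50_000, 10_000, 200),
    "region": rng.choice(["north", "south"], 200),
    "default_flag": rng.integers(0, 2, 200),
})
model = LogisticRegression().fit(cleaned[["age", "income"]], cleaned["default_flag"])

validator = SegmentValidator(
    model=model,
    raw_df=None,                      # skip raw-data checks in this sketch
    cleaned_df=cleaned,
    config={"model": {"features": ["age", "income"], "target": "default_flag"}},
    target_col="default_flag",
    segment_column="region",
    segment_values=["north", "south"],
    report_base="reports/{segment}",  # one report folder per segment value
)
results_by_segment = validator.run()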
tanml/report/report_builder.py
ADDED

@@ -0,0 +1,230 @@
+# tanml/report/report_builder.py
+from docxtpl import DocxTemplate, InlineImage
+from docx.shared import Inches, Mm
+from docx import Document
+from pathlib import Path
+import os, imgkit, copy as pycopy
+from importlib.resources import files
+
+TMP_DIR = Path(__file__).resolve().parents[1] / "tmp_report_assets"
+TMP_DIR.mkdir(exist_ok=True)
+
+
+class AttrDict(dict):
+    def __getattr__(self, item):
+        try:
+            return self[item]
+        except KeyError:
+            raise AttributeError(item)
+
+
+class ReportBuilder:
+    """Build a Word report from validation results."""
+
+    def __init__(self, results, template_path, output_path):
+        self.results = results
+        self.template_path = template_path or files("tanml.report.templates").joinpath("report_template.docx")
+
+        self.output_path = output_path
+
+        corr = results.get("CorrelationCheck", {})
+        self.corr_heatmap_path = corr.get("heatmap_path")
+        self.corr_pearson_path = corr.get("pearson_csv", "N/A")
+        self.corr_spearman_path = corr.get("spearman_csv", "N/A")
+
+    def _grab(self, name, default=None):
+        return (
+            self.results.get(name)
+            or self.results.get("check_results", {}).get(name)
+            or default
+        )
+
+    def build(self):
+        doc = DocxTemplate(str(self.template_path))
+
+
+        # Ensure RuleEngineCheck exists so template never crashes
+        self.results.setdefault(
+            "RuleEngineCheck", AttrDict({"rules": {}, "overall_pass": True})
+        )
+
+        # Jinja context
+        ctx = pycopy.deepcopy(self.results)
+        ctx.update(self.results.get("check_results", {}))  # 1st-level flatten
+
+        for k, v in list(ctx.items()):
+            if isinstance(v, dict) and k in v and len(v) == 1:
+                ctx[k] = v[k]
+
+        for k, note in [
+            ("RawDataCheck", "Raw-data check skipped"),
+            ("CleaningReproCheck", "Cleaning-repro check skipped"),
+        ]:
+            ctx.setdefault(k, AttrDict({"note": note}))
+
+        if "ModelMetaCheck" not in ctx or "model_type" not in ctx["ModelMetaCheck"]:
+            meta_fields = [
+                "model_type",
+                "model_class",
+                "module",
+                "n_features",
+                "feature_names",
+                "n_train_rows",
+                "target_balance",
+                "hyperparam_table",
+                "attributes",
+            ]
+            meta = {f: self.results.get(f) for f in meta_fields if self.results.get(f) is not None}
+            ctx["ModelMetaCheck"] = AttrDict(meta or {"note": "Model metadata not available"})
+
+        # SHAP image
+        shap_path = self._grab("SHAPCheck", {}).get("shap_plot_path")
+        ctx["shap_plot"] = (
+            InlineImage(doc, shap_path, width=Inches(5))
+            if shap_path and os.path.exists(shap_path)
+            else "SHAP plot not available"
+        )
+
+        # EDA
+        eda = self._grab("EDACheck", {})
+        ctx["eda_summary_path"] = eda.get("summary_stats", "N/A")
+        ctx["eda_missing_path"] = eda.get("missing_values", "N/A")
+        ctx["eda_images"] = [
+            InlineImage(doc, os.path.join("reports/eda", fn), width=Inches(4.5))
+            if os.path.exists(os.path.join("reports/eda", fn))
+            else f"Missing: {fn}"
+            for fn in eda.get("visualizations", [])
+        ]
+
+        # Correlation visuals
+        if self.corr_heatmap_path and os.path.exists(self.corr_heatmap_path):
+            ctx["correlation_heatmap"] = InlineImage(
+                doc, self.corr_heatmap_path, width=Inches(5)
+            )
+        else:
+            ctx["correlation_heatmap"] = "Heatmap not available"
+        ctx["correlation_pearson_path"] = self.corr_pearson_path
+        ctx["correlation_spearman_path"] = self.corr_spearman_path
+
+        # Performance
+        perf = self._grab("PerformanceCheck", {})
+        if not perf:
+            perf = {
+                "accuracy": "N/A",
+                "auc": "N/A",
+                "ks": "N/A",
+                "f1": "N/A",
+                "confusion_matrix": [],
+            }
+        ctx.setdefault("check_results", {})["PerformanceCheck"] = perf
+        ctx["PerformanceCheck"] = perf
+
+        # Logistic summary image
+        if "LogisticStatsCheck_obj" in self.results:
+            try:
+                add_logit_summary_image(
+                    doc, self.results["LogisticStatsCheck_obj"], ctx, "LogitSummaryImg"
+                )
+            except Exception as e:
+                print("⚠️ logistic summary image failed:", e)
+
+        # VIF
+        vif = self._grab("VIFCheck", {})
+        ctx["VIFCheck"] = AttrDict(
+            vif
+            if isinstance(vif, dict) and "vif_table" in vif
+            else {"vif_table": [], "high_vif_features": [], "error": "Invalid VIFCheck"}
+        )
+
+        # Stress / cluster
+        if isinstance(self.results.get("StressTestCheck"), list):
+            ctx["StressTestCheck"] = {"table": self.results["StressTestCheck"]}
+
+        cluster_rows = self._grab("InputClusterCheck", {}).get("cluster_table", [])
+        ctx.setdefault("InputClusterCheck", {})["cluster_table"] = [
+            {
+                "Cluster": r.get("Cluster") or r.get("cluster"),
+                "Count": r.get("Count") or r.get("count"),
+                "Percent": r.get("Percent") or r.get("percent"),
+            }
+            for r in cluster_rows
+            if isinstance(r, dict)
+        ]
+        plot_path = self._grab("InputClusterCheck", {}).get("cluster_plot_img")
+        ctx["InputClusterCheck"]["cluster_plot_img"] = (
+            InlineImage(doc, plot_path, width=Inches(5))
+            if plot_path and os.path.exists(plot_path)
+            else "Plot not available"
+        )
+
+        # Render DOCX template
+        print("🟢 ctx top-level keys:", list(ctx.keys()))
+        print("🔍 RawDataCheck value:", ctx.get("RawDataCheck"))
+
+        doc.render(ctx)
+        doc.save(self.output_path)
+
+        # Auto-insert tables after anchors
+        tbl_specs = [
+            {
+                "anchor": "Stress Testing Results",
+                "headers": [
+                    "feature",
+                    "perturbation",
+                    "accuracy",
+                    "auc",
+                    "delta_accuracy",
+                    "delta_auc",
+                ],
+                "rows": ctx.get("StressTestCheck", {}).get("table", []),
+            },
+            {
+                "anchor": "Cluster Summary Table:",
+                "headers": ["Cluster", "Count", "Percent"],
+                "rows": ctx.get("InputClusterCheck", {}).get("cluster_table", []),
+            },
+            {
+                "anchor": "Variance Inflation Factor (VIF) Check",
+                "headers": ["Feature", "VIF"],
+                "rows": ctx.get("VIFCheck", {}).get("vif_table", []),
+            },
+        ]
+
+        docx = Document(self.output_path)
+        for spec in tbl_specs:
+            if spec["rows"]:
+                tbl = build_table(docx, spec["headers"], spec["rows"])
+                insert_after(docx, spec["anchor"], tbl)
+                print(f"✅ added table after «{spec['anchor']}»")
+        docx.save(self.output_path)
+
+
+def build_table(doc, headers, rows):
+    tbl = doc.add_table(rows=1, cols=len(headers))
+    tbl.style = "Table Grid"
+    for i, h in enumerate(headers):
+        tbl.rows[0].cells[i].text = str(h)
+    for r in rows:
+        vals = [r.get(h, "") for h in headers] if isinstance(r, dict) else list(r)
+        vals += [""] * (len(headers) - len(vals))
+        row = tbl.add_row().cells
+        for i, v in enumerate(vals):
+            row[i].text = str(v)
+    return tbl
+
+
+def insert_after(doc, anchor, tbl):
+    for p in doc.paragraphs:
+        if anchor.lower() in p.text.lower():
+            parent = p._p.getparent()
+            parent.insert(parent.index(p._p) + 1, tbl._tbl)
+            return
+    print(f"⚠️ anchor «{anchor}» not found")
+
+
+def add_logit_summary_image(tpl_doc, sm_results, ctx, key):
+    html = TMP_DIR / "logit_summary.html"
+    html.write_text(sm_results.summary().as_html(), encoding="utf8")
+    png = TMP_DIR / "logit_summary.png"
+    imgkit.from_file(str(html), str(png), options={"quiet": ""})
+    ctx[key] = InlineImage(tpl_doc, str(png), width=Mm(160))
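`ReportBuilder` itself needs a docxtpl template with matching Jinja tags, so the quickest standalone demo is the module-level helpers: build a table from dict rows and move it underneath an anchor paragraph. The document, anchor text, and row values below are made up for illustration.

from docx import Document
from tanml.report.report_builder import build_table, insert_after

doc = Document()
doc.add_paragraph("Stress Testing Results")

rows = [
    {"feature": "income", "perturbation": "+10%", "accuracy": 0.81,
     "auc": 0.87, "delta_accuracy": -0.02, "delta_auc": -0.01},
]
headers = ["feature", "perturbation", "accuracy", "auc", "delta_accuracy", "delta_auc"]

tbl = build_table(doc, headers, rows)             # appends the table at the end
insert_after(doc, "Stress Testing Results", tbl)  # then moves it under the anchor
doc.save("stress_demo.docx")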
tanml/report/templates/report_template.docx
ADDED

Binary file
tanml/utils/__init__.py
ADDED

File without changes
tanml/utils/data_loader.py
ADDED

@@ -0,0 +1,17 @@
+import os
+import pandas as pd
+
+def load_dataframe(filepath):
+    ext = os.path.splitext(filepath)[1].lower()
+    if ext == ".csv":
+        return pd.read_csv(filepath)
+    elif ext in [".xls", ".xlsx"]:
+        return pd.read_excel(filepath)
+    elif ext == ".parquet":
+        return pd.read_parquet(filepath)
+    elif ext == ".sas7bdat":
+        return pd.read_sas(filepath)
+    elif ext in [".txt", ".tsv"]:
+        return pd.read_csv(filepath, sep="\t")
+    else:
+        raise ValueError(f"Unsupported file format: {ext}")
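`load_dataframe` dispatches purely on the file extension; the Excel and Parquet branches rely on the usual optional pandas engines (e.g. openpyxl, pyarrow) being installed. A quick CSV round-trip:

import pandas as pd
from tanml.utils.data_loader import load_dataframe

pd.DataFrame({"a": [1, 2], "b": [3, 4]}).to_csv("sample.csv", index=False)
df = load_dataframe("sample.csv")   # equivalent to pd.read_csv here
print(df.shape)                     # (2, 2)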
tanml/utils/model_loader.py
ADDED

@@ -0,0 +1,35 @@
+# tanml/utils/model_loader.py
+
+import os
+import joblib
+from tanml.utils.sas_loader import SASLogisticModel
+from tanml.utils.r_loader import RLogisticModel
+
+def load_model(model_path):
+    """
+    Load a model from path. Supports:
+    - sklearn/xgboost .pkl
+    - SAS .csv with _intercept.txt and _features.txt
+    - R exported logistic CSV
+    """
+    if not model_path:
+        raise ValueError("❌ No model path provided.")
+
+    if "r_logistic" in model_path.lower():
+        print("✅ Detected R Logistic Regression model")
+        return RLogisticModel(model_path)
+
+    elif model_path.endswith(".pkl"):
+        print(f"✅ Loading sklearn/XGB model from {model_path}")
+        return joblib.load(model_path)
+
+    elif model_path.endswith(".csv"):
+        base = os.path.splitext(model_path)[0]
+        return SASLogisticModel(
+            coeffs_path=model_path,
+            intercept_path=base + "_intercept.txt",
+            feature_order_path=base + "_features.txt"
+        )
+
+    else:
+        raise ValueError("❌ Unsupported model format. Use .pkl, .csv, or R model CSV")
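Note that the "r_logistic" substring test runs before the .csv branch, so a SAS coefficient CSV must not contain that substring in its name. A round-trip through the .pkl branch, assuming scikit-learn is available (the SAS and R branches need the sidecar files shown below in r_loader.py and sas_loader.py):

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from tanml.utils.model_loader import load_model

X = np.random.rand(50, 3)
y = (X[:, 0] > 0.5).astype(int)
joblib.dump(LogisticRegression().fit(X, y), "demo_model.pkl")

model = load_model("demo_model.pkl")   # dispatched on the .pkl extension
print(model.predict_proba(X[:2]))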
tanml/utils/r_loader.py
ADDED

@@ -0,0 +1,30 @@
+import pandas as pd
+import numpy as np
+
+class RLogisticModel:
+    """
+    Wrapper for logistic regression models exported from R.
+    Assumes CSV with columns: ID, y_true, y_pred_proba
+    """
+
+    def __init__(self, model_path: str):
+        self.model_path = model_path
+        self.df = pd.read_csv(model_path)
+
+        # Check required columns exist
+        expected_cols = {'y_true', 'y_pred_proba'}
+        if not expected_cols.issubset(set(self.df.columns)):
+            raise ValueError(f"R model CSV must contain columns: {expected_cols}")
+
+        self.y_true = self.df['y_true'].values
+        self.y_pred_proba = self.df['y_pred_proba'].values
+
+    def predict_proba(self, X=None):
+        """
+        Mimics sklearn's predict_proba format: n_samples x 2
+        """
+        proba = self.y_pred_proba.reshape(-1, 1)
+        return np.hstack([1 - proba, proba])
+
+    def get_true_labels(self):
+        return self.y_true
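A self-contained sketch of the scores file `RLogisticModel` expects (the file name and values are arbitrary): write a CSV with `y_true` and `y_pred_proba` columns, wrap it, and read back sklearn-style probabilities.

import pandas as pd
from tanml.utils.r_loader import RLogisticModel

pd.DataFrame({
    "ID": [1, 2, 3],
    "y_true": [0, 1, 1],
    "y_pred_proba": [0.12, 0.86, 0.55],
}).to_csv("r_logistic_scores.csv", index=False)

m = RLogisticModel("r_logistic_scores.csv")
print(m.predict_proba())       # columns: P(class 0), P(class 1)
print(m.get_true_labels())     # array([0, 1, 1])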
tanml/utils/sas_loader.py
ADDED

@@ -0,0 +1,50 @@
+# File: tanml/utils/sas_loader.py
+
+import pandas as pd
+import numpy as np
+
+
+class SASLogisticModel:
+    def __init__(self, coeffs_path, intercept_path, feature_order_path):
+        self.coeffs_path = coeffs_path
+        self.intercept_path = intercept_path
+        self.feature_order_path = feature_order_path
+
+        self.coefficients = self._load_coefficients()
+        self.intercept = self._load_intercept()
+        self.feature_order = self._load_feature_order()
+
+    def _load_coefficients(self):
+        return pd.read_csv(self.coeffs_path, index_col=0).squeeze("columns")
+
+    def _load_intercept(self):
+        with open(self.intercept_path) as f:
+            return float(f.read().strip())
+
+    def _load_feature_order(self):
+        with open(self.feature_order_path) as f:
+            return [line.strip() for line in f.readlines()]
+
+    def predict_proba(self, X):
+        """
+        Return a NumPy array shaped (n_samples, 2) like sklearn:
+        [:, 0] = P(class 0), [:, 1] = P(class 1)
+        """
+        X = X[self.feature_order]
+        logits = X.dot(self.coefficients) + self.intercept
+
+        # numeric stability clamp
+        logits = logits.clip(-700, 700)
+
+        proba_1 = 1 / (1 + np.exp(-logits))
+        proba_0 = 1 - proba_1
+        return np.vstack([proba_0, proba_1]).T  # shape (n, 2)
+
+    def predict(self, X):
+        """
+        Return class labels (0/1) based on 0.5 threshold.
+        Works with the NumPy array returned by predict_proba().
+        """
+        proba_1 = self.predict_proba(X)[:, 1]  # probability of class 1
+        return (proba_1 >= 0.5).astype(int)
+
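A sketch of the three sidecar artifacts `SASLogisticModel` reads, written here with made-up coefficients so the snippet is self-contained: a coefficients CSV (feature names in the index column), a plain-text intercept, and a feature-order list with one name per line.

import pandas as pd
from tanml.utils.sas_loader import SASLogisticModel

pd.DataFrame({"coef": {"age": 0.03, "income": -0.00001}}).to_csv("sas_logit.csv")
with open("sas_logit_intercept.txt", "w") as f:
    f.write("-1.5")
with open("sas_logit_features.txt", "w") as f:
    f.write("age\nincome\n")

model = SASLogisticModel(
    coeffs_path="sas_logit.csv",
    intercept_path="sas_logit_intercept.txt",
    feature_order_path="sas_logit_features.txt",
)
X = pd.DataFrame({"age": [25, 60], "income": [40_000, 90_000]})
print(model.predict_proba(X))   # shape (2, 2)
print(model.predict(X))         # 0/1 labels at a 0.5 threshold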
tanml/utils/yaml_generator.py
ADDED

@@ -0,0 +1,34 @@
+from pathlib import Path
+import shutil
+
+def generate_rules_yaml(
+    scenario: str = "A",
+    dest_path: str = "rules.yaml",
+    overwrite: bool = False
+):
+
+    scenario = scenario.upper()
+    template_map = {
+        "A": "rules_one_model_one_dataset.yaml",
+        "B": "rules_multiple_models_datasets.yaml",
+        "C": "rules_one_dataset_segment_column.yaml"
+    }
+
+    if scenario not in template_map:
+        raise ValueError("Invalid scenario. Must be 'A', 'B', or 'C'.")
+
+    src = Path(__file__).parent.parent / "config_templates" / template_map[scenario]
+    if not src.exists():
+        raise FileNotFoundError(f"Template not found at {src}")
+
+    dst = Path(dest_path)
+    if dst.exists() and not overwrite:
+        raise FileExistsError(f"{dst} already exists. Use --overwrite to replace it.")
+
+    if dst.parent and not dst.parent.exists():
+        dst.parent.mkdir(parents=True, exist_ok=True)
+
+    shutil.copyfile(src, dst)
+
+    print(f"✅ Created: {dst.resolve()} for Scenario {scenario}")
+    print("📝 Now edit this YAML to fill in your model, data, and feature details.")