tanml 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tanml might be problematic. Click here for more details.

Files changed (62) hide show
  1. tanml/__init__.py +1 -0
  2. tanml/check_runners/__init__.py +0 -0
  3. tanml/check_runners/base_runner.py +6 -0
  4. tanml/check_runners/cleaning_repro_runner.py +18 -0
  5. tanml/check_runners/correlation_runner.py +15 -0
  6. tanml/check_runners/data_quality_runner.py +24 -0
  7. tanml/check_runners/eda_runner.py +21 -0
  8. tanml/check_runners/explainability_runner.py +28 -0
  9. tanml/check_runners/input_cluster_runner.py +43 -0
  10. tanml/check_runners/logistic_stats_runner.py +28 -0
  11. tanml/check_runners/model_meta_runner.py +23 -0
  12. tanml/check_runners/performance_runner.py +28 -0
  13. tanml/check_runners/raw_data_runner.py +41 -0
  14. tanml/check_runners/rule_engine_runner.py +5 -0
  15. tanml/check_runners/stress_test_runner.py +26 -0
  16. tanml/check_runners/vif_runner.py +54 -0
  17. tanml/checks/__init__.py +0 -0
  18. tanml/checks/base.py +20 -0
  19. tanml/checks/cleaning_repro.py +47 -0
  20. tanml/checks/correlation.py +61 -0
  21. tanml/checks/data_quality.py +26 -0
  22. tanml/checks/eda.py +67 -0
  23. tanml/checks/explainability/shap_check.py +55 -0
  24. tanml/checks/input_cluster.py +109 -0
  25. tanml/checks/logit_stats.py +59 -0
  26. tanml/checks/model_contents.py +40 -0
  27. tanml/checks/model_meta.py +50 -0
  28. tanml/checks/performance.py +90 -0
  29. tanml/checks/raw_data.py +47 -0
  30. tanml/checks/rule_engine.py +45 -0
  31. tanml/checks/stress_test.py +64 -0
  32. tanml/checks/vif.py +51 -0
  33. tanml/cli/__init__.py +0 -0
  34. tanml/cli/arg_parser.py +31 -0
  35. tanml/cli/init_cmd.py +8 -0
  36. tanml/cli/main.py +27 -0
  37. tanml/cli/validate_cmd.py +7 -0
  38. tanml/config_templates/__init__.py +0 -0
  39. tanml/config_templates/rules_multiple_models_datasets.yaml +144 -0
  40. tanml/config_templates/rules_one_dataset_segment_column.yaml +140 -0
  41. tanml/config_templates/rules_one_model_one_dataset.yaml +143 -0
  42. tanml/engine/__init__.py +0 -0
  43. tanml/engine/check_agent_registry.py +42 -0
  44. tanml/engine/core_engine_agent.py +115 -0
  45. tanml/engine/segmentation_agent.py +118 -0
  46. tanml/engine/validation_agent.py +91 -0
  47. tanml/report/report_builder.py +230 -0
  48. tanml/report/templates/report_template.docx +0 -0
  49. tanml/utils/__init__.py +0 -0
  50. tanml/utils/data_loader.py +17 -0
  51. tanml/utils/model_loader.py +35 -0
  52. tanml/utils/r_loader.py +30 -0
  53. tanml/utils/sas_loader.py +50 -0
  54. tanml/utils/yaml_generator.py +34 -0
  55. tanml/utils/yaml_loader.py +5 -0
  56. tanml/validate.py +209 -0
  57. tanml-0.1.6.dist-info/METADATA +317 -0
  58. tanml-0.1.6.dist-info/RECORD +62 -0
  59. tanml-0.1.6.dist-info/WHEEL +5 -0
  60. tanml-0.1.6.dist-info/entry_points.txt +2 -0
  61. tanml-0.1.6.dist-info/licenses/LICENSE +21 -0
  62. tanml-0.1.6.dist-info/top_level.txt +1 -0
@@ -0,0 +1,118 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import joblib
5
+ import importlib
6
+ from tanml.utils.sas_loader import SASLogisticModel
7
+ from tanml.utils.r_loader import RLogisticModel
8
+ from tanml.engine.validation_agent import SegmentValidator
9
+ from tanml.utils.data_loader import load_dataframe
10
+
11
+
12
def handle_segmentation(segment_config, rule_config, args=None, report_output=None):
    """Validate every run listed in ``segment_config["runs"]``, one report per run.

    For each ``name -> run_cfg`` entry the model is taken from
    ``run_cfg["model"]``: the literal ``"train"`` (any case) fits a fresh model
    on the run's cleaned data, otherwise the value is treated as a path and
    loaded according to its name. Raw and cleaned frames are then resolved and
    handed to ``SegmentValidator``.

    Args:
        segment_config: Mapping with a ``"runs"`` mapping and an optional
            ``"column"`` used to slice shared datasets by run name.
        rule_config: Parsed rules mapping; ``paths``, ``model`` and
            ``model_source`` are read from it.
        args: Not used by this function; kept for call compatibility.
        report_output: Output directory template in which ``{segment}`` is
            replaced by the run name. ``None`` falls back to ``"reports"``.

    Returns:
        ``True`` after all runs have been validated.

    Raises:
        ValueError: If a model reference is unsupported, retraining settings
            are missing, or shared cleaned data must be sliced for retraining
            but no segment column is configured.
    """
    global_raw_path = rule_config.get("paths", {}).get("raw_data")
    segment_col = segment_config.get("column")  # only needed when a shared dataset is sliced per run
    # The original called .format() on None when no output was given; use the
    # same default directory SegmentValidator declares instead of crashing.
    report_template = "reports" if report_output is None else report_output

    print("πŸ” Detected segmentation setup in rules.yaml. Running each segment run separately...")

    # Shared (non run-specific) files are read once and reused across runs.
    frames = {}

    for name, run_cfg in segment_config["runs"].items():
        print(f"\nπŸ”Ή Validating segment: {name}")
        model_path = run_cfg["model"]
        cleaned_df_run = None

        if isinstance(model_path, str) and model_path.lower() == "train":
            print(f"πŸ› οΈ Retraining model from cleaned data for segment: {name}")
            cleaned_df_run = _cleaned_frame_for_run(
                name, run_cfg, rule_config, segment_col, frames, require_segment=True
            )
            model = _train_segment_model(name, cleaned_df_run, rule_config)
        else:
            model = _load_segment_model(name, model_path)

        # Raw data is optional: it is only used when the configured file exists.
        raw_df_run = None
        if global_raw_path and os.path.exists(global_raw_path):
            raw_df_run = _slice_for_segment(
                _read_shared_frame(global_raw_path, frames), segment_col, name
            )

        print(f"[DEBUG] raw_df_run for {name}: {type(raw_df_run)}, rows = {len(raw_df_run) if raw_df_run is not None else 'None'}")

        # The train path already resolved the cleaned frame; do not read it twice.
        if cleaned_df_run is None:
            cleaned_df_run = _cleaned_frame_for_run(
                name, run_cfg, rule_config, segment_col, frames, require_segment=False
            )

        report_base = report_template.format(segment=name)

        SegmentValidator(
            segment_column=segment_col,
            segment_values=[name],
            model=model,
            raw_df=raw_df_run,
            cleaned_df=cleaned_df_run,
            target_col=rule_config.get("model", {}).get("target"),
            config=rule_config,
            segment_name=name,
            report_base=report_base,
        ).run()

        # Mirrors the file name SegmentValidator writes inside report_base.
        report_path = os.path.join(report_base, f"report_{name}.docx")

        print(f"πŸ“„ Report saved: {report_path}")
        print(f"βœ… Segment '{name}' validated.")

    print("βœ… All segment runs completed.")
    return True


def _read_shared_frame(path, frames):
    """Load *path* on first use and return the cached frame afterwards."""
    if path not in frames:
        frames[path] = load_dataframe(path)
    return frames[path]


def _slice_for_segment(frame, segment_col, name):
    """Return the rows of *frame* for run *name*, or a private copy of the whole frame."""
    if segment_col:
        return frame[frame[segment_col] == name]
    # The shared frame is cached, so hand out a copy to keep runs independent.
    return frame.copy()


def _cleaned_frame_for_run(name, run_cfg, rule_config, segment_col, frames, require_segment):
    """Resolve the cleaned data for one run: its own file, else a slice of the shared file."""
    if "cleaned" in run_cfg:
        return load_dataframe(run_cfg["cleaned"])
    full_cleaned = _read_shared_frame(rule_config["paths"]["cleaned_data"], frames)
    if require_segment and not segment_col:
        raise ValueError("Missing segment column for slicing cleaned data.")
    return _slice_for_segment(full_cleaned, segment_col, name)


def _train_segment_model(name, cleaned_df, rule_config):
    """Instantiate the class named in ``model_source`` and fit it on *cleaned_df*."""
    X = cleaned_df[rule_config["model"]["features"]]
    y = cleaned_df[rule_config["model"]["target"]]

    model_source = rule_config.get("model_source", {})
    model_type = model_source.get("type")
    model_module = model_source.get("module")
    model_params = model_source.get("hyperparameters", {})

    if not model_type or not model_module:
        raise ValueError("❌ 'model_source.type' and 'model_source.module' must be defined for retraining.")

    model_class = getattr(importlib.import_module(model_module), model_type)
    model = model_class(**model_params)
    print(f"πŸ“¦ Using model: {model}")

    model.fit(X, y)
    print(f"βœ… Retrained {model_type} for segment '{name}'")
    return model


def _load_segment_model(name, model_path):
    """Load a stored model, choosing the loader from the path's name."""
    if isinstance(model_path, str):
        if "r_logistic" in model_path.lower():
            return RLogisticModel(model_path)
        if model_path.endswith(".pkl"):
            return joblib.load(model_path)
        if model_path.endswith(".csv"):
            base = os.path.splitext(model_path)[0]
            return SASLogisticModel(
                coeffs_path=model_path,
                intercept_path=base + "_intercept.txt",
                feature_order_path=base + "_features.txt",
            )
    raise ValueError(f"❌ Unsupported model format for segment '{name}': {model_path}")
@@ -0,0 +1,91 @@
1
+ from __future__ import annotations
2
+ from importlib.resources import files
3
+ from datetime import datetime
4
+ import tzlocal
5
+ import json
6
+ from pathlib import Path
7
+
8
+ from tanml.engine.core_engine_agent import ValidationEngine
9
+ from tanml.report.report_builder import ReportBuilder
10
+ from tanml.engine.check_agent_registry import CHECK_RUNNER_REGISTRY
11
+
12
+
13
+ class SegmentValidator:
14
+ def __init__(
15
+ self,
16
+ model,
17
+ raw_df,
18
+ cleaned_df,
19
+ config,
20
+ target_col=None,
21
+ segment_column=None,
22
+ segment_values=None,
23
+ segment_name=None,
24
+ report_base="reports"
25
+ ):
26
+ self.model = model
27
+ self.raw_df = raw_df
28
+ self.cleaned_df = cleaned_df
29
+ self.config = config
30
+ self.target_col = target_col
31
+ self.segment_column = segment_column
32
+ self.segment_values = segment_values
33
+ self.segment_name = segment_name
34
+ self.report_base = report_base
35
+
36
+ def run(self):
37
+ if self.segment_name:
38
+ return self._run_single(self.cleaned_df, self.segment_name)
39
+
40
+ if not self.segment_column:
41
+ raise ValueError("Segmentation column not specified in rules.yaml")
42
+
43
+ results = {}
44
+ for segment in self.segment_values:
45
+ segment_df = self.cleaned_df[self.cleaned_df[self.segment_column] == segment]
46
+ print(f"πŸ”Ή Running validation for segment value: {segment}")
47
+ result = self._run_single(segment_df, segment)
48
+ results[segment] = result
49
+
50
+ return results
51
+
52
+ def _run_single(self, segment_df, segment_name):
53
+ if self.target_col is None or self.target_col not in segment_df.columns:
54
+ raise ValueError(f"❌ Target column '{self.target_col}' missing in cleaned data for segment '{segment_name}'")
55
+
56
+ y = segment_df[self.target_col]
57
+ cols_to_drop = [c for c in (self.target_col, self.segment_column) if c in segment_df.columns]
58
+ X = segment_df.drop(columns=cols_to_drop)
59
+
60
+ # Pass raw_df to ValidationEngine so CleaningReproCheck works
61
+ engine = ValidationEngine(self.model, X, X, y, y, self.config, segment_df, self.raw_df)
62
+ results = engine.run_all_checks()
63
+
64
+ local_tz = tzlocal.get_localzone()
65
+ now = datetime.now(local_tz)
66
+ results["validation_date"] = now.strftime("%Y-%m-%d %H:%M:%S %Z (UTC%z)")
67
+ results["model_path"] = self.model.__class__.__name__
68
+ results["validated_by"] = "TanML Automated Validator"
69
+ results["rules"] = self.config
70
+
71
+ report_base_formatted = self.report_base.format(segment=segment_name) if "{segment}" in self.report_base else self.report_base
72
+ output_dir = Path(report_base_formatted)
73
+ output_dir.mkdir(parents=True, exist_ok=True)
74
+
75
+ tpl_cfg = self.config.get("output", {}).get("template_path") # may be None
76
+ template_path = (
77
+ Path(tpl_cfg).expanduser() if tpl_cfg
78
+ else files("tanml.report.templates").joinpath("report_template.docx")
79
+ )
80
+
81
+ output_path = output_dir / f"report_{segment_name}.docx"
82
+
83
+ builder = ReportBuilder(results, template_path, output_path)
84
+ builder.build()
85
+
86
+ print(f"πŸ“„ Report saved for segment '{segment_name}': {output_path}")
87
+ print(json.dumps(results, indent=2, default=str))
88
+ return results
89
+
90
+
91
+ __all__ = ["SegmentValidator", "CHECK_RUNNER_REGISTRY"]
@@ -0,0 +1,230 @@
1
+ # tanml/report/report_builder.py
2
+ from docxtpl import DocxTemplate, InlineImage
3
+ from docx.shared import Inches, Mm
4
+ from docx import Document
5
+ from pathlib import Path
6
+ import os, imgkit, copy as pycopy
7
+ from importlib.resources import files
8
+
9
# Scratch directory for intermediate report assets (HTML/PNG renderings).
# It lives next to the package sources; when that location is not writable
# (for example a read-only or system-wide install) fall back to the system
# temp directory so that importing this module does not fail.
TMP_DIR = Path(__file__).resolve().parents[1] / "tmp_report_assets"
try:
    TMP_DIR.mkdir(exist_ok=True)
except OSError:
    import tempfile

    TMP_DIR = Path(tempfile.gettempdir()) / "tanml_report_assets"
    TMP_DIR.mkdir(parents=True, exist_ok=True)
11
+
12
+
13
class AttrDict(dict):
    """Dictionary whose keys can also be read as attributes (``d.key``)."""

    def __getattr__(self, item):
        """Return ``self[item]``; a missing key surfaces as ``AttributeError``."""
        try:
            value = self[item]
        except KeyError:
            raise AttributeError(item)
        else:
            return value
19
+
20
+
21
class ReportBuilder:
    """Build a Word report from validation results."""

    def __init__(self, results, template_path, output_path):
        """Store the inputs and pre-resolve the correlation artefact paths.

        Args:
            results: Mapping of check name to result; checks may also sit one
                level down under ``results["check_results"]``.
            template_path: Template to render; a falsy value selects the
                packaged ``report_template.docx``.
            output_path: Where the rendered document is saved.
        """
        self.results = results
        self.template_path = template_path or files("tanml.report.templates").joinpath("report_template.docx")

        self.output_path = output_path

        # Use _grab so a CorrelationCheck nested under "check_results" is found
        # too, matching how every other check is looked up in build().
        corr = self._grab("CorrelationCheck", {})
        self.corr_heatmap_path = corr.get("heatmap_path")
        self.corr_pearson_path = corr.get("pearson_csv", "N/A")
        self.corr_spearman_path = corr.get("spearman_csv", "N/A")

    def _grab(self, name, default=None):
        """Return the first truthy of: top-level result, nested result, *default*."""
        return (
            self.results.get(name)
            or self.results.get("check_results", {}).get(name)
            or default
        )

    @staticmethod
    def _section_rows(ctx, check, key):
        """Return ``ctx[check][key]`` when that section is a dict, else an empty list."""
        section = ctx.get(check)
        if isinstance(section, dict):
            return section.get(key) or []
        return []

    def build(self):
        """Render the template with the results, save it, then append result tables.

        Note: a default ``RuleEngineCheck`` entry is added to ``self.results``
        itself when missing, so the caller's mapping is modified.
        """
        doc = DocxTemplate(str(self.template_path))

        # Ensure RuleEngineCheck exists so template never crashes
        self.results.setdefault(
            "RuleEngineCheck", AttrDict({"rules": {}, "overall_pass": True})
        )

        # Jinja context
        ctx = pycopy.deepcopy(self.results)
        ctx.update(self.results.get("check_results", {}))  # 1st-level flatten

        # Unwrap results shaped like {"Name": {"Name": payload}}.
        for k, v in list(ctx.items()):
            if isinstance(v, dict) and k in v and len(v) == 1:
                ctx[k] = v[k]

        for k, note in [
            ("RawDataCheck", "Raw-data check skipped"),
            ("CleaningReproCheck", "Cleaning-repro check skipped"),
        ]:
            ctx.setdefault(k, AttrDict({"note": note}))

        if "ModelMetaCheck" not in ctx or "model_type" not in ctx["ModelMetaCheck"]:
            # Rebuild the section from top-level fields when the check itself
            # did not provide a model_type.
            meta_fields = [
                "model_type",
                "model_class",
                "module",
                "n_features",
                "feature_names",
                "n_train_rows",
                "target_balance",
                "hyperparam_table",
                "attributes",
            ]
            meta = {f: self.results.get(f) for f in meta_fields if self.results.get(f) is not None}
            ctx["ModelMetaCheck"] = AttrDict(meta or {"note": "Model metadata not available"})

        # SHAP image
        shap_path = self._grab("SHAPCheck", {}).get("shap_plot_path")
        ctx["shap_plot"] = (
            InlineImage(doc, shap_path, width=Inches(5))
            if shap_path and os.path.exists(shap_path)
            else "SHAP plot not available"
        )

        # EDA
        # NOTE(review): images are looked up under the relative directory
        # "reports/eda" regardless of output_path - confirm this matches where
        # the EDA check writes them.
        eda = self._grab("EDACheck", {})
        ctx["eda_summary_path"] = eda.get("summary_stats", "N/A")
        ctx["eda_missing_path"] = eda.get("missing_values", "N/A")
        ctx["eda_images"] = [
            InlineImage(doc, os.path.join("reports/eda", fn), width=Inches(4.5))
            if os.path.exists(os.path.join("reports/eda", fn))
            else f"Missing: {fn}"
            for fn in eda.get("visualizations", [])
        ]

        # Correlation visuals
        if self.corr_heatmap_path and os.path.exists(self.corr_heatmap_path):
            ctx["correlation_heatmap"] = InlineImage(
                doc, self.corr_heatmap_path, width=Inches(5)
            )
        else:
            ctx["correlation_heatmap"] = "Heatmap not available"
        ctx["correlation_pearson_path"] = self.corr_pearson_path
        ctx["correlation_spearman_path"] = self.corr_spearman_path

        # Performance
        perf = self._grab("PerformanceCheck", {})
        if not perf:
            perf = {
                "accuracy": "N/A",
                "auc": "N/A",
                "ks": "N/A",
                "f1": "N/A",
                "confusion_matrix": [],
            }
        ctx.setdefault("check_results", {})["PerformanceCheck"] = perf
        ctx["PerformanceCheck"] = perf

        # Logistic summary image (best effort: a failure only prints a warning)
        if "LogisticStatsCheck_obj" in self.results:
            try:
                add_logit_summary_image(
                    doc, self.results["LogisticStatsCheck_obj"], ctx, "LogitSummaryImg"
                )
            except Exception as e:
                print("⚠️ logistic summary image failed:", e)

        # VIF
        vif = self._grab("VIFCheck", {})
        ctx["VIFCheck"] = AttrDict(
            vif
            if isinstance(vif, dict) and "vif_table" in vif
            else {"vif_table": [], "high_vif_features": [], "error": "Invalid VIFCheck"}
        )

        # Stress / cluster
        # A bare list of rows is wrapped as {"table": rows}. The top-level
        # result keeps precedence; a list that only arrived through
        # "check_results" or the unwrap step above is wrapped as well, since a
        # list left in ctx would break the table lookup below.
        stress = self.results.get("StressTestCheck")
        if not isinstance(stress, list):
            stress = ctx.get("StressTestCheck")
        if isinstance(stress, list):
            ctx["StressTestCheck"] = {"table": stress}

        cluster_rows = self._grab("InputClusterCheck", {}).get("cluster_table", [])
        ctx.setdefault("InputClusterCheck", {})["cluster_table"] = [
            {
                "Cluster": r.get("Cluster") or r.get("cluster"),
                "Count": r.get("Count") or r.get("count"),
                "Percent": r.get("Percent") or r.get("percent"),
            }
            for r in cluster_rows
            if isinstance(r, dict)
        ]
        plot_path = self._grab("InputClusterCheck", {}).get("cluster_plot_img")
        ctx["InputClusterCheck"]["cluster_plot_img"] = (
            InlineImage(doc, plot_path, width=Inches(5))
            if plot_path and os.path.exists(plot_path)
            else "Plot not available"
        )

        # Render DOCX template
        print("🟒 ctx top-level keys:", list(ctx.keys()))
        print("πŸ” RawDataCheck value:", ctx.get("RawDataCheck"))

        doc.render(ctx)
        doc.save(self.output_path)

        # Auto-insert tables after anchors
        tbl_specs = [
            {
                "anchor": "Stress Testing Results",
                "headers": [
                    "feature",
                    "perturbation",
                    "accuracy",
                    "auc",
                    "delta_accuracy",
                    "delta_auc",
                ],
                "rows": self._section_rows(ctx, "StressTestCheck", "table"),
            },
            {
                "anchor": "Cluster Summary Table:",
                "headers": ["Cluster", "Count", "Percent"],
                "rows": self._section_rows(ctx, "InputClusterCheck", "cluster_table"),
            },
            {
                "anchor": "Variance Inflation Factor (VIF) Check",
                "headers": ["Feature", "VIF"],
                "rows": self._section_rows(ctx, "VIFCheck", "vif_table"),
            },
        ]

        # Re-open the rendered file and place each non-empty table after its anchor.
        docx = Document(self.output_path)
        for spec in tbl_specs:
            if spec["rows"]:
                tbl = build_table(docx, spec["headers"], spec["rows"])
                insert_after(docx, spec["anchor"], tbl)
                print(f"βœ… added table after Β«{spec['anchor']}Β»")
        docx.save(self.output_path)
200
+
201
+
202
def build_table(doc, headers, rows):
    """Append a "Table Grid" table to *doc* with one header row plus one row per entry.

    Args:
        doc: Document object providing ``add_table(rows=..., cols=...)``.
        headers: Column titles; they also serve as lookup keys for dict rows.
        rows: Iterable of dicts (missing keys become empty cells) or
            sequences. Short sequences are padded with empty cells; values
            beyond the number of headers are dropped instead of raising
            ``IndexError``.

    Returns:
        The created table.
    """
    width = len(headers)
    tbl = doc.add_table(rows=1, cols=width)
    tbl.style = "Table Grid"
    for cell, header in zip(tbl.rows[0].cells, headers):
        cell.text = str(header)
    for r in rows:
        vals = [r.get(h, "") for h in headers] if isinstance(r, dict) else list(r)
        # Normalise every row to exactly one value per column.
        vals = (vals + [""] * width)[:width]
        for cell, v in zip(tbl.add_row().cells, vals):
            cell.text = str(v)
    return tbl
214
+
215
+
216
def insert_after(doc, anchor, tbl):
    """Move *tbl* directly after the first paragraph whose text contains *anchor*.

    The match is case-insensitive. When no paragraph matches, a warning is
    printed and the document is left untouched.
    """
    match = next(
        (para for para in doc.paragraphs if anchor.lower() in para.text.lower()),
        None,
    )
    if match is None:
        print(f"⚠️ anchor «{anchor}» not found")
        return
    container = match._p.getparent()
    position = container.index(match._p) + 1
    container.insert(position, tbl._tbl)
223
+
224
+
225
def add_logit_summary_image(tpl_doc, sm_results, ctx, key):
    """Render ``sm_results.summary()`` to a PNG and store it in ``ctx[key]`` as an inline image.

    The summary is written as HTML into ``TMP_DIR``, converted to PNG with
    imgkit, and wrapped in an ``InlineImage`` 160 mm wide. Both files use
    fixed names, so each call overwrites the previous rendering.
    """
    html_path = TMP_DIR / "logit_summary.html"
    png_path = TMP_DIR / "logit_summary.png"

    markup = sm_results.summary().as_html()
    html_path.write_text(markup, encoding="utf8")

    imgkit.from_file(str(html_path), str(png_path), options={"quiet": ""})
    ctx[key] = InlineImage(tpl_doc, str(png_path), width=Mm(160))
File without changes
@@ -0,0 +1,17 @@
1
+ import os
2
+ import pandas as pd
3
+
4
def load_dataframe(filepath):
    """Read a tabular file into a pandas DataFrame, choosing the reader by extension.

    The extension is matched case-insensitively: ``.csv``, ``.xls``/``.xlsx``,
    ``.parquet`` and ``.sas7bdat`` use the matching pandas reader, while
    ``.txt``/``.tsv`` are read as tab-separated text.

    Raises:
        ValueError: If the extension is not one of the supported formats.
    """
    suffix = os.path.splitext(filepath)[1].lower()

    readers = {
        ".csv": pd.read_csv,
        ".xls": pd.read_excel,
        ".xlsx": pd.read_excel,
        ".parquet": pd.read_parquet,
        ".sas7bdat": pd.read_sas,
    }
    if suffix in readers:
        return readers[suffix](filepath)
    if suffix in (".txt", ".tsv"):
        return pd.read_csv(filepath, sep="\t")
    raise ValueError(f"Unsupported file format: {suffix}")
@@ -0,0 +1,35 @@
1
+ # tanml/utils/model_loader.py
2
+
3
+ import os
4
+ import joblib
5
+ from tanml.utils.sas_loader import SASLogisticModel
6
+ from tanml.utils.r_loader import RLogisticModel
7
+
8
def load_model(model_path):
    """Load a model from *model_path*, choosing the loader from the path's name.

    Supported inputs, checked in this order (case-insensitively):
      - a path containing ``r_logistic``: R exported logistic CSV
      - a path ending in ``.pkl``: sklearn/xgboost model loaded with joblib
      - a path ending in ``.csv``: SAS coefficients, with companion
        ``<base>_intercept.txt`` and ``<base>_features.txt`` files

    Args:
        model_path: ``str`` or ``os.PathLike`` pointing at the model file.

    Returns:
        The loaded model object.

    Raises:
        ValueError: If no path is given or the format is not recognised.
    """
    if not model_path:
        raise ValueError("❌ No model path provided.")

    # Accept pathlib.Path (any os.PathLike) as well as plain strings.
    model_path = os.fspath(model_path)
    lowered = model_path.lower()

    if "r_logistic" in lowered:
        print("βœ… Detected R Logistic Regression model")
        return RLogisticModel(model_path)

    if lowered.endswith(".pkl"):
        print(f"βœ… Loading sklearn/XGB model from {model_path}")
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. Only load model files that come from a trusted source.
        return joblib.load(model_path)

    if lowered.endswith(".csv"):
        base = os.path.splitext(model_path)[0]
        return SASLogisticModel(
            coeffs_path=model_path,
            intercept_path=base + "_intercept.txt",
            feature_order_path=base + "_features.txt",
        )

    raise ValueError("❌ Unsupported model format. Use .pkl, .csv, or R model CSV")
@@ -0,0 +1,30 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+
4
class RLogisticModel:
    """
    Wrapper for logistic regression models exported from R.
    Assumes CSV with columns: ID, y_true, y_pred_proba
    (only y_true and y_pred_proba are checked for and read).
    """

    def __init__(self, model_path: str):
        """Read the CSV at *model_path* and keep its labels and probabilities as arrays.

        Raises:
            ValueError: If ``y_true`` or ``y_pred_proba`` is missing from the file.
        """
        self.model_path = model_path
        self.df = pd.read_csv(model_path)

        # Check required columns exist
        expected_cols = {'y_true', 'y_pred_proba'}
        present = set(self.df.columns)
        if not expected_cols <= present:
            raise ValueError(f"R model CSV must contain columns: {expected_cols}")

        self.y_true = self.df['y_true'].values
        self.y_pred_proba = self.df['y_pred_proba'].values

    def predict_proba(self, X=None):
        """
        Mimics sklearn's predict_proba format: n_samples x 2.
        X is ignored; the probabilities stored in the CSV are returned.
        """
        positive = self.y_pred_proba.reshape(-1, 1)
        negative = 1 - positive
        return np.concatenate((negative, positive), axis=1)

    def get_true_labels(self):
        """Return the ``y_true`` column as an array."""
        return self.y_true
@@ -0,0 +1,50 @@
1
+ # File: tanml/utils/sas_loader.py
2
+
3
+ import pandas as pd
4
+ import numpy as np
5
+
6
+
7
class SASLogisticModel:
    """Score a logistic regression described by three exported files.

    The coefficient CSV is read with its first column as the index and
    squeezed to a Series, the intercept file holds a single float, and the
    feature-order file lists one feature name per line.
    """

    def __init__(self, coeffs_path, intercept_path, feature_order_path):
        """Remember the three file paths and load their contents immediately."""
        self.coeffs_path = coeffs_path
        self.intercept_path = intercept_path
        self.feature_order_path = feature_order_path

        self.coefficients = self._load_coefficients()
        self.intercept = self._load_intercept()
        self.feature_order = self._load_feature_order()

    def _load_coefficients(self):
        """Read the coefficient CSV, indexed by its first column."""
        return pd.read_csv(self.coeffs_path, index_col=0).squeeze("columns")

    def _load_intercept(self):
        """Read the intercept file and parse its stripped content as a float."""
        with open(self.intercept_path) as f:
            return float(f.read().strip())

    def _load_feature_order(self):
        """Read the feature names, one per line, ignoring blank lines.

        Blank or whitespace-only lines are skipped so they do not turn into
        empty feature names that would fail the column selection in
        :meth:`predict_proba`.
        """
        with open(self.feature_order_path) as f:
            names = (line.strip() for line in f)
            return [name for name in names if name]

    def predict_proba(self, X):
        """
        Return a NumPy array shaped (n_samples, 2) like sklearn:
        [:, 0] = P(class 0), [:, 1] = P(class 1)
        """
        X = X[self.feature_order]
        logits = X.dot(self.coefficients) + self.intercept

        # numeric stability clamp: keeps np.exp below from overflowing
        logits = logits.clip(-700, 700)

        proba_1 = 1 / (1 + np.exp(-logits))
        proba_0 = 1 - proba_1
        return np.vstack([proba_0, proba_1]).T  # shape (n, 2)

    def predict(self, X):
        """
        Return class labels (0/1) based on 0.5 threshold.
        Works with the NumPy array returned by predict_proba().
        """
        proba_1 = self.predict_proba(X)[:, 1]  # probability of class 1
        return (proba_1 >= 0.5).astype(int)
+
@@ -0,0 +1,34 @@
1
+ from pathlib import Path
2
+ import shutil
3
+
4
def generate_rules_yaml(
    scenario: str = "A",
    dest_path: str = "rules.yaml",
    overwrite: bool = False
):
    """Copy the packaged rules template for *scenario* to *dest_path*.

    Args:
        scenario: ``"A"``, ``"B"`` or ``"C"`` (case-insensitive); selects the
            template file.
        dest_path: Target file; missing parent directories are created.
        overwrite: Allow replacing an existing *dest_path*.

    Raises:
        ValueError: If *scenario* is not A, B or C.
        FileNotFoundError: If the template file is missing.
        FileExistsError: If *dest_path* exists and *overwrite* is false.
    """
    scenario = scenario.upper()
    template_map = {
        "A": "rules_one_model_one_dataset.yaml",
        "B": "rules_multiple_models_datasets.yaml",
        "C": "rules_one_dataset_segment_column.yaml"
    }

    template_name = template_map.get(scenario)
    if template_name is None:
        raise ValueError("Invalid scenario. Must be 'A', 'B', or 'C'.")

    # Templates live in the "config_templates" directory one level above this module.
    src = Path(__file__).parent.parent / "config_templates" / template_name
    if not src.exists():
        raise FileNotFoundError(f"Template not found at {src}")

    dst = Path(dest_path)
    if not overwrite and dst.exists():
        raise FileExistsError(f"{dst} already exists. Use --overwrite to replace it.")

    target_dir = dst.parent
    if not target_dir.exists():
        target_dir.mkdir(parents=True, exist_ok=True)

    shutil.copyfile(src, dst)

    print(f"βœ… Created: {dst.resolve()} for Scenario {scenario}")
    print("πŸ‘‰ Now edit this YAML to fill in your model, data, and feature details.")
@@ -0,0 +1,5 @@
1
+ import yaml
2
+
3
def load_yaml_config(path):
    """Parse the YAML file at *path* with ``yaml.safe_load`` and return the result.

    The file is opened in binary mode so that PyYAML decodes the bytes itself
    instead of relying on the platform's default text encoding, which can
    mis-read non-ASCII characters on systems whose default is not UTF-8.
    """
    with open(path, "rb") as f:
        return yaml.safe_load(f)