tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tanml might be problematic. Click here for more details.
- tanml/__init__.py +1 -1
- tanml/check_runners/cleaning_repro_runner.py +2 -2
- tanml/check_runners/correlation_runner.py +49 -12
- tanml/check_runners/explainability_runner.py +12 -22
- tanml/check_runners/logistic_stats_runner.py +196 -17
- tanml/check_runners/performance_runner.py +82 -26
- tanml/check_runners/raw_data_runner.py +29 -14
- tanml/check_runners/regression_metrics_runner.py +195 -0
- tanml/check_runners/stress_test_runner.py +23 -6
- tanml/check_runners/vif_runner.py +33 -27
- tanml/checks/correlation.py +241 -41
- tanml/checks/explainability/shap_check.py +261 -29
- tanml/checks/logit_stats.py +186 -54
- tanml/checks/performance_classification.py +305 -0
- tanml/checks/raw_data.py +58 -23
- tanml/checks/regression_metrics.py +167 -0
- tanml/checks/stress_test.py +157 -53
- tanml/cli/main.py +99 -27
- tanml/engine/check_agent_registry.py +20 -10
- tanml/engine/core_engine_agent.py +199 -37
- tanml/models/registry.py +329 -0
- tanml/report/report_builder.py +1180 -147
- tanml/report/templates/report_template_cls.docx +0 -0
- tanml/report/templates/report_template_reg.docx +0 -0
- tanml/ui/app.py +1205 -0
- tanml/utils/data_loader.py +105 -15
- tanml-0.1.7.dist-info/METADATA +164 -0
- tanml-0.1.7.dist-info/RECORD +54 -0
- tanml/cli/arg_parser.py +0 -31
- tanml/cli/init_cmd.py +0 -8
- tanml/cli/validate_cmd.py +0 -7
- tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
- tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
- tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
- tanml/engine/segmentation_agent.py +0 -118
- tanml/engine/validation_agent.py +0 -91
- tanml/report/templates/report_template.docx +0 -0
- tanml/utils/model_loader.py +0 -35
- tanml/utils/r_loader.py +0 -30
- tanml/utils/sas_loader.py +0 -50
- tanml/utils/yaml_generator.py +0 -34
- tanml/utils/yaml_loader.py +0 -5
- tanml/validate.py +0 -209
- tanml-0.1.6.dist-info/METADATA +0 -317
- tanml-0.1.6.dist-info/RECORD +0 -62
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
|
@@ -1,22 +1,27 @@
|
|
|
1
|
-
|
|
2
|
-
ValidationEngine – runs all registered check-runners and assembles a
|
|
3
|
-
single results dictionary that the ReportBuilder / Jinja template expects.
|
|
4
|
-
"""
|
|
1
|
+
# tanml/engine/core_engine_agent.py
|
|
5
2
|
|
|
6
3
|
from tanml.engine.check_agent_registry import CHECK_RUNNER_REGISTRY
|
|
7
|
-
#from tanml.checks.cleaning_repro import CleaningReproCheck
|
|
8
4
|
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
import statsmodels.api as sm
|
|
10
|
+
_SM_AVAILABLE = True
|
|
11
|
+
except Exception:
|
|
12
|
+
_SM_AVAILABLE = False
|
|
9
13
|
|
|
10
14
|
# Check names whose runner output is stored as-is under the check's own key in
# the engine results (see ValidationEngine._integrate), instead of having its
# result dict merged into the top-level results. "RegressionMetrics" may be
# added here if its metrics should stay nested as well.
KEEP_AS_NESTED = set(
    """
    DataQualityCheck
    StressTestCheck
    InputClusterCheck
    InputClusterCoverageCheck
    RawDataCheck
    SHAPCheck
    VIFCheck
    CorrelationCheck
    EDACheck
    """.split()
)
|
|
21
26
|
|
|
22
27
|
|
|
@@ -30,8 +35,8 @@ class ValidationEngine:
|
|
|
30
35
|
y_test,
|
|
31
36
|
config,
|
|
32
37
|
cleaned_data,
|
|
33
|
-
raw_df=None
|
|
34
|
-
ctx=None
|
|
38
|
+
raw_df=None,
|
|
39
|
+
ctx=None
|
|
35
40
|
):
|
|
36
41
|
self.model = model
|
|
37
42
|
self.X_train = X_train
|
|
@@ -40,15 +45,149 @@ class ValidationEngine:
|
|
|
40
45
|
self.y_test = y_test
|
|
41
46
|
self.config = config
|
|
42
47
|
self.cleaned_data = cleaned_data
|
|
43
|
-
self.raw_df = raw_df
|
|
48
|
+
self.raw_df = raw_df
|
|
44
49
|
|
|
50
|
+
# allow resuming if config had check_results
|
|
45
51
|
self.results = dict(config.get("check_results", {}))
|
|
46
52
|
self.ctx = ctx or {}
|
|
47
|
-
|
|
53
|
+
|
|
54
|
+
self.task_type = self._infer_task_type(self.y_train, config, model)
|
|
55
|
+
|
|
56
|
+
# --- task-type detection -----------------------------------------------
@staticmethod
def _infer_task_type(y, config=None, model=None):
    """Decide whether the run is a ``"classification"`` or ``"regression"`` task.

    Sources are consulted in priority order; the first that gives a definite
    answer wins:

    1. ``config["model"]["type"]`` -- a free-text hint. It counts as
       classification if it mentions ``"class"`` or ``"logistic"``, and as
       regression if it mentions ``"regress"``.
    2. The model object -- sklearn's ``_estimator_type`` attribute, then the
       presence of ``predict_proba`` / ``decision_function``.
    3. The target values -- a non-numeric target, or a numeric one with at
       most 10 distinct non-null values, is treated as classification.

    Detection is best-effort. An error while reading one source just moves
    on to the next, and ``"classification"`` is returned if nothing decides.
    """
    # 1. Config hint
    try:
        mtype = (config or {}).get("model", {}).get("type", "")
        if isinstance(mtype, str):
            mtype = mtype.lower()
            # "logistic regression" contains "regress" but is a classifier,
            # so it has to be recognised before the regression keyword.
            if "class" in mtype or "logistic" in mtype:
                return "classification"
            if "regress" in mtype:
                return "regression"
    except Exception:
        pass  # malformed config (e.g. "model" entry not a dict): try next source

    # 2. Model introspection
    try:
        est = getattr(model, "_estimator_type", None)
        if est == "classifier":
            return "classification"
        if est == "regressor":
            return "regression"
        if hasattr(model, "predict_proba") or hasattr(model, "decision_function"):
            return "classification"
    except Exception:
        pass  # wrappers may raise from attribute access: try next source

    # 3. Label based
    try:
        if isinstance(y, (pd.Series, pd.DataFrame)):
            s = y.squeeze()
        else:
            s = np.asarray(y).reshape(-1)

        if not pd.api.types.is_numeric_dtype(s):
            # strings / categories / objects can only be class labels
            return "classification"
        # Heuristic: a small discrete set of numeric values -> classification
        n_unique = len(pd.Series(s).dropna().unique())
        return "classification" if n_unique <= 10 else "regression"
    except Exception:
        pass  # labels could not be inspected: use the fallback

    # Fallback
    return "classification"
|
|
112
|
+
# -----------------------------------------------------------------------
|
|
113
|
+
|
|
114
|
+
def _pick(self, *paths, default=None):
    """Return the value at the first key path that fully resolves in ``self.results``.

    Each positional argument is a sequence of keys walked through nested
    dicts, e.g. ``("RegressionMetrics", "rmse")``. A path is abandoned as
    soon as the current node is not a dict or lacks the next key, and the
    next path is then tried. An empty path resolves to ``self.results``
    itself. *default* is returned when no path resolves.
    """
    for keys in paths:
        node = self.results
        for key in keys:
            if not isinstance(node, dict) or key not in node:
                break
            node = node[key]
        else:
            # The inner loop ran to completion without ``break``: every key was found.
            return node
    return default
|
|
127
|
+
|
|
128
|
+
def _compute_linear_stats(self):
    """Fit a statsmodels OLS on the training split and record its summary.

    Does nothing unless ``self.task_type == "regression"``. Writes
    ``self.results["LinearStats"]`` as one of:

    * ``{"summary_text", "coeff_table", "status": "ok"}`` on success.
      ``coeff_table`` is a list of row dicts with keys ``feature``, ``coef``,
      ``std err``, ``t``, ``P>|t|``, ``ci_low`` and ``ci_high`` (95% CI),
      one per coefficient and including the added intercept ``const``.
    * ``{"error": ...}`` when statsmodels is unavailable or the fit fails.
    """
    if self.task_type != "regression":
        return
    if not _SM_AVAILABLE:
        self.results["LinearStats"] = {
            "error": "statsmodels not available; install `statsmodels` to see OLS summary."
        }
        return

    try:
        # Fit on the TRAIN split to mirror the sklearn fit. A constant column
        # is always added so the table carries an intercept row.
        Xc = sm.add_constant(self.X_train, has_constant="add")
        ols_res = sm.OLS(self.y_train, Xc, missing="drop").fit()

        # Build the table positionally from the model's exog names. With
        # pandas input these are the column names (plus "const"). With plain
        # NumPy input statsmodels returns bare arrays (no .index/.columns)
        # and generates the names itself, so both input kinds work here.
        names = list(ols_res.model.exog_names)
        ci = np.asarray(ols_res.conf_int(alpha=0.05))
        rows = [
            {
                "feature": name,
                "coef": float(coef),
                "std err": float(se),
                "t": float(t_val),
                "P>|t|": float(p_val),
                "ci_low": float(low),
                "ci_high": float(high),
            }
            for name, coef, se, t_val, p_val, (low, high) in zip(
                names,
                np.asarray(ols_res.params),
                np.asarray(ols_res.bse),
                np.asarray(ols_res.tvalues),
                np.asarray(ols_res.pvalues),
                ci,
            )
        ]

        self.results["LinearStats"] = {
            "summary_text": ols_res.summary().as_text(),
            "coeff_table": rows,
            "status": "ok",
        }
    except Exception as e:
        # Best-effort diagnostic (e.g. non-numeric features): record why OLS
        # failed rather than aborting the whole validation run.
        self.results["LinearStats"] = {"error": f"OLS stats failed: {e}"}
|
|
176
|
+
# ------------------------------------------------------------
|
|
177
|
+
|
|
178
|
+
def run_all_checks(self, progress_callback=None):
|
|
179
|
+
self.results["task_type"] = self.task_type
|
|
180
|
+
|
|
48
181
|
for check_name, runner_func in CHECK_RUNNER_REGISTRY.items():
|
|
49
182
|
if check_name in self.config.get("skip_checks", []):
|
|
50
183
|
continue
|
|
51
184
|
|
|
185
|
+
if progress_callback:
|
|
186
|
+
try:
|
|
187
|
+
progress_callback(f"Running {check_name}…")
|
|
188
|
+
except Exception:
|
|
189
|
+
pass
|
|
190
|
+
|
|
52
191
|
print(f"✅ Running {check_name}")
|
|
53
192
|
try:
|
|
54
193
|
result = runner_func(
|
|
@@ -59,30 +198,45 @@ class ValidationEngine:
|
|
|
59
198
|
self.y_test,
|
|
60
199
|
self.config,
|
|
61
200
|
self.cleaned_data,
|
|
62
|
-
raw_df=self.raw_df
|
|
201
|
+
raw_df=self.raw_df
|
|
63
202
|
)
|
|
64
|
-
|
|
65
203
|
self._integrate(check_name, result)
|
|
66
|
-
|
|
67
204
|
except Exception as e:
|
|
68
205
|
print(f"⚠️ {check_name} failed: {e}")
|
|
69
206
|
self.results[check_name] = {"error": str(e)}
|
|
70
207
|
|
|
71
|
-
#
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
#
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
208
|
+
# add OLS stats for regression (pretty coef table + p-values)
|
|
209
|
+
self._compute_linear_stats()
|
|
210
|
+
|
|
211
|
+
# -------- Build summary: TASK-AWARE --------
|
|
212
|
+
summary = {}
|
|
213
|
+
|
|
214
|
+
if self.task_type == "regression":
|
|
215
|
+
summary["rmse"] = self._pick(("RegressionMetrics", "rmse"))
|
|
216
|
+
summary["mae"] = self._pick(("RegressionMetrics", "mae"))
|
|
217
|
+
summary["r2"] = self._pick(("RegressionMetrics", "r2"))
|
|
218
|
+
else:
|
|
219
|
+
cls = self._pick(("performance", "classification", "summary")) or {}
|
|
220
|
+
summary["auc"] = cls.get("auc")
|
|
221
|
+
summary["ks"] = cls.get("ks")
|
|
222
|
+
summary["f1"] = cls.get("f1")
|
|
223
|
+
summary["pr_auc"] = cls.get("pr_auc")
|
|
224
|
+
|
|
225
|
+
# PSI (optional)
|
|
226
|
+
summary["max_psi"] = self._pick(
|
|
227
|
+
("PSICheck", "max_psi"),
|
|
228
|
+
("PopulationStabilityCheck", "max_psi"),
|
|
229
|
+
("max_psi",)
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
# Count failed checks
|
|
233
|
+
failed = 0
|
|
234
|
+
for k, v in self.results.items():
|
|
235
|
+
if isinstance(v, dict) and v.get("status") == "fail":
|
|
236
|
+
failed += 1
|
|
237
|
+
summary["rules_failed"] = failed
|
|
238
|
+
|
|
239
|
+
self.results["summary"] = summary
|
|
86
240
|
self.results["check_results"] = dict(self.results)
|
|
87
241
|
return self.results
|
|
88
242
|
|
|
@@ -91,25 +245,33 @@ class ValidationEngine:
|
|
|
91
245
|
if not result:
|
|
92
246
|
return
|
|
93
247
|
|
|
94
|
-
# Special flatten for LogisticStatsCheck
|
|
95
248
|
if check_name == "LogisticStatsCheck":
|
|
96
249
|
self.results.update(result)
|
|
97
250
|
return
|
|
98
251
|
|
|
99
|
-
# If it's a simple object (rare), store as-is
|
|
100
252
|
if not isinstance(result, dict):
|
|
101
253
|
self.results[check_name] = result
|
|
102
254
|
return
|
|
103
255
|
|
|
104
|
-
|
|
105
|
-
|
|
256
|
+
cluster_aliases = {
|
|
257
|
+
"InputClusterCoverageCheck",
|
|
258
|
+
"InputClusterCoverage",
|
|
259
|
+
"ClusterCoverageCheck",
|
|
260
|
+
"InputClustersCheck",
|
|
261
|
+
}
|
|
262
|
+
if check_name in cluster_aliases:
|
|
106
263
|
self.results[check_name] = result
|
|
264
|
+
self.results["InputClusterCheck"] = result
|
|
265
|
+
return
|
|
266
|
+
|
|
267
|
+
if set(result.keys()) == {"InputClusterCheck"}:
|
|
268
|
+
self.results["InputClusterCheck"] = result["InputClusterCheck"]
|
|
107
269
|
return
|
|
108
270
|
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
self.results[check_name] = result[check_name]
|
|
271
|
+
if check_name in KEEP_AS_NESTED:
|
|
272
|
+
self.results[check_name] = result
|
|
112
273
|
return
|
|
113
274
|
|
|
114
|
-
|
|
115
|
-
|
|
275
|
+
if isinstance(result, dict):
|
|
276
|
+
self.results.update(result)
|
|
277
|
+
return
|
tanml/models/registry.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
# tanml/models/registry.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Callable, Dict, Optional, Tuple, Literal
|
|
5
|
+
|
|
6
|
+
Task = Literal["classification", "regression"]
|
|
7
|
+
|
|
8
|
+
@dataclass(frozen=True)
class ModelSpec:
    """Registry entry describing how to build and configure one estimator.

    Attribute assignment is blocked (``frozen=True``), but the dict-valued
    fields are ordinary dicts, so their contents should be treated as
    read-only as well.
    """

    # Task family served by the estimator ("classification" or "regression").
    task: Task
    # Dotted path of the estimator class, e.g. "sklearn.ensemble.RandomForestClassifier".
    import_path: str
    # Constructor keyword arguments applied before any user overrides.
    defaults: Dict[str, Any] = field(default_factory=dict)
    # Per-parameter UI metadata: name -> (widget type, choices or None, help text or None).
    ui_schema: Dict[str, Tuple[str, Optional[Tuple[Any, ...]], Optional[str]]] = field(
        default_factory=dict
    )
    # Alternative parameter name -> canonical estimator parameter name.
    aliases: Dict[str, str] = field(default_factory=dict)
|
|
16
|
+
|
|
17
|
+
# -------------------- 20 MODELS --------------------
|
|
18
|
+
|
|
19
|
+
# Registry of supported estimators, keyed by (library, estimator class name).
# list_models / get_spec / build_estimator / ui_schema_for read from this dict.
# NOTE(review): build_estimator discards None-valued user overrides, so help
# texts below that mention "None" (e.g. RandomForestClassifier max_depth
# "None=unbounded") describe values a user cannot currently override to —
# confirm intended.
_REGISTRY: Dict[Tuple[str, str], ModelSpec] = {
    # -------- Classification (10) --------
    ("sklearn", "LogisticRegression"): ModelSpec(
        task="classification",
        import_path="sklearn.linear_model.LogisticRegression",
        defaults=dict(penalty="l2", solver="lbfgs", C=1.0, class_weight=None, max_iter=1000, random_state=42),
        ui_schema={
            # NOTE(review): per the scikit-learn docs the "lbfgs" solver only
            # supports the l2 penalty; "l1" needs "liblinear" or "saga". The UI
            # allows the invalid pairing — confirm it is validated elsewhere.
            "penalty": ("choice", ("l2", "l1"), "Regularization"),
            "solver": ("choice", ("lbfgs", "liblinear", "saga"), "Solver"),
            "C": ("float", None, "Inverse regularization strength"),
            "class_weight": ("choice", (None, "balanced"), "Imbalance handling"),
            "max_iter": ("int", None, "Max iterations"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("sklearn", "RandomForestClassifier"): ModelSpec(
        task="classification",
        import_path="sklearn.ensemble.RandomForestClassifier",
        defaults=dict(n_estimators=400, max_depth=16, min_samples_split=2, min_samples_leaf=1,
                      class_weight=None, random_state=42, n_jobs=-1),
        ui_schema={
            "n_estimators": ("int", None, "Number of trees"),
            "max_depth": ("int", None, "Tree depth (None=unbounded)"),
            "min_samples_split": ("int", None, "Min samples to split"),
            "min_samples_leaf": ("int", None, "Min samples per leaf"),
            "class_weight": ("choice", (None, "balanced", "balanced_subsample"), "Imbalance"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("xgboost", "XGBClassifier"): ModelSpec(
        task="classification",
        import_path="xgboost.XGBClassifier",
        defaults=dict(n_estimators=600, max_depth=6, learning_rate=0.05, subsample=0.8, colsample_bytree=0.8,
                      min_child_weight=1, reg_lambda=1.0, tree_method="hist", random_state=42, n_jobs=-1),
        ui_schema={
            "n_estimators": ("int", None, "Boosting rounds"),
            "max_depth": ("int", None, "Tree depth"),
            "learning_rate": ("float", None, "Eta"),
            "subsample": ("float", None, "Row subsample"),
            "colsample_bytree": ("float", None, "Column subsample"),
            "min_child_weight": ("float", None, "Min child weight"),
            "reg_lambda": ("float", None, "L2 regularization"),
            "tree_method": ("choice", ("hist", "auto"), "Grow method"),
            "random_state": ("int", None, "Seed"),
            "n_jobs": ("int", None, "Threads"),
        },
    ),
    # NOTE(review): per the LightGBM parameter docs, subsample (bagging_fraction)
    # only takes effect when subsample_freq (bagging_freq) > 0. That is not set
    # here, so row subsampling is likely inactive — confirm intent.
    ("lightgbm", "LGBMClassifier"): ModelSpec(
        task="classification",
        import_path="lightgbm.LGBMClassifier",
        defaults=dict(n_estimators=800, num_leaves=31, max_depth=-1, learning_rate=0.05, subsample=0.8,
                      colsample_bytree=0.8, min_child_samples=20, reg_lambda=1.0, random_state=42, n_jobs=-1),
        ui_schema={
            "n_estimators": ("int", None, "Boosting rounds"),
            "num_leaves": ("int", None, "Max leaves"),
            "max_depth": ("int", None, "-1 = auto"),
            "learning_rate": ("float", None, "Shrinkage"),
            "subsample": ("float", None, "Row subsample"),
            "colsample_bytree": ("float", None, "Column subsample"),
            "min_child_samples": ("int", None, "Min child samples"),
            "reg_lambda": ("float", None, "L2 reg"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("sklearn", "SVC"): ModelSpec(
        task="classification",
        import_path="sklearn.svm.SVC",
        # probability=True so the model exposes predict_proba.
        defaults=dict(C=1.0, gamma="scale", kernel="rbf", probability=True, class_weight=None, random_state=42),
        ui_schema={
            "C": ("float", None, "Regularization"),
            "gamma": ("choice", ("scale", "auto"), "RBF width"),
            "kernel": ("choice", ("rbf", "linear", "poly", "sigmoid"), "Kernel"),
            "class_weight": ("choice", (None, "balanced"), "Imbalance"),
            "probability": ("bool", None, "Calibrated probs"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("sklearn", "KNeighborsClassifier"): ModelSpec(
        task="classification",
        import_path="sklearn.neighbors.KNeighborsClassifier",
        defaults=dict(n_neighbors=15, weights="distance", p=2),
        ui_schema={
            "n_neighbors": ("int", None, "k"),
            "weights": ("choice", ("uniform", "distance"), "Weights"),
            "p": ("int", None, "Minkowski p"),
        },
    ),
    ("sklearn", "GaussianNB"): ModelSpec(
        task="classification",
        import_path="sklearn.naive_bayes.GaussianNB",
        defaults=dict(var_smoothing=1e-9),
        ui_schema={"var_smoothing": ("float", None, "Variance smoothing")},
    ),
    ("catboost", "CatBoostClassifier"): ModelSpec(
        task="classification",
        import_path="catboost.CatBoostClassifier",
        defaults=dict(iterations=800, depth=6, learning_rate=0.05, l2_leaf_reg=3.0, subsample=0.8,
                      loss_function="Logloss", random_state=42, verbose=False),
        ui_schema={
            "iterations": ("int", None, "Rounds"),
            "depth": ("int", None, "Depth"),
            "learning_rate": ("float", None, "Eta"),
            "l2_leaf_reg": ("float", None, "L2 reg"),
            "subsample": ("float", None, "Row subsample"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("sklearn", "ExtraTreesClassifier"): ModelSpec(
        task="classification",
        import_path="sklearn.ensemble.ExtraTreesClassifier",
        # max_depth defaults to None here, although its ui_schema type is "int".
        defaults=dict(n_estimators=400, max_depth=None, min_samples_split=2, min_samples_leaf=1,
                      class_weight=None, random_state=42, n_jobs=-1),
        ui_schema={
            "n_estimators": ("int", None, "Trees"),
            "max_depth": ("int", None, "Depth"),
            "min_samples_split": ("int", None, "Min split"),
            "min_samples_leaf": ("int", None, "Min leaf"),
            "class_weight": ("choice", (None, "balanced", "balanced_subsample"), "Imbalance"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("sklearn", "HistGradientBoostingClassifier"): ModelSpec(
        task="classification",
        import_path="sklearn.ensemble.HistGradientBoostingClassifier",
        defaults=dict(max_depth=None, learning_rate=0.1, max_bins=255, l2_regularization=0.0,
                      early_stopping=True, random_state=42),
        ui_schema={
            "max_depth": ("int", None, "Depth (None=auto)"),
            "learning_rate": ("float", None, "Eta"),
            "max_bins": ("int", None, "Bins"),
            "l2_regularization": ("float", None, "L2 reg"),
            "early_stopping": ("bool", None, "Early stop"),
            "random_state": ("int", None, "Seed"),
        },
    ),

    # -------- Regression (10) --------
    ("sklearn", "LinearRegression"): ModelSpec(
        task="regression",
        import_path="sklearn.linear_model.LinearRegression",
        defaults=dict(fit_intercept=True, positive=False),
        ui_schema={
            "fit_intercept": ("bool", None, "Fit intercept"),
            "positive": ("bool", None, "Positive coef"),
        },
    ),
    ("sklearn", "RandomForestRegressor"): ModelSpec(
        task="regression",
        import_path="sklearn.ensemble.RandomForestRegressor",
        defaults=dict(n_estimators=400, max_depth=16, min_samples_split=2, min_samples_leaf=1,
                      random_state=42, n_jobs=-1),
        ui_schema={
            "n_estimators": ("int", None, "Trees"),
            "max_depth": ("int", None, "Depth"),
            "min_samples_split": ("int", None, "Min split"),
            "min_samples_leaf": ("int", None, "Min leaf"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("xgboost", "XGBRegressor"): ModelSpec(
        task="regression",
        import_path="xgboost.XGBRegressor",
        defaults=dict(n_estimators=800, max_depth=6, learning_rate=0.05, subsample=0.8, colsample_bytree=0.8,
                      min_child_weight=1, reg_lambda=1.0, tree_method="hist", random_state=42, n_jobs=-1),
        ui_schema={
            "n_estimators": ("int", None, "Rounds"),
            "max_depth": ("int", None, "Depth"),
            "learning_rate": ("float", None, "Eta"),
            "subsample": ("float", None, "Row subsample"),
            "colsample_bytree": ("float", None, "Column subsample"),
            "min_child_weight": ("float", None, "Min child weight"),
            "reg_lambda": ("float", None, "L2 reg"),
            "tree_method": ("choice", ("hist", "auto"), "Grow method"),
            "random_state": ("int", None, "Seed"),
            "n_jobs": ("int", None, "Threads"),
        },
    ),
    # NOTE(review): as for LGBMClassifier, subsample likely has no effect
    # without subsample_freq > 0 — confirm intent.
    ("lightgbm", "LGBMRegressor"): ModelSpec(
        task="regression",
        import_path="lightgbm.LGBMRegressor",
        defaults=dict(n_estimators=1200, num_leaves=31, max_depth=-1, learning_rate=0.05, subsample=0.8,
                      colsample_bytree=0.8, min_child_samples=20, reg_lambda=1.0, random_state=42, n_jobs=-1),
        ui_schema={
            "n_estimators": ("int", None, "Rounds"),
            "num_leaves": ("int", None, "Max leaves"),
            "max_depth": ("int", None, "-1 = auto"),
            "learning_rate": ("float", None, "Eta"),
            "subsample": ("float", None, "Row subsample"),
            "colsample_bytree": ("float", None, "Column subsample"),
            "min_child_samples": ("int", None, "Min child samples"),
            "reg_lambda": ("float", None, "L2 reg"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("sklearn", "SVR"): ModelSpec(
        task="regression",
        import_path="sklearn.svm.SVR",
        defaults=dict(C=1.0, gamma="scale", epsilon=0.1, kernel="rbf"),
        ui_schema={
            "C": ("float", None, "Regularization"),
            "gamma": ("choice", ("scale", "auto"), "RBF width"),
            "epsilon": ("float", None, "Epsilon tube"),
            "kernel": ("choice", ("rbf", "linear", "poly", "sigmoid"), "Kernel"),
        },
    ),
    ("sklearn", "KNeighborsRegressor"): ModelSpec(
        task="regression",
        import_path="sklearn.neighbors.KNeighborsRegressor",
        defaults=dict(n_neighbors=15, weights="distance", p=2),
        ui_schema={
            "n_neighbors": ("int", None, "k"),
            "weights": ("choice", ("uniform", "distance"), "Weights"),
            "p": ("int", None, "Minkowski p"),
        },
    ),
    ("sklearn", "ElasticNet"): ModelSpec(
        task="regression",
        import_path="sklearn.linear_model.ElasticNet",
        defaults=dict(alpha=0.001, l1_ratio=0.5, max_iter=1000, random_state=42),
        ui_schema={
            "alpha": ("float", None, "Reg strength"),
            "l1_ratio": ("float", None, "L1 vs L2 mix"),
            "max_iter": ("int", None, "Max iterations"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("catboost", "CatBoostRegressor"): ModelSpec(
        task="regression",
        import_path="catboost.CatBoostRegressor",
        defaults=dict(iterations=1000, depth=6, learning_rate=0.05, l2_leaf_reg=3.0, subsample=0.8,
                      loss_function="RMSE", random_state=42, verbose=False),
        ui_schema={
            "iterations": ("int", None, "Rounds"),
            "depth": ("int", None, "Depth"),
            "learning_rate": ("float", None, "Eta"),
            "l2_leaf_reg": ("float", None, "L2 reg"),
            "subsample": ("float", None, "Row subsample"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("sklearn", "ExtraTreesRegressor"): ModelSpec(
        task="regression",
        import_path="sklearn.ensemble.ExtraTreesRegressor",
        # max_depth defaults to None here, although its ui_schema type is "int".
        defaults=dict(n_estimators=400, max_depth=None, min_samples_split=2, min_samples_leaf=1,
                      random_state=42, n_jobs=-1),
        ui_schema={
            "n_estimators": ("int", None, "Trees"),
            "max_depth": ("int", None, "Depth"),
            "min_samples_split": ("int", None, "Min split"),
            "min_samples_leaf": ("int", None, "Min leaf"),
            "random_state": ("int", None, "Seed"),
        },
    ),
    ("sklearn", "HistGradientBoostingRegressor"): ModelSpec(
        task="regression",
        import_path="sklearn.ensemble.HistGradientBoostingRegressor",
        defaults=dict(max_depth=None, learning_rate=0.1, max_bins=255, l2_regularization=0.0,
                      early_stopping=True, random_state=42),
        ui_schema={
            "max_depth": ("int", None, "Depth (None=auto)"),
            "learning_rate": ("float", None, "Eta"),
            "max_bins": ("int", None, "Bins"),
            "l2_regularization": ("float", None, "L2 reg"),
            "early_stopping": ("bool", None, "Early stop"),
            "random_state": ("int", None, "Seed"),
        },
    ),
}
|
|
287
|
+
|
|
288
|
+
# --------------- Helpers ---------------
|
|
289
|
+
|
|
290
|
+
def list_models(task: Optional[Task] = None) -> Dict[Tuple[str, str], ModelSpec]:
    """Return registry entries, optionally restricted to one task family.

    Args:
        task: ``"classification"`` or ``"regression"``. Any falsy value
            (``None``, ``""``) selects every registered model.

    Returns:
        A new dict mapping ``(library, estimator name)`` to its ``ModelSpec``.
    """
    if not task:
        return dict(_REGISTRY)
    return {key: spec for key, spec in _REGISTRY.items() if spec.task == task}
|
|
294
|
+
|
|
295
|
+
def get_spec(library: str, algo: str) -> ModelSpec:
    """Look up the ``ModelSpec`` registered under ``(library, algo)``.

    Raises:
        KeyError: if no such model is registered.
    """
    spec = _REGISTRY.get((library, algo))
    if spec is None:
        raise KeyError(f"Unknown model: {library}.{algo}")
    return spec
|
|
300
|
+
|
|
301
|
+
def _lazy_import(import_path: str) -> Callable[..., Any]:
    """Import and return the attribute named by a dotted ``module.attr`` path.

    The import happens only when called, so optional back-ends (xgboost,
    lightgbm, catboost) are needed only when actually used.

    Raises:
        ValueError: if *import_path* has no module part (no dot).
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    # importlib.import_module is the documented replacement for __import__.
    import importlib

    mod_name, sep, attr_name = import_path.rpartition(".")
    if not sep or not mod_name:
        raise ValueError(f"Expected a dotted 'module.attr' path, got {import_path!r}")
    module = importlib.import_module(mod_name)
    return getattr(module, attr_name)
|
|
305
|
+
|
|
306
|
+
def build_estimator(
    library: str,
    algo: str,
    params: Optional[Dict[str, Any]] = None,
    *,
    drop_none: bool = True,
):
    """Instantiate the registered estimator ``library.algo``.

    The spec's ``defaults`` are applied first, then *params* on top of them.
    Each key in *params* is first translated through the spec's ``aliases``
    map to its canonical parameter name.

    Args:
        library: Registry library key, e.g. ``"sklearn"``.
        algo: Registry estimator key, e.g. ``"RandomForestClassifier"``.
        params: Optional user overrides (estimator keyword arguments).
        drop_none: When True (the default, preserving earlier behaviour),
            ``None``-valued overrides are ignored, so blank UI fields keep
            the registered default. Pass False to forward ``None``
            explicitly, e.g. ``max_depth=None`` for an unbounded forest.

    Returns:
        A new, unfitted estimator instance.

    Raises:
        KeyError: if ``(library, algo)`` is not registered.
        ImportError: if the estimator's package cannot be imported.
    """
    spec = get_spec(library, algo)
    estimator_cls = _lazy_import(spec.import_path)
    kwargs = dict(spec.defaults)  # copy: never mutate the shared registry entry
    if params:
        # Resolve aliases first, so a later key wins for the same canonical name.
        canonical = {spec.aliases.get(key, key): value for key, value in params.items()}
        kwargs.update(
            {
                key: value
                for key, value in canonical.items()
                if value is not None or not drop_none
            }
        )
    return estimator_cls(**kwargs)
|
|
317
|
+
|
|
318
|
+
def ui_schema_for(library: str, algo: str) -> Dict[str, Tuple[str, Optional[Tuple[Any, ...]], Optional[str]]]:
    """Return the UI parameter schema for ``library.algo``.

    The result maps each parameter name to ``(widget type, choices or None,
    help text or None)``. A shallow copy is returned. ``ModelSpec`` is only
    shallowly frozen, so handing out the registry's own dict would let a
    caller's edits silently change the schema for every later caller. The
    entry tuples themselves are immutable, so a shallow copy is enough.

    Raises:
        KeyError: if ``(library, algo)`` is not registered.
    """
    return dict(get_spec(library, algo).ui_schema)
|
|
320
|
+
|
|
321
|
+
def _has_text_labels(y) -> bool:
    """Return True when *y* holds string, bytes or boolean labels.

    Such targets can only be classification labels, whatever their
    cardinality. DataFrame-like inputs (anything with ``.columns``) return
    False, because iterating them would yield column names, not values.
    """
    if hasattr(y, "columns"):
        return False
    kind = getattr(getattr(y, "dtype", None), "kind", None)
    if kind in ("b", "U", "S"):
        return True
    if kind is not None and kind != "O":
        return False  # numeric / datetime NumPy or pandas dtype
    try:
        return any(isinstance(v, (str, bytes, bool)) for v in y)
    except TypeError:  # not iterable
        return False


def infer_task_from_target(y, max_classes: int = 3) -> Task:
    """Guess the task family from the target values.

    Args:
        y: Target values, such as a pandas Series, a NumPy array or any iterable.
        max_classes: Largest number of distinct values still treated as a
            classification target (default 3, the historical threshold).

    Returns:
        ``"classification"`` when *y* has at most *max_classes* distinct
        values, or holds text/boolean labels of any cardinality. Otherwise
        ``"regression"``, which is also the answer when the values cannot be
        counted at all.
    """
    try:
        n = int(y.nunique())  # pandas Series fast path (NaN not counted)
    except Exception:
        try:
            n = len(set(y))
        except Exception:
            return "regression"  # uncountable target: keep the historical fallback
    if n <= max_classes or _has_text_labels(y):
        return "classification"
    return "regression"
|