tanml 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tanml might be problematic.

Files changed (49)
  1. tanml/__init__.py +1 -1
  2. tanml/check_runners/cleaning_repro_runner.py +2 -2
  3. tanml/check_runners/correlation_runner.py +49 -12
  4. tanml/check_runners/explainability_runner.py +12 -22
  5. tanml/check_runners/logistic_stats_runner.py +196 -17
  6. tanml/check_runners/performance_runner.py +82 -26
  7. tanml/check_runners/raw_data_runner.py +29 -14
  8. tanml/check_runners/regression_metrics_runner.py +195 -0
  9. tanml/check_runners/stress_test_runner.py +23 -6
  10. tanml/check_runners/vif_runner.py +33 -27
  11. tanml/checks/correlation.py +241 -41
  12. tanml/checks/explainability/shap_check.py +261 -29
  13. tanml/checks/logit_stats.py +186 -54
  14. tanml/checks/performance_classification.py +305 -0
  15. tanml/checks/raw_data.py +58 -23
  16. tanml/checks/regression_metrics.py +167 -0
  17. tanml/checks/stress_test.py +157 -53
  18. tanml/cli/main.py +99 -27
  19. tanml/engine/check_agent_registry.py +20 -10
  20. tanml/engine/core_engine_agent.py +199 -37
  21. tanml/models/registry.py +329 -0
  22. tanml/report/report_builder.py +1180 -147
  23. tanml/report/templates/report_template_cls.docx +0 -0
  24. tanml/report/templates/report_template_reg.docx +0 -0
  25. tanml/ui/app.py +1205 -0
  26. tanml/utils/data_loader.py +105 -15
  27. tanml-0.1.7.dist-info/METADATA +164 -0
  28. tanml-0.1.7.dist-info/RECORD +54 -0
  29. tanml/cli/arg_parser.py +0 -31
  30. tanml/cli/init_cmd.py +0 -8
  31. tanml/cli/validate_cmd.py +0 -7
  32. tanml/config_templates/rules_multiple_models_datasets.yaml +0 -144
  33. tanml/config_templates/rules_one_dataset_segment_column.yaml +0 -140
  34. tanml/config_templates/rules_one_model_one_dataset.yaml +0 -143
  35. tanml/engine/segmentation_agent.py +0 -118
  36. tanml/engine/validation_agent.py +0 -91
  37. tanml/report/templates/report_template.docx +0 -0
  38. tanml/utils/model_loader.py +0 -35
  39. tanml/utils/r_loader.py +0 -30
  40. tanml/utils/sas_loader.py +0 -50
  41. tanml/utils/yaml_generator.py +0 -34
  42. tanml/utils/yaml_loader.py +0 -5
  43. tanml/validate.py +0 -209
  44. tanml-0.1.6.dist-info/METADATA +0 -317
  45. tanml-0.1.6.dist-info/RECORD +0 -62
  46. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/WHEEL +0 -0
  47. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/entry_points.txt +0 -0
  48. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/licenses/LICENSE +0 -0
  49. {tanml-0.1.6.dist-info → tanml-0.1.7.dist-info}/top_level.txt +0 -0
tanml/engine/core_engine_agent.py
@@ -1,22 +1,27 @@
-"""
-ValidationEngine – runs all registered check-runners and assembles a
-single results dictionary that the ReportBuilder / Jinja template expects.
-"""
+# tanml/engine/core_engine_agent.py
 
 from tanml.engine.check_agent_registry import CHECK_RUNNER_REGISTRY
-#from tanml.checks.cleaning_repro import CleaningReproCheck
 
+import numpy as np
+import pandas as pd
+
+try:
+    import statsmodels.api as sm
+    _SM_AVAILABLE = True
+except Exception:
+    _SM_AVAILABLE = False
 
 KEEP_AS_NESTED = {
     "DataQualityCheck",
     "StressTestCheck",
     "InputClusterCheck",
+    "InputClusterCoverageCheck",
     "RawDataCheck",
-    #"CleaningReproCheck",
     "SHAPCheck",
     "VIFCheck",
     "CorrelationCheck",
     "EDACheck",
+    # "RegressionMetrics",  # you can keep nested if desired
 }
 
 
@@ -30,8 +35,8 @@ class ValidationEngine:
         y_test,
         config,
         cleaned_data,
-        raw_df=None ,
-        ctx=None
+        raw_df=None,
+        ctx=None
     ):
         self.model = model
         self.X_train = X_train
@@ -40,15 +45,149 @@ class ValidationEngine:
         self.y_test = y_test
         self.config = config
         self.cleaned_data = cleaned_data
-        self.raw_df = raw_df
+        self.raw_df = raw_df
 
+        # allow resuming if config had check_results
         self.results = dict(config.get("check_results", {}))
         self.ctx = ctx or {}
-    def run_all_checks(self):
+
+        self.task_type = self._infer_task_type(self.y_train, config, model)
+
+    # --- better detection logic --------------------------------------------
+    @staticmethod
+    def _infer_task_type(y, config=None, model=None):
+        """
+        Decide if task is classification or regression.
+        Priority:
+          1. config["model"]["type"]
+          2. model._estimator_type (sklearn)
+          3. y values (unique count)
+        """
+        # 1. Config hint
+        try:
+            mtype = (config or {}).get("model", {}).get("type", "")
+            if isinstance(mtype, str):
+                mtype = mtype.lower()
+                if "class" in mtype:
+                    return "classification"
+                if "regress" in mtype:
+                    return "regression"
+        except Exception:
+            pass
+
+        # 2. Model introspection
+        try:
+            if hasattr(model, "_estimator_type"):
+                est = getattr(model, "_estimator_type", "")
+                if est == "classifier":
+                    return "classification"
+                if est == "regressor":
+                    return "regression"
+            if hasattr(model, "predict_proba") or hasattr(model, "decision_function"):
+                return "classification"
+        except Exception:
+            pass
+
+        # 3. Label based
+        try:
+            if isinstance(y, (pd.Series, pd.DataFrame)):
+                s = y.squeeze()
+            else:
+                s = np.asarray(y).reshape(-1)
+
+            unique_vals = pd.Series(s).dropna().unique()
+            # Heuristic: small discrete set -> classification
+            if pd.api.types.is_numeric_dtype(s):
+                if len(unique_vals) <= 10:
+                    return "classification"
+                return "regression"
+            else:
+                # non-numeric target -> classification
+                return "classification"
+        except Exception:
+            pass
+
+        # Fallback
+        return "classification"
+    # -----------------------------------------------------------------------
+
+    def _pick(self, *paths, default=None):
+        for path in paths:
+            cur = self.results
+            ok = True
+            for p in path:
+                if isinstance(cur, dict) and p in cur:
+                    cur = cur[p]
+                else:
+                    ok = False
+                    break
+            if ok:
+                return cur
+        return default
+
+    def _compute_linear_stats(self):
+        """
+        Optional: compute a statsmodels OLS summary + coefficient table for regression runs.
+        Writes results into self.results["LinearStats"].
+        """
+        if self.task_type != "regression":
+            return
+        if not _SM_AVAILABLE:
+            self.results["LinearStats"] = {
+                "error": "statsmodels not available; install `statsmodels` to see OLS summary."
+            }
+            return
+
+        try:
+            # add constant and fit OLS on TRAIN split to mirror sklearn fit
+            X = self.X_train
+            y = self.y_train
+            Xc = sm.add_constant(X, has_constant="add")
+            ols_model = sm.OLS(y, Xc, missing="drop")
+            ols_res = ols_model.fit()
+
+            # Build coefficient table (including intercept 'const')
+            params = ols_res.params
+            bse = ols_res.bse
+            tvals = ols_res.tvalues
+            pvals = ols_res.pvalues
+            ci = ols_res.conf_int(alpha=0.05)
+            ci.columns = ["ci_low", "ci_high"]
+
+            rows = []
+            for name in params.index:
+                rows.append({
+                    "feature": name,
+                    "coef": float(params[name]),
+                    "std err": float(bse.get(name, float("nan"))),
+                    "t": float(tvals.get(name, float("nan"))),
+                    "P>|t|": float(pvals.get(name, float("nan"))),
+                    "ci_low": float(ci.loc[name, "ci_low"]) if name in ci.index else None,
+                    "ci_high": float(ci.loc[name, "ci_high"]) if name in ci.index else None,
+                })
+
+            self.results["LinearStats"] = {
+                "summary_text": ols_res.summary().as_text(),
+                "coeff_table": rows,
+                "status": "ok",
+            }
+        except Exception as e:
+            self.results["LinearStats"] = {"error": f"OLS stats failed: {e}"}
+    # ------------------------------------------------------------
+
+    def run_all_checks(self, progress_callback=None):
+        self.results["task_type"] = self.task_type
+
         for check_name, runner_func in CHECK_RUNNER_REGISTRY.items():
             if check_name in self.config.get("skip_checks", []):
                 continue
 
+            if progress_callback:
+                try:
+                    progress_callback(f"Running {check_name}…")
+                except Exception:
+                    pass
+
             print(f"✅ Running {check_name}")
             try:
                 result = runner_func(
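Reviewer note: the detection cascade above is config-first. A minimal sketch of how it resolves on hypothetical targets (assuming the module path in the new header comment, tanml.engine.core_engine_agent, and that pandas is installed):

    import pandas as pd
    from tanml.engine.core_engine_agent import ValidationEngine

    y_binary = pd.Series([0, 1, 0, 1, 1])                   # numeric, 2 unique values
    y_price = pd.Series([float(i) * 1.7 for i in range(50)])  # numeric, 50 unique values

    # Step 1: an explicit config hint wins over the model and the labels.
    ValidationEngine._infer_task_type(y_binary, config={"model": {"type": "regression"}})  # -> "regression"

    # Step 3: with no config hint or model, the unique-count heuristic decides
    # (numeric targets with <= 10 distinct values are treated as classification).
    ValidationEngine._infer_task_type(y_binary)  # -> "classification"
    ValidationEngine._infer_task_type(y_price)   # -> "regression"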
@@ -59,30 +198,45 @@ class ValidationEngine:
                     self.y_test,
                     self.config,
                     self.cleaned_data,
-                    raw_df=self.raw_df
+                    raw_df=self.raw_df
                 )
-
                 self._integrate(check_name, result)
-
             except Exception as e:
                 print(f"⚠️ {check_name} failed: {e}")
                 self.results[check_name] = {"error": str(e)}
 
-        # Add CleaningReproCheck manually
-        # if self.raw_df is not None:
-        #     print("✅ Running CleaningReproCheck")
-        #     try:
-        #         check = CleaningReproCheck(self.raw_df, self.cleaned_data)
-
-        #         self.results["CleaningReproCheck"] = check.run()
-        #     except Exception as e:
-        #         print(f"⚠️ CleaningReproCheck failed: {e}")
-        #         self.results["CleaningReproCheck"] = {"error": str(e)}
-        # else:
-        #     print("⚠️ Skipping CleaningReproCheck raw_df not provided")
-        #     self.results["CleaningReproCheck"] = {"error": "raw_data not available"}
-
-        # convenience copy for template
+        # add OLS stats for regression (pretty coef table + p-values)
+        self._compute_linear_stats()
+
+        # -------- Build summary: TASK-AWARE --------
+        summary = {}
+
+        if self.task_type == "regression":
+            summary["rmse"] = self._pick(("RegressionMetrics", "rmse"))
+            summary["mae"] = self._pick(("RegressionMetrics", "mae"))
+            summary["r2"] = self._pick(("RegressionMetrics", "r2"))
+        else:
+            cls = self._pick(("performance", "classification", "summary")) or {}
+            summary["auc"] = cls.get("auc")
+            summary["ks"] = cls.get("ks")
+            summary["f1"] = cls.get("f1")
+            summary["pr_auc"] = cls.get("pr_auc")
+
+        # PSI (optional)
+        summary["max_psi"] = self._pick(
+            ("PSICheck", "max_psi"),
+            ("PopulationStabilityCheck", "max_psi"),
+            ("max_psi",)
+        )
+
+        # Count failed checks
+        failed = 0
+        for k, v in self.results.items():
+            if isinstance(v, dict) and v.get("status") == "fail":
+                failed += 1
+        summary["rules_failed"] = failed
+
+        self.results["summary"] = summary
         self.results["check_results"] = dict(self.results)
         return self.results
 
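Reviewer note: this task-aware summary block replaces the old flat "convenience copy" and is what the report template reads. An illustrative shape for a regression run (placeholder numbers, not real output):

    # results["summary"] after run_all_checks() on a regression task:
    {
        "rmse": 4.21,        # pulled from RegressionMetrics via _pick
        "mae": 3.10,
        "r2": 0.87,
        "max_psi": None,     # stays None unless a PSI-style check populated it
        "rules_failed": 0,   # checks whose result dict carries status == "fail"
    }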
@@ -91,25 +245,33 @@ class ValidationEngine:
         if not result:
             return
 
-        # Special flatten for LogisticStatsCheck
         if check_name == "LogisticStatsCheck":
             self.results.update(result)
             return
 
-        # If it's a simple object (rare), store as-is
         if not isinstance(result, dict):
             self.results[check_name] = result
             return
 
-        # Keep entire dict nested
-        if check_name in KEEP_AS_NESTED:
+        cluster_aliases = {
+            "InputClusterCoverageCheck",
+            "InputClusterCoverage",
+            "ClusterCoverageCheck",
+            "InputClustersCheck",
+        }
+        if check_name in cluster_aliases:
             self.results[check_name] = result
+            self.results["InputClusterCheck"] = result
+            return
+
+        if set(result.keys()) == {"InputClusterCheck"}:
+            self.results["InputClusterCheck"] = result["InputClusterCheck"]
             return
 
-        # If runner returns {"CheckName": {...}}, unwrap
-        if set(result.keys()) == {check_name}:
-            self.results[check_name] = result[check_name]
+        if check_name in KEEP_AS_NESTED:
+            self.results[check_name] = result
             return
 
-        # Default: merge into root
-        self.results.update(result)
+        if isinstance(result, dict):
+            self.results.update(result)
+            return
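Reviewer note: _integrate now routes results in a fixed order (cluster aliases, canonical-key unwrap, KEEP_AS_NESTED, merge-to-root), and the old {check_name} unwrap branch is gone. A self-contained sketch of that order, not the package code verbatim:

    results = {}
    KEEP = {"SHAPCheck"}  # stand-in for KEEP_AS_NESTED
    ALIASES = {"InputClusterCoverageCheck", "InputClusterCoverage",
               "ClusterCoverageCheck", "InputClustersCheck"}

    def integrate(name, result):
        if name in ALIASES:
            results[name] = result
            results["InputClusterCheck"] = result   # mirrored under the canonical key
        elif set(result) == {"InputClusterCheck"}:
            results["InputClusterCheck"] = result["InputClusterCheck"]  # unwrap
        elif name in KEEP:
            results[name] = result                  # kept nested for the template
        else:
            results.update(result)                  # default: merge into root

    integrate("ClusterCoverageCheck", {"coverage": 0.93})
    integrate("SHAPCheck", {"top_features": ["x1", "x2"]})
    integrate("PerformanceCheck", {"auc": 0.81})
    # results: {"ClusterCoverageCheck": {...}, "InputClusterCheck": {...},
    #           "SHAPCheck": {...}, "auc": 0.81}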
tanml/models/registry.py (new file)
@@ -0,0 +1,329 @@
+# tanml/models/registry.py
+from __future__ import annotations
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, Optional, Tuple, Literal
+
+Task = Literal["classification", "regression"]
+
+@dataclass(frozen=True)
+class ModelSpec:
+    task: Task
+    import_path: str  # e.g., "sklearn.ensemble.RandomForestClassifier"
+    defaults: Dict[str, Any] = field(default_factory=dict)
+    # UI schema: param -> (type, choices_or_None, help_or_None)
+    ui_schema: Dict[str, Tuple[str, Optional[Tuple[Any, ...]], Optional[str]]] = field(default_factory=dict)
+    aliases: Dict[str, str] = field(default_factory=dict)  # optional param alias map
+
+# -------------------- 20 MODELS --------------------
+
+_REGISTRY: Dict[Tuple[str, str], ModelSpec] = {
+    # -------- Classification (10) --------
+    ("sklearn", "LogisticRegression"): ModelSpec(
+        task="classification",
+        import_path="sklearn.linear_model.LogisticRegression",
+        defaults=dict(penalty="l2", solver="lbfgs", C=1.0, class_weight=None, max_iter=1000, random_state=42),
+        ui_schema={
+            "penalty": ("choice", ("l2", "l1"), "Regularization"),
+            "solver": ("choice", ("lbfgs", "liblinear", "saga"), "Solver"),
+            "C": ("float", None, "Inverse regularization strength"),
+            "class_weight": ("choice", (None, "balanced"), "Imbalance handling"),
+            "max_iter": ("int", None, "Max iterations"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("sklearn", "RandomForestClassifier"): ModelSpec(
+        task="classification",
+        import_path="sklearn.ensemble.RandomForestClassifier",
+        defaults=dict(n_estimators=400, max_depth=16, min_samples_split=2, min_samples_leaf=1,
+                      class_weight=None, random_state=42, n_jobs=-1),
+        ui_schema={
+            "n_estimators": ("int", None, "Number of trees"),
+            "max_depth": ("int", None, "Tree depth (None=unbounded)"),
+            "min_samples_split": ("int", None, "Min samples to split"),
+            "min_samples_leaf": ("int", None, "Min samples per leaf"),
+            "class_weight": ("choice", (None, "balanced", "balanced_subsample"), "Imbalance"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("xgboost", "XGBClassifier"): ModelSpec(
+        task="classification",
+        import_path="xgboost.XGBClassifier",
+        defaults=dict(n_estimators=600, max_depth=6, learning_rate=0.05, subsample=0.8, colsample_bytree=0.8,
+                      min_child_weight=1, reg_lambda=1.0, tree_method="hist", random_state=42, n_jobs=-1),
+        ui_schema={
+            "n_estimators": ("int", None, "Boosting rounds"),
+            "max_depth": ("int", None, "Tree depth"),
+            "learning_rate": ("float", None, "Eta"),
+            "subsample": ("float", None, "Row subsample"),
+            "colsample_bytree": ("float", None, "Column subsample"),
+            "min_child_weight": ("float", None, "Min child weight"),
+            "reg_lambda": ("float", None, "L2 regularization"),
+            "tree_method": ("choice", ("hist", "auto"), "Grow method"),
+            "random_state": ("int", None, "Seed"),
+            "n_jobs": ("int", None, "Threads"),
+        },
+    ),
+    ("lightgbm", "LGBMClassifier"): ModelSpec(
+        task="classification",
+        import_path="lightgbm.LGBMClassifier",
+        defaults=dict(n_estimators=800, num_leaves=31, max_depth=-1, learning_rate=0.05, subsample=0.8,
+                      colsample_bytree=0.8, min_child_samples=20, reg_lambda=1.0, random_state=42, n_jobs=-1),
+        ui_schema={
+            "n_estimators": ("int", None, "Boosting rounds"),
+            "num_leaves": ("int", None, "Max leaves"),
+            "max_depth": ("int", None, "-1 = auto"),
+            "learning_rate": ("float", None, "Shrinkage"),
+            "subsample": ("float", None, "Row subsample"),
+            "colsample_bytree": ("float", None, "Column subsample"),
+            "min_child_samples": ("int", None, "Min child samples"),
+            "reg_lambda": ("float", None, "L2 reg"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("sklearn", "SVC"): ModelSpec(
+        task="classification",
+        import_path="sklearn.svm.SVC",
+        defaults=dict(C=1.0, gamma="scale", kernel="rbf", probability=True, class_weight=None, random_state=42),
+        ui_schema={
+            "C": ("float", None, "Regularization"),
+            "gamma": ("choice", ("scale", "auto"), "RBF width"),
+            "kernel": ("choice", ("rbf", "linear", "poly", "sigmoid"), "Kernel"),
+            "class_weight": ("choice", (None, "balanced"), "Imbalance"),
+            "probability": ("bool", None, "Calibrated probs"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("sklearn", "KNeighborsClassifier"): ModelSpec(
+        task="classification",
+        import_path="sklearn.neighbors.KNeighborsClassifier",
+        defaults=dict(n_neighbors=15, weights="distance", p=2),
+        ui_schema={
+            "n_neighbors": ("int", None, "k"),
+            "weights": ("choice", ("uniform", "distance"), "Weights"),
+            "p": ("int", None, "Minkowski p"),
+        },
+    ),
+    ("sklearn", "GaussianNB"): ModelSpec(
+        task="classification",
+        import_path="sklearn.naive_bayes.GaussianNB",
+        defaults=dict(var_smoothing=1e-9),
+        ui_schema={"var_smoothing": ("float", None, "Variance smoothing")},
+    ),
+    ("catboost", "CatBoostClassifier"): ModelSpec(
+        task="classification",
+        import_path="catboost.CatBoostClassifier",
+        defaults=dict(iterations=800, depth=6, learning_rate=0.05, l2_leaf_reg=3.0, subsample=0.8,
+                      loss_function="Logloss", random_state=42, verbose=False),
+        ui_schema={
+            "iterations": ("int", None, "Rounds"),
+            "depth": ("int", None, "Depth"),
+            "learning_rate": ("float", None, "Eta"),
+            "l2_leaf_reg": ("float", None, "L2 reg"),
+            "subsample": ("float", None, "Row subsample"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("sklearn", "ExtraTreesClassifier"): ModelSpec(
+        task="classification",
+        import_path="sklearn.ensemble.ExtraTreesClassifier",
+        defaults=dict(n_estimators=400, max_depth=None, min_samples_split=2, min_samples_leaf=1,
+                      class_weight=None, random_state=42, n_jobs=-1),
+        ui_schema={
+            "n_estimators": ("int", None, "Trees"),
+            "max_depth": ("int", None, "Depth"),
+            "min_samples_split": ("int", None, "Min split"),
+            "min_samples_leaf": ("int", None, "Min leaf"),
+            "class_weight": ("choice", (None, "balanced", "balanced_subsample"), "Imbalance"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("sklearn", "HistGradientBoostingClassifier"): ModelSpec(
+        task="classification",
+        import_path="sklearn.ensemble.HistGradientBoostingClassifier",
+        defaults=dict(max_depth=None, learning_rate=0.1, max_bins=255, l2_regularization=0.0,
+                      early_stopping=True, random_state=42),
+        ui_schema={
+            "max_depth": ("int", None, "Depth (None=auto)"),
+            "learning_rate": ("float", None, "Eta"),
+            "max_bins": ("int", None, "Bins"),
+            "l2_regularization": ("float", None, "L2 reg"),
+            "early_stopping": ("bool", None, "Early stop"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+
+    # -------- Regression (10) --------
+    ("sklearn", "LinearRegression"): ModelSpec(
+        task="regression",
+        import_path="sklearn.linear_model.LinearRegression",
+        defaults=dict(fit_intercept=True, positive=False),
+        ui_schema={
+            "fit_intercept": ("bool", None, "Fit intercept"),
+            "positive": ("bool", None, "Positive coef"),
+        },
+    ),
+    ("sklearn", "RandomForestRegressor"): ModelSpec(
+        task="regression",
+        import_path="sklearn.ensemble.RandomForestRegressor",
+        defaults=dict(n_estimators=400, max_depth=16, min_samples_split=2, min_samples_leaf=1,
+                      random_state=42, n_jobs=-1),
+        ui_schema={
+            "n_estimators": ("int", None, "Trees"),
+            "max_depth": ("int", None, "Depth"),
+            "min_samples_split": ("int", None, "Min split"),
+            "min_samples_leaf": ("int", None, "Min leaf"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("xgboost", "XGBRegressor"): ModelSpec(
+        task="regression",
+        import_path="xgboost.XGBRegressor",
+        defaults=dict(n_estimators=800, max_depth=6, learning_rate=0.05, subsample=0.8, colsample_bytree=0.8,
+                      min_child_weight=1, reg_lambda=1.0, tree_method="hist", random_state=42, n_jobs=-1),
+        ui_schema={
+            "n_estimators": ("int", None, "Rounds"),
+            "max_depth": ("int", None, "Depth"),
+            "learning_rate": ("float", None, "Eta"),
+            "subsample": ("float", None, "Row subsample"),
+            "colsample_bytree": ("float", None, "Column subsample"),
+            "min_child_weight": ("float", None, "Min child weight"),
+            "reg_lambda": ("float", None, "L2 reg"),
+            "tree_method": ("choice", ("hist", "auto"), "Grow method"),
+            "random_state": ("int", None, "Seed"),
+            "n_jobs": ("int", None, "Threads"),
+        },
+    ),
+    ("lightgbm", "LGBMRegressor"): ModelSpec(
+        task="regression",
+        import_path="lightgbm.LGBMRegressor",
+        defaults=dict(n_estimators=1200, num_leaves=31, max_depth=-1, learning_rate=0.05, subsample=0.8,
+                      colsample_bytree=0.8, min_child_samples=20, reg_lambda=1.0, random_state=42, n_jobs=-1),
+        ui_schema={
+            "n_estimators": ("int", None, "Rounds"),
+            "num_leaves": ("int", None, "Max leaves"),
+            "max_depth": ("int", None, "-1 = auto"),
+            "learning_rate": ("float", None, "Eta"),
+            "subsample": ("float", None, "Row subsample"),
+            "colsample_bytree": ("float", None, "Column subsample"),
+            "min_child_samples": ("int", None, "Min child samples"),
+            "reg_lambda": ("float", None, "L2 reg"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("sklearn", "SVR"): ModelSpec(
+        task="regression",
+        import_path="sklearn.svm.SVR",
+        defaults=dict(C=1.0, gamma="scale", epsilon=0.1, kernel="rbf"),
+        ui_schema={
+            "C": ("float", None, "Regularization"),
+            "gamma": ("choice", ("scale", "auto"), "RBF width"),
+            "epsilon": ("float", None, "Epsilon tube"),
+            "kernel": ("choice", ("rbf", "linear", "poly", "sigmoid"), "Kernel"),
+        },
+    ),
+    ("sklearn", "KNeighborsRegressor"): ModelSpec(
+        task="regression",
+        import_path="sklearn.neighbors.KNeighborsRegressor",
+        defaults=dict(n_neighbors=15, weights="distance", p=2),
+        ui_schema={
+            "n_neighbors": ("int", None, "k"),
+            "weights": ("choice", ("uniform", "distance"), "Weights"),
+            "p": ("int", None, "Minkowski p"),
+        },
+    ),
+    ("sklearn", "ElasticNet"): ModelSpec(
+        task="regression",
+        import_path="sklearn.linear_model.ElasticNet",
+        defaults=dict(alpha=0.001, l1_ratio=0.5, max_iter=1000, random_state=42),
+        ui_schema={
+            "alpha": ("float", None, "Reg strength"),
+            "l1_ratio": ("float", None, "L1 vs L2 mix"),
+            "max_iter": ("int", None, "Max iterations"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("catboost", "CatBoostRegressor"): ModelSpec(
+        task="regression",
+        import_path="catboost.CatBoostRegressor",
+        defaults=dict(iterations=1000, depth=6, learning_rate=0.05, l2_leaf_reg=3.0, subsample=0.8,
+                      loss_function="RMSE", random_state=42, verbose=False),
+        ui_schema={
+            "iterations": ("int", None, "Rounds"),
+            "depth": ("int", None, "Depth"),
+            "learning_rate": ("float", None, "Eta"),
+            "l2_leaf_reg": ("float", None, "L2 reg"),
+            "subsample": ("float", None, "Row subsample"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("sklearn", "ExtraTreesRegressor"): ModelSpec(
+        task="regression",
+        import_path="sklearn.ensemble.ExtraTreesRegressor",
+        defaults=dict(n_estimators=400, max_depth=None, min_samples_split=2, min_samples_leaf=1,
+                      random_state=42, n_jobs=-1),
+        ui_schema={
+            "n_estimators": ("int", None, "Trees"),
+            "max_depth": ("int", None, "Depth"),
+            "min_samples_split": ("int", None, "Min split"),
+            "min_samples_leaf": ("int", None, "Min leaf"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+    ("sklearn", "HistGradientBoostingRegressor"): ModelSpec(
+        task="regression",
+        import_path="sklearn.ensemble.HistGradientBoostingRegressor",
+        defaults=dict(max_depth=None, learning_rate=0.1, max_bins=255, l2_regularization=0.0,
+                      early_stopping=True, random_state=42),
+        ui_schema={
+            "max_depth": ("int", None, "Depth (None=auto)"),
+            "learning_rate": ("float", None, "Eta"),
+            "max_bins": ("int", None, "Bins"),
+            "l2_regularization": ("float", None, "L2 reg"),
+            "early_stopping": ("bool", None, "Early stop"),
+            "random_state": ("int", None, "Seed"),
+        },
+    ),
+}
+
+# --------------- Helpers ---------------
+
+def list_models(task: Optional[Task] = None) -> Dict[Tuple[str, str], ModelSpec]:
+    if task:
+        return {k: v for k, v in _REGISTRY.items() if v.task == task}
+    return dict(_REGISTRY)
+
+def get_spec(library: str, algo: str) -> ModelSpec:
+    key = (library, algo)
+    if key not in _REGISTRY:
+        raise KeyError(f"Unknown model: {library}.{algo}")
+    return _REGISTRY[key]
+
+def _lazy_import(import_path: str) -> Callable[..., Any]:
+    mod_name, cls_name = import_path.rsplit(".", 1)
+    mod = __import__(mod_name, fromlist=[cls_name])
+    return getattr(mod, cls_name)
+
+def build_estimator(library: str, algo: str, params: Optional[Dict[str, Any]] = None):
+    spec = get_spec(library, algo)
+    Cls = _lazy_import(spec.import_path)
+    kwargs = dict(spec.defaults)
+    if params:
+        canon = {}
+        for k, v in params.items():
+            k2 = spec.aliases.get(k, k)
+            canon[k2] = v
+        kwargs.update({k: v for k, v in canon.items() if v is not None})
+    return Cls(**kwargs)
+
+def ui_schema_for(library: str, algo: str) -> Dict[str, Tuple[str, Optional[Tuple[Any, ...]], Optional[str]]]:
+    return get_spec(library, algo).ui_schema
+
+def infer_task_from_target(y) -> Task:
+    try:
+        n = int(y.nunique())  # pandas Series fast path
+    except Exception:
+        try:
+            n = len(set(y))
+        except Exception:
+            n = 10
+    return "classification" if n <= 3 else "regression"
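Reviewer note: a short usage sketch of the new registry helpers, assuming scikit-learn is installed (the xgboost/lightgbm/catboost entries are imported lazily, so those packages are only needed when their entries are selected):

    import pandas as pd
    from tanml.models.registry import (
        list_models, build_estimator, ui_schema_for, infer_task_from_target,
    )

    # Enumerate registered regressors: {(library, algo): ModelSpec, ...}
    for (lib, algo), spec in list_models(task="regression").items():
        print(lib, algo, "->", spec.import_path)

    # Defaults merged with overrides; None-valued overrides are dropped,
    # so passing {"class_weight": None} leaves the registered default in place.
    clf = build_estimator("sklearn", "LogisticRegression", {"C": 0.5})
    print(clf)  # e.g. LogisticRegression(C=0.5, max_iter=1000, random_state=42)

    # The ui_schema tuples drive widget rendering: (type, choices_or_None, help)
    print(ui_schema_for("sklearn", "SVC")["kernel"])
    # ('choice', ('rbf', 'linear', 'poly', 'sigmoid'), 'Kernel')

    # Note this helper's cutoff (<= 3 unique values -> classification) is stricter
    # than the engine's _infer_task_type heuristic (<= 10).
    print(infer_task_from_target(pd.Series([0, 1, 1, 0])))  # 'classification'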