syntaxmatrix 1.4.6__py3-none-any.whl → 2.5.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- syntaxmatrix/__init__.py +13 -8
- syntaxmatrix/agentic/__init__.py +0 -0
- syntaxmatrix/agentic/agent_tools.py +24 -0
- syntaxmatrix/agentic/agents.py +810 -0
- syntaxmatrix/agentic/code_tools_registry.py +37 -0
- syntaxmatrix/agentic/model_templates.py +1790 -0
- syntaxmatrix/auth.py +308 -14
- syntaxmatrix/commentary.py +328 -0
- syntaxmatrix/core.py +993 -375
- syntaxmatrix/dataset_preprocessing.py +218 -0
- syntaxmatrix/db.py +92 -95
- syntaxmatrix/display.py +95 -121
- syntaxmatrix/generate_page.py +634 -0
- syntaxmatrix/gpt_models_latest.py +46 -0
- syntaxmatrix/history_store.py +26 -29
- syntaxmatrix/kernel_manager.py +96 -17
- syntaxmatrix/llm_store.py +1 -1
- syntaxmatrix/plottings.py +6 -0
- syntaxmatrix/profiles.py +64 -8
- syntaxmatrix/project_root.py +55 -43
- syntaxmatrix/routes.py +5072 -1398
- syntaxmatrix/session.py +19 -0
- syntaxmatrix/settings/logging.py +40 -0
- syntaxmatrix/settings/model_map.py +300 -33
- syntaxmatrix/settings/prompts.py +273 -62
- syntaxmatrix/settings/string_navbar.py +3 -3
- syntaxmatrix/static/docs.md +272 -0
- syntaxmatrix/static/icons/favicon.png +0 -0
- syntaxmatrix/static/icons/hero_bg.jpg +0 -0
- syntaxmatrix/templates/dashboard.html +608 -147
- syntaxmatrix/templates/docs.html +71 -0
- syntaxmatrix/templates/error.html +2 -3
- syntaxmatrix/templates/login.html +1 -0
- syntaxmatrix/templates/register.html +1 -0
- syntaxmatrix/ui_modes.py +14 -0
- syntaxmatrix/utils.py +2482 -159
- syntaxmatrix/vectorizer.py +16 -12
- {syntaxmatrix-1.4.6.dist-info → syntaxmatrix-2.5.5.4.dist-info}/METADATA +20 -17
- syntaxmatrix-2.5.5.4.dist-info/RECORD +68 -0
- syntaxmatrix/model_templates.py +0 -30
- syntaxmatrix/static/icons/favicon.ico +0 -0
- syntaxmatrix-1.4.6.dist-info/RECORD +0 -54
- {syntaxmatrix-1.4.6.dist-info → syntaxmatrix-2.5.5.4.dist-info}/WHEEL +0 -0
- {syntaxmatrix-1.4.6.dist-info → syntaxmatrix-2.5.5.4.dist-info}/licenses/LICENSE.txt +0 -0
- {syntaxmatrix-1.4.6.dist-info → syntaxmatrix-2.5.5.4.dist-info}/top_level.txt +0 -0
syntaxmatrix/utils.py
CHANGED
|
@@ -1,50 +1,1188 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
import re, os, textwrap
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import re, textwrap
|
|
4
3
|
import pandas as pd
|
|
5
|
-
import
|
|
4
|
+
import numpy as np
|
|
6
5
|
import warnings
|
|
7
|
-
from .model_templates import classification
|
|
8
|
-
import syntaxmatrix as smx
|
|
9
6
|
|
|
7
|
+
from difflib import get_close_matches
|
|
8
|
+
from typing import Iterable, Tuple, Dict
|
|
9
|
+
import inspect
|
|
10
|
+
from sklearn.preprocessing import OneHotEncoder
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
from syntaxmatrix.agentic.model_templates import (
|
|
14
|
+
classification, regression, multilabel_classification,
|
|
15
|
+
eda_overview, eda_correlation,
|
|
16
|
+
anomaly_detection, ts_anomaly_detection,
|
|
17
|
+
dimensionality_reduction, feature_selection,
|
|
18
|
+
time_series_forecasting, time_series_classification,
|
|
19
|
+
unknown_group_proxy_pack, viz_line,
|
|
20
|
+
clustering, recommendation, topic_modelling,
|
|
21
|
+
viz_pie, viz_count_bar, viz_box, viz_scatter,
|
|
22
|
+
viz_stacked_bar, viz_distribution, viz_area, viz_kde,
|
|
23
|
+
)
|
|
24
|
+
import ast
|
|
10
25
|
|
|
11
26
|
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
|
|
12
27
|
|
|
28
|
+
# Intent labels for which generated code can be injected automatically.
# NOTE(review): these names mirror the templates imported from
# syntaxmatrix.agentic.model_templates above; presumably each label selects
# the matching template — verify against the consumers of this list.
INJECTABLE_INTENTS = [
    "classification",
    "multilabel_classification",
    "regression",
    "anomaly_detection",
    "time_series_forecasting",
    "time_series_classification",
    "ts_anomaly_detection",
    "dimensionality_reduction",
    "feature_selection",
    "clustering",
    "eda",
    "correlation_analysis",
    "visualisation",
    "recommendation",
    "topic_modelling",
]
|
|
45
|
+
|
|
46
|
+
def classify_ml_job(prompt: str) -> str:
    """
    Very-light keyword-based intent classifier for an ML prompt.

    Args:
        prompt: Free-text user request.

    Returns:
        One of:
            'greeting' | 'feature_selection' | 'dimensionality_reduction' |
            'anomaly_detection' | 'stat_test' | 'time_series' |
            'clustering' | 'classification' | 'regression' | 'eda' (default)
    """
    p = prompt.lower().strip()

    # Greeting detection. Match whole words only: a bare prefix test would
    # misclassify prompts such as "histogram of sales" (starts with "hi").
    greetings = {"hi", "hello", "hey", "good morning", "good afternoon", "good evening", "greetings"}
    for g in greetings:
        # Exact match, or the greeting followed by a non-alphanumeric
        # separator (space, comma, '!', ...).
        if p == g or (p.startswith(g) and not p[len(g)].isalnum()):
            return "greeting"

    # Feature selection / importance intent
    if any(k in p for k in (
        "feature selection", "select k best", "selectkbest", "rfe",
        "mutual information", "feature importance", "permutation importance",
        "feature engineering suggestions",
    )):
        return "feature_selection"

    # Dimensionality reduction intent
    if any(k in p for k in (
        "pca", "principal component", "dimensionality reduction",
        "reduce dimension", "reduce dimensionality", "t-sne", "tsne", "umap",
    )):
        return "dimensionality_reduction"

    # Anomaly / outlier intent
    if any(k in p for k in (
        "anomaly", "anomalies", "outlier", "outliers", "novelty",
        "fraud", "deviation", "rare event", "rare events", "odd pattern",
        "suspicious",
    )):
        return "anomaly_detection"

    if any(k in p for k in ("t-test", "anova", "p-value")):
        return "stat_test"
    if "forecast" in p or "prophet" in p:
        return "time_series"
    if "cluster" in p or "kmeans" in p:
        return "clustering"
    if any(k in p for k in ("accuracy", "precision", "roc")):
        return "classification"
    if any(k in p for k in ("rmse", "r2", "mae")):
        return "regression"

    # Default: exploratory data analysis
    return "eda"
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def harden_ai_code(code: str) -> str:
|
|
97
|
+
"""
|
|
98
|
+
Make any AI-generated cell resilient:
|
|
99
|
+
- Safe seaborn wrappers + sentinel vars (boxplot/barplot/etc.)
|
|
100
|
+
- Remove 'numeric_only=' args
|
|
101
|
+
- Replace pd.concat(...) with _safe_concat(...)
|
|
102
|
+
- Relax 'required_cols' hard fails
|
|
103
|
+
- Make static numeric_vars dynamic
|
|
104
|
+
- Wrap the whole block in try/except so no exception bubbles up
|
|
105
|
+
"""
|
|
106
|
+
# Remove any LLM-added try/except blocks (hardener adds its own)
|
|
107
|
+
import re
|
|
108
|
+
|
|
109
|
+
def strip_placeholders(code: str) -> str:
|
|
110
|
+
code = re.sub(r"\bshow\(\s*\.\.\.\s*\)",
|
|
111
|
+
"show('⚠ Block skipped due to an error.')",
|
|
112
|
+
code)
|
|
113
|
+
code = re.sub(r"\breturn\s+\.\.\.", "return None", code)
|
|
114
|
+
return code
|
|
115
|
+
|
|
116
|
+
def _indent(code: str, spaces: int = 4) -> str:
|
|
117
|
+
pad = " " * spaces
|
|
118
|
+
return "\n".join(pad + line for line in code.splitlines())
|
|
119
|
+
|
|
120
|
+
def _SMX_OHE(**k):
|
|
121
|
+
# normalise arg name across sklearn versions
|
|
122
|
+
if "sparse" in k and "sparse_output" not in k:
|
|
123
|
+
k["sparse_output"] = k.pop("sparse")
|
|
124
|
+
# default behaviour we want
|
|
125
|
+
k.setdefault("handle_unknown", "ignore")
|
|
126
|
+
k.setdefault("sparse_output", False)
|
|
127
|
+
try:
|
|
128
|
+
# if running on old sklearn without sparse_output, translate back
|
|
129
|
+
if "sparse_output" not in inspect.signature(OneHotEncoder).parameters:
|
|
130
|
+
if "sparse_output" in k:
|
|
131
|
+
k["sparse"] = k.pop("sparse_output")
|
|
132
|
+
return OneHotEncoder(**k)
|
|
133
|
+
except TypeError:
|
|
134
|
+
# final fallback: try legacy name
|
|
135
|
+
if "sparse_output" in k:
|
|
136
|
+
k["sparse"] = k.pop("sparse_output")
|
|
137
|
+
return OneHotEncoder(**k)
|
|
138
|
+
|
|
139
|
+
def _strip_stray_backrefs(code: str) -> str:
|
|
140
|
+
code = re.sub(r'(?m)^\s*\\\d+\s*', '', code)
|
|
141
|
+
code = re.sub(r'(?m)[;]\s*\\\d+\s*', '; ', code)
|
|
142
|
+
return code
|
|
143
|
+
|
|
144
|
+
def _wrap_metric_calls(code: str) -> str:
|
|
145
|
+
names = [
|
|
146
|
+
"r2_score","accuracy_score","precision_score","recall_score","f1_score",
|
|
147
|
+
"roc_auc_score","classification_report","confusion_matrix",
|
|
148
|
+
"mean_absolute_error","mean_absolute_percentage_error",
|
|
149
|
+
"explained_variance_score","log_loss","average_precision_score",
|
|
150
|
+
"precision_recall_fscore_support"
|
|
151
|
+
]
|
|
152
|
+
pat = re.compile(r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\(")
|
|
153
|
+
def repl(m):
|
|
154
|
+
prefix = m.group(1) or "" # "", "metrics.", or "sklearn.metrics."
|
|
155
|
+
name = m.group(2)
|
|
156
|
+
return f"_SMX_call({prefix}{name}, "
|
|
157
|
+
return pat.sub(repl, code)
|
|
158
|
+
|
|
159
|
+
def _smx_patch_mean_squared_error_squared_kw():
|
|
160
|
+
"""
|
|
161
|
+
sklearn<0.22 doesn't accept mean_squared_error(..., squared=False).
|
|
162
|
+
Patch the module attr so 'from sklearn.metrics import mean_squared_error'
|
|
163
|
+
receives a wrapper that drops 'squared' if the underlying call rejects it.
|
|
164
|
+
"""
|
|
165
|
+
try:
|
|
166
|
+
import sklearn.metrics as _sm
|
|
167
|
+
_orig = getattr(_sm, "mean_squared_error", None)
|
|
168
|
+
if not callable(_orig):
|
|
169
|
+
return
|
|
170
|
+
def _mse_compat(y_true, y_pred, *a, **k):
|
|
171
|
+
if "squared" in k:
|
|
172
|
+
try:
|
|
173
|
+
return _orig(y_true, y_pred, *a, **k)
|
|
174
|
+
except TypeError:
|
|
175
|
+
k.pop("squared", None)
|
|
176
|
+
return _orig(y_true, y_pred, *a, **k)
|
|
177
|
+
return _orig(y_true, y_pred, *a, **k)
|
|
178
|
+
_sm.mean_squared_error = _mse_compat
|
|
179
|
+
except Exception:
|
|
180
|
+
pass
|
|
181
|
+
|
|
182
|
+
def _smx_patch_kmeans_n_init_auto():
|
|
183
|
+
"""
|
|
184
|
+
sklearn>=1.4 accepts n_init='auto'; older versions want an int.
|
|
185
|
+
Patch sklearn.cluster.KMeans so 'auto' is converted to 10 if TypeError occurs.
|
|
186
|
+
"""
|
|
187
|
+
try:
|
|
188
|
+
import sklearn.cluster as _sc
|
|
189
|
+
_Orig = getattr(_sc, "KMeans", None)
|
|
190
|
+
if _Orig is None:
|
|
191
|
+
return
|
|
192
|
+
class KMeansCompat(_Orig):
|
|
193
|
+
def __init__(self, *a, **k):
|
|
194
|
+
if isinstance(k.get("n_init", None), str):
|
|
195
|
+
try:
|
|
196
|
+
super().__init__(*a, **k)
|
|
197
|
+
return
|
|
198
|
+
except TypeError:
|
|
199
|
+
k["n_init"] = 10
|
|
200
|
+
super().__init__(*a, **k)
|
|
201
|
+
_sc.KMeans = KMeansCompat
|
|
202
|
+
except Exception:
|
|
203
|
+
pass
|
|
204
|
+
|
|
205
|
+
def _smx_patch_ohe_name_api():
|
|
206
|
+
"""
|
|
207
|
+
Guard get_feature_names_out on older OneHotEncoder.
|
|
208
|
+
Your templates already use _SMX_OHE; this adds a soft fallback for feature names.
|
|
209
|
+
"""
|
|
210
|
+
try:
|
|
211
|
+
from sklearn.preprocessing import OneHotEncoder as _OHE
|
|
212
|
+
_orig_get = getattr(_OHE, "get_feature_names_out", None)
|
|
213
|
+
if _orig_get is None:
|
|
214
|
+
# Monkey-patch instance method via mixin
|
|
215
|
+
def _fallback_get_feature_names_out(self, input_features=None):
|
|
216
|
+
cats = getattr(self, "categories_", None) or []
|
|
217
|
+
input_features = input_features or [f"x{i}" for i in range(len(cats))]
|
|
218
|
+
names = []
|
|
219
|
+
for base, cat_list in zip(input_features, cats):
|
|
220
|
+
for j, _ in enumerate(cat_list):
|
|
221
|
+
names.append(f"{base}__{j}")
|
|
222
|
+
return names
|
|
223
|
+
_OHE.get_feature_names_out = _fallback_get_feature_names_out
|
|
224
|
+
except Exception:
|
|
225
|
+
pass
|
|
226
|
+
|
|
227
|
+
# Register and run patches once per execution
|
|
228
|
+
for _patch in (
|
|
229
|
+
_smx_patch_mean_squared_error_squared_kw,
|
|
230
|
+
_smx_patch_kmeans_n_init_auto,
|
|
231
|
+
_smx_patch_ohe_name_api,
|
|
232
|
+
):
|
|
233
|
+
try:
|
|
234
|
+
_patch()
|
|
235
|
+
except Exception:
|
|
236
|
+
pass
|
|
237
|
+
|
|
238
|
+
PREFACE = (
|
|
239
|
+
"# === SMX Auto-Hardening Preface (do not edit) ===\n"
|
|
240
|
+
"import warnings, numpy as np, pandas as pd, matplotlib.pyplot as plt\n"
|
|
241
|
+
"warnings.filterwarnings('ignore')\n"
|
|
242
|
+
"try:\n"
|
|
243
|
+
" import seaborn as sns\n"
|
|
244
|
+
"except Exception:\n"
|
|
245
|
+
" class _Dummy:\n"
|
|
246
|
+
" def __getattr__(self, name):\n"
|
|
247
|
+
" def _f(*a, **k):\n"
|
|
248
|
+
" from syntaxmatrix.display import show\n"
|
|
249
|
+
" show('⚠ seaborn not available; plot skipped.')\n"
|
|
250
|
+
" return _f\n"
|
|
251
|
+
" sns = _Dummy()\n"
|
|
252
|
+
"\n"
|
|
253
|
+
"from syntaxmatrix.display import show as _SMX_base_show\n"
|
|
254
|
+
"def _SMX_caption_from_ctx():\n"
|
|
255
|
+
" g = globals()\n"
|
|
256
|
+
" t = g.get('refined_question') or g.get('askai_question') or 'Table'\n"
|
|
257
|
+
" return str(t).strip().splitlines()[0][:120]\n"
|
|
258
|
+
"\n"
|
|
259
|
+
"def show(obj, title=None):\n"
|
|
260
|
+
" try:\n"
|
|
261
|
+
" import pandas as pd\n"
|
|
262
|
+
" if isinstance(obj, pd.DataFrame):\n"
|
|
263
|
+
" cap = (title or _SMX_caption_from_ctx())\n"
|
|
264
|
+
" try:\n"
|
|
265
|
+
" return _SMX_base_show(obj.style.set_caption(cap))\n"
|
|
266
|
+
" except Exception:\n"
|
|
267
|
+
" pass\n"
|
|
268
|
+
" except Exception:\n"
|
|
269
|
+
" pass\n"
|
|
270
|
+
" return _SMX_base_show(obj)\n"
|
|
271
|
+
"\n"
|
|
272
|
+
"def _SMX_axes_have_titles(fig=None):\n"
|
|
273
|
+
" import matplotlib.pyplot as _plt\n"
|
|
274
|
+
" fig = fig or _plt.gcf()\n"
|
|
275
|
+
" try:\n"
|
|
276
|
+
" for _ax in fig.get_axes():\n"
|
|
277
|
+
" if (_ax.get_title() or '').strip():\n"
|
|
278
|
+
" return True\n"
|
|
279
|
+
" except Exception:\n"
|
|
280
|
+
" pass\n"
|
|
281
|
+
" return False\n"
|
|
282
|
+
"\n"
|
|
283
|
+
"def _SMX_export_png():\n"
|
|
284
|
+
" import io, base64\n"
|
|
285
|
+
" fig = plt.gcf()\n"
|
|
286
|
+
" try:\n"
|
|
287
|
+
" if not _SMX_axes_have_titles(fig):\n"
|
|
288
|
+
" fig.suptitle(_SMX_caption_from_ctx(), fontsize=10)\n"
|
|
289
|
+
" except Exception:\n"
|
|
290
|
+
" pass\n"
|
|
291
|
+
" buf = io.BytesIO()\n"
|
|
292
|
+
" plt.savefig(buf, format='png', bbox_inches='tight')\n"
|
|
293
|
+
" buf.seek(0)\n"
|
|
294
|
+
" from IPython.display import display, HTML\n"
|
|
295
|
+
" _img = base64.b64encode(buf.read()).decode('ascii')\n"
|
|
296
|
+
" display(HTML(f\"<img src='data:image/png;base64,{_img}' style='max-width:100%;height:auto;border:1px solid #ccc;border-radius:4px;'/>\"))\n"
|
|
297
|
+
" plt.close()\n"
|
|
298
|
+
"\n"
|
|
299
|
+
"def _pick_df():\n"
|
|
300
|
+
" return globals().get('df', None)\n"
|
|
301
|
+
"\n"
|
|
302
|
+
"def _pick_ax_slot():\n"
|
|
303
|
+
" ax = None\n"
|
|
304
|
+
" try:\n"
|
|
305
|
+
" _axes = globals().get('axes', None)\n"
|
|
306
|
+
" import numpy as _np\n"
|
|
307
|
+
" if _axes is not None:\n"
|
|
308
|
+
" arr = _np.ravel(_axes)\n"
|
|
309
|
+
" for _a in arr:\n"
|
|
310
|
+
" try:\n"
|
|
311
|
+
" if hasattr(_a,'has_data') and not _a.has_data():\n"
|
|
312
|
+
" ax = _a; break\n"
|
|
313
|
+
" except Exception:\n"
|
|
314
|
+
" continue\n"
|
|
315
|
+
" except Exception:\n"
|
|
316
|
+
" ax = None\n"
|
|
317
|
+
" return ax\n"
|
|
318
|
+
"\n"
|
|
319
|
+
"def _first_numeric(_d):\n"
|
|
320
|
+
" import numpy as np, pandas as pd\n"
|
|
321
|
+
" try:\n"
|
|
322
|
+
" preferred = [\"median_house_value\", \"price\", \"value\", \"target\", \"label\", \"y\"]\n"
|
|
323
|
+
" for c in preferred:\n"
|
|
324
|
+
" if c in _d.columns and pd.api.types.is_numeric_dtype(_d[c]):\n"
|
|
325
|
+
" return c\n"
|
|
326
|
+
" cols = _d.select_dtypes(include=[np.number]).columns.tolist()\n"
|
|
327
|
+
" return cols[0] if cols else None\n"
|
|
328
|
+
" except Exception:\n"
|
|
329
|
+
" return None\n"
|
|
330
|
+
"\n"
|
|
331
|
+
"def _first_categorical(_d):\n"
|
|
332
|
+
" import pandas as pd, numpy as np\n"
|
|
333
|
+
" try:\n"
|
|
334
|
+
" num = set(_d.select_dtypes(include=[np.number]).columns.tolist())\n"
|
|
335
|
+
" cand = [c for c in _d.columns if c not in num and _d[c].nunique(dropna=True) <= 50]\n"
|
|
336
|
+
" return cand[0] if cand else None\n"
|
|
337
|
+
" except Exception:\n"
|
|
338
|
+
" return None\n"
|
|
339
|
+
"\n"
|
|
340
|
+
"boxplot = barplot = histplot = distplot = lineplot = countplot = heatmap = pairplot = None\n"
|
|
341
|
+
"\n"
|
|
342
|
+
"def _safe_plot(func, *args, **kwargs):\n"
|
|
343
|
+
" try:\n"
|
|
344
|
+
" ax = func(*args, **kwargs)\n"
|
|
345
|
+
" if ax is None:\n"
|
|
346
|
+
" ax = plt.gca()\n"
|
|
347
|
+
" try:\n"
|
|
348
|
+
" if hasattr(ax, 'has_data') and not ax.has_data():\n"
|
|
349
|
+
" from syntaxmatrix.display import show as _show\n"
|
|
350
|
+
" _show('⚠ Empty plot: no data drawn.')\n"
|
|
351
|
+
" except Exception:\n"
|
|
352
|
+
" pass\n"
|
|
353
|
+
" try: plt.tight_layout()\n"
|
|
354
|
+
" except Exception: pass\n"
|
|
355
|
+
" return ax\n"
|
|
356
|
+
" except Exception as e:\n"
|
|
357
|
+
" from syntaxmatrix.display import show as _show\n"
|
|
358
|
+
" _show(f'⚠ Plot skipped: {type(e).__name__}: {e}')\n"
|
|
359
|
+
" return None\n"
|
|
360
|
+
"\n"
|
|
361
|
+
"def SB_histplot(*a, **k):\n"
|
|
362
|
+
" _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
|
|
363
|
+
" _sentinel = (len(a) >= 1 and a[0] is None)\n"
|
|
364
|
+
" if (not a or _sentinel) and not k:\n"
|
|
365
|
+
" d = _pick_df()\n"
|
|
366
|
+
" if d is not None:\n"
|
|
367
|
+
" x = _first_numeric(d)\n"
|
|
368
|
+
" if x is not None:\n"
|
|
369
|
+
" def _draw():\n"
|
|
370
|
+
" plt.hist(d[x].dropna())\n"
|
|
371
|
+
" ax = plt.gca()\n"
|
|
372
|
+
" if not (ax.get_title() or '').strip():\n"
|
|
373
|
+
" ax.set_title(f'Distribution of {x}')\n"
|
|
374
|
+
" return ax\n"
|
|
375
|
+
" return _safe_plot(lambda **kw: _draw())\n"
|
|
376
|
+
" if _missing:\n"
|
|
377
|
+
" return _safe_plot(lambda **kw: plt.hist([]))\n"
|
|
378
|
+
" if _sentinel:\n"
|
|
379
|
+
" a = a[1:]\n"
|
|
380
|
+
" return _safe_plot(getattr(sns,'histplot', plt.hist), *a, **k)\n"
|
|
381
|
+
"\n"
|
|
382
|
+
"def SB_barplot(*a, **k):\n"
|
|
383
|
+
" _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
|
|
384
|
+
" _sentinel = (len(a) >= 1 and a[0] is None)\n"
|
|
385
|
+
" _ax = k.get('ax') or _pick_ax_slot()\n"
|
|
386
|
+
" if _ax is not None:\n"
|
|
387
|
+
" try: plt.sca(_ax)\n"
|
|
388
|
+
" except Exception: pass\n"
|
|
389
|
+
" k.setdefault('ax', _ax)\n"
|
|
390
|
+
" if (not a or _sentinel) and not k:\n"
|
|
391
|
+
" d = _pick_df()\n"
|
|
392
|
+
" if d is not None:\n"
|
|
393
|
+
" x = _first_categorical(d)\n"
|
|
394
|
+
" y = _first_numeric(d)\n"
|
|
395
|
+
" if x and y:\n"
|
|
396
|
+
" import pandas as _pd\n"
|
|
397
|
+
" g = d.groupby(x)[y].mean().reset_index()\n"
|
|
398
|
+
" def _draw():\n"
|
|
399
|
+
" if _missing:\n"
|
|
400
|
+
" plt.bar(g[x], g[y])\n"
|
|
401
|
+
" else:\n"
|
|
402
|
+
" sns.barplot(data=g, x=x, y=y, ax=k.get('ax'))\n"
|
|
403
|
+
" ax = plt.gca()\n"
|
|
404
|
+
" if not (ax.get_title() or '').strip():\n"
|
|
405
|
+
" ax.set_title(f'Mean {y} by {x}')\n"
|
|
406
|
+
" return ax\n"
|
|
407
|
+
" return _safe_plot(lambda **kw: _draw())\n"
|
|
408
|
+
" if _missing:\n"
|
|
409
|
+
" return _safe_plot(lambda **kw: plt.bar([], []))\n"
|
|
410
|
+
" if _sentinel:\n"
|
|
411
|
+
" a = a[1:]\n"
|
|
412
|
+
" return _safe_plot(sns.barplot, *a, **k)\n"
|
|
413
|
+
"\n"
|
|
414
|
+
"def SB_boxplot(*a, **k):\n"
|
|
415
|
+
" _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
|
|
416
|
+
" _sentinel = (len(a) >= 1 and a[0] is None)\n"
|
|
417
|
+
" _ax = k.get('ax') or _pick_ax_slot()\n"
|
|
418
|
+
" if _ax is not None:\n"
|
|
419
|
+
" try: plt.sca(_ax)\n"
|
|
420
|
+
" except Exception: pass\n"
|
|
421
|
+
" k.setdefault('ax', _ax)\n"
|
|
422
|
+
" if (not a or _sentinel) and not k:\n"
|
|
423
|
+
" d = _pick_df()\n"
|
|
424
|
+
" if d is not None:\n"
|
|
425
|
+
" x = _first_categorical(d)\n"
|
|
426
|
+
" y = _first_numeric(d)\n"
|
|
427
|
+
" if x and y:\n"
|
|
428
|
+
" def _draw():\n"
|
|
429
|
+
" if _missing:\n"
|
|
430
|
+
" plt.boxplot(d[y].dropna())\n"
|
|
431
|
+
" else:\n"
|
|
432
|
+
" sns.boxplot(data=d, x=x, y=y, ax=k.get('ax'))\n"
|
|
433
|
+
" ax = plt.gca()\n"
|
|
434
|
+
" if not (ax.get_title() or '').strip():\n"
|
|
435
|
+
" ax.set_title(f'Distribution of {y} by {x}')\n"
|
|
436
|
+
" return ax\n"
|
|
437
|
+
" return _safe_plot(lambda **kw: _draw())\n"
|
|
438
|
+
" if _missing:\n"
|
|
439
|
+
" return _safe_plot(lambda **kw: plt.boxplot([]))\n"
|
|
440
|
+
" if _sentinel:\n"
|
|
441
|
+
" a = a[1:]\n"
|
|
442
|
+
" return _safe_plot(sns.boxplot, *a, **k)\n"
|
|
443
|
+
"\n"
|
|
444
|
+
"def SB_scatterplot(*a, **k):\n"
|
|
445
|
+
" _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
|
|
446
|
+
" fn = getattr(sns,'scatterplot', None)\n"
|
|
447
|
+
" # If seaborn is unavailable OR the caller passed (data=..., x='col', y='col'),\n"
|
|
448
|
+
" # use a robust matplotlib path that looks up data and coerces to numeric.\n"
|
|
449
|
+
" if _missing or fn is None:\n"
|
|
450
|
+
" data = k.get('data'); x = k.get('x'); y = k.get('y')\n"
|
|
451
|
+
" if data is not None and isinstance(x, str) and isinstance(y, str) and x in data.columns and y in data.columns:\n"
|
|
452
|
+
" xs = pd.to_numeric(data[x], errors='coerce')\n"
|
|
453
|
+
" ys = pd.to_numeric(data[y], errors='coerce')\n"
|
|
454
|
+
" m = xs.notna() & ys.notna()\n"
|
|
455
|
+
" def _draw():\n"
|
|
456
|
+
" plt.scatter(xs[m], ys[m])\n"
|
|
457
|
+
" ax = plt.gca()\n"
|
|
458
|
+
" if not (ax.get_title() or '').strip():\n"
|
|
459
|
+
" ax.set_title(f'{y} vs {x}')\n"
|
|
460
|
+
" return ax\n"
|
|
461
|
+
" return _safe_plot(lambda **kw: _draw())\n"
|
|
462
|
+
" # else: fall back to auto-pick two numeric columns\n"
|
|
463
|
+
" d = _pick_df()\n"
|
|
464
|
+
" if d is not None:\n"
|
|
465
|
+
" num = d.select_dtypes(include=[np.number]).columns.tolist()\n"
|
|
466
|
+
" if len(num) >= 2:\n"
|
|
467
|
+
" def _draw2():\n"
|
|
468
|
+
" plt.scatter(d[num[0]], d[num[1]])\n"
|
|
469
|
+
" ax = plt.gca()\n"
|
|
470
|
+
" if not (ax.get_title() or '').strip():\n"
|
|
471
|
+
" ax.set_title(f'{num[1]} vs {num[0]}')\n"
|
|
472
|
+
" return ax\n"
|
|
473
|
+
" return _safe_plot(lambda **kw: _draw2())\n"
|
|
474
|
+
" return _safe_plot(lambda **kw: plt.scatter([], []))\n"
|
|
475
|
+
" # seaborn path\n"
|
|
476
|
+
" return _safe_plot(fn, *a, **k)\n"
|
|
477
|
+
"\n"
|
|
478
|
+
"def SB_heatmap(*a, **k):\n"
|
|
479
|
+
" _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
|
|
480
|
+
" data = None\n"
|
|
481
|
+
" if a:\n"
|
|
482
|
+
" data = a[0]\n"
|
|
483
|
+
" elif 'data' in k:\n"
|
|
484
|
+
" data = k['data']\n"
|
|
485
|
+
" if data is None:\n"
|
|
486
|
+
" d = _pick_df()\n"
|
|
487
|
+
" try:\n"
|
|
488
|
+
" if d is not None:\n"
|
|
489
|
+
" import numpy as _np\n"
|
|
490
|
+
" data = d.select_dtypes(include=[_np.number]).corr()\n"
|
|
491
|
+
" except Exception:\n"
|
|
492
|
+
" data = None\n"
|
|
493
|
+
" if data is None:\n"
|
|
494
|
+
" from syntaxmatrix.display import show as _show\n"
|
|
495
|
+
" _show('⚠ Heatmap skipped: no data.')\n"
|
|
496
|
+
" return None\n"
|
|
497
|
+
" if not _missing and hasattr(sns, 'heatmap'):\n"
|
|
498
|
+
" _k = {kk: vv for kk, vv in k.items() if kk != 'data'}\n"
|
|
499
|
+
" def _draw():\n"
|
|
500
|
+
" ax = sns.heatmap(data, **_k)\n"
|
|
501
|
+
" try:\n"
|
|
502
|
+
" ax = ax or plt.gca()\n"
|
|
503
|
+
" if not (ax.get_title() or '').strip():\n"
|
|
504
|
+
" ax.set_title('Correlation Heatmap')\n"
|
|
505
|
+
" except Exception:\n"
|
|
506
|
+
" pass\n"
|
|
507
|
+
" return ax\n"
|
|
508
|
+
" return _safe_plot(lambda **kw: _draw())\n"
|
|
509
|
+
" def _mat_heat():\n"
|
|
510
|
+
" im = plt.imshow(data, aspect='auto')\n"
|
|
511
|
+
" try: plt.colorbar()\n"
|
|
512
|
+
" except Exception: pass\n"
|
|
513
|
+
" try:\n"
|
|
514
|
+
" cols = list(getattr(data, 'columns', []))\n"
|
|
515
|
+
" rows = list(getattr(data, 'index', []))\n"
|
|
516
|
+
" if cols: plt.xticks(range(len(cols)), cols, rotation=90)\n"
|
|
517
|
+
" if rows: plt.yticks(range(len(rows)), rows)\n"
|
|
518
|
+
" except Exception:\n"
|
|
519
|
+
" pass\n"
|
|
520
|
+
" ax = plt.gca()\n"
|
|
521
|
+
" try:\n"
|
|
522
|
+
" if not (ax.get_title() or '').strip():\n"
|
|
523
|
+
" ax.set_title('Correlation Heatmap')\n"
|
|
524
|
+
" except Exception:\n"
|
|
525
|
+
" pass\n"
|
|
526
|
+
" return ax\n"
|
|
527
|
+
" return _safe_plot(lambda **kw: _mat_heat())\n"
|
|
528
|
+
"\n"
|
|
529
|
+
"def _safe_concat(objs, **kwargs):\n"
|
|
530
|
+
" import pandas as _pd\n"
|
|
531
|
+
" if objs is None: return _pd.DataFrame()\n"
|
|
532
|
+
" if isinstance(objs,(list,tuple)) and len(objs)==0: return _pd.DataFrame()\n"
|
|
533
|
+
" try: return _pd.concat(objs, **kwargs)\n"
|
|
534
|
+
" except Exception as e:\n"
|
|
535
|
+
" show(f'⚠ concat skipped: {e}')\n"
|
|
536
|
+
" return _pd.DataFrame()\n"
|
|
537
|
+
"\n"
|
|
538
|
+
"from sklearn.preprocessing import OneHotEncoder\n"
|
|
539
|
+
"import inspect\n"
|
|
540
|
+
"def _SMX_OHE(**k):\n"
|
|
541
|
+
" # normalise arg name across sklearn versions\n"
|
|
542
|
+
" if 'sparse' in k and 'sparse_output' not in k:\n"
|
|
543
|
+
" k['sparse_output'] = k.pop('sparse')\n"
|
|
544
|
+
" k.setdefault('handle_unknown','ignore')\n"
|
|
545
|
+
" k.setdefault('sparse_output', False)\n"
|
|
546
|
+
" try:\n"
|
|
547
|
+
" sig = inspect.signature(OneHotEncoder)\n"
|
|
548
|
+
" if 'sparse_output' not in sig.parameters and 'sparse_output' in k:\n"
|
|
549
|
+
" k['sparse'] = k.pop('sparse_output')\n"
|
|
550
|
+
" except Exception:\n"
|
|
551
|
+
" if 'sparse_output' in k:\n"
|
|
552
|
+
" k['sparse'] = k.pop('sparse_output')\n"
|
|
553
|
+
" return OneHotEncoder(**k)\n"
|
|
554
|
+
"\n"
|
|
555
|
+
"import numpy as _np\n"
|
|
556
|
+
"def _SMX_mm(a, b):\n"
|
|
557
|
+
" try:\n"
|
|
558
|
+
" return a @ b # normal path\n"
|
|
559
|
+
" except Exception:\n"
|
|
560
|
+
" try:\n"
|
|
561
|
+
" A = _np.asarray(a); B = _np.asarray(b)\n"
|
|
562
|
+
" # If same 2D shape (e.g. (n,k) & (n,k)), treat as row-wise dot\n"
|
|
563
|
+
" if A.ndim==2 and B.ndim==2 and A.shape==B.shape:\n"
|
|
564
|
+
" return (A * B).sum(axis=1)\n"
|
|
565
|
+
" # Otherwise try element-wise product (broadcast if possible)\n"
|
|
566
|
+
" return A * B\n"
|
|
567
|
+
" except Exception as e:\n"
|
|
568
|
+
" from syntaxmatrix.display import show\n"
|
|
569
|
+
" show(f'⚠ Matmul relaxed: {type(e).__name__}: {e}'); return _np.nan\n"
|
|
570
|
+
"\n"
|
|
571
|
+
"def _SMX_call(fn, *a, **k):\n"
|
|
572
|
+
" try:\n"
|
|
573
|
+
" return fn(*a, **k)\n"
|
|
574
|
+
" except TypeError as e:\n"
|
|
575
|
+
" msg = str(e)\n"
|
|
576
|
+
" if \"unexpected keyword argument 'squared'\" in msg:\n"
|
|
577
|
+
" k.pop('squared', None)\n"
|
|
578
|
+
" return fn(*a, **k)\n"
|
|
579
|
+
" raise\n"
|
|
580
|
+
"\n"
|
|
581
|
+
"def _SMX_rmse(y_true, y_pred):\n"
|
|
582
|
+
" try:\n"
|
|
583
|
+
" from sklearn.metrics import mean_squared_error as _mse\n"
|
|
584
|
+
" try:\n"
|
|
585
|
+
" return _mse(y_true, y_pred, squared=False)\n"
|
|
586
|
+
" except TypeError:\n"
|
|
587
|
+
" return (_mse(y_true, y_pred)) ** 0.5\n"
|
|
588
|
+
" except Exception:\n"
|
|
589
|
+
" import numpy as _np\n"
|
|
590
|
+
" yt = _np.asarray(y_true, dtype=float)\n"
|
|
591
|
+
" yp = _np.asarray(y_pred, dtype=float)\n"
|
|
592
|
+
" diff = yt - yp\n"
|
|
593
|
+
" return float((_np.mean(diff * diff)) ** 0.5)\n"
|
|
594
|
+
"\n"
|
|
595
|
+
"import pandas as _pd\n"
|
|
596
|
+
"import numpy as _np\n"
|
|
597
|
+
"def _SMX_autocoerce_dates(_df):\n"
|
|
598
|
+
" if _df is None or not hasattr(_df, 'columns'): return\n"
|
|
599
|
+
" for c in list(_df.columns):\n"
|
|
600
|
+
" s = _df[c]\n"
|
|
601
|
+
" n = str(c).lower()\n"
|
|
602
|
+
" if _pd.api.types.is_datetime64_any_dtype(s):\n"
|
|
603
|
+
" continue\n"
|
|
604
|
+
" if _pd.api.types.is_object_dtype(s) or ('date' in n or 'time' in n or 'timestamp' in n or n.endswith('_dt')):\n"
|
|
605
|
+
" try:\n"
|
|
606
|
+
" conv = _pd.to_datetime(s, errors='coerce', utc=True).dt.tz_localize(None)\n"
|
|
607
|
+
" # accept only if at least 10% (min 3) parse as dates\n"
|
|
608
|
+
" if getattr(conv, 'notna', lambda: _pd.Series([]))().sum() >= max(3, int(0.1*len(_df))):\n"
|
|
609
|
+
" _df[c] = conv\n"
|
|
610
|
+
" except Exception:\n"
|
|
611
|
+
" pass\n"
|
|
612
|
+
"\n"
|
|
613
|
+
"def _SMX_autocoerce_numeric(_df, cols):\n"
|
|
614
|
+
" if _df is None: return\n"
|
|
615
|
+
" for c in cols:\n"
|
|
616
|
+
" if c in getattr(_df, 'columns', []):\n"
|
|
617
|
+
" try:\n"
|
|
618
|
+
" _df[c] = _pd.to_numeric(_df[c], errors='coerce')\n"
|
|
619
|
+
" except Exception:\n"
|
|
620
|
+
" pass\n"
|
|
621
|
+
"\n"
|
|
622
|
+
"def show(obj, title=None):\n"
|
|
623
|
+
" try:\n"
|
|
624
|
+
" import pandas as pd, numbers\n"
|
|
625
|
+
" cap = (title or _SMX_caption_from_ctx())\n"
|
|
626
|
+
" # 1) DataFrame → Styler with caption\n"
|
|
627
|
+
" if isinstance(obj, pd.DataFrame):\n"
|
|
628
|
+
" try: return _SMX_base_show(obj.style.set_caption(cap))\n"
|
|
629
|
+
" except Exception: pass\n"
|
|
630
|
+
" # 2) dict of scalars → DataFrame with caption\n"
|
|
631
|
+
" if isinstance(obj, dict) and all(isinstance(v, numbers.Number) for v in obj.values()):\n"
|
|
632
|
+
" df_ = pd.DataFrame({'metric': list(obj.keys()), 'value': list(obj.values())})\n"
|
|
633
|
+
" try: return _SMX_base_show(df_.style.set_caption(cap))\n"
|
|
634
|
+
" except Exception: return _SMX_base_show(df_)\n"
|
|
635
|
+
" except Exception:\n"
|
|
636
|
+
" pass\n"
|
|
637
|
+
" return _SMX_base_show(obj)\n"
|
|
638
|
+
)
|
|
639
|
+
|
|
640
|
+
PREFACE_IMPORT = "from syntaxmatrix.smx_preface import *\n"
|
|
641
|
+
# if PREFACE not in code:
|
|
642
|
+
# code = PREFACE_IMPORT + code
|
|
643
|
+
|
|
644
|
+
fixed = code
|
|
645
|
+
|
|
646
|
+
fixed = re.sub(
|
|
647
|
+
r"(?s)^\s*try:\s*(.*?)\s*except\s+Exception\s+as\s+\w+:\s*\n\s*show\([^\n]*\)\s*$",
|
|
648
|
+
r"\1",
|
|
649
|
+
fixed.strip()
|
|
650
|
+
)
|
|
651
|
+
|
|
652
|
+
# 1) Strip numeric_only=... (version-agnostic)
|
|
653
|
+
fixed = re.sub(r",\s*numeric_only\s*=\s*(True|False|None)", "", fixed, flags=re.I)
|
|
654
|
+
fixed = re.sub(r"\bnumeric_only\s*=\s*(True|False|None)\s*,\s*", "", fixed, flags=re.I)
|
|
655
|
+
fixed = re.sub(r"\bnumeric_only\s*=\s*(True|False|None)\b", "", fixed, flags=re.I)
|
|
656
|
+
|
|
657
|
+
# 2) Use safe seaborn wrappers
|
|
658
|
+
fixed = re.sub(r"\bsns\.boxplot\s*\(", "SB_boxplot(", fixed)
|
|
659
|
+
fixed = re.sub(r"\bsns\.barplot\s*\(", "SB_barplot(", fixed)
|
|
660
|
+
fixed = re.sub(r"\bsns\.histplot\s*\(", "SB_histplot(", fixed)
|
|
661
|
+
fixed = re.sub(r"\bsns\.scatterplot\s*\(", "SB_scatterplot(", fixed)
|
|
662
|
+
|
|
663
|
+
# 3) Guard concat calls
|
|
664
|
+
fixed = re.sub(r"\bpd\.concat\s*\(", "_safe_concat(", fixed)
|
|
665
|
+
fixed = re.sub(r"\bOneHotEncoder\s*\(", "_SMX_OHE(", fixed)
|
|
666
|
+
# Route np.dot to tolerant matmul
|
|
667
|
+
fixed = re.sub(r"\bnp\.dot\s*\(", "_SMX_mm(", fixed)
|
|
668
|
+
fixed = re.sub(r"(df\s*\[[^\]]+\])\s*\.dt", r"SMX_dt(\1).dt", fixed)
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
# 4) Relax any 'required_cols' hard failure blocks
|
|
672
|
+
fixed = re.sub(
|
|
673
|
+
r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
|
|
674
|
+
"required_cols = [c for c in df.columns]\n# (relaxed by SMX hardener)",
|
|
675
|
+
fixed,
|
|
676
|
+
flags=re.S,
|
|
677
|
+
)
|
|
678
|
+
|
|
679
|
+
# 5) Make static numeric_vars lists dynamic
|
|
680
|
+
fixed = re.sub(
|
|
681
|
+
r"\bnumeric_vars\s*=\s*\[.*?\]",
|
|
682
|
+
"numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
|
|
683
|
+
fixed,
|
|
684
|
+
flags=re.S,
|
|
685
|
+
)
|
|
686
|
+
# normalise all .dt usages on df[...] / df.attr / df.loc[...] to SMX_dt(...)
|
|
687
|
+
fixed = re.sub(
|
|
688
|
+
r"((?:df\s*(?:\.\s*(?:loc|iloc)\s*)?\[[^\]]+\]|df\s*\.\s*[A-Za-z_]\w*))\s*\.dt\b",
|
|
689
|
+
lambda m: f"SMX_dt({m.group(1)}).dt",
|
|
690
|
+
fixed
|
|
691
|
+
)
|
|
692
|
+
|
|
693
|
+
try:
|
|
694
|
+
class _SMXMatmulRewriter(ast.NodeTransformer):
|
|
695
|
+
def visit_BinOp(self, node):
|
|
696
|
+
self.generic_visit(node)
|
|
697
|
+
if isinstance(node.op, ast.MatMult):
|
|
698
|
+
return ast.Call(func=ast.Name(id="_SMX_mm", ctx=ast.Load()),
|
|
699
|
+
args=[node.left, node.right], keywords=[])
|
|
700
|
+
return node
|
|
701
|
+
_tree = ast.parse(fixed)
|
|
702
|
+
_tree = _SMXMatmulRewriter().visit(_tree)
|
|
703
|
+
fixed = ast.unparse(_tree)
|
|
704
|
+
except Exception:
|
|
705
|
+
# If AST rewrite fails, keep original; _SMX_mm will still handle np.dot(...)
|
|
706
|
+
pass
|
|
707
|
+
|
|
708
|
+
# 6) Final safety wrapper
|
|
709
|
+
fixed = fixed.replace("\t", " ")
|
|
710
|
+
fixed = textwrap.dedent(fixed).strip("\n")
|
|
711
|
+
|
|
712
|
+
fixed = _strip_stray_backrefs(fixed)
|
|
713
|
+
fixed = _wrap_metric_calls(fixed)
|
|
714
|
+
|
|
715
|
+
# If the transformed code is still not syntactically valid, fall back to a
|
|
716
|
+
# very defensive generic snippet that depends only on `df`. This guarantees
|
|
717
|
+
try:
|
|
718
|
+
ast.parse(fixed)
|
|
719
|
+
except (SyntaxError, IndentationError):
|
|
720
|
+
fixed = (
|
|
721
|
+
"import pandas as pd\n"
|
|
722
|
+
"df = df.copy()\n"
|
|
723
|
+
"_info = {\n"
|
|
724
|
+
" 'rows': len(df),\n"
|
|
725
|
+
" 'cols': len(df.columns),\n"
|
|
726
|
+
" 'numeric_cols': len(df.select_dtypes(include=['number','bool']).columns),\n"
|
|
727
|
+
" 'categorical_cols': len(df.select_dtypes(exclude=['number','bool']).columns),\n"
|
|
728
|
+
"}\n"
|
|
729
|
+
"show(df.head(), title='Sample of data')\n"
|
|
730
|
+
"show(_info, title='Dataset summary')\n"
|
|
731
|
+
"try:\n"
|
|
732
|
+
" _num = df.select_dtypes(include=['number','bool']).columns.tolist()\n"
|
|
733
|
+
" if _num:\n"
|
|
734
|
+
" SB_histplot()\n"
|
|
735
|
+
" _SMX_export_png()\n"
|
|
736
|
+
"except Exception as e:\n"
|
|
737
|
+
" show(f\"⚠ Fallback visualisation failed: {type(e).__name__}: {e}\")\n"
|
|
738
|
+
)
|
|
739
|
+
|
|
740
|
+
# Fix placeholder Ellipsis handlers from LLM
|
|
741
|
+
fixed = re.sub(
|
|
742
|
+
r"except\s+Exception\s+as\s+e:\s*\n\s*show\(\.\.\.\)",
|
|
743
|
+
"except Exception as e:\n show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")",
|
|
744
|
+
fixed,
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
wrapped = PREFACE + "try:\n" + _indent(fixed) + "\nexcept Exception as e:\n show(...)\n"
|
|
748
|
+
wrapped = wrapped.lstrip()
|
|
749
|
+
return wrapped
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
def indent_code(code: str, spaces: int = 4) -> str:
    """Return *code* with every line (including blank ones) shifted right by *spaces* spaces."""
    prefix = " " * spaces
    shifted = [prefix + ln for ln in code.splitlines()]
    return "\n".join(shifted)
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
def wrap_llm_code_safe(code: str) -> str:
    """Wrap *code* in a try/except so a runtime failure is reported via show() instead of crashing."""
    # Swallow any runtime error from the LLM block instead of crashing the run
    header = "# __SAFE_WRAPPED__\ntry:\n"
    handler = (
        "\n"
        "except Exception as e:\n"
        "    from syntaxmatrix.display import show\n"
        "    show(f\"⚠️ Skipped LLM block due to: {type(e).__name__}: {e}\")\n"
    )
    return header + indent_code(code) + handler
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
def fix_boxplot_placeholder(code: str) -> str:
    """Replace the invalid placeholder call ``sns.boxplot(boxplot)`` with a concrete
    call built from the ``df``/``group_label``/``m`` names expected in generated code."""
    pattern = re.compile(r"sns\.boxplot\(\s*boxplot\s*\)")
    replacement = "sns.boxplot(x=group_label, y=m, data=df.loc[df[m].notnull()], showfliers=False)"
    return pattern.sub(replacement, code)
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
def relax_required_columns(code: str) -> str:
    """Soften generated code that hard-fails on missing columns.

    The three-line pattern ``required_cols = [...]`` / ``missing = [...]`` /
    ``if missing: raise ...`` is rewritten so required_cols simply mirrors
    ``df.columns`` and no exception is raised.
    """
    hard_fail = re.compile(
        r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
        re.S,
    )
    return hard_fail.sub("required_cols = [c for c in df.columns]\n", code)
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
def make_numeric_vars_dynamic(code: str) -> str:
    """Swap any hard-coded ``numeric_vars = [...]`` list for a dynamic
    select_dtypes-based selection so the code adapts to whatever columns df has."""
    static_list = re.compile(r"numeric_vars\s*=\s*\[.*?\]", re.S)
    dynamic = "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()"
    return static_list.sub(dynamic, code)
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
def auto_inject_template(code: str, intents, df) -> str:
    """Prepend a task-specific skeleton template when the LLM forgot the core logic.

    Inspects *code* for evidence of model fitting / plotting and, for each declared
    intent, prepends the matching template (classification, regression, clustering,
    EDA, time-series, ...) built from *df* when that evidence is missing.

    Fixes vs. the previous revision:
    - the ``recommendation`` and ``topic_modelling`` branches joined template and
      code with ``"\\n\\n"`` (two literal backslash-n character pairs), corrupting
      the generated source; they now use a real blank-line separator like every
      other branch;
    - an explicit ``return code`` after the intent loop guards the empty-*intents*
      case, which previously fell off the end and returned None.

    Parameters:
        code: the LLM-generated Python snippet.
        intents: iterable of intent labels predicted for the user question.
        df: the active DataFrame, used to guess targets/columns for templates.
    Returns:
        The (possibly template-prefixed) code string; never None.
    """

    has_fit = ".fit(" in code
    has_plot = any(k in code for k in ("plt.", "sns.", ".plot(", ".hist("))

    UNKNOWN_TOKENS = {
        "unknown","not reported","not_reported","not known","n/a","na",
        "none","nan","missing","unreported","unspecified","null","-",""
    }

    # --- Safe template caller: passes only supported kwargs, falls back cleanly ---
    def _call_template(func, df, **hints):
        import inspect
        try:
            params = inspect.signature(func).parameters
            kw = {k: v for k, v in hints.items() if k in params}
            try:
                return func(df, **kw)
            except TypeError:
                # In case the template changed its signature at runtime
                return func(df)
        except Exception:
            # Absolute safety net
            try:
                return func(df)
            except Exception:
                # As a last resort, return empty code so we don't 500
                return ""

    def _guess_classification_target(df: pd.DataFrame) -> str | None:
        cols = list(df.columns)

        # Helper: does this column look like a sensible label?
        def _is_reasonable_class_col(s: pd.Series, col_name: str) -> bool:
            try:
                nunq = s.dropna().nunique()
            except Exception:
                return False
            # need at least 2 classes, but not hundreds
            if nunq < 2 or nunq > 20:
                return False
            bad_name_keys = ("id", "identifier", "index", "uuid", "key")
            name = str(col_name).lower()
            if any(k in name for k in bad_name_keys):
                return False
            return True

        # 1) columns whose names look like labels
        label_keys = ("target", "label", "outcome", "class", "y", "status")
        name_candidates: list[str] = []
        for key in label_keys:
            for c in cols:
                if key in str(c).lower():
                    name_candidates.append(c)
            if name_candidates:
                break  # keep the earliest matching key-group

        # prioritise name-based candidates that also look like proper label columns
        for c in name_candidates:
            if _is_reasonable_class_col(df[c], c):
                return c
        if name_candidates:
            # fall back to the first name-based candidate if none passed the shape test
            return name_candidates[0]

        # 2) any column with a small number of distinct values (likely a class label)
        for c in cols:
            s = df[c]
            if _is_reasonable_class_col(s, c):
                return c

        # Nothing suitable found
        return None

    def _guess_regression_target(df: pd.DataFrame) -> str | None:
        num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
        if not num_cols:
            return None
        # Avoid obvious ID-like columns
        bad_keys = ("id", "identifier", "index")
        candidates = [c for c in num_cols if not any(k in str(c).lower() for k in bad_keys)]
        return (candidates or num_cols)[-1]

    def _guess_time_col(df: pd.DataFrame) -> str | None:
        # Prefer actual datetime dtype
        dt_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.datetime64)]
        if dt_cols:
            return dt_cols[0]

        # Fallback: name-based hints
        name_keys = ["date", "time", "timestamp", "datetime", "ds", "period"]
        for c in df.columns:
            name = str(c).lower()
            if any(k in name for k in name_keys):
                return c
        return None

    def _guess_entity_col(df: pd.DataFrame) -> str | None:
        # Typical sequence IDs: id, patient, subject, device, series, entity
        keys = ["id", "patient", "subject", "device", "series", "entity"]
        candidates = []
        for c in df.columns:
            name = str(c).lower()
            if any(k in name for k in keys):
                candidates.append(c)
        return candidates[0] if candidates else None

    def _guess_ts_class_target(df: pd.DataFrame) -> str | None:
        # Try label-like names first
        keys = ["target", "label", "class", "outcome", "y"]
        for key in keys:
            for c in df.columns:
                if key in str(c).lower():
                    return c

        # Fallback: any column with few distinct values (e.g. <= 10)
        for c in df.columns:
            s = df[c]
            # avoid obvious IDs
            if any(k in str(c).lower() for k in ["id", "index"]):
                continue
            try:
                nunq = s.dropna().nunique()
            except Exception:
                continue
            if 1 < nunq <= 10:
                return c

        return None

    def _guess_multilabel_cols(df: pd.DataFrame) -> list[str]:
        cols = list(df.columns)
        lbl_like = [c for c in cols if str(c).startswith(("LBL_", "lbl_"))]
        # also include boolean/binary columns with suitable names
        for c in cols:
            s = df[c]
            try:
                nunq = s.dropna().nunique()
            except Exception:
                continue
            if nunq in (2,) and c not in lbl_like:
                # avoid obvious IDs
                if not any(k in str(c).lower() for k in ("id","index","uuid","identifier")):
                    lbl_like.append(c)
        # keep at most, say, 12 to avoid accidental flood
        return lbl_like[:12]

    def _find_unknownish_column(df: pd.DataFrame) -> str | None:
        # Search categorical-like columns for any 'unknown-like' values or high missingness
        candidates = []
        for c in df.columns:
            s = df[c]
            # focus on object/category/boolean-ish or low-card columns
            if not (pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= 20):
                continue
            try:
                vals = s.astype(str).str.strip().str.lower()
            except Exception:
                continue
            # score: presence of unknown tokens + missing rate
            token_hit = int(vals.isin(UNKNOWN_TOKENS).any())
            miss_rate = s.isna().mean()
            name_bonus = int(any(k in str(c).lower() for k in ("status","history","report","known","flag")))
            score = 3*token_hit + 2*name_bonus + miss_rate
            if token_hit or miss_rate > 0.05 or name_bonus:
                candidates.append((score, c))
        if not candidates:
            return None
        candidates.sort(reverse=True)
        return candidates[0][1]

    def _guess_numeric_cols(df: pd.DataFrame, max_n: int = 6) -> list[str]:
        cols = [c for c in df.select_dtypes(include=[np.number, "bool"]).columns if not any(k in str(c).lower() for k in ("id","identifier","index","uuid"))]
        # prefer non-constant columns
        scored = []
        for c in cols:
            try:
                v = df[c].dropna()
                var = float(v.var()) if len(v) else 0.0
                scored.append((var, c))
            except Exception:
                continue
        scored.sort(reverse=True)
        return [c for _, c in scored[:max_n]]

    def _guess_categorical_cols(df: pd.DataFrame, exclude: set[str] | None = None, max_card: int = 12, max_n: int = 5) -> list[str]:
        exclude = exclude or set()
        picks = []
        for c in df.columns:
            if c in exclude:
                continue
            s = df[c]
            if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= max_card:
                nunq = s.dropna().nunique()
                if 2 <= nunq <= max_card and not any(k in str(c).lower() for k in ("id","identifier","index","uuid")):
                    picks.append((nunq, c))
        picks.sort(reverse=True)
        return [c for _, c in picks[:max_n]]

    def _guess_outcome_col(df: pd.DataFrame, exclude: set[str] | None = None) -> str | None:
        exclude = exclude or set()
        # name hints first
        name_keys = ("outcome","target","label","risk","score","result","prevalence","positivity")
        for c in df.columns:
            if c in exclude:
                continue
            name = str(c).lower()
            if any(k in name for k in name_keys) and pd.api.types.is_numeric_dtype(df[c]):
                return c
        # fallback: any binary numeric
        for c in df.select_dtypes(include=[np.number, "bool"]).columns:
            if c in exclude:
                continue
            try:
                if df[c].dropna().nunique() == 2:
                    return c
            except Exception:
                continue
        return None

    def _pick_viz_template(signal: str):
        s = signal.lower()

        # explicit chart requests
        if any(k in s for k in ("pie", "donut")):
            return viz_pie

        if any(k in s for k in ("stacked", "100% stacked", "composition", "proportion", "share by")):
            return viz_stacked_bar

        if any(k in s for k in ("distribution", "hist", "histogram", "bins")):
            return viz_distribution

        if any(k in s for k in ("kde", "density")):
            return viz_kde

        if any(k in s for k in ("box", "boxplot", "violin", "spread", "outlier")):
            return viz_box

        if any(k in s for k in ("scatter", "relationship", "vs ", "correlate")):
            return viz_scatter

        if any(k in s for k in ("count", "counts", "frequency", "bar chart", "barplot")):
            return viz_count_bar

        if any(k in s for k in ("area", "trend", "over time", "time series")):
            return viz_area

        # fallback
        return viz_line

    for intent in intents:

        if intent not in INJECTABLE_INTENTS:
            return code

        # Correlation analysis
        if intent == "correlation_analysis" and not has_fit:
            return eda_correlation(df) + "\n\n" + code

        # Generic visualisation (keyword-based)
        if intent == "visualisation" and not has_fit and not has_plot:
            rq = str(globals().get("refined_question", ""))
            signal = rq + "\n" + str(intents) + "\n" + code
            tpl = _pick_viz_template(signal)
            return tpl(df) + "\n\n" + code

        if intent == "clustering" and not has_fit:
            return clustering(df) + "\n\n" + code

        # FIX: these two branches previously used "\\n\\n" (literal backslash-n),
        # which injected the characters \n\n into the generated source.
        if intent == "recommendation" and not has_fit:
            return recommendation(df) + "\n\n" + code

        if intent == "topic_modelling" and not has_fit:
            return topic_modelling(df) + "\n\n" + code

        if intent == "eda" and not has_fit:
            return code + "\n\nSB_heatmap(df.corr())"  # Inject heatmap if 'eda' intent

        # --- Classification ------------------------------------------------
        if intent == "classification" and not has_fit:
            target = _guess_classification_target(df)
            if target:
                return classification(df) + "\n\n" + code

        # --- Regression ----------------------------------------------------
        if intent == "regression" and not has_fit:
            target = _guess_regression_target(df)
            if target:
                return regression(df) + "\n\n" + code

        # --- Anomaly detection --------------------------------------------
        if intent == "anomaly_detection":
            uses_anomaly = any(k in code for k in ("IsolationForest", "LocalOutlierFactor", "OneClassSVM"))
            if not uses_anomaly:
                return anomaly_detection(df) + "\n\n" + code

        # --- Time-series anomaly detection --------------------------------
        if intent == "ts_anomaly_detection":
            uses_ts = "STL(" in code or "seasonal_decompose(" in code
            if not uses_ts:
                return ts_anomaly_detection(df) + "\n\n" + code

        # --- Time-series classification --------------------------------
        if intent == "time_series_classification" and not has_fit:
            time_col = _guess_time_col(df)
            entity_col = _guess_entity_col(df)
            target_col = _guess_ts_class_target(df)

            # If we can't confidently identify these, do NOT inject anything
            if time_col and entity_col and target_col:
                return time_series_classification(df, entity_col, time_col, target_col) + "\n\n" + code

        # --- Dimensionality reduction --------------------------------------
        if intent == "dimensionality_reduction":
            uses_dr = any(k in code for k in ("PCA(", "TSNE("))
            if not uses_dr:
                return dimensionality_reduction(df) + "\n\n" + code

        # --- Feature selection ---------------------------------------------
        if intent == "feature_selection":
            uses_fs = any(k in code for k in (
                "mutual_info_", "permutation_importance(", "SelectKBest(", "RFE("
            ))
            if not uses_fs:
                return feature_selection(df) + "\n\n" + code

        # --- EDA / correlation / visualisation -----------------------------
        if intent in ("eda", "correlation_analysis", "visualisation") and not has_plot:
            if intent == "correlation_analysis":
                return eda_correlation(df) + "\n\n" + code
            else:
                return eda_overview(df) + "\n\n" + code

        # --- Time-series forecasting ---------------------------------------
        if intent == "time_series_forecasting" and not has_fit:
            uses_ts_forecast = any(k in code for k in (
                "ARIMA", "ExponentialSmoothing", "forecast", "predict("
            ))
            if not uses_ts_forecast:
                return time_series_forecasting(df) + "\n\n" + code

        # --- Multi-label classification -----------------------------------
        if intent in ("multilabel_classification",) and not has_fit:
            label_cols = _guess_multilabel_cols(df)
            if len(label_cols) >= 2:
                return multilabel_classification(df, label_cols) + "\n\n" + code

        group_col = _find_unknownish_column(df)
        if group_col:
            num_cols = _guess_numeric_cols(df)
            cat_cols = _guess_categorical_cols(df, exclude={group_col})
            outcome_col = None  # generic; let template skip if not present
            tpl = unknown_group_proxy_pack(df, group_col, UNKNOWN_TOKENS, num_cols, cat_cols, outcome_col)

            # Return template + guarded (repaired) LLM code, so it never crashes
            repaired = make_numeric_vars_dynamic(relax_required_columns(fix_boxplot_placeholder(code)))
            return tpl + "\n\n" + wrap_llm_code_safe(repaired)

        return code

    # FIX: with an empty intents iterable the loop body never runs; return the
    # code unchanged instead of implicitly returning None.
    return code
|
|
1163
|
+
|
|
1164
|
+
|
|
1165
|
+
def fix_values_sum_numeric_only_bug(code: str) -> str:
    """
    If a previous pass injected numeric_only=True into a NumPy-style sum,
    e.g. .values.sum(numeric_only=True), strip it and canonicalize to .to_numpy().sum().
    """
    # Both the .values and the .to_numpy() spellings collapse to the same canonical call.
    offending = (
        r"\.values\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
        r"\.to_numpy\(\)\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
    )
    for pattern in offending:
        code = re.sub(pattern, ".to_numpy().sum()", code, flags=re.IGNORECASE)
    return code
|
|
1185
|
+
|
|
48
1186
|
|
|
49
1187
|
def strip_describe_slice(code: str) -> str:
|
|
50
1188
|
"""
|
|
@@ -59,10 +1197,12 @@ def strip_describe_slice(code: str) -> str:
|
|
|
59
1197
|
)
|
|
60
1198
|
return pat.sub(r"\1)", code)
|
|
61
1199
|
|
|
1200
|
+
|
|
62
1201
|
def remove_plt_show(code: str) -> str:
    """Drop every line containing a plt.show() call from the generated code string."""
    kept = [ln for ln in code.splitlines() if "plt.show()" not in ln]
    return "\n".join(kept)
|
|
65
1204
|
|
|
1205
|
+
|
|
66
1206
|
def patch_plot_with_table(code: str) -> str:
|
|
67
1207
|
"""
|
|
68
1208
|
▸ strips every `plt.show()` (avoids warnings)
|
|
@@ -149,7 +1289,7 @@ def patch_plot_with_table(code: str) -> str:
|
|
|
149
1289
|
")\n"
|
|
150
1290
|
)
|
|
151
1291
|
|
|
152
|
-
tbl_block += "
|
|
1292
|
+
tbl_block += "show(summary_table, title='Summary Statistics')"
|
|
153
1293
|
|
|
154
1294
|
# 5. inject image-export block, then table block, after the plot
|
|
155
1295
|
patched = (
|
|
@@ -289,10 +1429,10 @@ def refine_eda_question(raw_question, df=None, max_points=1000):
|
|
|
289
1429
|
"Use: tbl = df.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}); display(tbl)"
|
|
290
1430
|
)
|
|
291
1431
|
|
|
292
|
-
|
|
293
1432
|
# 9. Fallback: return the raw question
|
|
294
1433
|
return q
|
|
295
1434
|
|
|
1435
|
+
|
|
296
1436
|
def patch_plot_code(code, df, user_question=None):
|
|
297
1437
|
|
|
298
1438
|
# ── Early guard: abort nicely if the generated code references columns that
|
|
@@ -313,10 +1453,13 @@ def patch_plot_code(code, df, user_question=None):
|
|
|
313
1453
|
|
|
314
1454
|
if missing_cols:
|
|
315
1455
|
cols_list = ", ".join(missing_cols)
|
|
316
|
-
|
|
317
|
-
f"
|
|
318
|
-
|
|
1456
|
+
warning = (
|
|
1457
|
+
f"show('⚠️ Warning: code references missing column(s): \"{cols_list}\". "
|
|
1458
|
+
"These must either exist in df or be created earlier in the code; "
|
|
1459
|
+
"otherwise you may see a KeyError.')\n"
|
|
319
1460
|
)
|
|
1461
|
+
# Prepend the warning but keep the original code so it can still run
|
|
1462
|
+
code = warning + code
|
|
320
1463
|
|
|
321
1464
|
# 1. For line plots (auto-aggregate)
|
|
322
1465
|
m_l = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
|
|
@@ -425,6 +1568,16 @@ def patch_plot_code(code, df, user_question=None):
|
|
|
425
1568
|
# Fallback: Return original code
|
|
426
1569
|
return code
|
|
427
1570
|
|
|
1571
|
+
|
|
1572
|
+
def ensure_matplotlib_title(code, title_var="refined_question"):
    """Append a guarded plt.title(...) call when *code* draws a plot but never sets a title.

    The title text is taken at runtime from *title_var* (truncated to 120 chars)
    and the call is wrapped in try/except so a missing variable cannot crash the run.
    """
    import re
    draws_plot = re.search(r"\b(plt\.(plot|scatter|bar|hist)|ax\.(plot|scatter|bar|hist))\b", code)
    sets_title = re.search(r"\b(plt\.title|ax\.set_title)\s*\(", code)
    if draws_plot is None or sets_title is not None:
        return code
    suffix = f"\ntry:\n    plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
    return code + suffix
|
|
1579
|
+
|
|
1580
|
+
|
|
428
1581
|
def ensure_output(code: str) -> str:
|
|
429
1582
|
"""
|
|
430
1583
|
Guarantees that AI-generated code actually surfaces results in the UI
|
|
@@ -441,7 +1594,6 @@ def ensure_output(code: str) -> str:
|
|
|
441
1594
|
# not a comment / print / assignment / pyplot call
|
|
442
1595
|
if (last and not last.startswith(("print(", "plt.", "#")) and "=" not in last):
|
|
443
1596
|
lines[-1] = f"_out = {last}"
|
|
444
|
-
lines.append("from syntaxmatrix.display import show")
|
|
445
1597
|
lines.append("show(_out)")
|
|
446
1598
|
|
|
447
1599
|
# ── 3· auto-surface common stats tuples (stat, p) ───────────────────
|
|
@@ -449,14 +1601,12 @@ def ensure_output(code: str) -> str:
|
|
|
449
1601
|
if re.search(r"\bchi2\s*,\s*p\s*,", code) and "show((" in code:
|
|
450
1602
|
pass # AI already shows the tuple
|
|
451
1603
|
elif re.search(r"\bchi2\s*,\s*p\s*,", code):
|
|
452
|
-
lines.append("from syntaxmatrix.display import show")
|
|
453
1604
|
lines.append("show((chi2, p))")
|
|
454
1605
|
|
|
455
1606
|
# ── 4· classification report (string) ───────────────────────────────
|
|
456
1607
|
cr_match = re.search(r"^\s*(\w+)\s*=\s*classification_report\(", code, re.M)
|
|
457
1608
|
if cr_match and f"show({cr_match.group(1)})" not in "\n".join(lines):
|
|
458
1609
|
var = cr_match.group(1)
|
|
459
|
-
lines.append("from syntaxmatrix.display import show")
|
|
460
1610
|
lines.append(f"show({var})")
|
|
461
1611
|
|
|
462
1612
|
# 5-bis · pivot tables (DataFrame)
|
|
@@ -493,18 +1643,17 @@ def ensure_output(code: str) -> str:
|
|
|
493
1643
|
assign_scalar = re.match(r"\s*(\w+)\s*=\s*.+\.shape\[\s*0\s*\]\s*$", lines[-1])
|
|
494
1644
|
if assign_scalar:
|
|
495
1645
|
var = assign_scalar.group(1)
|
|
496
|
-
lines.append("from syntaxmatrix.display import show")
|
|
497
1646
|
lines.append(f"show({var})")
|
|
498
1647
|
|
|
499
1648
|
# ── 8. utils.ensure_output()
|
|
500
1649
|
assign_df = re.match(r"\s*(\w+)\s*=\s*df\[", lines[-1])
|
|
501
1650
|
if assign_df:
|
|
502
1651
|
var = assign_df.group(1)
|
|
503
|
-
lines.append("from syntaxmatrix.display import show")
|
|
504
1652
|
lines.append(f"show({var})")
|
|
505
1653
|
|
|
506
1654
|
return "\n".join(lines)
|
|
507
1655
|
|
|
1656
|
+
|
|
508
1657
|
def get_plotting_imports(code):
|
|
509
1658
|
imports = []
|
|
510
1659
|
if "plt." in code and "import matplotlib.pyplot as plt" not in code:
|
|
@@ -524,6 +1673,7 @@ def get_plotting_imports(code):
|
|
|
524
1673
|
code = "\n".join(imports) + "\n\n" + code
|
|
525
1674
|
return code
|
|
526
1675
|
|
|
1676
|
+
|
|
527
1677
|
def patch_pairplot(code, df):
|
|
528
1678
|
if "sns.pairplot" in code:
|
|
529
1679
|
# Always assign and print pairgrid
|
|
@@ -534,29 +1684,82 @@ def patch_pairplot(code, df):
|
|
|
534
1684
|
code += "\nprint(pairgrid)"
|
|
535
1685
|
return code
|
|
536
1686
|
|
|
1687
|
+
|
|
537
1688
|
def ensure_image_output(code: str) -> str:
    """
    Replace each plt.show() with an indented _SMX_export_png() call.
    This keeps block indentation valid and still renders images in the dashboard.
    """
    if "plt.show()" not in code:
        return code

    result = []
    for line in code.splitlines():
        if "plt.show()" not in line:
            result.append(line)
            continue

        # Handles a bare call, a call after other statements on the same line,
        # and several calls on one line — one exporter per removed call.
        lead = line[: len(line) - len(line.lstrip())]
        pieces = line.split("plt.show()")

        # statements preceding the first plt.show() stay on their own line
        head = pieces[0]
        if head.strip():
            result.append(head.rstrip())

        # one exporter per removed plt.show(), at the original indent
        result.extend(lead + "_SMX_export_png()" for _ in pieces[1:])

        # statements following the last plt.show() also keep the indent
        tail = pieces[-1]
        if tail.strip():
            result.append(lead + tail.lstrip())

    return "\n".join(result)
|
|
1723
|
+
|
|
1724
|
+
|
|
1725
|
+
def clean_llm_code(code: str) -> str:
    """
    Make LLM output safe to exec:
    - If fenced blocks exist, keep the largest one (usually the real code).
    - Otherwise strip any stray ``` / ```python lines.
    - Remove common markdown/preamble junk.
    """
    text = str(code or "")

    # Pull out fenced blocks (```python ... ``` or ``` ... ```)
    fenced = re.findall(r"```(?:python)?\s*(.*?)```", text, flags=re.I | re.S)

    biggest = max(fenced, key=lambda b: len(b.strip())) if fenced else None
    if biggest is not None and len(biggest.strip().splitlines()) >= 10:
        # a substantial block — treat it as the real code
        text = biggest
    else:
        # no (meaningful) block: just delete the fence-marker lines
        text = re.sub(r"^```.*?$", "", text, flags=re.M)

    # Drop typical markdown/preamble lines the model sometimes emits
    drop_prefixes = (
        "here is", "here's", "below is", "sure,", "certainly",
        "explanation", "note:", "```"
    )
    kept = [
        ln for ln in text.splitlines()
        if not any(ln.strip().lower().startswith(p) for p in drop_prefixes)
    ]

    return "\n".join(kept).strip()
|
|
557
1762
|
|
|
558
|
-
# exporter BEFORE the original plt.show()
|
|
559
|
-
return code.replace("plt.show()", exporter + "plt.show()")
|
|
560
1763
|
|
|
561
1764
|
def fix_groupby_describe_slice(code: str) -> str:
|
|
562
1765
|
"""
|
|
@@ -579,6 +1782,7 @@ def fix_groupby_describe_slice(code: str) -> str:
|
|
|
579
1782
|
)
|
|
580
1783
|
return pat.sub(repl, code)
|
|
581
1784
|
|
|
1785
|
+
|
|
582
1786
|
def fix_importance_groupby(code: str) -> str:
|
|
583
1787
|
pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
|
|
584
1788
|
if "importance_df" in code:
|
|
@@ -625,10 +1829,12 @@ def inject_auto_preprocessing(code: str) -> str:
|
|
|
625
1829
|
# simply prepend; model code that follows can wrap estimator in a Pipeline
|
|
626
1830
|
return prep_snippet + code
|
|
627
1831
|
|
|
1832
|
+
|
|
628
1833
|
def fix_to_datetime_errors(code: str) -> str:
|
|
629
1834
|
"""
|
|
630
1835
|
Force every pd.to_datetime(…) call to ignore bad dates so that
|
|
631
|
-
|
|
1836
|
+
|
|
1837
|
+
'year 16500 is out of range' and similar issues don’t crash runs.
|
|
632
1838
|
"""
|
|
633
1839
|
import re
|
|
634
1840
|
# look for any pd.to_datetime( … )
|
|
@@ -641,25 +1847,67 @@ def fix_to_datetime_errors(code: str) -> str:
|
|
|
641
1847
|
return f"pd.to_datetime({inside}, errors='coerce')"
|
|
642
1848
|
return pat.sub(repl, code)
|
|
643
1849
|
|
|
1850
|
+
|
|
644
1851
|
def fix_numeric_sum(code: str) -> str:
    """
    Strip every ``numeric_only=True/False/None`` keyword argument from the
    generated code so aggregation calls work across pandas versions.

    Some pandas releases reject ``numeric_only`` for Series / grouped sums;
    the generated code is expected to select numeric columns explicitly
    (e.g. via ``select_dtypes``) instead of relying on this keyword.
    """
    value = r"(True|False|None)"
    # Order matters: trailing-argument form first, then leading, then lone.
    for pattern in (
        rf",\s*numeric_only\s*=\s*{value}",      # ..., numeric_only=X
        rf"numeric_only\s*=\s*{value}\s*,\s*",   # numeric_only=X, ...
        rf"numeric_only\s*=\s*{value}",          # numeric_only=X (lone argument)
    ):
        code = re.sub(pattern, "", code, flags=re.IGNORECASE)

    return code
|
|
1885
|
+
|
|
1886
|
+
|
|
1887
|
+
def fix_concat_empty_list(code: str) -> str:
    """
    Guard pd.concat calls against an empty list of objects.

    Rewrites e.g.
        pd.concat(frames, ignore_index=True)
        pd.concat(frames)
    into
        pd.concat(frames or [pd.DataFrame()], ignore_index=True)
        pd.concat(frames or [pd.DataFrame()])

    Applies only when the first argument is a bare identifier; other call
    shapes are left untouched.
    """
    call_pat = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
    # m.group(2) is the character that followed the name: ',' or ')'
    return call_pat.sub(
        lambda m: f"pd.concat({m.group(1)} or [pd.DataFrame()]{m.group(2)}",
        code,
    )
|
|
662
1909
|
|
|
1910
|
+
|
|
663
1911
|
def fix_numeric_aggs(code: str) -> str:
|
|
664
1912
|
_AGG_FUNCS = ("sum", "mean")
|
|
665
1913
|
pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
|
|
@@ -673,69 +1921,66 @@ def fix_numeric_aggs(code: str) -> str:
|
|
|
673
1921
|
return f".{func}({args}numeric_only=True)"
|
|
674
1922
|
return pat.sub(_repl, code)
|
|
675
1923
|
|
|
1924
|
+
|
|
676
1925
|
def ensure_accuracy_block(code: str) -> str:
    """
    Inject an evaluation block right after the last ``<est>.fit(...)`` call.

    Classification → accuracy + weighted F1
    Regression     → R², RMSE, MAE

    The task type is inferred heuristically from estimator names present in
    the code.  If any recognised metric is already computed, the code is
    returned unchanged.
    """
    import re

    # If any proper metric already exists, do nothing.
    if re.search(
        r"\b(accuracy_score|f1_score|r2_score|mean_squared_error|mean_absolute_error)\b",
        code,
    ):
        return code

    # Find the last "<var>.fit(" occurrence to reuse the estimator variable name.
    fits = list(re.finditer(r"(\w+)\.fit\s*\(", code))
    if not fits:
        return code  # no estimator

    last = fits[-1]
    var = last.group(1)

    # BUGFIX: the previous version measured whitespace *at* the match start,
    # which begins on a word character, so the indent was always "" (and the
    # textwrap.dedent it then applied stripped any indent again).  Measure the
    # leading whitespace of the line containing the fit call instead, and
    # apply it to every injected line.
    line_start = code.rfind("\n", 0, last.start()) + 1
    indent = re.match(r"[ \t]*", code[line_start:]).group(0)

    # Detect regression by estimator names / hints in code.
    is_regression = bool(
        re.search(
            r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
            r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
            r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
        )
        or re.search(r"\bOLS\s*\(", code)
        or re.search(r"\bRegressor\b", code)
    )

    if is_regression:
        # RMSE needs numpy; inject the import if the code clearly lacks it.
        if "import numpy as np" not in code and "np." not in code:
            code = "import numpy as np\n" + code
            # offsets shifted by the prepended line: re-anchor on the fit call
            last = list(re.finditer(r"(\w+)\.fit\s*\(", code))[-1]
        lines = [
            "# ── automatic regression evaluation ─────────",
            "from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error",
            f"y_pred = {var}.predict(X_test)",
            "r2 = r2_score(y_test, y_pred)",
            "rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))",
            "mae = float(mean_absolute_error(y_test, y_pred))",
            'print(f"R²: {r2:.4f} | RMSE: {rmse:.4f} | MAE: {mae:.4f}")',
        ]
    else:
        lines = [
            "# ── automatic classification evaluation ─────────",
            "from sklearn.metrics import accuracy_score, f1_score",
            f"y_pred = {var}.predict(X_test)",
            "acc = accuracy_score(y_test, y_pred)",
            "f1 = f1_score(y_test, y_pred, average='weighted')",
            'print(f"Accuracy: {acc:.2%} | F1 (weighted): {f1:.3f}")',
        ]

    eval_block = "\n" + "\n".join(indent + ln for ln in lines) + "\n"

    # Insert right after the line containing the fit call.  BUGFIX: when the
    # fit call sits on the final line without a trailing newline, find()
    # returned -1 and the block was inserted at position 0; append instead.
    nl = code.find("\n", last.end())
    insert_at = len(code) if nl == -1 else nl + 1
    return code[:insert_at] + eval_block + code[insert_at:]
|
|
704
1983
|
|
|
705
|
-
def classify(prompt: str) -> str:
|
|
706
|
-
"""
|
|
707
|
-
Very-light intent classifier.
|
|
708
|
-
Returns one of:
|
|
709
|
-
'stat_test' | 'time_series' | 'clustering'
|
|
710
|
-
'classification' | 'regression' | 'eda'
|
|
711
|
-
"""
|
|
712
|
-
p = prompt.lower().strip()
|
|
713
|
-
greetings = {"hi", "hello", "hey", "good morning", "good afternoon", "good evening", "greetings"}
|
|
714
|
-
if any(p.startswith(g) or p == g for g in greetings):
|
|
715
|
-
return "greeting"
|
|
716
|
-
|
|
717
|
-
if any(k in p for k in ("t-test", "anova", "p-value")):
|
|
718
|
-
return "stat_test"
|
|
719
|
-
if "forecast" in p or "prophet" in p:
|
|
720
|
-
return "time_series"
|
|
721
|
-
if "cluster" in p or "kmeans" in p:
|
|
722
|
-
return "clustering"
|
|
723
|
-
if any(k in p for k in ("accuracy", "precision", "roc")):
|
|
724
|
-
return "classification"
|
|
725
|
-
if any(k in p for k in ("rmse", "r2", "mae")):
|
|
726
|
-
return "regression"
|
|
727
|
-
return "eda"
|
|
728
|
-
|
|
729
|
-
def auto_inject_template(code: str, intent: str, df) -> str:
|
|
730
|
-
"""If the LLM forgot the core logic, prepend a skeleton block."""
|
|
731
|
-
has_fit = ".fit(" in code
|
|
732
|
-
|
|
733
|
-
if intent == "classification" and not has_fit:
|
|
734
|
-
# guess a y column that contains 'diabetes' as in your dataset
|
|
735
|
-
target = next((c for c in df.columns if "diabetes" in c.lower()), None)
|
|
736
|
-
if target:
|
|
737
|
-
return classification(df, target) + "\n\n" + code
|
|
738
|
-
return code
|
|
739
1984
|
|
|
740
1985
|
def fix_scatter_and_summary(code: str) -> str:
|
|
741
1986
|
"""
|
|
@@ -763,6 +2008,7 @@ def fix_scatter_and_summary(code: str) -> str:
|
|
|
763
2008
|
|
|
764
2009
|
return code
|
|
765
2010
|
|
|
2011
|
+
|
|
766
2012
|
def auto_format_with_black(code: str) -> str:
|
|
767
2013
|
"""
|
|
768
2014
|
Format the generated code with Black. Falls back silently if Black
|
|
@@ -777,6 +2023,7 @@ def auto_format_with_black(code: str) -> str:
|
|
|
777
2023
|
except Exception:
|
|
778
2024
|
return code
|
|
779
2025
|
|
|
2026
|
+
|
|
780
2027
|
def ensure_preproc_in_pipeline(code: str) -> str:
|
|
781
2028
|
"""
|
|
782
2029
|
If code defines `preproc = ColumnTransformer(...)` but then builds
|
|
@@ -789,52 +2036,1128 @@ def ensure_preproc_in_pipeline(code: str) -> str:
|
|
|
789
2036
|
code
|
|
790
2037
|
)
|
|
791
2038
|
|
|
2039
|
+
|
|
792
2040
|
def fix_plain_prints(code: str) -> str:
    """
    Rewrite bare ``print(var)`` — where *var* looks like a dataframe/series/
    ndarray/etc — to go through SyntaxMatrix's smart display so it renders in
    the dashboard.  String prints (``print("...")``) are left alone.

    BUGFIX: the previous rewrite emitted the replacement at column 0, which
    de-indented prints inside functions/loops and produced invalid Python.
    The line's leading indentation is now preserved on both inserted lines.
    """
    new = re.sub(
        r"(?m)^([ \t]*)print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
        r"\1from syntaxmatrix.display import show\n\1show(\2)",
        code,
    )
    return new
|
|
2054
|
+
|
|
803
2055
|
|
|
804
|
-
|
|
805
|
-
# ✂
|
|
806
|
-
# --------------------------------------------------------------------------
|
|
807
|
-
def drop_bad_classification_metrics(code: str, y) -> str:
|
|
2056
|
+
def fix_print_html(code: str) -> str:
    """
    Ensure HTML / DataFrame-HTML output is *displayed* (and captured by the
    kernel) instead of printed as '<IPython.core.display.HTML object>' to the
    server console.

    Rewrites (whole-line, indentation preserved):
        print(HTML(...))        → display(HTML(...))
        print(display(...))     → display(...)
        print(df.to_html(...))  → display(HTML(df.to_html(...)))

    Also prepends ``from IPython.display import display, HTML`` if required.

    BUGFIX: the previous substitutions replaced only the line *prefix*, which
    dropped indentation and left unbalanced parentheses (an extra ')' in the
    print(display(...)) case, a missing ')' in the to_html case).  Matching
    the whole line keeps parentheses balanced.
    """
    import re

    new = code

    # 1) print(HTML(...)) -> display(HTML(...))
    new = re.sub(
        r"(?m)^([ \t]*)print\s*\(\s*(HTML\s*\(.*\))\s*\)\s*$",
        r"\1display(\2)",
        new,
    )

    # 2) print(display(...)) -> display(...)
    new = re.sub(
        r"(?m)^([ \t]*)print\s*\(\s*(display\s*\(.*\))\s*\)\s*$",
        r"\1\2",
        new,
    )

    # 3) print(<expr>.to_html(...)) -> display(HTML(<expr>.to_html(...)))
    new = re.sub(
        r"(?m)^([ \t]*)print\s*\(\s*([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*\s*\.to_html\s*\(.*\))\s*\)\s*$",
        r"\1display(HTML(\2))",
        new,
    )

    # If code references HTML() or display(), make sure the import exists.
    if ("HTML(" in new or re.search(r"\bdisplay\s*\(", new)) and \
       "from IPython.display import display, HTML" not in new:
        new = "from IPython.display import display, HTML\n" + new

    return new
|
|
2087
|
+
|
|
813
2088
|
|
|
814
|
-
|
|
815
|
-
• a pandas Series -> y.dtype.kind is available
|
|
816
|
-
• a pandas DataFrame (multi-column) -> we infer by looking at *all*
|
|
2089
|
+
def ensure_ipy_display(code: str) -> str:
    """
    Make sure a cell that calls display() also imports it, so that
    display(HTML(...)) produces 'display_data' events the kernel captures.
    """
    needed_import = "from IPython.display import display, HTML"
    calls_display = "display(" in code
    if calls_display and needed_import not in code:
        return needed_import + "\n" + code
    return code
|
|
2097
|
+
|
|
2098
|
+
|
|
2099
|
+
def drop_bad_classification_metrics(code: str, y_or_df) -> str:
    """
    Strip classification metrics (accuracy_score, classification_report,
    confusion_matrix) from cells that are actually *regression*.

    Regression is inferred from:
      1) estimator names present in the code (LinearRegression, OLS, ...), or
      2) the target dtype, when ``y = df['...']`` can be parsed and a
         DataFrame (or a Series) was supplied.

    Safe across datasets and queries.
    """
    import re
    import pandas as pd

    # 1) Fast path: the estimator class name alone settles it.
    regressor_names = (
        r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
        r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
        r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b"
    )
    is_regression = bool(re.search(regressor_names, code) or re.search(r"\bOLS\s*\(", code))

    # 2) Otherwise try to infer from the target's dtype, if we can.
    if not is_regression:
        try:
            target = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
            if target and hasattr(y_or_df, "columns") and target.group(1) in getattr(y_or_df, "columns", []):
                y = y_or_df[target.group(1)]
                if pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
                    is_regression = True
            else:
                # A Series may have been passed directly.
                y = y_or_df
                if hasattr(y, "dtype") and pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
                    is_regression = True
        except Exception:
            pass

    if not is_regression:
        return code

    # Drop classification-only lines.
    for bad in (r"\n.*accuracy_score[^\n]*", r"\n.*classification_report[^\n]*", r"\n.*confusion_matrix[^\n]*"):
        code = re.sub(bad, "", code, flags=re.I)

    return code
|
|
838
2142
|
|
|
839
|
-
|
|
840
|
-
|
|
2143
|
+
|
|
2144
|
+
def force_capture_display(code: str) -> str:
    """
    Ensure our executor captures HTML output:
      - Remove any import that would override our injected 'display' hook.
      - Keep/allow importing HTML only.
      - Handle alias cases like 'display as d' and 'import IPython.display as disp'.

    BUGFIXES vs the previous version:
      * 'from IPython.display import display, HTML as H' used a plain
        "HTML\\1" replacement, which produced the invalid name 'HTMLH';
        a callable replacement now emits 'HTML as H'.
      * the 'display as X' pattern required a two-character alias (\\w+ after
        the first char), so one-letter aliases like 'd' were missed.
      * the 'disp.display(' rewrite relied on a same-line lookahead for the
        'import IPython.display as disp' statement (which lives on another
        line), so it never fired; aliases are now collected up-front.
    """
    import re

    new = code

    # 'from IPython.display import display, HTML [as X]' -> keep HTML only
    def _keep_html(m):
        alias = m.group(1)
        return "from IPython.display import HTML" + (f" as {alias}" if alias else "")

    new = re.sub(
        r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*,\s*HTML\s*(?:as\s+([A-Za-z_]\w*))?\s*$",
        _keep_html, new,
    )

    # 'from IPython.display import display as d' -> 'd = display'
    new = re.sub(
        r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s+as\s+([A-Za-z_]\w*)\s*$",
        r"\1 = display", new,
    )

    # 'from IPython.display import display' -> remove (use our injected display)
    new = re.sub(
        r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*$",
        r"# display import removed (SMX capture active)", new,
    )

    # Fully-qualified calls: IPython.display.display(...) -> display(...)
    new = re.sub(r"(?m)\bIPython\.display\.display\s*\(", "display(", new)

    # 'import IPython.display as disp' + 'disp.display(...)' -> 'display(...)'
    for alias in re.findall(r"(?m)^\s*import\s+IPython\.display\s+as\s+([A-Za-z_]\w*)\s*$", new):
        new = re.sub(rf"(?m)\b{re.escape(alias)}\.display\s*\(", "display(", new)

    return new
|
|
2183
|
+
|
|
2184
|
+
|
|
2185
|
+
def strip_matplotlib_show(code: str) -> str:
    """Drop blocking plt.show() lines — figures are exported as base64 instead."""
    import re
    show_call = re.compile(r"(?m)^\s*plt\.show\(\)\s*$")
    return show_call.sub("", code)
|
|
2189
|
+
|
|
2190
|
+
|
|
2191
|
+
def inject_display_shim(code: str) -> str:
    """
    Provide display()/HTML() if missing, forwarding to our executor hook.
    Harmless if the names already exist.

    The shim is prepended as source text, so it only takes effect when the
    returned string is later exec'd by the kernel.
    """
    # NOTE(review): `__builtins__.get(...)` assumes __builtins__ is a dict,
    # which holds inside exec'd cells but not in an imported module (where
    # __builtins__ is the builtins module) — confirm against the kernel setup.
    shim = (
        "try:\n"
        "    display\n"
        "except NameError:\n"
        "    def display(obj=None, **kwargs):\n"
        "        __builtins__.get('_smx_display', print)(obj)\n"
        "try:\n"
        "    HTML\n"
        "except NameError:\n"
        "    class HTML:\n"
        "        def __init__(self, data): self.data = str(data)\n"
        "        def _repr_html_(self): return self.data\n"
        "\n"
    )
    return shim + code
|
|
2211
|
+
|
|
2212
|
+
|
|
2213
|
+
def strip_spurious_column_tokens(code: str) -> str:
    """
    Drop common stop-words ('the', 'whether', ...) that leak into column
    lists, e.g.:
        predictors = ['BMI','the','HbA1c']
        df[['GGT','whether','BMI']]
    Other strings are left intact.
    """
    STOP = {
        "the", "whether", "a", "an", "and", "or", "of", "to", "in", "on", "for", "by",
        "with", "as", "at", "from", "that", "this", "these", "those", "is", "are", "was", "were",
        "coef", "Coef", "coefficient", "Coefficient",
    }

    def _norm(token: str) -> str:
        # Normalise for comparison: lowercase, alphanumerics only.
        return re.sub(r"[^a-z0-9]+", "", token.lower())

    def _clean_list(body: str) -> str:
        # Rebuild a string list, keeping only non-stopword items.
        quoted = re.findall(r"(['\"])(.*?)\1", body)
        if not quoted:
            return "[" + body + "]"
        survivors = [quote + text + quote for quote, text in quoted if _norm(text) not in STOP]
        return "[" + ", ".join(survivors) + "]"

    # Variable assignments: predictors/features/columns/cols = [...]
    code = re.sub(
        r"(?m)\b(predictors|features|columns|cols)\s*=\s*\[([^\]]+)\]",
        lambda m: m.group(1) + " = " + _clean_list(m.group(2)),
        code,
    )

    # df[[ ... ]] selections
    code = re.sub(
        r"df\s*\[\s*\[([^\]]+)\]\s*\]",
        lambda m: "df[" + _clean_list(m.group(1)) + "]",
        code,
    )

    return code
|
|
2250
|
+
|
|
2251
|
+
|
|
2252
|
+
def patch_prefix_seaborn_calls(code: str) -> str:
    """
    Prefix bare seaborn calls with `sns.`:
    `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, …
    Already-qualified calls (`sns.barplot(`, `obj.barplot(`) are untouched.
    """
    SEABORN_FUNCS = (
        "barplot", "countplot", "boxplot", "violinplot", "stripplot", "swarmplot",
        "histplot", "kdeplot", "jointplot", "pairplot", "heatmap", "clustermap",
        "scatterplot", "lineplot", "catplot", "displot", "lmplot",
    )
    # (?<![\w.]) — no preceding word char or dot, so obj.barplot / mybarplot survive.
    bare_call = re.compile(
        r"(?<![\w\.])(" + "|".join(SEABORN_FUNCS) + r")\s*\(",
        flags=re.MULTILINE,
    )
    return bare_call.sub(lambda m: "sns." + m.group(1) + "(", code)
|
|
2276
|
+
|
|
2277
|
+
|
|
2278
|
+
def patch_ensure_seaborn_import(code: str) -> str:
    """
    If the cell uses seaborn (`sns.`), guarantee exactly one
    `import seaborn as sns` plus a quiet default theme.

    The injection goes just after the leading import block when one exists,
    otherwise it is prepended to the code.
    """
    uses_seaborn = "sns." in code
    already_imported = re.search(
        r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE
    ) is not None

    if not uses_seaborn or already_imported:
        return code

    inject = "import seaborn as sns\ntry:\n sns.set_theme()\nexcept Exception:\n pass\n"
    leading_imports = re.search(
        r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+",
        code,
        flags=re.MULTILINE,
    )
    if leading_imports:
        cut = leading_imports.end()
        return code[:cut] + inject + code[cut:]
    return inject + code
|
|
2295
|
+
|
|
2296
|
+
|
|
2297
|
+
def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
|
|
2298
|
+
"""
|
|
2299
|
+
Normalise pie-chart requests.
|
|
2300
|
+
|
|
2301
|
+
Supports three patterns:
|
|
2302
|
+
A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
|
|
2303
|
+
B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
|
|
2304
|
+
→ one pie per facet level (grid) + counts bar of facet sizes.
|
|
2305
|
+
C) Single pie when no split/facet is requested.
|
|
2306
|
+
|
|
2307
|
+
Notes:
|
|
2308
|
+
- Pie variables must be categorical (or numeric binned).
|
|
2309
|
+
- Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
|
|
2310
|
+
"""
|
|
2311
|
+
|
|
2312
|
+
q = (user_question or "")
|
|
2313
|
+
q_low = q.lower()
|
|
2314
|
+
|
|
2315
|
+
# Prefer explicit: df['col'].value_counts()
|
|
2316
|
+
m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
|
|
2317
|
+
col = m.group(1) if m else None
|
|
2318
|
+
|
|
2319
|
+
# ---------- helpers ----------
|
|
2320
|
+
def _is_cat(col):
|
|
2321
|
+
return (str(df[col].dtype).startswith("category")
|
|
2322
|
+
or df[col].dtype == "object"
|
|
2323
|
+
or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
|
|
2324
|
+
|
|
2325
|
+
def _cats_from_question(question: str):
|
|
2326
|
+
found = []
|
|
2327
|
+
for c in df.columns:
|
|
2328
|
+
if c.lower() in question.lower() and _is_cat(c):
|
|
2329
|
+
found.append(c)
|
|
2330
|
+
# dedupe preserve order
|
|
2331
|
+
seen, out = set(), []
|
|
2332
|
+
for c in found:
|
|
2333
|
+
if c not in seen:
|
|
2334
|
+
out.append(c); seen.add(c)
|
|
2335
|
+
return out
|
|
2336
|
+
|
|
2337
|
+
def _fallback_cat():
|
|
2338
|
+
cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
|
|
2339
|
+
if not cats: return None
|
|
2340
|
+
cats.sort(key=lambda t: t[1])
|
|
2341
|
+
return cats[0][0]
|
|
2342
|
+
|
|
2343
|
+
def _infer_comp_pref(question: str) -> str:
|
|
2344
|
+
ql = (question or "").lower()
|
|
2345
|
+
if "heatmap" in ql or "matrix" in ql:
|
|
2346
|
+
return "heatmap"
|
|
2347
|
+
if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
|
|
2348
|
+
return "stacked_bar_pct"
|
|
2349
|
+
if "stacked" in ql:
|
|
2350
|
+
return "stacked_bar"
|
|
2351
|
+
if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
|
|
2352
|
+
return "grouped_bar"
|
|
2353
|
+
return "counts_bar"
|
|
2354
|
+
|
|
2355
|
+
# parse threshold split like "HbA1c ≥ 6.5"
|
|
2356
|
+
def _parse_split(question: str):
|
|
2357
|
+
ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
|
|
2358
|
+
m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
|
|
2359
|
+
if not m: return None
|
|
2360
|
+
col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
|
|
2361
|
+
op = ops_map.get(op_raw);
|
|
2362
|
+
if not op: return None
|
|
2363
|
+
# case-insensitive column match
|
|
2364
|
+
candidates = {c.lower(): c for c in df.columns}
|
|
2365
|
+
col = candidates.get(col_raw.lower())
|
|
2366
|
+
if not col: return None
|
|
2367
|
+
try: val = float(val_raw)
|
|
2368
|
+
except Exception: return None
|
|
2369
|
+
return (col, op, val)
|
|
2370
|
+
|
|
2371
|
+
# facet extractor: "by/ across / within each / per <col>", or "bin <col>", or named category list
|
|
2372
|
+
def _extract_facet(question: str):
|
|
2373
|
+
# 1) explicit "by/ across / within / per <col>"
|
|
2374
|
+
for kw in [" by ", " across ", " within ", " within each ", " per "]:
|
|
2375
|
+
m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
|
|
2376
|
+
if m:
|
|
2377
|
+
col_raw = m.group(1).strip()
|
|
2378
|
+
candidates = {c.lower(): c for c in df.columns}
|
|
2379
|
+
if col_raw.lower() in candidates:
|
|
2380
|
+
return (candidates[col_raw.lower()], "auto")
|
|
2381
|
+
# 2) "bin <col>"
|
|
2382
|
+
m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
|
|
2383
|
+
if m2:
|
|
2384
|
+
col_raw = m2.group(1).strip()
|
|
2385
|
+
candidates = {c.lower(): c for c in df.columns}
|
|
2386
|
+
if col_raw.lower() in candidates:
|
|
2387
|
+
return (candidates[col_raw.lower()], "bin")
|
|
2388
|
+
# 3) BMI special: mentions of normal/overweight/obese imply BMI categories
|
|
2389
|
+
if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
|
|
2390
|
+
any(c.lower() == "bmi" for c in df.columns.str.lower()):
|
|
2391
|
+
bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
|
|
2392
|
+
return (bmi_col, "bmi")
|
|
2393
|
+
return None
|
|
2394
|
+
|
|
2395
|
+
def _bmi_bins(series: pd.Series):
|
|
2396
|
+
# WHO cutoffs
|
|
2397
|
+
bins = [-np.inf, 18.5, 25, 30, np.inf]
|
|
2398
|
+
labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
|
|
2399
|
+
return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
|
|
2400
|
+
|
|
2401
|
+
wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
|
|
2402
|
+
if not wants_pie:
|
|
2403
|
+
return code
|
|
2404
|
+
|
|
2405
|
+
split = _parse_split(q)
|
|
2406
|
+
facet = _extract_facet(q)
|
|
2407
|
+
cats = _cats_from_question(q)
|
|
2408
|
+
_comp_pref = _infer_comp_pref(q)
|
|
2409
|
+
|
|
2410
|
+
# Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
|
|
2411
|
+
for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
|
|
2412
|
+
if hard in df.columns and hard not in cats and hard.lower() in q_low:
|
|
2413
|
+
cats.append(hard)
|
|
2414
|
+
|
|
2415
|
+
# --------------- CASE A: threshold split (cohorts) ---------------
|
|
2416
|
+
if split:
|
|
2417
|
+
if not (cats or any(_is_cat(c) for c in df.columns)):
|
|
2418
|
+
return code
|
|
2419
|
+
if not cats:
|
|
2420
|
+
pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
|
|
2421
|
+
pool.sort(key=lambda t: t[1])
|
|
2422
|
+
cats = [t[0] for t in pool[:3]] if pool else []
|
|
2423
|
+
if not cats:
|
|
2424
|
+
return code
|
|
2425
|
+
|
|
2426
|
+
split_col, op, val = split
|
|
2427
|
+
cond_str = f"(df['{split_col}'] {op} {val})"
|
|
2428
|
+
snippet = f"""
|
|
2429
|
+
import numpy as np
|
|
2430
|
+
import pandas as pd
|
|
2431
|
+
import matplotlib.pyplot as plt
|
|
2432
|
+
|
|
2433
|
+
_mask_a = ({cond_str}) & df['{split_col}'].notna()
|
|
2434
|
+
_mask_b = (~({cond_str})) & df['{split_col}'].notna()
|
|
2435
|
+
|
|
2436
|
+
_cohort_a_name = "{split_col} {op} {val}"
|
|
2437
|
+
_cohort_b_name = "NOT ({split_col} {op} {val})"
|
|
2438
|
+
|
|
2439
|
+
_cat_cols = {cats!r}
|
|
2440
|
+
n = len(_cat_cols)
|
|
2441
|
+
fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
|
|
2442
|
+
if n == 1:
|
|
2443
|
+
axes = np.array([axes])
|
|
2444
|
+
|
|
2445
|
+
for i, col in enumerate(_cat_cols):
|
|
2446
|
+
s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
|
|
2447
|
+
s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})
|
|
2448
|
+
|
|
2449
|
+
ax_a = axes[i, 0]; ax_b = axes[i, 1]
|
|
2450
|
+
if len(s_a) > 0:
|
|
2451
|
+
ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
|
|
2452
|
+
autopct='%1.1f%%', startangle=90, counterclock=False)
|
|
2453
|
+
ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')
|
|
2454
|
+
|
|
2455
|
+
if len(s_b) > 0:
|
|
2456
|
+
ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
|
|
2457
|
+
autopct='%1.1f%%', startangle=90, counterclock=False)
|
|
2458
|
+
ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')
|
|
2459
|
+
|
|
2460
|
+
plt.tight_layout(); plt.show()
|
|
2461
|
+
|
|
2462
|
+
# grouped bar complement
|
|
2463
|
+
for col in _cat_cols:
|
|
2464
|
+
_tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
|
|
2465
|
+
.assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
|
|
2466
|
+
_tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
|
|
2467
|
+
_tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
|
|
2468
|
+
|
|
2469
|
+
if _comp_pref == "grouped_bar":
|
|
2470
|
+
ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
|
|
2471
|
+
ax.set_title(f"{col} by cohort (grouped)")
|
|
2472
|
+
ax.set_xlabel(col); ax.set_ylabel("Count")
|
|
2473
|
+
plt.tight_layout(); plt.show()
|
|
2474
|
+
|
|
2475
|
+
elif _comp_pref == "stacked_bar":
|
|
2476
|
+
ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
|
|
2477
|
+
ax.set_title(f"{col} by cohort (stacked)")
|
|
2478
|
+
ax.set_xlabel(col); ax.set_ylabel("Count")
|
|
2479
|
+
plt.tight_layout(); plt.show()
|
|
2480
|
+
|
|
2481
|
+
elif _comp_pref == "stacked_bar_pct":
|
|
2482
|
+
_perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
|
|
2483
|
+
ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
|
|
2484
|
+
ax.set_title(f"{col} by cohort (100% stacked)")
|
|
2485
|
+
ax.set_xlabel(col); ax.set_ylabel("Percent")
|
|
2486
|
+
plt.tight_layout(); plt.show()
|
|
2487
|
+
|
|
2488
|
+
elif _comp_pref == "heatmap":
|
|
2489
|
+
_perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
|
|
2490
|
+
import numpy as np
|
|
2491
|
+
fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
|
|
2492
|
+
im = ax.imshow(_perc.values, aspect='auto')
|
|
2493
|
+
ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
|
|
2494
|
+
ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
|
|
2495
|
+
ax.set_title(f"{col} by cohort — % heatmap")
|
|
2496
|
+
for i in range(_perc.shape[0]):
|
|
2497
|
+
for j in range(_perc.shape[1]):
|
|
2498
|
+
ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
|
|
2499
|
+
fig.colorbar(im, ax=ax, label="%")
|
|
2500
|
+
plt.tight_layout(); plt.show()
|
|
2501
|
+
|
|
2502
|
+
else: # counts_bar (default)
|
|
2503
|
+
ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
|
|
2504
|
+
ax.set_title(f"{col}: total counts (both cohorts)")
|
|
2505
|
+
ax.set_xlabel(col); ax.set_ylabel("Count")
|
|
2506
|
+
plt.tight_layout(); plt.show()
|
|
2507
|
+
""".lstrip()
|
|
2508
|
+
return snippet
|
|
2509
|
+
|
|
2510
|
+
# --------------- CASE B: facet-by (categories/bins) ---------------
|
|
2511
|
+
if facet:
|
|
2512
|
+
facet_col, how = facet
|
|
2513
|
+
# Build facet series
|
|
2514
|
+
if pd.api.types.is_numeric_dtype(df[facet_col]):
|
|
2515
|
+
if how == "bmi":
|
|
2516
|
+
facet_series = _bmi_bins(df[facet_col])
|
|
2517
|
+
else:
|
|
2518
|
+
# generic numeric bins: 3 equal-width bins by default
|
|
2519
|
+
facet_series = pd.cut(df[facet_col].astype(float), bins=3)
|
|
2520
|
+
else:
|
|
2521
|
+
facet_series = df[facet_col].astype(str)
|
|
2522
|
+
|
|
2523
|
+
# Choose pie dimension (categorical to count inside each facet)
|
|
2524
|
+
pie_dim = None
|
|
2525
|
+
for c in cats:
|
|
2526
|
+
if c in df.columns and _is_cat(c):
|
|
2527
|
+
pie_dim = c; break
|
|
2528
|
+
if pie_dim is None:
|
|
2529
|
+
pie_dim = _fallback_cat()
|
|
2530
|
+
if pie_dim is None:
|
|
2531
|
+
return code
|
|
2532
|
+
|
|
2533
|
+
snippet = f"""
|
|
2534
|
+
import math
|
|
2535
|
+
import pandas as pd
|
|
2536
|
+
import matplotlib.pyplot as plt
|
|
2537
|
+
|
|
2538
|
+
df = df.copy()
|
|
2539
|
+
_preferred = "{facet_col}" if "{facet_col}" in df.columns else None
|
|
2540
|
+
|
|
2541
|
+
def _select_facet_col(df, preferred=None):
|
|
2542
|
+
if preferred is not None:
|
|
2543
|
+
return preferred
|
|
2544
|
+
# Prefer low-cardinality categoricals (readable pies/grids)
|
|
2545
|
+
cat_cols = [
|
|
2546
|
+
c for c in df.columns
|
|
2547
|
+
if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
|
|
2548
|
+
and df[c].nunique() > 1 and df[c].nunique() <= 20
|
|
2549
|
+
]
|
|
2550
|
+
if cat_cols:
|
|
2551
|
+
cat_cols.sort(key=lambda c: df[c].nunique())
|
|
2552
|
+
return cat_cols[0]
|
|
2553
|
+
# Else fall back to first usable numeric
|
|
2554
|
+
num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
|
|
2555
|
+
return num_cols[0] if num_cols else None
|
|
2556
|
+
|
|
2557
|
+
_facet_col = _select_facet_col(df, _preferred)
|
|
2558
|
+
|
|
2559
|
+
if _facet_col is None:
|
|
2560
|
+
# Nothing suitable → single facet keeps pipeline alive
|
|
2561
|
+
df["__facet__"] = "All"
|
|
2562
|
+
else:
|
|
2563
|
+
s = df[_facet_col]
|
|
2564
|
+
if pd.api.types.is_numeric_dtype(s):
|
|
2565
|
+
# Robust numeric binning: quantiles first, fallback to equal-width
|
|
2566
|
+
uniq = pd.Series(s).dropna().nunique()
|
|
2567
|
+
q = 3 if uniq < 10 else 4 if uniq < 30 else 5
|
|
2568
|
+
try:
|
|
2569
|
+
df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
|
|
2570
|
+
except Exception:
|
|
2571
|
+
df["__facet__"] = pd.cut(s.astype(float), bins=q)
|
|
2572
|
+
else:
|
|
2573
|
+
# Cap long tails; keep top categories
|
|
2574
|
+
vc = s.astype(str).value_counts()
|
|
2575
|
+
keep = vc.index[:{top_n}]
|
|
2576
|
+
df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")
|
|
2577
|
+
|
|
2578
|
+
levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
|
|
2579
|
+
levels = [x for x in levels if x != "nan"]
|
|
2580
|
+
levels.sort()
|
|
2581
|
+
|
|
2582
|
+
m = len(levels)
|
|
2583
|
+
cols = 3 if m >= 3 else m or 1
|
|
2584
|
+
rows = int(math.ceil(m / cols))
|
|
2585
|
+
|
|
2586
|
+
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
|
|
2587
|
+
if not isinstance(axes, (list, np.ndarray)):
|
|
2588
|
+
axes = np.array([[axes]])
|
|
2589
|
+
axes = axes.reshape(rows, cols)
|
|
2590
|
+
|
|
2591
|
+
for i, lvl in enumerate(levels):
|
|
2592
|
+
r, c = divmod(i, cols)
|
|
2593
|
+
ax = axes[r, c]
|
|
2594
|
+
s = (df.loc[df["__facet"].astype(str) == str(lvl), "{pie_dim}"]
|
|
2595
|
+
.astype(str).value_counts().nlargest({top_n}))
|
|
2596
|
+
if len(s) > 0:
|
|
2597
|
+
ax.pie(s.values, labels=[str(x) for x in s.index],
|
|
2598
|
+
autopct='%1.1f%%', startangle=90, counterclock=False)
|
|
2599
|
+
ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')
|
|
2600
|
+
|
|
2601
|
+
# hide any empty subplots
|
|
2602
|
+
for j in range(m, rows*cols):
|
|
2603
|
+
r, c = divmod(j, cols)
|
|
2604
|
+
axes[r, c].axis("off")
|
|
2605
|
+
|
|
2606
|
+
plt.tight_layout(); plt.show()
|
|
2607
|
+
|
|
2608
|
+
# --- companion visual (adaptive) ---
|
|
2609
|
+
_comp_pref = "{_comp_pref}"
|
|
2610
|
+
# build contingency table: pie_dim x facet
|
|
2611
|
+
_tab = (df[["__facet__", "{pie_dim}"]]
|
|
2612
|
+
.dropna()
|
|
2613
|
+
.astype({{"__facet__": str, "{pie_dim}": str}})
|
|
2614
|
+
.value_counts()
|
|
2615
|
+
.unstack(level="__facet__")
|
|
2616
|
+
.fillna(0))
|
|
2617
|
+
|
|
2618
|
+
# keep top categories by overall size
|
|
2619
|
+
_tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
|
|
2620
|
+
|
|
2621
|
+
if _comp_pref == "grouped_bar":
|
|
2622
|
+
ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
|
|
2623
|
+
ax.set_title("{pie_dim} by {facet_col} (grouped)")
|
|
2624
|
+
ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
|
|
2625
|
+
plt.tight_layout(); plt.show()
|
|
2626
|
+
|
|
2627
|
+
elif _comp_pref == "stacked_bar":
|
|
2628
|
+
ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
|
|
2629
|
+
ax.set_title("{pie_dim} by {facet_col} (stacked)")
|
|
2630
|
+
ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
|
|
2631
|
+
plt.tight_layout(); plt.show()
|
|
2632
|
+
|
|
2633
|
+
elif _comp_pref == "stacked_bar_pct":
|
|
2634
|
+
_perc = _tab.div(_tab.sum(axis=0), axis=1) * 100 # column-normalised to 100%
|
|
2635
|
+
ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
|
|
2636
|
+
ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
|
|
2637
|
+
ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
|
|
2638
|
+
plt.tight_layout(); plt.show()
|
|
2639
|
+
|
|
2640
|
+
elif _comp_pref == "heatmap":
|
|
2641
|
+
_perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
|
|
2642
|
+
import numpy as np
|
|
2643
|
+
fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
|
|
2644
|
+
im = ax.imshow(_perc.values, aspect='auto')
|
|
2645
|
+
ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
|
|
2646
|
+
ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
|
|
2647
|
+
ax.set_title("{pie_dim} by {facet_col} — % heatmap")
|
|
2648
|
+
for i in range(_perc.shape[0]):
|
|
2649
|
+
for j in range(_perc.shape[1]):
|
|
2650
|
+
ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
|
|
2651
|
+
fig.colorbar(im, ax=ax, label="%")
|
|
2652
|
+
plt.tight_layout(); plt.show()
|
|
2653
|
+
|
|
2654
|
+
else: # counts_bar (default denominators)
|
|
2655
|
+
_counts = df["__facet"].value_counts()
|
|
2656
|
+
ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
|
|
2657
|
+
ax.set_title("Counts by {facet_col}")
|
|
2658
|
+
ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
|
|
2659
|
+
plt.tight_layout(); plt.show()
|
|
2660
|
+
|
|
2661
|
+
""".lstrip()
|
|
2662
|
+
return snippet
|
|
2663
|
+
|
|
2664
|
+
# --------------- CASE C: single pie ---------------
|
|
2665
|
+
chosen = None
|
|
2666
|
+
for c in cats:
|
|
2667
|
+
if c in df.columns and _is_cat(c):
|
|
2668
|
+
chosen = c; break
|
|
2669
|
+
if chosen is None:
|
|
2670
|
+
chosen = _fallback_cat()
|
|
2671
|
+
|
|
2672
|
+
if chosen:
|
|
2673
|
+
snippet = f"""
|
|
2674
|
+
import matplotlib.pyplot as plt
|
|
2675
|
+
counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
|
|
2676
|
+
fig, ax = plt.subplots()
|
|
2677
|
+
if len(counts) > 0:
|
|
2678
|
+
ax.pie(counts.values, labels=[str(i) for i in counts.index],
|
|
2679
|
+
autopct='%1.1f%%', startangle=90, counterclock=False)
|
|
2680
|
+
ax.set_title('Distribution of {chosen} (top {top_n})')
|
|
2681
|
+
ax.axis('equal')
|
|
2682
|
+
plt.show()
|
|
2683
|
+
""".lstrip()
|
|
2684
|
+
return snippet
|
|
2685
|
+
|
|
2686
|
+
# numeric last resort
|
|
2687
|
+
num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
|
|
2688
|
+
if num_cols:
|
|
2689
|
+
col = num_cols[0]
|
|
2690
|
+
snippet = f"""
|
|
2691
|
+
import pandas as pd
|
|
2692
|
+
import matplotlib.pyplot as plt
|
|
2693
|
+
bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
|
|
2694
|
+
counts = bins.value_counts().sort_index()
|
|
2695
|
+
fig, ax = plt.subplots()
|
|
2696
|
+
if len(counts) > 0:
|
|
2697
|
+
ax.pie(counts.values, labels=[str(i) for i in counts.index],
|
|
2698
|
+
autopct='%1.1f%%', startangle=90, counterclock=False)
|
|
2699
|
+
ax.set_title('Distribution of {col} (binned)')
|
|
2700
|
+
ax.axis('equal')
|
|
2701
|
+
plt.show()
|
|
2702
|
+
""".lstrip()
|
|
2703
|
+
return snippet
|
|
2704
|
+
|
|
2705
|
+
return code
|
|
2706
|
+
|
|
2707
|
+
|
|
2708
|
+
def patch_fix_seaborn_palette_calls(code: str) -> str:
    """
    Remove seaborn `palette=` kwargs when no `hue=` is present in the same call.

    Fixes the seaborn FutureWarning:
    'Passing `palette` without assigning `hue` ...'.

    Args:
        code: Generated Python source to patch.

    Returns:
        The patched source (returned unchanged when no seaborn calls exist).
    """
    if "sns." not in code:
        return code

    # Targets common seaborn plotters. NOTE: the alternation must be
    # non-capturing — a capturing group here would shift group numbering so
    # that group(2) is the plotter *name* rather than the call arguments,
    # and the rewrite below would corrupt every matched call.
    funcs = r"(?:boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
    pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)

    def _fix_call(m):
        head, inner = m.group(1), m.group(2)
        # If there's already hue=, keep as is
        if re.search(r"(?<!\w)hue\s*=", inner):
            return f"{head}{inner})"
        # Otherwise remove palette=... safely (and any adjacent comma spacing)
        inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
        inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
        inner2 = re.sub(r"\s*,\s*$", "", inner2)  # drop any dangling trailing comma
        return f"{head}{inner2})"

    return pattern.sub(_fix_call, code)
|
|
2732
|
+
|
|
2733
|
+
|
|
2734
|
+
def patch_quiet_specific_warnings(code: str) -> str:
    """
    Insert targeted warning filters (not blanket ignores) into generated code.

    Silences:
      - the seaborn palette/hue FutureWarning
      - python-dotenv parse chatter
    """
    filters = (
        "warnings.filterwarnings(\n"
        " 'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
        "warnings.filterwarnings(\n"
        " 'ignore', message=r'python-dotenv could not parse statement.*')\n"
    )
    prelude = "import warnings\n" + filters

    # If warnings is already imported, append only the filters right after it.
    if "import warnings" in code:
        return re.sub(
            r"(import warnings[^\n]*\n)",
            lambda m: m.group(1) + filters,
            code,
            count=1,
        )

    # No warnings import yet: drop the full prelude after the first import
    # block when one exists, otherwise at the very top of the file.
    m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
    if m is None:
        return prelude + code
    cut = m.end()
    return code[:cut] + prelude + code[cut:]
|
|
2765
|
+
|
|
2766
|
+
|
|
2767
|
+
def _norm_col_name(s: str) -> str:
|
|
2768
|
+
"""normalise a column name: lowercase + strip non-alphanumerics."""
|
|
2769
|
+
return re.sub(r"[^a-z0-9]+", "", str(s).lower())
|
|
2770
|
+
|
|
2771
|
+
|
|
2772
|
+
def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
    """Return the actual df column that matches any candidate (after normalisation)."""
    # Map each column's normalised form back to its real name.
    lookup = {_norm_col_name(col): col for col in df.columns}
    matches = (lookup.get(_norm_col_name(cand)) for cand in candidates)
    return next((m for m in matches if m is not None), None)
|
|
2780
|
+
|
|
2781
|
+
|
|
2782
|
+
def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
    """
    Materialise a canonical copy of an aliased column at `target`.

    If `target` is already a column, nothing changes. Otherwise the first
    matching alias (after name normalisation) is copied to `target`; the
    original column is kept. Returns (df, found_bool).
    """
    if target not in df.columns:
        source = _first_present(df, [target, *aliases])
        if source is None:
            return df, False
        df[target] = df[source]
    return df, True
|
|
2794
|
+
|
|
2795
|
+
|
|
2796
|
+
def strip_python_dotenv(code: str) -> str:
    """
    Remove any use of python-dotenv from generated code, including:
      - single and multi-line 'from dotenv import ...'
      - 'import dotenv' (with or without alias) and calls via any alias
      - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
      - IPython magics (%load_ext dotenv, %dotenv, %env ...)
      - shell installs like '!pip install python-dotenv'

    Args:
        code: Generated source to sanitise.

    Returns:
        The source with dotenv usage stripped and excess blank lines tidied.
    """
    # 0) Kill IPython magics & shell installs referencing dotenv
    code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
    code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
    # NOTE: this removes *every* %env magic, not only dotenv-related ones.
    code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
    code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)

    # 1) Remove single-line 'from dotenv import ...'
    code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)

    # 2) Remove multi-line 'from dotenv import ( ... )' blocks
    code = re.sub(
        r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
        "",
        code,
        flags=re.MULTILINE,
    )

    # 3) Remove 'import dotenv' (with optional alias). Capture alias names
    #    first so calls routed through an alias can be dropped in step 5.
    aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
                         code, flags=re.MULTILINE)
    code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
                  "", code, flags=re.MULTILINE)

    # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
    #    e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
    fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
    # bare calls
    code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
    # dotted calls with any identifier prefix (alias or module)
    code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
                  "", code, flags=re.MULTILINE)

    # 5) If any alias imported earlier slipped through (method chains etc.),
    #    remove whole-line calls made through that alias.
    for al in aliases:
        code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)

    # 6) Tidy excess blank lines
    code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
    return code
|
|
2846
|
+
|
|
2847
|
+
|
|
2848
|
+
def fix_predict_calls_records_arg(code: str) -> str:
    """
    Strip `.to_dict('records')` / `.to_dict(orient="records")` from arguments
    passed to predict_* helpers so a DataFrame is handed over instead of a
    list of dicts. Operates line-by-line to avoid rewriting unrelated code.

    Examples fixed:
        predict_patient(X_test.iloc[:5].to_dict('records'))
        predict_risk(df.head(3).to_dict(orient="records"))
        → predict_patient(X_test.iloc[:5])
    """
    records_call = re.compile(
        r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)"
    )
    patched = []
    for ln in code.splitlines():
        # Cheap substring screen before the regex runs.
        if "predict_" in ln and "to_dict" in ln and "records" in ln:
            ln = records_call.sub("", ln)
        patched.append(ln)
    return "\n".join(patched)
|
|
2868
|
+
|
|
2869
|
+
|
|
2870
|
+
def fix_fstring_backslash_paths(code: str) -> str:
    """
    Rewrite bad f-strings like: f"...{out_dir\\plots\\img.png}..."
    into: f"...{os.path.join(out_dir, r'plots\\img.png')}".
    Only f-strings holding a backslash path inside {...} are touched.
    """
    # {var\rest-of-path} where var may be dotted (e.g., cfg.out)
    path_expr = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")

    def _join(m):
        var = m.group(1)
        tail = m.group(2).strip().replace('"', '\\"')
        return "{os.path.join(" + var + ', r"' + tail + '")}'

    patched = []
    for ln in code.splitlines():
        # Only lines containing an f-string prefix need scanning.
        if ('f"' in ln) or ("f'" in ln):
            ln = path_expr.sub(_join, ln)
        patched.append(ln)
    return "\n".join(patched)
|
|
2889
|
+
|
|
2890
|
+
|
|
2891
|
+
def ensure_os_import(code: str) -> str:
    """
    Inject 'import os' at the top when os.path.join is used but no os import
    (plain or from-style) is present.
    """
    if "os.path.join(" not in code:
        return code
    already_imported = (
        re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
        or re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
    )
    return code if already_imported else "import os\n" + code
|
|
2901
|
+
|
|
2902
|
+
|
|
2903
|
+
def fix_seaborn_boxplot_nameerror(code: str) -> str:
    """
    Repair bad calls of the form: sns.boxplot(boxplot)

    Heuristic:
      - plot_df + FH_status + var → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
      - plot_df + var             → sns.boxplot(data=plot_df, y=var, ax=ax)
      - plot_df only              → sns.boxplot(data=plot_df, ax=ax)
      - otherwise                 → sns.boxplot(ax=ax)
    Also guarantees a matplotlib Axes 'ax' exists.
    """
    bad_call = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
    if bad_call.search(code) is None:
        return code

    have_df = re.search(r"\bplot_df\b", code) is not None
    have_var = re.search(r"\bvar\b", code) is not None
    have_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))

    if have_df and have_var and have_fh:
        call = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
    elif have_df and have_var:
        call = "sns.boxplot(data=plot_df, y=var, ax=ax)"
    elif have_df:
        call = "sns.boxplot(data=plot_df, ax=ax)"
    else:
        call = "sns.boxplot(ax=ax)"

    patched = bad_call.sub(call, code)

    # Make sure 'fig, ax = plt.subplots(...)' exists before the first sns call.
    if "ax=" in call and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", patched):
        first_sns = re.search(r"^\s*sns\.", patched, flags=re.MULTILINE)
        at = first_sns.start() if first_sns else 0
        patched = patched[:at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + patched[at:]

    return patched
|
|
2940
|
+
|
|
2941
|
+
|
|
2942
|
+
def fix_seaborn_barplot_nameerror(code: str) -> str:
    """
    Fix bad calls like: sns.barplot(barplot)

    Strategy mirrors the boxplot fixer: prefer data=plot_df with x/y if
    available, otherwise degrade safely to an empty call on an existing Axes.

    Args:
        code: Generated source to patch.

    Returns:
        The patched source (unchanged when no bad call is present).
    """
    # NOTE: uses the module-level `re` import like every sibling patcher;
    # the previous function-local `import re` was redundant.
    pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
    if not pattern.search(code):
        return code

    has_plot_df = re.search(r"\bplot_df\b", code) is not None
    has_var = re.search(r"\bvar\b", code) is not None
    has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))

    if has_plot_df and has_var and has_fh:
        replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
    elif has_plot_df and has_var:
        replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
    elif has_plot_df:
        replacement = "sns.barplot(data=plot_df, ax=ax)"
    else:
        replacement = "sns.barplot(ax=ax)"

    # ensure an Axes 'ax' exists (no-op if already present)
    if "ax =" not in code:
        code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code

    return pattern.sub(replacement, code)
|
|
2971
|
+
|
|
2972
|
+
|
|
2973
|
+
def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
    """
    Parses the raw text to extract and format the 'refined question',
    'intents (tasks)', and 'chronology of tasks' sections.

    Args:
        raw_text: The complete input string containing the ML pipeline structure.

    Returns:
        A tuple containing:
        (formatted_question_str, formatted_intents_str, formatted_chronology_str)

    Raises:
        ValueError: If the input does not contain the three expected sections.
    """
    # --- 1. Regex Pattern to Extract Sections ---
    # Each named group captures the content after its header up to the next
    # header (or end of string); re.DOTALL lets '.' span newlines.
    pattern = re.compile(
        r"refined question:(?P<question>.*?)"
        r"intents \(tasks\):(?P<intents>.*?)"
        r"Chronology of tasks:(?P<chronology>.*)",
        re.IGNORECASE | re.DOTALL
    )

    match = pattern.search(raw_text)

    if not match:
        raise ValueError("Input text structure does not match the expected pattern.")

    # --- 2. Extract Content ---
    question_content = match.group('question').strip()
    intents_content = match.group('intents').strip()
    chronology_content = match.group('chronology').strip()

    # --- 3. Formatting Functions ---

    def format_question(content):
        """Formats the Refined Question section."""
        # Collapse newlines and runs of whitespace into single spaces.
        # (The previous `.replace(' ', ' ')` was a no-op.)
        content = " ".join(content.split())

        formatted = (
            "<b> Refined Question:</b>\n"
            f"{content}\n"
        )
        return formatted

    def format_intents(content):
        """Formats the Intents (Tasks) section as a structured list."""
        tasks = []
        # Pattern: 'N. **Text** - Content' up to the next numbered task or EOS.
        task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)

        raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]

        for task in raw_tasks:
            # Replace the initial task number and **Heading** with a Heading 3
            task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)

            # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
            task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
            tasks.append(task)

        formatted_tasks = "\n\n".join(tasks)

        return (
            "\n---\n"
            "## 2. Methodology and Tasks\n\n"
            f"{formatted_tasks}\n"
        )

    def format_chronology(content):
        """Formats the Chronology section as a LaTeX arrow chain."""
        # BUGFIX: the original non-raw ' \rightarrow ' embedded a literal
        # carriage return (\r); escape the backslash to emit LaTeX \rightarrow.
        content = content.strip().replace(' ', ' \\rightarrow ')
        formatted = (
            "\n---\n"
            "## 3. Chronology of Tasks\n"
            f"$$\\text{{{content}}}$$"
        )
        return formatted

    # --- 4. Format and Return ---
    formatted_question = format_question(question_content)
    formatted_intents = format_intents(intents_content)
    formatted_chronology = format_chronology(chronology_content)

    return formatted_question, formatted_intents, formatted_chronology
|
|
3066
|
+
|
|
3067
|
+
|
|
3068
|
+
def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str,
                         title: str = "🔬 Machine Learning Pipeline for Predicting Family History of Diabetes") -> str:
    """Combine all formatted parts into a final Markdown report string.

    Args:
        formatted_question: Output of the question formatter.
        formatted_intents: Output of the intents formatter.
        formatted_chronology: Output of the chronology formatter.
        title: Report heading. Defaults to the previously hard-coded title so
            existing callers see identical output.

    Returns:
        The assembled Markdown report.
    """
    return (
        f"# {title}\n\n"
        f"{formatted_question}\n"
        f"{formatted_intents}\n"
        f"{formatted_chronology}\n"
    )
|
|
3076
|
+
|
|
3077
|
+
|
|
3078
|
+
def fix_confusion_matrix_for_multilabel(code: str) -> str:
    """
    Rewrite ConfusionMatrixDisplay.from_estimator(est, X, y) calls into the
    from_predictions(y, est.predict(X)) form, which works in multi-label
    loops without requiring the estimator to expose _estimator_type.
    """
    pattern = r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)"
    replacement = r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))"
    return re.sub(pattern, replacement, code)
|
|
3089
|
+
|
|
3090
|
+
|
|
3091
|
+
def smx_auto_title_plots(ctx=None, fallback="Analysis"):
    """
    Give every open Matplotlib/Seaborn Axes a title when it has none.

    Title preference: ctx['refined_question'] -> ctx['askai_question'] -> fallback.
    Axes that already carry a non-empty title are left untouched.
    """
    import matplotlib.pyplot as plt

    def _open_figures():
        try:
            from matplotlib._pylab_helpers import Gcf
            return [mgr.canvas.figure for mgr in Gcf.get_all_fig_managers()]
        except Exception:
            # Best effort fallback
            nums = plt.get_fignums()
            return [plt.figure(n) for n in nums] if nums else []

    # Pick a concise title: first line only, capped at 120 characters.
    title = None
    if isinstance(ctx, dict):
        title = ctx.get("refined_question") or ctx.get("askai_question")
    title = (str(title).strip().splitlines()[0][:120]) if title else fallback

    for fig in _open_figures():
        for ax in getattr(fig, "axes", []):
            try:
                if not (ax.get_title() or "").strip():
                    ax.set_title(title)
            except Exception:
                pass
        try:
            fig.tight_layout()
        except Exception:
            pass
|
|
3125
|
+
|
|
3126
|
+
|
|
3127
|
+
def patch_fix_sentinel_plot_calls(code: str) -> str:
    """
    Normalise 'sentinel first-arg' calls so the SB_* wrappers can pick
    sane defaults:
        SB_barplot(barplot)       -> SB_barplot()
        SB_barplot(barplot, ...)  -> SB_barplot(...)
        sns.barplot(barplot)      -> SB_barplot()
        sns.barplot(barplot, ...) -> SB_barplot(...)
    Applies to: histplot, boxplot, barplot, lineplot, countplot, heatmap,
    pairplot, scatterplot.
    """
    plotters = ('histplot', 'boxplot', 'barplot', 'lineplot',
                'countplot', 'heatmap', 'pairplot', 'scatterplot')
    for name in plotters:
        rules = (
            # SB_* with sentinel as the first arg (with or without trailing args)
            (rf"\bSB_{name}\s*\(\s*{name}\s*\)", f"SB_{name}()"),
            (rf"\bSB_{name}\s*\(\s*{name}\s*,", f"SB_{name}("),
            # sns.* with sentinel as the first arg -> route to SB_* wrapper
            (rf"\bsns\.{name}\s*\(\s*{name}\s*\)", f"SB_{name}()"),
            (rf"\bsns\.{name}\s*\(\s*{name}\s*,", f"SB_{name}("),
        )
        for pat, repl in rules:
            code = re.sub(pat, repl, code)
    return code
|
|
3145
|
+
|
|
3146
|
+
|
|
3147
|
+
def patch_rmse_calls(code: str) -> str:
    """
    Make RMSE robust across sklearn versions.

    - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
      (the `squared` kwarg was deprecated and then removed in newer sklearn).
    - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.

    Args:
        code: Generated source to patch.

    Returns:
        The patched source.
    """
    # NOTE: relies on the module-level `re` import like its sibling patchers;
    # the previous function-local `import re` was redundant.
    # (a) Specific RMSE pattern
    code = re.sub(
        r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
        r"_SMX_rmse(\1)",
        code,
        flags=re.DOTALL
    )
    # (b) Guard any other MSE calls
    code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
    return code
|