syntaxmatrix-2.3.5-py3-none-any.whl → syntaxmatrix-2.5.5.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- syntaxmatrix/agentic/__init__.py +0 -0
- syntaxmatrix/agentic/agent_tools.py +24 -0
- syntaxmatrix/agentic/agents.py +810 -0
- syntaxmatrix/agentic/code_tools_registry.py +37 -0
- syntaxmatrix/agentic/model_templates.py +1790 -0
- syntaxmatrix/commentary.py +134 -112
- syntaxmatrix/core.py +385 -245
- syntaxmatrix/dataset_preprocessing.py +218 -0
- syntaxmatrix/display.py +89 -37
- syntaxmatrix/gpt_models_latest.py +5 -4
- syntaxmatrix/profiles.py +19 -4
- syntaxmatrix/routes.py +947 -141
- syntaxmatrix/settings/model_map.py +38 -30
- syntaxmatrix/static/icons/hero_bg.jpg +0 -0
- syntaxmatrix/templates/dashboard.html +248 -54
- syntaxmatrix/utils.py +2254 -84
- {syntaxmatrix-2.3.5.dist-info → syntaxmatrix-2.5.5.5.dist-info}/METADATA +16 -17
- {syntaxmatrix-2.3.5.dist-info → syntaxmatrix-2.5.5.5.dist-info}/RECORD +21 -15
- syntaxmatrix/model_templates.py +0 -29
- {syntaxmatrix-2.3.5.dist-info → syntaxmatrix-2.5.5.5.dist-info}/WHEEL +0 -0
- {syntaxmatrix-2.3.5.dist-info → syntaxmatrix-2.5.5.5.dist-info}/licenses/LICENSE.txt +0 -0
- {syntaxmatrix-2.3.5.dist-info → syntaxmatrix-2.5.5.5.dist-info}/top_level.txt +0 -0
syntaxmatrix/utils.py
CHANGED
@@ -1,15 +1,1189 @@
-import
-
-import re, os, textwrap
+from __future__ import annotations
+import re, textwrap
 import pandas as pd
-import
+import numpy as np
 import warnings
-from .model_templates import classification
-import syntaxmatrix as smx
 
+from difflib import get_close_matches
+from typing import Iterable, Tuple, Dict
+import inspect
+from sklearn.preprocessing import OneHotEncoder
+
+
+from syntaxmatrix.agentic.model_templates import (
+    classification, regression, multilabel_classification,
+    eda_overview, eda_correlation,
+    anomaly_detection, ts_anomaly_detection,
+    dimensionality_reduction, feature_selection,
+    time_series_forecasting, time_series_classification,
+    unknown_group_proxy_pack, viz_line,
+    clustering, recommendation, topic_modelling,
+    viz_pie, viz_count_bar, viz_box, viz_scatter,
+    viz_stacked_bar, viz_distribution, viz_area, viz_kde,
+)
+import ast
 
 warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
 
+INJECTABLE_INTENTS = [
+    "classification",
+    "multilabel_classification",
+    "regression",
+    "anomaly_detection",
+    "time_series_forecasting",
+    "time_series_classification",
+    "ts_anomaly_detection",
+    "dimensionality_reduction",
+    "feature_selection",
+    "clustering",
+    "eda",
+    "correlation_analysis",
+    "visualisation",
+    "recommendation",
+    "topic_modelling",
+]
+
+def classify_ml_job(prompt: str) -> str:
+    """
+    Very-light keyword intent classifier.
+    Returns one of:
+      'greeting' | 'feature_selection' | 'dimensionality_reduction'
+      'anomaly_detection' | 'stat_test' | 'time_series' | 'clustering'
+      'classification' | 'regression' | 'eda'
+    """
+    p = prompt.lower()
+
+    greetings = {"hi", "hello", "hey", "good morning", "good afternoon", "good evening", "greetings"}
+    if any(p.startswith(g) or p == g for g in greetings):
+        return "greeting"
+
+    # Feature selection / importance intent
+    if any(k in p for k in (
+        "feature selection", "select k best", "selectkbest", "rfe",
+        "mutual information", "feature importance", "permutation importance",
+        "feature engineering suggestions"
+    )):
+        return "feature_selection"
+
+    # Dimensionality reduction intent
+    if any(k in p for k in (
+        "pca", "principal component", "dimensionality reduction",
+        "reduce dimension", "reduce dimensionality", "t-sne", "tsne", "umap"
+    )):
+        return "dimensionality_reduction"
+
+    # Anomaly / outlier intent
+    if any(k in p for k in (
+        "anomaly", "anomalies", "outlier", "outliers", "novelty",
+        "fraud", "deviation", "rare event", "rare events", "odd pattern",
+        "suspicious"
+    )):
+        return "anomaly_detection"
+
+    if any(k in p for k in ("t-test", "anova", "p-value")):
+        return "stat_test"
+    if "forecast" in p or "prophet" in p:
+        return "time_series"
+    if "cluster" in p or "kmeans" in p:
+        return "clustering"
+    if any(k in p for k in ("accuracy", "precision", "roc")):
+        return "classification"
+    if any(k in p for k in ("rmse", "r2", "mae")):
+        return "regression"
+
+    return "eda"
+
+
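For reference, a minimal usage sketch of this keyword router (not part of the package diff; the prompts are illustrative only):

# Hypothetical usage of classify_ml_job, assumed importable from syntaxmatrix.utils.
from syntaxmatrix.utils import classify_ml_job

examples = [
    "hi there",                               # -> "greeting"
    "run RFE for feature importance",         # -> "feature_selection"
    "reduce dimensionality with PCA",         # -> "dimensionality_reduction"
    "flag fraud and rare events",             # -> "anomaly_detection"
    "report RMSE and MAE on the test split",  # -> "regression"
    "plot sales by region",                   # -> "eda" (fallback)
]
for prompt in examples:
    print(f"{prompt!r} -> {classify_ml_job(prompt)}")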
+def harden_ai_code(code: str) -> str:
+    """
+    Make any AI-generated cell resilient:
+      - Safe seaborn wrappers + sentinel vars (boxplot/barplot/etc.)
+      - Remove 'numeric_only=' args
+      - Replace pd.concat(...) with _safe_concat(...)
+      - Relax 'required_cols' hard fails
+      - Make static numeric_vars dynamic
+      - Wrap the whole block in try/except so no exception bubbles up
+    """
+    # Remove any LLM-added try/except blocks (hardener adds its own)
+    import re
+
+    def strip_placeholders(code: str) -> str:
+        code = re.sub(r"\bshow\(\s*\.\.\.\s*\)",
+                      "show('⚠ Block skipped due to an error.')",
+                      code)
+        code = re.sub(r"\breturn\s+\.\.\.", "return None", code)
+        return code
+
+    def _indent(code: str, spaces: int = 4) -> str:
+        pad = " " * spaces
+        return "\n".join(pad + line for line in code.splitlines())
+
+    def _SMX_OHE(**k):
+        # normalise arg name across sklearn versions
+        if "sparse" in k and "sparse_output" not in k:
+            k["sparse_output"] = k.pop("sparse")
+        # default behaviour we want
+        k.setdefault("handle_unknown", "ignore")
+        k.setdefault("sparse_output", False)
+        try:
+            # if running on old sklearn without sparse_output, translate back
+            if "sparse_output" not in inspect.signature(OneHotEncoder).parameters:
+                if "sparse_output" in k:
+                    k["sparse"] = k.pop("sparse_output")
+            return OneHotEncoder(**k)
+        except TypeError:
+            # final fallback: try legacy name
+            if "sparse_output" in k:
+                k["sparse"] = k.pop("sparse_output")
+            return OneHotEncoder(**k)
+
+    def _strip_stray_backrefs(code: str) -> str:
+        code = re.sub(r'(?m)^\s*\\\d+\s*', '', code)
+        code = re.sub(r'(?m)[;]\s*\\\d+\s*', '; ', code)
+        return code
+
+    def _wrap_metric_calls(code: str) -> str:
+        names = [
+            "r2_score","accuracy_score","precision_score","recall_score","f1_score",
+            "roc_auc_score","classification_report","confusion_matrix",
+            "mean_absolute_error","mean_absolute_percentage_error",
+            "explained_variance_score","log_loss","average_precision_score",
+            "precision_recall_fscore_support"
+        ]
+        pat = re.compile(r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\(")
+        def repl(m):
+            prefix = m.group(1) or ""  # "", "metrics.", or "sklearn.metrics."
+            name = m.group(2)
+            return f"_SMX_call({prefix}{name}, "
+        return pat.sub(repl, code)
+
+    def _smx_patch_mean_squared_error_squared_kw():
+        """
+        sklearn<0.22 doesn't accept mean_squared_error(..., squared=False).
+        Patch the module attr so 'from sklearn.metrics import mean_squared_error'
+        receives a wrapper that drops 'squared' if the underlying call rejects it.
+        """
+        try:
+            import sklearn.metrics as _sm
+            _orig = getattr(_sm, "mean_squared_error", None)
+            if not callable(_orig):
+                return
+            def _mse_compat(y_true, y_pred, *a, **k):
+                if "squared" in k:
+                    try:
+                        return _orig(y_true, y_pred, *a, **k)
+                    except TypeError:
+                        k.pop("squared", None)
+                        return _orig(y_true, y_pred, *a, **k)
+                return _orig(y_true, y_pred, *a, **k)
+            _sm.mean_squared_error = _mse_compat
+        except Exception:
+            pass
+
+    def _smx_patch_kmeans_n_init_auto():
+        """
+        sklearn>=1.4 accepts n_init='auto'; older versions want an int.
+        Patch sklearn.cluster.KMeans so 'auto' is converted to 10 if TypeError occurs.
+        """
+        try:
+            import sklearn.cluster as _sc
+            _Orig = getattr(_sc, "KMeans", None)
+            if _Orig is None:
+                return
+            class KMeansCompat(_Orig):
+                def __init__(self, *a, **k):
+                    if isinstance(k.get("n_init", None), str):
+                        try:
+                            super().__init__(*a, **k)
+                            return
+                        except TypeError:
+                            k["n_init"] = 10
+                    super().__init__(*a, **k)
+            _sc.KMeans = KMeansCompat
+        except Exception:
+            pass
+
+    def _smx_patch_ohe_name_api():
+        """
+        Guard get_feature_names_out on older OneHotEncoder.
+        Templates already use _SMX_OHE; this adds a soft fallback for feature names.
+        """
+        try:
+            from sklearn.preprocessing import OneHotEncoder as _OHE
+            _orig_get = getattr(_OHE, "get_feature_names_out", None)
+            if _orig_get is None:
+                # Monkey-patch instance method via mixin
+                def _fallback_get_feature_names_out(self, input_features=None):
+                    cats = getattr(self, "categories_", None) or []
+                    input_features = input_features or [f"x{i}" for i in range(len(cats))]
+                    names = []
+                    for base, cat_list in zip(input_features, cats):
+                        for j, _ in enumerate(cat_list):
+                            names.append(f"{base}__{j}")
+                    return names
+                _OHE.get_feature_names_out = _fallback_get_feature_names_out
+        except Exception:
+            pass
+
+    # Register and run patches once per execution
+    for _patch in (
+        _smx_patch_mean_squared_error_squared_kw,
+        _smx_patch_kmeans_n_init_auto,
+        _smx_patch_ohe_name_api,
+    ):
+        try:
+            _patch()
+        except Exception:
+            pass
+
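These compat helpers all follow one pattern: try the modern signature, catch TypeError, retry with the legacy spelling. A standalone sketch of the same pattern (not part of the package diff; uses only the public sklearn API):

# Try the modern kwarg first; fall back for sklearn versions that lack it.
from sklearn.metrics import mean_squared_error

def rmse_compat(y_true, y_pred):
    try:
        return mean_squared_error(y_true, y_pred, squared=False)
    except TypeError:  # sklearn without the 'squared' kwarg
        return mean_squared_error(y_true, y_pred) ** 0.5

print(rmse_compat([1.0, 2.0, 3.0], [1.0, 2.5, 2.0]))  # ≈ 0.6455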
+    PREFACE = (
+        "# === SMX Auto-Hardening Preface (do not edit) ===\n"
+        "import warnings, numpy as np, pandas as pd, matplotlib.pyplot as plt\n"
+        "warnings.filterwarnings('ignore')\n"
+        "try:\n"
+        "    import seaborn as sns\n"
+        "except Exception:\n"
+        "    class _Dummy:\n"
+        "        def __getattr__(self, name):\n"
+        "            def _f(*a, **k):\n"
+        "                from syntaxmatrix.display import show\n"
+        "                show('⚠ seaborn not available; plot skipped.')\n"
+        "            return _f\n"
+        "    sns = _Dummy()\n"
+        "\n"
+        "from syntaxmatrix.display import show as _SMX_base_show\n"
+        "def _SMX_caption_from_ctx():\n"
+        "    g = globals()\n"
+        "    t = g.get('refined_question') or g.get('askai_question') or 'Table'\n"
+        "    return str(t).strip().splitlines()[0][:120]\n"
+        "\n"
+        "def show(obj, title=None):\n"
+        "    try:\n"
+        "        import pandas as pd\n"
+        "        if isinstance(obj, pd.DataFrame):\n"
+        "            cap = (title or _SMX_caption_from_ctx())\n"
+        "            try:\n"
+        "                return _SMX_base_show(obj.style.set_caption(cap))\n"
+        "            except Exception:\n"
+        "                pass\n"
+        "    except Exception:\n"
+        "        pass\n"
+        "    return _SMX_base_show(obj)\n"
+        "\n"
+        "def _SMX_axes_have_titles(fig=None):\n"
+        "    import matplotlib.pyplot as _plt\n"
+        "    fig = fig or _plt.gcf()\n"
+        "    try:\n"
+        "        for _ax in fig.get_axes():\n"
+        "            if (_ax.get_title() or '').strip():\n"
+        "                return True\n"
+        "    except Exception:\n"
+        "        pass\n"
+        "    return False\n"
+        "\n"
+        "def _SMX_export_png():\n"
+        "    import io, base64\n"
+        "    fig = plt.gcf()\n"
+        "    try:\n"
+        "        if not _SMX_axes_have_titles(fig):\n"
+        "            fig.suptitle(_SMX_caption_from_ctx(), fontsize=10)\n"
+        "    except Exception:\n"
+        "        pass\n"
+        "    buf = io.BytesIO()\n"
+        "    plt.savefig(buf, format='png', bbox_inches='tight')\n"
+        "    buf.seek(0)\n"
+        "    from IPython.display import display, HTML\n"
+        "    _img = base64.b64encode(buf.read()).decode('ascii')\n"
+        "    display(HTML(f\"<img src='data:image/png;base64,{_img}' style='max-width:100%;height:auto;border:1px solid #ccc;border-radius:4px;'/>\"))\n"
+        "    plt.close()\n"
+        "\n"
+        "def _pick_df():\n"
+        "    return globals().get('df', None)\n"
+        "\n"
+        "def _pick_ax_slot():\n"
+        "    ax = None\n"
+        "    try:\n"
+        "        _axes = globals().get('axes', None)\n"
+        "        import numpy as _np\n"
+        "        if _axes is not None:\n"
+        "            arr = _np.ravel(_axes)\n"
+        "            for _a in arr:\n"
+        "                try:\n"
+        "                    if hasattr(_a,'has_data') and not _a.has_data():\n"
+        "                        ax = _a; break\n"
+        "                except Exception:\n"
+        "                    continue\n"
+        "    except Exception:\n"
+        "        ax = None\n"
+        "    return ax\n"
+        "\n"
+        "def _first_numeric(_d):\n"
+        "    import numpy as np, pandas as pd\n"
+        "    try:\n"
+        "        preferred = [\"median_house_value\", \"price\", \"value\", \"target\", \"label\", \"y\"]\n"
+        "        for c in preferred:\n"
+        "            if c in _d.columns and pd.api.types.is_numeric_dtype(_d[c]):\n"
+        "                return c\n"
+        "        cols = _d.select_dtypes(include=[np.number]).columns.tolist()\n"
+        "        return cols[0] if cols else None\n"
+        "    except Exception:\n"
+        "        return None\n"
+        "\n"
+        "def _first_categorical(_d):\n"
+        "    import pandas as pd, numpy as np\n"
+        "    try:\n"
+        "        num = set(_d.select_dtypes(include=[np.number]).columns.tolist())\n"
+        "        cand = [c for c in _d.columns if c not in num and _d[c].nunique(dropna=True) <= 50]\n"
+        "        return cand[0] if cand else None\n"
+        "    except Exception:\n"
+        "        return None\n"
+        "\n"
+        "boxplot = barplot = histplot = distplot = lineplot = countplot = heatmap = pairplot = None\n"
+        "\n"
+        "def _safe_plot(func, *args, **kwargs):\n"
+        "    try:\n"
+        "        ax = func(*args, **kwargs)\n"
+        "        if ax is None:\n"
+        "            ax = plt.gca()\n"
+        "        try:\n"
+        "            if hasattr(ax, 'has_data') and not ax.has_data():\n"
+        "                from syntaxmatrix.display import show as _show\n"
+        "                _show('⚠ Empty plot: no data drawn.')\n"
+        "        except Exception:\n"
+        "            pass\n"
+        "        try: plt.tight_layout()\n"
+        "        except Exception: pass\n"
+        "        return ax\n"
+        "    except Exception as e:\n"
+        "        from syntaxmatrix.display import show as _show\n"
+        "        _show(f'⚠ Plot skipped: {type(e).__name__}: {e}')\n"
+        "        return None\n"
+        "\n"
+        "def SB_histplot(*a, **k):\n"
+        "    _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
+        "    _sentinel = (len(a) >= 1 and a[0] is None)\n"
+        "    if (not a or _sentinel) and not k:\n"
+        "        d = _pick_df()\n"
+        "        if d is not None:\n"
+        "            x = _first_numeric(d)\n"
+        "            if x is not None:\n"
+        "                def _draw():\n"
+        "                    plt.hist(d[x].dropna())\n"
+        "                    ax = plt.gca()\n"
+        "                    if not (ax.get_title() or '').strip():\n"
+        "                        ax.set_title(f'Distribution of {x}')\n"
+        "                    return ax\n"
+        "                return _safe_plot(lambda **kw: _draw())\n"
+        "        if _missing:\n"
+        "            return _safe_plot(lambda **kw: plt.hist([]))\n"
+        "    if _sentinel:\n"
+        "        a = a[1:]\n"
+        "    return _safe_plot(getattr(sns,'histplot', plt.hist), *a, **k)\n"
+        "\n"
+        "def SB_barplot(*a, **k):\n"
+        "    _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
+        "    _sentinel = (len(a) >= 1 and a[0] is None)\n"
+        "    _ax = k.get('ax') or _pick_ax_slot()\n"
+        "    if _ax is not None:\n"
+        "        try: plt.sca(_ax)\n"
+        "        except Exception: pass\n"
+        "        k.setdefault('ax', _ax)\n"
+        "    if (not a or _sentinel) and not k:\n"
+        "        d = _pick_df()\n"
+        "        if d is not None:\n"
+        "            x = _first_categorical(d)\n"
+        "            y = _first_numeric(d)\n"
+        "            if x and y:\n"
+        "                import pandas as _pd\n"
+        "                g = d.groupby(x)[y].mean().reset_index()\n"
+        "                def _draw():\n"
+        "                    if _missing:\n"
+        "                        plt.bar(g[x], g[y])\n"
+        "                    else:\n"
+        "                        sns.barplot(data=g, x=x, y=y, ax=k.get('ax'))\n"
+        "                    ax = plt.gca()\n"
+        "                    if not (ax.get_title() or '').strip():\n"
+        "                        ax.set_title(f'Mean {y} by {x}')\n"
+        "                    return ax\n"
+        "                return _safe_plot(lambda **kw: _draw())\n"
+        "        if _missing:\n"
+        "            return _safe_plot(lambda **kw: plt.bar([], []))\n"
+        "    if _sentinel:\n"
+        "        a = a[1:]\n"
+        "    return _safe_plot(sns.barplot, *a, **k)\n"
+        "\n"
+        "def SB_boxplot(*a, **k):\n"
+        "    _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
+        "    _sentinel = (len(a) >= 1 and a[0] is None)\n"
+        "    _ax = k.get('ax') or _pick_ax_slot()\n"
+        "    if _ax is not None:\n"
+        "        try: plt.sca(_ax)\n"
+        "        except Exception: pass\n"
+        "        k.setdefault('ax', _ax)\n"
+        "    if (not a or _sentinel) and not k:\n"
+        "        d = _pick_df()\n"
+        "        if d is not None:\n"
+        "            x = _first_categorical(d)\n"
+        "            y = _first_numeric(d)\n"
+        "            if x and y:\n"
+        "                def _draw():\n"
+        "                    if _missing:\n"
+        "                        plt.boxplot(d[y].dropna())\n"
+        "                    else:\n"
+        "                        sns.boxplot(data=d, x=x, y=y, ax=k.get('ax'))\n"
+        "                    ax = plt.gca()\n"
+        "                    if not (ax.get_title() or '').strip():\n"
+        "                        ax.set_title(f'Distribution of {y} by {x}')\n"
+        "                    return ax\n"
+        "                return _safe_plot(lambda **kw: _draw())\n"
+        "        if _missing:\n"
+        "            return _safe_plot(lambda **kw: plt.boxplot([]))\n"
+        "    if _sentinel:\n"
+        "        a = a[1:]\n"
+        "    return _safe_plot(sns.boxplot, *a, **k)\n"
+        "\n"
+        "def SB_scatterplot(*a, **k):\n"
+        "    _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
+        "    fn = getattr(sns,'scatterplot', None)\n"
+        "    # If seaborn is unavailable OR the caller passed (data=..., x='col', y='col'),\n"
+        "    # use a robust matplotlib path that looks up data and coerces to numeric.\n"
+        "    if _missing or fn is None:\n"
+        "        data = k.get('data'); x = k.get('x'); y = k.get('y')\n"
+        "        if data is not None and isinstance(x, str) and isinstance(y, str) and x in data.columns and y in data.columns:\n"
+        "            xs = pd.to_numeric(data[x], errors='coerce')\n"
+        "            ys = pd.to_numeric(data[y], errors='coerce')\n"
+        "            m = xs.notna() & ys.notna()\n"
+        "            def _draw():\n"
+        "                plt.scatter(xs[m], ys[m])\n"
+        "                ax = plt.gca()\n"
+        "                if not (ax.get_title() or '').strip():\n"
+        "                    ax.set_title(f'{y} vs {x}')\n"
+        "                return ax\n"
+        "            return _safe_plot(lambda **kw: _draw())\n"
+        "        # else: fall back to auto-pick two numeric columns\n"
+        "        d = _pick_df()\n"
+        "        if d is not None:\n"
+        "            num = d.select_dtypes(include=[np.number]).columns.tolist()\n"
+        "            if len(num) >= 2:\n"
+        "                def _draw2():\n"
+        "                    plt.scatter(d[num[0]], d[num[1]])\n"
+        "                    ax = plt.gca()\n"
+        "                    if not (ax.get_title() or '').strip():\n"
+        "                        ax.set_title(f'{num[1]} vs {num[0]}')\n"
+        "                    return ax\n"
+        "                return _safe_plot(lambda **kw: _draw2())\n"
+        "        return _safe_plot(lambda **kw: plt.scatter([], []))\n"
+        "    # seaborn path\n"
+        "    return _safe_plot(fn, *a, **k)\n"
+        "\n"
+        "def SB_heatmap(*a, **k):\n"
+        "    _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
+        "    data = None\n"
+        "    if a:\n"
+        "        data = a[0]\n"
+        "    elif 'data' in k:\n"
+        "        data = k['data']\n"
+        "    if data is None:\n"
+        "        d = _pick_df()\n"
+        "        try:\n"
+        "            if d is not None:\n"
+        "                import numpy as _np\n"
+        "                data = d.select_dtypes(include=[_np.number]).corr()\n"
+        "        except Exception:\n"
+        "            data = None\n"
+        "    if data is None:\n"
+        "        from syntaxmatrix.display import show as _show\n"
+        "        _show('⚠ Heatmap skipped: no data.')\n"
+        "        return None\n"
+        "    if not _missing and hasattr(sns, 'heatmap'):\n"
+        "        _k = {kk: vv for kk, vv in k.items() if kk != 'data'}\n"
+        "        def _draw():\n"
+        "            ax = sns.heatmap(data, **_k)\n"
+        "            try:\n"
+        "                ax = ax or plt.gca()\n"
+        "                if not (ax.get_title() or '').strip():\n"
+        "                    ax.set_title('Correlation Heatmap')\n"
+        "            except Exception:\n"
+        "                pass\n"
+        "            return ax\n"
+        "        return _safe_plot(lambda **kw: _draw())\n"
+        "    def _mat_heat():\n"
+        "        im = plt.imshow(data, aspect='auto')\n"
+        "        try: plt.colorbar()\n"
+        "        except Exception: pass\n"
+        "        try:\n"
+        "            cols = list(getattr(data, 'columns', []))\n"
+        "            rows = list(getattr(data, 'index', []))\n"
+        "            if cols: plt.xticks(range(len(cols)), cols, rotation=90)\n"
+        "            if rows: plt.yticks(range(len(rows)), rows)\n"
+        "        except Exception:\n"
+        "            pass\n"
+        "        ax = plt.gca()\n"
+        "        try:\n"
+        "            if not (ax.get_title() or '').strip():\n"
+        "                ax.set_title('Correlation Heatmap')\n"
+        "        except Exception:\n"
+        "            pass\n"
+        "        return ax\n"
+        "    return _safe_plot(lambda **kw: _mat_heat())\n"
+        "\n"
+        "def _safe_concat(objs, **kwargs):\n"
+        "    import pandas as _pd\n"
+        "    if objs is None: return _pd.DataFrame()\n"
+        "    if isinstance(objs,(list,tuple)) and len(objs)==0: return _pd.DataFrame()\n"
+        "    try: return _pd.concat(objs, **kwargs)\n"
+        "    except Exception as e:\n"
+        "        show(f'⚠ concat skipped: {e}')\n"
+        "        return _pd.DataFrame()\n"
+        "\n"
+        "from sklearn.preprocessing import OneHotEncoder\n"
+        "import inspect\n"
+        "def _SMX_OHE(**k):\n"
+        "    # normalise arg name across sklearn versions\n"
+        "    if 'sparse' in k and 'sparse_output' not in k:\n"
+        "        k['sparse_output'] = k.pop('sparse')\n"
+        "    k.setdefault('handle_unknown','ignore')\n"
+        "    k.setdefault('sparse_output', False)\n"
+        "    try:\n"
+        "        sig = inspect.signature(OneHotEncoder)\n"
+        "        if 'sparse_output' not in sig.parameters and 'sparse_output' in k:\n"
+        "            k['sparse'] = k.pop('sparse_output')\n"
+        "    except Exception:\n"
+        "        if 'sparse_output' in k:\n"
+        "            k['sparse'] = k.pop('sparse_output')\n"
+        "    return OneHotEncoder(**k)\n"
+        "\n"
+        "import numpy as _np\n"
+        "def _SMX_mm(a, b):\n"
+        "    try:\n"
+        "        return a @ b  # normal path\n"
+        "    except Exception:\n"
+        "        try:\n"
+        "            A = _np.asarray(a); B = _np.asarray(b)\n"
+        "            # If same 2D shape (e.g. (n,k) & (n,k)), treat as row-wise dot\n"
+        "            if A.ndim==2 and B.ndim==2 and A.shape==B.shape:\n"
+        "                return (A * B).sum(axis=1)\n"
+        "            # Otherwise try element-wise product (broadcast if possible)\n"
+        "            return A * B\n"
+        "        except Exception as e:\n"
+        "            from syntaxmatrix.display import show\n"
+        "            show(f'⚠ Matmul relaxed: {type(e).__name__}: {e}'); return _np.nan\n"
+        "\n"
+        "def _SMX_call(fn, *a, **k):\n"
+        "    try:\n"
+        "        return fn(*a, **k)\n"
+        "    except TypeError as e:\n"
+        "        msg = str(e)\n"
+        "        if \"unexpected keyword argument 'squared'\" in msg:\n"
+        "            k.pop('squared', None)\n"
+        "            return fn(*a, **k)\n"
+        "        raise\n"
+        "\n"
+        "def _SMX_rmse(y_true, y_pred):\n"
+        "    try:\n"
+        "        from sklearn.metrics import mean_squared_error as _mse\n"
+        "        try:\n"
+        "            return _mse(y_true, y_pred, squared=False)\n"
+        "        except TypeError:\n"
+        "            return (_mse(y_true, y_pred)) ** 0.5\n"
+        "    except Exception:\n"
+        "        import numpy as _np\n"
+        "        yt = _np.asarray(y_true, dtype=float)\n"
+        "        yp = _np.asarray(y_pred, dtype=float)\n"
+        "        diff = yt - yp\n"
+        "        return float((_np.mean(diff * diff)) ** 0.5)\n"
+        "\n"
+        "import pandas as _pd\n"
+        "import numpy as _np\n"
+        "def _SMX_autocoerce_dates(_df):\n"
+        "    if _df is None or not hasattr(_df, 'columns'): return\n"
+        "    for c in list(_df.columns):\n"
+        "        s = _df[c]\n"
+        "        n = str(c).lower()\n"
+        "        if _pd.api.types.is_datetime64_any_dtype(s):\n"
+        "            continue\n"
+        "        if _pd.api.types.is_object_dtype(s) or ('date' in n or 'time' in n or 'timestamp' in n or n.endswith('_dt')):\n"
+        "            try:\n"
+        "                conv = _pd.to_datetime(s, errors='coerce', utc=True).dt.tz_localize(None)\n"
+        "                # accept only if at least 10% (min 3) parse as dates\n"
+        "                if getattr(conv, 'notna', lambda: _pd.Series([]))().sum() >= max(3, int(0.1*len(_df))):\n"
+        "                    _df[c] = conv\n"
+        "            except Exception:\n"
+        "                pass\n"
+        "\n"
+        "def _SMX_autocoerce_numeric(_df, cols):\n"
+        "    if _df is None: return\n"
+        "    for c in cols:\n"
+        "        if c in getattr(_df, 'columns', []):\n"
+        "            try:\n"
+        "                _df[c] = _pd.to_numeric(_df[c], errors='coerce')\n"
+        "            except Exception:\n"
+        "                pass\n"
+        "\n"
+        "def show(obj, title=None):\n"
+        "    try:\n"
+        "        import pandas as pd, numbers\n"
+        "        cap = (title or _SMX_caption_from_ctx())\n"
+        "        # 1) DataFrame → Styler with caption\n"
+        "        if isinstance(obj, pd.DataFrame):\n"
+        "            try: return _SMX_base_show(obj.style.set_caption(cap))\n"
+        "            except Exception: pass\n"
+        "        # 2) dict of scalars → DataFrame with caption\n"
+        "        if isinstance(obj, dict) and all(isinstance(v, numbers.Number) for v in obj.values()):\n"
+        "            df_ = pd.DataFrame({'metric': list(obj.keys()), 'value': list(obj.values())})\n"
+        "            try: return _SMX_base_show(df_.style.set_caption(cap))\n"
+        "            except Exception: return _SMX_base_show(df_)\n"
+        "    except Exception:\n"
+        "        pass\n"
+        "    return _SMX_base_show(obj)\n"
+    )
+
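For reference, the row-wise-dot fallback that the preface's _SMX_mm applies can be sketched standalone (not part of the package diff):

# Same-shaped 2-D operands that fail matmul fall back to a row-wise dot.
import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])       # shape (2, 2)
B = np.array([[10.0, 20.0], [30.0, 40.0]])   # shape (2, 2)
try:
    out = A @ B.T[:1]                        # (2,2) @ (1,2) raises ValueError
except ValueError:
    out = (A * B).sum(axis=1)                # row-wise dot: [50.0, 250.0]
print(out)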
+    PREFACE_IMPORT = "from syntaxmatrix.smx_preface import *\n"
+    # if PREFACE not in code:
+    #     code = PREFACE_IMPORT + code
+
+    fixed = code
+
+    fixed = re.sub(
+        r"(?s)^\s*try:\s*(.*?)\s*except\s+Exception\s+as\s+\w+:\s*\n\s*show\([^\n]*\)\s*$",
+        r"\1",
+        fixed.strip()
+    )
+
+    # 1) Strip numeric_only=... (version-agnostic)
+    fixed = re.sub(r",\s*numeric_only\s*=\s*(True|False|None)", "", fixed, flags=re.I)
+    fixed = re.sub(r"\bnumeric_only\s*=\s*(True|False|None)\s*,\s*", "", fixed, flags=re.I)
+    fixed = re.sub(r"\bnumeric_only\s*=\s*(True|False|None)\b", "", fixed, flags=re.I)
+
+    # 2) Use safe seaborn wrappers
+    fixed = re.sub(r"\bsns\.boxplot\s*\(", "SB_boxplot(", fixed)
+    fixed = re.sub(r"\bsns\.barplot\s*\(", "SB_barplot(", fixed)
+    fixed = re.sub(r"\bsns\.histplot\s*\(", "SB_histplot(", fixed)
+    fixed = re.sub(r"\bsns\.scatterplot\s*\(", "SB_scatterplot(", fixed)
+
+    # 3) Guard concat calls
+    fixed = re.sub(r"\bpd\.concat\s*\(", "_safe_concat(", fixed)
+    fixed = re.sub(r"\bOneHotEncoder\s*\(", "_SMX_OHE(", fixed)
+    # Route np.dot to tolerant matmul
+    fixed = re.sub(r"\bnp\.dot\s*\(", "_SMX_mm(", fixed)
+    fixed = re.sub(r"(df\s*\[[^\]]+\])\s*\.dt", r"SMX_dt(\1).dt", fixed)
+
+
+    # 4) Relax any 'required_cols' hard failure blocks
+    fixed = re.sub(
+        r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
+        "required_cols = [c for c in df.columns]\n# (relaxed by SMX hardener)",
+        fixed,
+        flags=re.S,
+    )
+
+    # 5) Make static numeric_vars lists dynamic
+    fixed = re.sub(
+        r"\bnumeric_vars\s*=\s*\[.*?\]",
+        "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
+        fixed,
+        flags=re.S,
+    )
+    # normalise all .dt usages on df[...] / df.attr / df.loc[...] to SMX_dt(...)
+    fixed = re.sub(
+        r"((?:df\s*(?:\.\s*(?:loc|iloc)\s*)?\[[^\]]+\]|df\s*\.\s*[A-Za-z_]\w*))\s*\.dt\b",
+        lambda m: f"SMX_dt({m.group(1)}).dt",
+        fixed
+    )
+
+    try:
+        class _SMXMatmulRewriter(ast.NodeTransformer):
+            def visit_BinOp(self, node):
+                self.generic_visit(node)
+                if isinstance(node.op, ast.MatMult):
+                    return ast.Call(func=ast.Name(id="_SMX_mm", ctx=ast.Load()),
+                                    args=[node.left, node.right], keywords=[])
+                return node
+        _tree = ast.parse(fixed)
+        _tree = _SMXMatmulRewriter().visit(_tree)
+        fixed = ast.unparse(_tree)
+    except Exception:
+        # If AST rewrite fails, keep original; _SMX_mm will still handle np.dot(...)
+        pass
+
+    # 6) Final safety wrapper
+    fixed = fixed.replace("\t", "    ")
+    fixed = textwrap.dedent(fixed).strip("\n")
+
+    fixed = _strip_stray_backrefs(fixed)
+    fixed = _wrap_metric_calls(fixed)
+
+    # If the transformed code is still not syntactically valid, fall back to a
+    # very defensive generic snippet that depends only on `df`. This guarantees
+    # the cell still produces output instead of crashing.
+    try:
+        ast.parse(fixed)
+    except (SyntaxError, IndentationError):
+        fixed = (
+            "import pandas as pd\n"
+            "df = df.copy()\n"
+            "_info = {\n"
+            "    'rows': len(df),\n"
+            "    'cols': len(df.columns),\n"
+            "    'numeric_cols': len(df.select_dtypes(include=['number','bool']).columns),\n"
+            "    'categorical_cols': len(df.select_dtypes(exclude=['number','bool']).columns),\n"
+            "}\n"
+            "show(df.head(), title='Sample of data')\n"
+            "show(_info, title='Dataset summary')\n"
+            "try:\n"
+            "    _num = df.select_dtypes(include=['number','bool']).columns.tolist()\n"
+            "    if _num:\n"
+            "        SB_histplot()\n"
+            "        _SMX_export_png()\n"
+            "except Exception as e:\n"
+            "    show(f\"⚠ Fallback visualisation failed: {type(e).__name__}: {e}\")\n"
+        )
+
+    # Fix placeholder Ellipsis handlers from LLM
+    fixed = re.sub(
+        r"except\s+Exception\s+as\s+e:\s*\n\s*show\(\.\.\.\)",
+        "except Exception as e:\n    show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")",
+        fixed,
+    )
+
+    wrapped = PREFACE + "try:\n" + _indent(fixed) + "\nexcept Exception as e:\n    show(...)\n"
+    wrapped = wrapped.lstrip()
+    return wrapped
+
+
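As a quick illustration of what the hardener does to a fragile LLM cell (not part of the package diff; exact output depends on the regex and AST passes above):

# Feed a typical generated cell through the hardener and inspect the result.
from syntaxmatrix.utils import harden_ai_code

raw = (
    "corr = df.corr(numeric_only=True)\n"
    "sns.boxplot(data=df, x='group', y='value')\n"
    "combined = pd.concat([df, df])\n"
)
hardened = harden_ai_code(raw)
# Expected effects: numeric_only= stripped, sns.boxplot -> SB_boxplot,
# pd.concat -> _safe_concat, and the whole cell wrapped in try/except
# behind the SMX preface.
print(hardened[:400])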
+def indent_code(code: str, spaces: int = 4) -> str:
+    pad = " " * spaces
+    return "\n".join(pad + line for line in code.splitlines())
+
+
+def wrap_llm_code_safe(code: str) -> str:
+    # Swallow any runtime error from the LLM block instead of crashing the run
+    return (
+        "# __SAFE_WRAPPED__\n"
+        "try:\n" + indent_code(code) + "\n"
+        "except Exception as e:\n"
+        "    from syntaxmatrix.display import show\n"
+        "    show(f\"⚠️ Skipped LLM block due to: {type(e).__name__}: {e}\")\n"
+    )
+
+
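A minimal sketch of the wrapper's output (not part of the package diff; assumes a SyntaxMatrix runtime where syntaxmatrix.display.show is importable):

from syntaxmatrix.utils import wrap_llm_code_safe

snippet = "result = 1 / 0"
print(wrap_llm_code_safe(snippet))
# # __SAFE_WRAPPED__
# try:
#     result = 1 / 0
# except Exception as e:
#     from syntaxmatrix.display import show
#     show(f"⚠️ Skipped LLM block due to: {type(e).__name__}: {e}")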
+def fix_boxplot_placeholder(code: str) -> str:
+    # Replace invalid 'sns.boxplot(boxplot)' with a safe call using df/group_label/m
+    return re.sub(
+        r"sns\.boxplot\(\s*boxplot\s*\)",
+        "sns.boxplot(x=group_label, y=m, data=df.loc[df[m].notnull()], showfliers=False)",
+        code
+    )
+
+
+def relax_required_columns(code: str) -> str:
+    # Remove hard failure on required_cols; keep a soft filter instead
+    return re.sub(
+        r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
+        "required_cols = [c for c in df.columns]\n",
+        code,
+        flags=re.S
+    )
+
+
+def make_numeric_vars_dynamic(code: str) -> str:
+    # Replace any static numeric_vars list with a dynamic selection
+    return re.sub(
+        r"numeric_vars\s*=\s*\[.*?\]",
+        "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
+        code,
+        flags=re.S
+    )
+
+
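An illustrative check of the two rewrites on a toy snippet (not part of the package diff):

from syntaxmatrix.utils import relax_required_columns, make_numeric_vars_dynamic

code = (
    "required_cols = ['a', 'b']\n"
    "missing = [c for c in required_cols if c not in df.columns]\n"
    "if missing: raise KeyError(missing)\n"
    "numeric_vars = ['a', 'b']\n"
)
out = make_numeric_vars_dynamic(relax_required_columns(code))
print(out)
# The hard required_cols/raise block is replaced by a soft filter, and
# numeric_vars is now computed from df instead of being hard-coded.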
+def auto_inject_template(code: str, intents, df) -> str:
+    """If the LLM forgot the core logic, prepend a skeleton block."""
+
+    has_fit = ".fit(" in code
+    has_plot = any(k in code for k in ("plt.", "sns.", ".plot(", ".hist("))
+
+    UNKNOWN_TOKENS = {
+        "unknown","not reported","not_reported","not known","n/a","na",
+        "none","nan","missing","unreported","unspecified","null","-",""
+    }
+
+    # --- Safe template caller: passes only supported kwargs, falls back cleanly ---
+    def _call_template(func, df, **hints):
+        import inspect
+        try:
+            params = inspect.signature(func).parameters
+            kw = {k: v for k, v in hints.items() if k in params}
+            try:
+                return func(df, **kw)
+            except TypeError:
+                # In case the template changed its signature at runtime
+                return func(df)
+        except Exception:
+            # Absolute safety net
+            try:
+                return func(df)
+            except Exception:
+                # As a last resort, return empty code so we don't 500
+                return ""
+
+    def _guess_classification_target(df: pd.DataFrame) -> str | None:
+        cols = list(df.columns)
+
+        # Helper: does this column look like a sensible label?
+        def _is_reasonable_class_col(s: pd.Series, col_name: str) -> bool:
+            try:
+                nunq = s.dropna().nunique()
+            except Exception:
+                return False
+            # need at least 2 classes, but not hundreds
+            if nunq < 2 or nunq > 20:
+                return False
+            bad_name_keys = ("id", "identifier", "index", "uuid", "key")
+            name = str(col_name).lower()
+            if any(k in name for k in bad_name_keys):
+                return False
+            return True
+
+        # 1) columns whose names look like labels
+        label_keys = ("target", "label", "outcome", "class", "y", "status")
+        name_candidates: list[str] = []
+        for key in label_keys:
+            for c in cols:
+                if key in str(c).lower():
+                    name_candidates.append(c)
+            if name_candidates:
+                break  # keep the earliest matching key-group
+
+        # prioritise name-based candidates that also look like proper label columns
+        for c in name_candidates:
+            if _is_reasonable_class_col(df[c], c):
+                return c
+        if name_candidates:
+            # fall back to the first name-based candidate if none passed the shape test
+            return name_candidates[0]
+
+        # 2) any column with a small number of distinct values (likely a class label)
+        for c in cols:
+            s = df[c]
+            if _is_reasonable_class_col(s, c):
+                return c
+
+        # Nothing suitable found
+        return None
+
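A standalone sketch of the same label-guessing heuristic (not part of the package diff; the real helper is nested inside auto_inject_template and not importable directly):

import pandas as pd

def guess_label(df: pd.DataFrame) -> str | None:
    for c in df.columns:                       # name hints first
        if any(k in str(c).lower() for k in ("target", "label", "outcome", "class")):
            return c
    for c in df.columns:                       # else: low-cardinality, non-ID columns
        n = df[c].dropna().nunique()
        if 2 <= n <= 20 and "id" not in str(c).lower():
            return c
    return None

toy = pd.DataFrame({"user_id": [1, 2, 3, 4],
                    "churn_label": [0, 1, 0, 1],
                    "spend": [9.5, 3.2, 7.7, 1.1]})
print(guess_label(toy))  # churn_label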
+    def _guess_regression_target(df: pd.DataFrame) -> str | None:
+        num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
+        if not num_cols:
+            return None
+        # Avoid obvious ID-like columns
+        bad_keys = ("id", "identifier", "index")
+        candidates = [c for c in num_cols if not any(k in str(c).lower() for k in bad_keys)]
+        return (candidates or num_cols)[-1]
+
+    def _guess_time_col(df: pd.DataFrame) -> str | None:
+        # Prefer actual datetime dtype
+        dt_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.datetime64)]
+        if dt_cols:
+            return dt_cols[0]
+
+        # Fallback: name-based hints
+        name_keys = ["date", "time", "timestamp", "datetime", "ds", "period"]
+        for c in df.columns:
+            name = str(c).lower()
+            if any(k in name for k in name_keys):
+                return c
+        return None
+
+    def _guess_entity_col(df: pd.DataFrame) -> str | None:
+        # Typical sequence IDs: id, patient, subject, device, series, entity
+        keys = ["id", "patient", "subject", "device", "series", "entity"]
+        candidates = []
+        for c in df.columns:
+            name = str(c).lower()
+            if any(k in name for k in keys):
+                candidates.append(c)
+        return candidates[0] if candidates else None
+
+    def _guess_ts_class_target(df: pd.DataFrame) -> str | None:
+        # Try label-like names first
+        keys = ["target", "label", "class", "outcome", "y"]
+        for key in keys:
+            for c in df.columns:
+                if key in str(c).lower():
+                    return c
+
+        # Fallback: any column with few distinct values (e.g. <= 10)
+        for c in df.columns:
+            s = df[c]
+            # avoid obvious IDs
+            if any(k in str(c).lower() for k in ["id", "index"]):
+                continue
+            try:
+                nunq = s.dropna().nunique()
+            except Exception:
+                continue
+            if 1 < nunq <= 10:
+                return c
+
+        return None
+
+    def _guess_multilabel_cols(df: pd.DataFrame) -> list[str]:
+        cols = list(df.columns)
+        lbl_like = [c for c in cols if str(c).startswith(("LBL_", "lbl_"))]
+        # also include boolean/binary columns with suitable names
+        for c in cols:
+            s = df[c]
+            try:
+                nunq = s.dropna().nunique()
+            except Exception:
+                continue
+            if nunq in (2,) and c not in lbl_like:
+                # avoid obvious IDs
+                if not any(k in str(c).lower() for k in ("id","index","uuid","identifier")):
+                    lbl_like.append(c)
+        # keep at most, say, 12 to avoid accidental flood
+        return lbl_like[:12]
+
+    def _find_unknownish_column(df: pd.DataFrame) -> str | None:
+        # Search categorical-like columns for any 'unknown-like' values or high missingness
+        candidates = []
+        for c in df.columns:
+            s = df[c]
+            # focus on object/category/boolean-ish or low-card columns
+            if not (pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= 20):
+                continue
+            try:
+                vals = s.astype(str).str.strip().str.lower()
+            except Exception:
+                continue
+            # score: presence of unknown tokens + missing rate
+            token_hit = int(vals.isin(UNKNOWN_TOKENS).any())
+            miss_rate = s.isna().mean()
+            name_bonus = int(any(k in str(c).lower() for k in ("status","history","report","known","flag")))
+            score = 3*token_hit + 2*name_bonus + miss_rate
+            if token_hit or miss_rate > 0.05 or name_bonus:
+                candidates.append((score, c))
+        if not candidates:
+            return None
+        candidates.sort(reverse=True)
+        return candidates[0][1]
+
+    def _guess_numeric_cols(df: pd.DataFrame, max_n: int = 6) -> list[str]:
+        cols = [c for c in df.select_dtypes(include=[np.number, "bool"]).columns if not any(k in str(c).lower() for k in ("id","identifier","index","uuid"))]
+        # prefer non-constant columns
+        scored = []
+        for c in cols:
+            try:
+                v = df[c].dropna()
+                var = float(v.var()) if len(v) else 0.0
+                scored.append((var, c))
+            except Exception:
+                continue
+        scored.sort(reverse=True)
+        return [c for _, c in scored[:max_n]]
+
+    def _guess_categorical_cols(df: pd.DataFrame, exclude: set[str] | None = None, max_card: int = 12, max_n: int = 5) -> list[str]:
+        exclude = exclude or set()
+        picks = []
+        for c in df.columns:
+            if c in exclude:
+                continue
+            s = df[c]
+            if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= max_card:
+                nunq = s.dropna().nunique()
+                if 2 <= nunq <= max_card and not any(k in str(c).lower() for k in ("id","identifier","index","uuid")):
+                    picks.append((nunq, c))
+        picks.sort(reverse=True)
+        return [c for _, c in picks[:max_n]]
+
+    def _guess_outcome_col(df: pd.DataFrame, exclude: set[str] | None = None) -> str | None:
+        exclude = exclude or set()
+        # name hints first
+        name_keys = ("outcome","target","label","risk","score","result","prevalence","positivity")
+        for c in df.columns:
+            if c in exclude:
+                continue
+            name = str(c).lower()
+            if any(k in name for k in name_keys) and pd.api.types.is_numeric_dtype(df[c]):
+                return c
+        # fallback: any binary numeric
+        for c in df.select_dtypes(include=[np.number, "bool"]).columns:
+            if c in exclude:
+                continue
+            try:
+                if df[c].dropna().nunique() == 2:
+                    return c
+            except Exception:
+                continue
+        return None
+
+
+    def _pick_viz_template(signal: str):
+        s = signal.lower()
+
+        # explicit chart requests
+        if any(k in s for k in ("pie", "donut")):
+            return viz_pie
+
+        if any(k in s for k in ("stacked", "100% stacked", "composition", "proportion", "share by")):
+            return viz_stacked_bar
+
+        if any(k in s for k in ("distribution", "hist", "histogram", "bins")):
+            return viz_distribution
+
+        if any(k in s for k in ("kde", "density")):
+            return viz_kde
+
+        # box / scatter / count-bar requests
+        if any(k in s for k in ("box", "boxplot", "violin", "spread", "outlier")):
+            return viz_box
+
+        if any(k in s for k in ("scatter", "relationship", "vs ", "correlate")):
+            return viz_scatter
+
+        if any(k in s for k in ("count", "counts", "frequency", "bar chart", "barplot")):
+            return viz_count_bar
+
+        if any(k in s for k in ("area", "trend", "over time", "time series")):
+            return viz_area
+
+        # fallback
+        return viz_line
+
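A standalone sketch of the keyword-to-chart routing above (not part of the package diff; phrases are illustrative and the expected templates follow the branch order of _pick_viz_template):

requests = {
    "share of sales as a pie chart": "viz_pie",
    "histogram of ages with 20 bins": "viz_distribution",
    "boxplot of price spread by region": "viz_box",
    "price vs square footage scatter": "viz_scatter",
    "monthly revenue trend over time": "viz_area",
    "just plot it": "viz_line",  # fallback branch
}
for phrase, expected in requests.items():
    print(f"{phrase!r} -> {expected}")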
+    for intent in intents:
+
+        if intent not in INJECTABLE_INTENTS:
+            return code
+
+        # Correlation analysis
+        if intent == "correlation_analysis" and not has_fit:
+            return eda_correlation(df) + "\n\n" + code
+
+        # Generic visualisation (keyword-based)
+        if intent == "visualisation" and not has_fit and not has_plot:
+            rq = str(globals().get("refined_question", ""))
+            # aq = str(globals().get("askai_question", ""))
+            signal = rq + "\n" + str(intents) + "\n" + code
+            tpl = _pick_viz_template(signal)
+            return tpl(df) + "\n\n" + code
+
+        if intent == "clustering" and not has_fit:
+            return clustering(df) + "\n\n" + code
+
+        if intent == "recommendation" and not has_fit:
+            return recommendation(df) + "\n\n" + code
+
+        if intent == "topic_modelling" and not has_fit:
+            return topic_modelling(df) + "\n\n" + code
+
+        if intent == "eda" and not has_fit:
+            return code + "\n\nSB_heatmap(df.corr())"  # Inject heatmap for 'eda' intent
+
+        # --- Classification ------------------------------------------------
+        if intent == "classification" and not has_fit:
+            target = _guess_classification_target(df)
+            if target:
+                return classification(df) + "\n\n" + code
+            # return _call_template(classification, df, target) + "\n\n" + code
+
+        # --- Regression ----------------------------------------------------
+        if intent == "regression" and not has_fit:
+            target = _guess_regression_target(df)
+            if target:
+                return regression(df) + "\n\n" + code
+            # return _call_template(regression, df, target) + "\n\n" + code
+
+        # --- Anomaly detection --------------------------------------------
+        if intent == "anomaly_detection":
+            uses_anomaly = any(k in code for k in ("IsolationForest", "LocalOutlierFactor", "OneClassSVM"))
+            if not uses_anomaly:
+                return anomaly_detection(df) + "\n\n" + code
+
+        # --- Time-series anomaly detection --------------------------------
+        if intent == "ts_anomaly_detection":
+            uses_ts = "STL(" in code or "seasonal_decompose(" in code
+            if not uses_ts:
+                return ts_anomaly_detection(df) + "\n\n" + code
+
+        # --- Time-series classification ------------------------------------
+        if intent == "time_series_classification" and not has_fit:
+            time_col = _guess_time_col(df)
+            entity_col = _guess_entity_col(df)
+            target_col = _guess_ts_class_target(df)
+
+            # If we can't confidently identify these, do NOT inject anything
+            if time_col and entity_col and target_col:
+                return time_series_classification(df, entity_col, time_col, target_col) + "\n\n" + code
+
+        # --- Dimensionality reduction --------------------------------------
+        if intent == "dimensionality_reduction":
+            uses_dr = any(k in code for k in ("PCA(", "TSNE("))
+            if not uses_dr:
+                return dimensionality_reduction(df) + "\n\n" + code
+
+        # --- Feature selection ---------------------------------------------
+        if intent == "feature_selection":
+            uses_fs = any(k in code for k in (
+                "mutual_info_", "permutation_importance(", "SelectKBest(", "RFE("
+            ))
+            if not uses_fs:
+                return feature_selection(df) + "\n\n" + code
+
+        # --- EDA / correlation / visualisation -----------------------------
+        if intent in ("eda", "correlation_analysis", "visualisation") and not has_plot:
+            if intent == "correlation_analysis":
+                return eda_correlation(df) + "\n\n" + code
+            else:
+                return eda_overview(df) + "\n\n" + code
+
+        # --- Time-series forecasting ---------------------------------------
+        if intent == "time_series_forecasting" and not has_fit:
+            uses_ts_forecast = any(k in code for k in (
+                "ARIMA", "ExponentialSmoothing", "forecast", "predict("
+            ))
+            if not uses_ts_forecast:
+                return time_series_forecasting(df) + "\n\n" + code
+
+        # --- Multi-label classification -----------------------------------
+        if intent in ("multilabel_classification",) and not has_fit:
+            label_cols = _guess_multilabel_cols(df)
+            if len(label_cols) >= 2:
+                return multilabel_classification(df, label_cols) + "\n\n" + code
+
+            group_col = _find_unknownish_column(df)
+            if group_col:
+                num_cols = _guess_numeric_cols(df)
+                cat_cols = _guess_categorical_cols(df, exclude={group_col})
+                outcome_col = None  # generic; let template skip if not present
+                tpl = unknown_group_proxy_pack(df, group_col, UNKNOWN_TOKENS, num_cols, cat_cols, outcome_col)
+
+                # Return template + guarded (repaired) LLM code, so it never crashes
+                repaired = make_numeric_vars_dynamic(relax_required_columns(fix_boxplot_placeholder(code)))
+                return tpl + "\n\n" + wrap_llm_code_safe(repaired)
+
+    return code
+
+
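An illustrative call (not part of the package diff; df stands in for the session's active DataFrame, and the exact skeleton text comes from the classification template):

import pandas as pd
from syntaxmatrix.utils import auto_inject_template

df = pd.DataFrame({"label": [0, 1, 0, 1], "x1": [1.0, 2.0, 3.0, 4.0]})
llm_code = "show(df.describe())"      # no model fitting present
out = auto_inject_template(llm_code, ["classification"], df)
print(out.startswith(llm_code))       # False: a classification skeleton was prepended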
+def fix_values_sum_numeric_only_bug(code: str) -> str:
+    """
+    If a previous pass injected numeric_only=True into a NumPy-style sum,
+    e.g. .values.sum(numeric_only=True), strip it and canonicalize to .to_numpy().sum().
+    """
+    # .values.sum(numeric_only=True, ...)
+    code = re.sub(
+        r"\.values\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
+        ".to_numpy().sum()",
+        code,
+        flags=re.IGNORECASE,
+    )
+    # .to_numpy().sum(numeric_only=True, ...)
+    code = re.sub(
+        r"\.to_numpy\(\)\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
+        ".to_numpy().sum()",
+        code,
+        flags=re.IGNORECASE,
+    )
+    return code
+
+
 def strip_describe_slice(code: str) -> str:
     """
     Remove any pattern like df.groupby(...).describe()[[ ... ]] because
@@ -23,10 +1197,12 @@ def strip_describe_slice(code: str) -> str:
     )
     return pat.sub(r"\1)", code)
 
+
 def remove_plt_show(code: str) -> str:
     """Removes all plt.show() calls from the generated code string."""
     return "\n".join(line for line in code.splitlines() if "plt.show()" not in line)
 
+
 def patch_plot_with_table(code: str) -> str:
     """
     ▸ strips every `plt.show()` (avoids warnings)
@@ -113,7 +1289,7 @@ def patch_plot_with_table(code: str) -> str:
         ")\n"
     )
 
-    tbl_block += "
+    tbl_block += "show(summary_table, title='Summary Statistics')"
 
     # 5. inject image-export block, then table block, after the plot
     patched = (
@@ -253,10 +1429,10 @@ def refine_eda_question(raw_question, df=None, max_points=1000):
         "Use: tbl = df.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}); display(tbl)"
     )
 
-
     # 9. Fallback: return the raw question
     return q
 
+
 def patch_plot_code(code, df, user_question=None):
 
     # ── Early guard: abort nicely if the generated code references columns that
@@ -277,10 +1453,13 @@ def patch_plot_code(code, df, user_question=None):
 
     if missing_cols:
         cols_list = ", ".join(missing_cols)
-
-        f"
-
+        warning = (
+            f"show('⚠️ Warning: code references missing column(s): \"{cols_list}\". "
+            "These must either exist in df or be created earlier in the code; "
+            "otherwise you may see a KeyError.')\n"
         )
+        # Prepend the warning but keep the original code so it can still run
+        code = warning + code
 
     # 1. For line plots (auto-aggregate)
     m_l = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
@@ -389,6 +1568,16 @@ def patch_plot_code(code, df, user_question=None):
     # Fallback: Return original code
     return code
 
+
+def ensure_matplotlib_title(code, title_var="refined_question"):
+    import re
+    makes_plot = re.search(r"\b(plt\.(plot|scatter|bar|hist)|ax\.(plot|scatter|bar|hist))\b", code)
+    has_title = re.search(r"\b(plt\.title|ax\.set_title)\s*\(", code)
+    if makes_plot and not has_title:
+        code += f"\ntry:\n    plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
+    return code
+
+
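For reference, a minimal sketch of the title guard in action (not part of the package diff; refined_question is the session variable used for the caption):

from syntaxmatrix.utils import ensure_matplotlib_title

code = "plt.scatter(df['x'], df['y'])"
print(ensure_matplotlib_title(code))
# plt.scatter(df['x'], df['y'])
# try:
#     plt.title(str(refined_question)[:120])
# except Exception: pass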
def ensure_output(code: str) -> str:
|
|
393
1582
|
"""
|
|
394
1583
|
Guarantees that AI-generated code actually surfaces results in the UI
|
|
@@ -405,7 +1594,6 @@ def ensure_output(code: str) -> str:
|
|
|
405
1594
|
# not a comment / print / assignment / pyplot call
|
|
406
1595
|
if (last and not last.startswith(("print(", "plt.", "#")) and "=" not in last):
|
|
407
1596
|
lines[-1] = f"_out = {last}"
|
|
408
|
-
lines.append("from syntaxmatrix.display import show")
|
|
409
1597
|
lines.append("show(_out)")
|
|
410
1598
|
|
|
411
1599
|
# ── 3· auto-surface common stats tuples (stat, p) ───────────────────
|
|
@@ -413,14 +1601,12 @@ def ensure_output(code: str) -> str:
     if re.search(r"\bchi2\s*,\s*p\s*,", code) and "show((" in code:
         pass  # AI already shows the tuple
     elif re.search(r"\bchi2\s*,\s*p\s*,", code):
-        lines.append("from syntaxmatrix.display import show")
         lines.append("show((chi2, p))")
 
     # ── 4· classification report (string) ───────────────────────────────
     cr_match = re.search(r"^\s*(\w+)\s*=\s*classification_report\(", code, re.M)
     if cr_match and f"show({cr_match.group(1)})" not in "\n".join(lines):
         var = cr_match.group(1)
-        lines.append("from syntaxmatrix.display import show")
         lines.append(f"show({var})")
 
     # 5-bis · pivot tables (DataFrame)
@@ -457,18 +1643,17 @@ def ensure_output(code: str) -> str:
     assign_scalar = re.match(r"\s*(\w+)\s*=\s*.+\.shape\[\s*0\s*\]\s*$", lines[-1])
     if assign_scalar:
         var = assign_scalar.group(1)
-        lines.append("from syntaxmatrix.display import show")
         lines.append(f"show({var})")
 
     # ── 8. utils.ensure_output()
     assign_df = re.match(r"\s*(\w+)\s*=\s*df\[", lines[-1])
     if assign_df:
         var = assign_df.group(1)
-        lines.append("from syntaxmatrix.display import show")
         lines.append(f"show({var})")
 
     return "\n".join(lines)
 
+
 def get_plotting_imports(code):
     imports = []
     if "plt." in code and "import matplotlib.pyplot as plt" not in code:
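Note on the removed lines in this hunk: every per-branch `lines.append("from syntaxmatrix.display import show")` is gone in 2.5.5.5, so ensure_output() appears to assume show() is already present in the execution namespace (presumably injected by the executor). A hypothetical before/after:

    snippet = "result = df['Age'].mean()\nresult"
    print(ensure_output(snippet))
    # result = df['Age'].mean()
    # _out = result
    # show(_out)        <- appended; no import line is added any more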
@@ -488,6 +1673,7 @@ def get_plotting_imports(code):
         code = "\n".join(imports) + "\n\n" + code
     return code
 
+
 def patch_pairplot(code, df):
     if "sns.pairplot" in code:
         # Always assign and print pairgrid
@@ -498,29 +1684,82 @@ def patch_pairplot(code, df):
         code += "\nprint(pairgrid)"
     return code
 
+
 def ensure_image_output(code: str) -> str:
     """
-
-
+    Replace each plt.show() with an indented _SMX_export_png() call.
+    This keeps block indentation valid and still renders images in the dashboard.
     """
     if "plt.show()" not in code:
         return code
 
-
-
-
-    "
-
-
-
-
-
-
-
+    import re
+    out_lines = []
+    for ln in code.splitlines():
+        if "plt.show()" not in ln:
+            out_lines.append(ln)
+            continue
+
+        # works for:
+        #   plt.show()
+        #   plt.tight_layout(); plt.show()
+        #   ... ; plt.show(); ...   (multiple on one line)
+        indent = re.match(r"^(\s*)", ln).group(1)
+        parts = ln.split("plt.show()")
+
+        # keep whatever is before the first plt.show()
+        if parts[0].strip():
+            out_lines.append(parts[0].rstrip())
+
+        # for every plt.show() we removed, insert exporter at same indent
+        for _ in range(len(parts) - 1):
+            out_lines.append(indent + "_SMX_export_png()")
+
+        # keep whatever comes after the last plt.show()
+        if parts[-1].strip():
+            out_lines.append(indent + parts[-1].lstrip())
+
+    return "\n".join(out_lines)
+
+
+def clean_llm_code(code: str) -> str:
+    """
+    Make LLM output safe to exec:
+      - If fenced blocks exist, keep the largest one (usually the real code).
+      - Otherwise strip any stray ``` / ```python lines.
+      - Remove common markdown/preamble junk.
+    """
+    code = str(code or "")
+
+    # Extract fenced blocks (```python ... ``` or ``` ... ```)
+    blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
+
+    if blocks:
+        # pick the largest block; small trailing blocks are usually garbage
+        largest = max(blocks, key=lambda b: len(b.strip()))
+        if len(largest.strip().splitlines()) >= 10:
+            code = largest
+        else:
+            # if no meaningful block, just remove fence markers
+            code = re.sub(r"^```.*?$", "", code, flags=re.M)
+    else:
+        # no complete blocks — still remove any stray fence lines
+        code = re.sub(r"^```.*?$", "", code, flags=re.M)
+
+    # Strip common markdown/preamble lines
+    drop_prefixes = (
+        "here is", "here's", "below is", "sure,", "certainly",
+        "explanation", "note:", "```"
     )
+    cleaned_lines = []
+    for ln in code.splitlines():
+        s = ln.strip().lower()
+        if any(s.startswith(p) for p in drop_prefixes):
+            continue
+        cleaned_lines.append(ln)
+
+    return "\n".join(cleaned_lines).strip()
 
-    # exporter BEFORE the original plt.show()
-    return code.replace("plt.show()", exporter + "plt.show()")
 
 def fix_groupby_describe_slice(code: str) -> str:
     """
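A hypothetical round-trip through the new clean_llm_code():

    raw = (
        "Here is the code you asked for:\n"
        "```python\n" + "\n".join(f"x{i} = {i}" for i in range(12)) + "\n```\n"
        "Note: adjust as needed."
    )
    cleaned = clean_llm_code(raw)
    # Only the twelve assignment lines survive: the fenced block is extracted
    # (it has >= 10 lines) and the "Here is ..." / "Note: ..." chatter is dropped.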
@@ -543,6 +1782,7 @@ def fix_groupby_describe_slice(code: str) -> str:
     )
     return pat.sub(repl, code)
 
+
 def fix_importance_groupby(code: str) -> str:
     pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
     if "importance_df" in code:
@@ -589,10 +1829,12 @@ def inject_auto_preprocessing(code: str) -> str:
     # simply prepend; model code that follows can wrap estimator in a Pipeline
     return prep_snippet + code
 
+
 def fix_to_datetime_errors(code: str) -> str:
     """
     Force every pd.to_datetime(…) call to ignore bad dates so that
-
+
+    'year 16500 is out of range' and similar issues don’t crash runs.
     """
     import re
     # look for any pd.to_datetime( … )
@@ -605,25 +1847,67 @@ def fix_to_datetime_errors(code: str) -> str:
         return f"pd.to_datetime({inside}, errors='coerce')"
     return pat.sub(repl, code)
 
+
 def fix_numeric_sum(code: str) -> str:
     """
-
-
+    Make .sum(...) code safe across pandas versions by removing any
+    numeric_only=... argument (True/False/None) from function calls.
+
+    This avoids errors on pandas versions where numeric_only is not
+    supported for Series/grouped sums, and we rely instead on explicit
+    numeric column selection (e.g. select_dtypes) in the generated code.
     """
-
+    # Case 1: ..., numeric_only=True/False/None
+    code = re.sub(
+        r",\s*numeric_only\s*=\s*(True|False|None)",
+        "",
+        code,
+        flags=re.IGNORECASE,
+    )
 
-
-
-
-
+    # Case 2: numeric_only=True/False/None, ... (as first argument)
+    code = re.sub(
+        r"numeric_only\s*=\s*(True|False|None)\s*,\s*",
+        "",
+        code,
+        flags=re.IGNORECASE,
+    )
 
-
-
-
-
+    # Case 3: numeric_only=True/False/None (only argument)
+    code = re.sub(
+        r"numeric_only\s*=\s*(True|False|None)",
+        "",
+        code,
+        flags=re.IGNORECASE,
+    )
+
+    return code
+
+
+def fix_concat_empty_list(code: str) -> str:
+    """
+    Make pd.concat calls resilient to empty lists of objects.
+
+    Transforms patterns like:
+        pd.concat(frames, ignore_index=True)
+        pd.concat(frames)
+
+    into:
+        pd.concat(frames or [pd.DataFrame()], ignore_index=True)
+        pd.concat(frames or [pd.DataFrame()])
+
+    Only triggers when the first argument is a simple variable name.
+    """
+    pattern = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
+
+    def _repl(m):
+        name = m.group(1)
+        sep = m.group(2)  # ',' or ')'
+        return f"pd.concat({name} or [pd.DataFrame()]{sep}"
 
     return pattern.sub(_repl, code)
 
+
 def fix_numeric_aggs(code: str) -> str:
     _AGG_FUNCS = ("sum", "mean")
     pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
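What the new fix_concat_empty_list() does to a typical generated line:

    code = "combined = pd.concat(frames, ignore_index=True)"
    print(fix_concat_empty_list(code))
    # combined = pd.concat(frames or [pd.DataFrame()], ignore_index=True)
    # An empty `frames` list now yields an empty DataFrame instead of
    # raising "ValueError: No objects to concatenate".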
@@ -637,6 +1921,7 @@ def fix_numeric_aggs(code: str) -> str:
         return f".{func}({args}numeric_only=True)"
     return pat.sub(_repl, code)
 
+
 def ensure_accuracy_block(code: str) -> str:
     """
     Inject a sensible evaluation block right after the last `<est>.fit(...)`
@@ -696,40 +1981,6 @@ def ensure_accuracy_block(code: str) -> str:
     insert_at = code.find("\n", m[-1].end()) + 1
     return code[:insert_at] + eval_block + code[insert_at:]
 
-def classify(prompt: str) -> str:
-    """
-    Very-light intent classifier.
-    Returns one of:
-      'stat_test' | 'time_series' | 'clustering'
-      'classification' | 'regression' | 'eda'
-    """
-    p = prompt.lower().strip()
-    greetings = {"hi", "hello", "hey", "good morning", "good afternoon", "good evening", "greetings"}
-    if any(p.startswith(g) or p == g for g in greetings):
-        return "greeting"
-
-    if any(k in p for k in ("t-test", "anova", "p-value")):
-        return "stat_test"
-    if "forecast" in p or "prophet" in p:
-        return "time_series"
-    if "cluster" in p or "kmeans" in p:
-        return "clustering"
-    if any(k in p for k in ("accuracy", "precision", "roc")):
-        return "classification"
-    if any(k in p for k in ("rmse", "r2", "mae")):
-        return "regression"
-    return "eda"
-
-def auto_inject_template(code: str, intent: str, df) -> str:
-    """If the LLM forgot the core logic, prepend a skeleton block."""
-    has_fit = ".fit(" in code
-
-    if intent == "classification" and not has_fit:
-        # guess a y column that contains 'diabetes' as in your dataset
-        target = next((c for c in df.columns if "diabetes" in c.lower()), None)
-        if target:
-            return classification(df, target) + "\n\n" + code
-    return code
 
 def fix_scatter_and_summary(code: str) -> str:
     """
@@ -757,6 +2008,7 @@ def fix_scatter_and_summary(code: str) -> str:
 
     return code
 
+
 def auto_format_with_black(code: str) -> str:
     """
     Format the generated code with Black. Falls back silently if Black
@@ -771,6 +2023,7 @@ def auto_format_with_black(code: str) -> str:
     except Exception:
         return code
 
+
 def ensure_preproc_in_pipeline(code: str) -> str:
     """
     If code defines `preproc = ColumnTransformer(...)` but then builds
@@ -783,13 +2036,14 @@ def ensure_preproc_in_pipeline(code: str) -> str:
         code
     )
 
+
 def fix_plain_prints(code: str) -> str:
     """
     Rewrite bare `print(var)` where var looks like a dataframe/series/ndarray/etc
     to go through SyntaxMatrix's smart display (so it renders in the dashboard).
     Keeps string prints alone.
     """
-
+
     # Skip obvious string-literal prints
     new = re.sub(
         r"(?m)^\s*print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
@@ -798,6 +2052,7 @@ def fix_plain_prints(code: str) -> str:
     )
     return new
 
+
 def fix_print_html(code: str) -> str:
     """
     Ensure that HTML / DataFrame HTML are *displayed* (and captured by the kernel),
@@ -830,6 +2085,7 @@ def fix_print_html(code: str) -> str:
 
     return new
 
+
 def ensure_ipy_display(code: str) -> str:
     """
     Guarantee that the cell has proper IPython display imports so that
@@ -838,9 +2094,8 @@ def ensure_ipy_display(code: str) -> str:
     if "display(" in code and "from IPython.display import display, HTML" not in code:
         return "from IPython.display import display, HTML\n" + code
     return code
-
-
-# --------------------------------------------------------------------------
+
+
 def drop_bad_classification_metrics(code: str, y_or_df) -> str:
     """
     Remove classification metrics (accuracy_score, classification_report, confusion_matrix)
@@ -885,6 +2140,7 @@ def drop_bad_classification_metrics(code: str, y_or_df) -> str:
 
     return code
 
+
 def force_capture_display(code: str) -> str:
     """
     Ensure our executor captures HTML output:
@@ -925,11 +2181,13 @@ def force_capture_display(code: str) -> str:
     )
     return new
 
+
 def strip_matplotlib_show(code: str) -> str:
     """Remove blocking plt.show() calls (we export base64 instead)."""
     import re
     return re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code)
 
+
 def inject_display_shim(code: str) -> str:
     """
     Provide display()/HTML() if missing, forwarding to our executor hook.
@@ -951,6 +2209,7 @@ def inject_display_shim(code: str) -> str:
     )
     return shim + code
 
+
 def strip_spurious_column_tokens(code: str) -> str:
     """
     Remove common stop-words ('the','whether', ...) when they appear
@@ -961,7 +2220,8 @@ def strip_spurious_column_tokens(code: str) -> str:
     """
     STOP = {
         "the","whether","a","an","and","or","of","to","in","on","for","by",
-        "with","as","at","from","that","this","these","those","is","are","was","were"
+        "with","as","at","from","that","this","these","those","is","are","was","were",
+        "coef", "Coef", "coefficient", "Coefficient"
     }
 
     def _norm(s: str) -> str:
@@ -984,10 +2244,920 @@ def strip_spurious_column_tokens(code: str) -> str:
 
     # df[[ ... ]] selections
     code = re.sub(
-        r"df\s*\[\s*\[([^\]]+)\]\s*\]",
-
+        r"df\s*\[\s*\[([^\]]+)\]\s*\]", lambda m: "df[" + _clean_list(m.group(1)) + "]", code)
+
+    return code
+
+
+def patch_prefix_seaborn_calls(code: str) -> str:
+    """
+    Ensure bare seaborn calls are prefixed with `sns.`.
+    E.g., `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, etc.
+    """
+    if "sns." in code:
+        # still fix any leftover bare calls alongside prefixed ones
+        pass
+
+    # functions commonly used from seaborn
+    funcs = [
+        "barplot","countplot","boxplot","violinplot","stripplot","swarmplot",
+        "histplot","kdeplot","jointplot","pairplot","heatmap","clustermap",
+        "scatterplot","lineplot","catplot","displot","lmplot"
+    ]
+    # Replace bare function calls not already qualified by a dot (e.g., obj.barplot)
+    # (?<![\w.]) ensures no preceding word char or dot; avoids touching obj.barplot or mybarplot
+    pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(", flags=re.MULTILINE)
+
+    def _add_prefix(m):
+        fn = m.group(1)
+        return f"sns.{fn}("
+
+    return pattern.sub(_add_prefix, code)
+
+
+def patch_ensure_seaborn_import(code: str) -> str:
+    """
+    If seaborn is used (sns.) ensure `import seaborn as sns` exists once.
+    Also set a quiet theme for consistent visuals.
+    """
+    needs_sns = "sns." in code
+    has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
+    if needs_sns and not has_import:
+        # Insert after the first block of imports if possible, else at top
+        import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
+        inject = "import seaborn as sns\ntry:\n    sns.set_theme()\nexcept Exception:\n    pass\n"
+        if import_block:
+            start = import_block.end()
+            code = code[:start] + inject + code[start:]
+        else:
+            code = inject + code
+    return code
+
+
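Behaviour sketch for the two seaborn patchers above (hypothetical input string):

    code = "barplot(x='Ethnicity', y='BMI', data=df)"
    code = patch_prefix_seaborn_calls(code)
    # -> sns.barplot(x='Ethnicity', y='BMI', data=df)
    code = patch_ensure_seaborn_import(code)
    # -> import seaborn as sns        (plus a best-effort sns.set_theme())
    #    sns.barplot(x='Ethnicity', y='BMI', data=df)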
+def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
+    """
+    Normalise pie-chart requests.
+
+    Supports three patterns:
+      A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
+      B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
+         → one pie per facet level (grid) + counts bar of facet sizes.
+      C) Single pie when no split/facet is requested.
+
+    Notes:
+      - Pie variables must be categorical (or numeric binned).
+      - Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
+    """
+
+    q = (user_question or "")
+    q_low = q.lower()
+
+    # Prefer explicit: df['col'].value_counts()
+    m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
+    col = m.group(1) if m else None
+
+    # ---------- helpers ----------
+    def _is_cat(col):
+        return (str(df[col].dtype).startswith("category")
+                or df[col].dtype == "object"
+                or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
+
+    def _cats_from_question(question: str):
+        found = []
+        for c in df.columns:
+            if c.lower() in question.lower() and _is_cat(c):
+                found.append(c)
+        # dedupe preserve order
+        seen, out = set(), []
+        for c in found:
+            if c not in seen:
+                out.append(c); seen.add(c)
+        return out
+
+    def _fallback_cat():
+        cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
+        if not cats: return None
+        cats.sort(key=lambda t: t[1])
+        return cats[0][0]
+
+    def _infer_comp_pref(question: str) -> str:
+        ql = (question or "").lower()
+        if "heatmap" in ql or "matrix" in ql:
+            return "heatmap"
+        if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
+            return "stacked_bar_pct"
+        if "stacked" in ql:
+            return "stacked_bar"
+        if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
+            return "grouped_bar"
+        return "counts_bar"
+
+    # parse threshold split like "HbA1c ≥ 6.5"
+    def _parse_split(question: str):
+        ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
+        m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
+        if not m: return None
+        col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
+        op = ops_map.get(op_raw)
+        if not op: return None
+        # case-insensitive column match
+        candidates = {c.lower(): c for c in df.columns}
+        col = candidates.get(col_raw.lower())
+        if not col: return None
+        try: val = float(val_raw)
+        except Exception: return None
+        return (col, op, val)
+
+    # facet extractor: "by / across / within each / per <col>", or "bin <col>", or named category list
+    def _extract_facet(question: str):
+        # 1) explicit "by / across / within / per <col>"
+        for kw in [" by ", " across ", " within ", " within each ", " per "]:
+            m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
+            if m:
+                col_raw = m.group(1).strip()
+                candidates = {c.lower(): c for c in df.columns}
+                if col_raw.lower() in candidates:
+                    return (candidates[col_raw.lower()], "auto")
+        # 2) "bin <col>"
+        m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
+        if m2:
+            col_raw = m2.group(1).strip()
+            candidates = {c.lower(): c for c in df.columns}
+            if col_raw.lower() in candidates:
+                return (candidates[col_raw.lower()], "bin")
+        # 3) BMI special: mentions of normal/overweight/obese imply BMI categories
+        if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
+           any(c.lower() == "bmi" for c in df.columns.str.lower()):
+            bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
+            return (bmi_col, "bmi")
+        return None
+
+    def _bmi_bins(series: pd.Series):
+        # WHO cutoffs
+        bins = [-np.inf, 18.5, 25, 30, np.inf]
+        labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
+        return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
+
+    wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
+    if not wants_pie:
+        return code
+
+    split = _parse_split(q)
+    facet = _extract_facet(q)
+    cats = _cats_from_question(q)
+    _comp_pref = _infer_comp_pref(q)
+
+    # Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
+    for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
+        if hard in df.columns and hard not in cats and hard.lower() in q_low:
+            cats.append(hard)
+
+    # --------------- CASE A: threshold split (cohorts) ---------------
+    if split:
+        if not (cats or any(_is_cat(c) for c in df.columns)):
+            return code
+        if not cats:
+            pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
+            pool.sort(key=lambda t: t[1])
+            cats = [t[0] for t in pool[:3]] if pool else []
+        if not cats:
+            return code
+
+        split_col, op, val = split
+        cond_str = f"(df['{split_col}'] {op} {val})"
+        snippet = f"""
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+_mask_a = ({cond_str}) & df['{split_col}'].notna()
+_mask_b = (~({cond_str})) & df['{split_col}'].notna()
+
+_cohort_a_name = "{split_col} {op} {val}"
+_cohort_b_name = "NOT ({split_col} {op} {val})"
+
+_comp_pref = "{_comp_pref}"
+_cat_cols = {cats!r}
+n = len(_cat_cols)
+fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
+if n == 1:
+    axes = np.array([axes])
+
+for i, col in enumerate(_cat_cols):
+    s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
+    s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})
+
+    ax_a = axes[i, 0]; ax_b = axes[i, 1]
+    if len(s_a) > 0:
+        ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
+                 autopct='%1.1f%%', startangle=90, counterclock=False)
+        ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')
+
+    if len(s_b) > 0:
+        ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
+                 autopct='%1.1f%%', startangle=90, counterclock=False)
+        ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')
+
+plt.tight_layout(); plt.show()
+
+# grouped bar complement
+for col in _cat_cols:
+    _tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
+              .assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
+    _tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
+    _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
+
+    if _comp_pref == "grouped_bar":
+        ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
+        ax.set_title(f"{{col}} by cohort (grouped)")
+        ax.set_xlabel(col); ax.set_ylabel("Count")
+        plt.tight_layout(); plt.show()
+
+    elif _comp_pref == "stacked_bar":
+        ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
+        ax.set_title(f"{{col}} by cohort (stacked)")
+        ax.set_xlabel(col); ax.set_ylabel("Count")
+        plt.tight_layout(); plt.show()
+
+    elif _comp_pref == "stacked_bar_pct":
+        _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
+        ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
+        ax.set_title(f"{{col}} by cohort (100% stacked)")
+        ax.set_xlabel(col); ax.set_ylabel("Percent")
+        plt.tight_layout(); plt.show()
+
+    elif _comp_pref == "heatmap":
+        _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
+        fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
+        im = ax.imshow(_perc.values, aspect='auto')
+        ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
+        ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
+        ax.set_title(f"{{col}} by cohort — % heatmap")
+        for i in range(_perc.shape[0]):
+            for j in range(_perc.shape[1]):
+                ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
+        fig.colorbar(im, ax=ax, label="%")
+        plt.tight_layout(); plt.show()
+
+    else:  # counts_bar (default)
+        ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
+        ax.set_title(f"{{col}}: total counts (both cohorts)")
+        ax.set_xlabel(col); ax.set_ylabel("Count")
+        plt.tight_layout(); plt.show()
+""".lstrip()
+        return snippet
+
+    # --------------- CASE B: facet-by (categories/bins) ---------------
+    if facet:
+        facet_col, how = facet
+        # Build facet series
+        if pd.api.types.is_numeric_dtype(df[facet_col]):
+            if how == "bmi":
+                facet_series = _bmi_bins(df[facet_col])
+            else:
+                # generic numeric bins: 3 equal-width bins by default
+                facet_series = pd.cut(df[facet_col].astype(float), bins=3)
+        else:
+            facet_series = df[facet_col].astype(str)
+
+        # Choose pie dimension (categorical to count inside each facet)
+        pie_dim = None
+        for c in cats:
+            if c in df.columns and _is_cat(c):
+                pie_dim = c; break
+        if pie_dim is None:
+            pie_dim = _fallback_cat()
+        if pie_dim is None:
+            return code
+
+        snippet = f"""
+import math
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+df = df.copy()
+_preferred = "{facet_col}" if "{facet_col}" in df.columns else None
+
+def _select_facet_col(df, preferred=None):
+    if preferred is not None:
+        return preferred
+    # Prefer low-cardinality categoricals (readable pies/grids)
+    cat_cols = [
+        c for c in df.columns
+        if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
+        and df[c].nunique() > 1 and df[c].nunique() <= 20
+    ]
+    if cat_cols:
+        cat_cols.sort(key=lambda c: df[c].nunique())
+        return cat_cols[0]
+    # Else fall back to first usable numeric
+    num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
+    return num_cols[0] if num_cols else None
+
+_facet_col = _select_facet_col(df, _preferred)
+
+if _facet_col is None:
+    # Nothing suitable → single facet keeps pipeline alive
+    df["__facet__"] = "All"
+else:
+    s = df[_facet_col]
+    if pd.api.types.is_numeric_dtype(s):
+        # Robust numeric binning: quantiles first, fallback to equal-width
+        uniq = pd.Series(s).dropna().nunique()
+        q = 3 if uniq < 10 else 4 if uniq < 30 else 5
+        try:
+            df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
+        except Exception:
+            df["__facet__"] = pd.cut(s.astype(float), bins=q)
+    else:
+        # Cap long tails; keep top categories
+        vc = s.astype(str).value_counts()
+        keep = vc.index[:{top_n}]
+        df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")
+
+levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
+levels = [x for x in levels if x != "nan"]
+levels.sort()
+
+m = len(levels)
+cols = 3 if m >= 3 else m or 1
+rows = int(math.ceil(m / cols))
+
+fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
+if not isinstance(axes, (list, np.ndarray)):
+    axes = np.array([[axes]])
+axes = axes.reshape(rows, cols)
+
+for i, lvl in enumerate(levels):
+    r, c = divmod(i, cols)
+    ax = axes[r, c]
+    s = (df.loc[df["__facet__"].astype(str) == str(lvl), "{pie_dim}"]
+           .astype(str).value_counts().nlargest({top_n}))
+    if len(s) > 0:
+        ax.pie(s.values, labels=[str(x) for x in s.index],
+               autopct='%1.1f%%', startangle=90, counterclock=False)
+        ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')
+
+# hide any empty subplots
+for j in range(m, rows*cols):
+    r, c = divmod(j, cols)
+    axes[r, c].axis("off")
+
+plt.tight_layout(); plt.show()
+
+# --- companion visual (adaptive) ---
+_comp_pref = "{_comp_pref}"
+# build contingency table: pie_dim x facet
+_tab = (df[["__facet__", "{pie_dim}"]]
+          .dropna()
+          .astype({{"__facet__": str, "{pie_dim}": str}})
+          .value_counts()
+          .unstack(level="__facet__")
+          .fillna(0))
+
+# keep top categories by overall size
+_tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
+
+if _comp_pref == "grouped_bar":
+    ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
+    ax.set_title("{pie_dim} by {facet_col} (grouped)")
+    ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
+    plt.tight_layout(); plt.show()
+
+elif _comp_pref == "stacked_bar":
+    ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
+    ax.set_title("{pie_dim} by {facet_col} (stacked)")
+    ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
+    plt.tight_layout(); plt.show()
+
+elif _comp_pref == "stacked_bar_pct":
+    _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100  # column-normalised to 100%
+    ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
+    ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
+    ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
+    plt.tight_layout(); plt.show()
+
+elif _comp_pref == "heatmap":
+    _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
+    fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
+    im = ax.imshow(_perc.values, aspect='auto')
+    ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
+    ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
+    ax.set_title("{pie_dim} by {facet_col} — % heatmap")
+    for i in range(_perc.shape[0]):
+        for j in range(_perc.shape[1]):
+            ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
+    fig.colorbar(im, ax=ax, label="%")
+    plt.tight_layout(); plt.show()
+
+else:  # counts_bar (default denominators)
+    _counts = df["__facet__"].value_counts()
+    ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
+    ax.set_title("Counts by {facet_col}")
+    ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
+    plt.tight_layout(); plt.show()
+""".lstrip()
+        return snippet
+
+    # --------------- CASE C: single pie ---------------
+    chosen = None
+    for c in cats:
+        if c in df.columns and _is_cat(c):
+            chosen = c; break
+    if chosen is None:
+        chosen = _fallback_cat()
+
+    if chosen:
+        snippet = f"""
+import matplotlib.pyplot as plt
+counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
+fig, ax = plt.subplots()
+if len(counts) > 0:
+    ax.pie(counts.values, labels=[str(i) for i in counts.index],
+           autopct='%1.1f%%', startangle=90, counterclock=False)
+    ax.set_title('Distribution of {chosen} (top {top_n})')
+    ax.axis('equal')
+plt.show()
+""".lstrip()
+        return snippet
+
+    # numeric last resort
+    num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
+    if num_cols:
+        col = num_cols[0]
+        snippet = f"""
+import pandas as pd
+import matplotlib.pyplot as plt
+bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
+counts = bins.value_counts().sort_index()
+fig, ax = plt.subplots()
+if len(counts) > 0:
+    ax.pie(counts.values, labels=[str(i) for i in counts.index],
+           autopct='%1.1f%%', startangle=90, counterclock=False)
+    ax.set_title('Distribution of {col} (binned)')
+    ax.axis('equal')
+plt.show()
+""".lstrip()
+        return snippet
+
+    return code
+
+
|
|
2709
|
+
"""
|
|
2710
|
+
Removes seaborn `palette=` when no `hue=` is present in the same call.
|
|
2711
|
+
Fixes FutureWarning: 'Passing `palette` without assigning `hue` ...'.
|
|
2712
|
+
"""
|
|
2713
|
+
if "sns." not in code:
|
|
2714
|
+
return code
|
|
2715
|
+
|
|
2716
|
+
# Targets common seaborn plotters
|
|
2717
|
+
funcs = r"(boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
|
|
2718
|
+
pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)
|
|
2719
|
+
|
|
2720
|
+
def _fix_call(m):
|
|
2721
|
+
head, inner = m.group(1), m.group(2)
|
|
2722
|
+
# If there's already hue=, keep as is
|
|
2723
|
+
if re.search(r"(?<!\w)hue\s*=", inner):
|
|
2724
|
+
return f"{head}{inner})"
|
|
2725
|
+
# Otherwise remove palette=... safely (and any adjacent comma spacing)
|
|
2726
|
+
inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
|
|
2727
|
+
inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
|
|
2728
|
+
inner2 = re.sub(r"\s*,\s*\)", ")", f"{inner2})")[:-1] # clean trailing comma before ')'
|
|
2729
|
+
return f"{head}{inner2})"
|
|
2730
|
+
|
|
2731
|
+
return pattern.sub(_fix_call, code)
|
|
2732
|
+
|
|
2733
|
+
|
|
2734
|
+
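Example of the palette fix (the FutureWarning text it targets is quoted in the docstring):

    before = "sns.boxplot(data=df, x='Ethnicity', y='BMI', palette='Set2')"
    print(patch_fix_seaborn_palette_calls(before))
    # sns.boxplot(data=df, x='Ethnicity', y='BMI')
    # Calls that already pass hue= are left untouched.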
+def patch_quiet_specific_warnings(code: str) -> str:
+    """
+    Inserts targeted warning filters (not blanket ignores).
+      - seaborn palette/hue deprecation
+      - python-dotenv parse chatter
+    """
+    prelude = (
+        "import warnings\n"
+        "warnings.filterwarnings(\n"
+        "    'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
+        "warnings.filterwarnings(\n"
+        "    'ignore', message=r'python-dotenv could not parse statement.*')\n"
+    )
+    # If warnings already imported once, just add filters; else insert full prelude.
+    if "import warnings" in code:
+        code = re.sub(
+            r"(import warnings[^\n]*\n)",
+            lambda m: m.group(1) + prelude.replace("import warnings\n", ""),
+            code,
+            count=1
+        )
+    else:
+        # place after first import block if possible
+        m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
+        if m:
+            idx = m.end()
+            code = code[:idx] + prelude + code[idx:]
+        else:
+            code = prelude + code
+    return code
+
+
+def _norm_col_name(s: str) -> str:
+    """normalise a column name: lowercase + strip non-alphanumerics."""
+    return re.sub(r"[^a-z0-9]+", "", str(s).lower())
+
+
+def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
+    """return the actual df column that matches any candidate (after normalisation)."""
+    norm_map = {_norm_col_name(c): c for c in df.columns}
+    for cand in candidates:
+        hit = norm_map.get(_norm_col_name(cand))
+        if hit is not None:
+            return hit
+    return None
+
+
+def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
+    """
+    If any alias exists, materialise a canonical copy at `target` (don’t drop the original).
+    Returns (df, found_bool).
+    """
+    if target in df.columns:
+        return df, True
+    col = _first_present(df, [target, *aliases])
+    if col is None:
+        return df, False
+    df[target] = df[col]
+    return df, True
+
+
+def strip_python_dotenv(code: str) -> str:
+    """
+    Remove any use of python-dotenv from generated code, including:
+      - single and multi-line 'from dotenv import ...'
+      - 'import dotenv' (with or without alias) and calls via any alias
+      - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
+      - IPython magics (%load_ext dotenv, %dotenv, %env …)
+      - shell installs like '!pip install python-dotenv'
+    """
+    original = code
+
+    # 0) Kill IPython magics & shell installs referencing dotenv
+    code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
+    code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
+    code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
+    code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)
+
+    # 1) Remove single-line 'from dotenv import ...'
+    code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)
+
+    # 2) Remove multi-line 'from dotenv import ( ... )' blocks
+    code = re.sub(
+        r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
+        "",
+        code,
+        flags=re.MULTILINE,
+    )
+
+    # 3) Remove 'import dotenv' (with optional alias). Capture alias names.
+    aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
+                         code, flags=re.MULTILINE)
+    code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
+                  "", code, flags=re.MULTILINE)
+
+    # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
+    #    e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
+    fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
+    # bare calls
+    code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
+    # dotted calls with any identifier prefix (alias or module)
+    code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
+                  "", code, flags=re.MULTILINE)
+
+    # 5) If any alias imported earlier slipped through (method chains etc.), remove lines using that alias.
+    for al in aliases:
+        code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
+
+    # 6) Tidy excess blank lines
+    code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
+    return code
+
+
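strip_python_dotenv() in one line (hypothetical generated cell):

    messy = "from dotenv import load_dotenv\nload_dotenv()\nprint('ok')\n"
    print(strip_python_dotenv(messy))
    # print('ok')
    # Both the import and the load_dotenv() call are gone; blank runs are tidied.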
+def fix_predict_calls_records_arg(code: str) -> str:
+    """
+    If generated code calls predict_* with a list-of-dicts via .to_dict('records')
+    (or orient='records'), strip the .to_dict(...) so a DataFrame is passed instead.
+    Works line-by-line to avoid over-rewrites elsewhere.
+    Examples fixed:
+        predict_patient(X_test.iloc[:5].to_dict('records'))
+        predict_risk(df.head(3).to_dict(orient="records"))
+      → predict_patient(X_test.iloc[:5])
+    """
+    fixed_lines = []
+    for line in code.splitlines():
+        if "predict_" in line and "to_dict" in line and "records" in line:
+            line = re.sub(
+                r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)",
+                "",
+                line
+            )
+        fixed_lines.append(line)
+    return "\n".join(fixed_lines)
+
+
+def fix_fstring_backslash_paths(code: str) -> str:
+    """
+    Fix bad f-strings like: f"...{out_dir\\plots\\img.png}..."
+      → f"...{os.path.join(out_dir, r'plots\\img.png')}"
+    Only touches f-strings that contain a backslash path inside {...}.
+    """
+    def _fix_line(line: str) -> str:
+        # quick check: only f-strings need scanning
+        if not (("f\"" in line) or ("f'" in line) or ("f\"\"\"" in line) or ("f'''" in line)):
+            return line
+        # {var\rest-of-path} where var can be dotted (e.g., cfg.out)
+        pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
+        def repl(m):
+            left = m.group(1)
+            right = m.group(2).strip().replace('"', '\\"')
+            return "{os.path.join(" + left + ', r"' + right + '")}'
+        return pattern.sub(repl, line)
+
+    return "\n".join(_fix_line(ln) for ln in code.splitlines())
+
+
+def ensure_os_import(code: str) -> str:
+    """
+    If os.path.join is used but 'import os' is missing, inject it at the top.
+    """
+    needs = "os.path.join(" in code
+    has_import_os = re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
+    has_from_os = re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
+    if needs and not (has_import_os or has_from_os):
+        return "import os\n" + code
+    return code
+
+
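The two path helpers combine like this (the broken input line is made up; a single-quoted f-string is used so the rewritten output stays valid on older Pythons):

    bad = "path = f'{out_dir\\plots\\img.png}'"       # backslash inside {...}
    good = ensure_os_import(fix_fstring_backslash_paths(bad))
    # import os
    # path = f'{os.path.join(out_dir, r"plots\img.png")}'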
+def fix_seaborn_boxplot_nameerror(code: str) -> str:
+    """
+    Fix bad calls like: sns.boxplot(boxplot)
+    Heuristic:
+      - If plot_df + FH_status + var exist → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
+      - Else if plot_df + var            → sns.boxplot(data=plot_df, y=var, ax=ax)
+      - Else if plot_df only             → sns.boxplot(data=plot_df, ax=ax)
+      - Else                             → sns.boxplot(ax=ax)
+    Ensures a matplotlib Axes 'ax' exists.
+    """
+    pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
+    if not pattern.search(code):
+        return code
+
+    has_plot_df = re.search(r"\bplot_df\b", code) is not None
+    has_var = re.search(r"\bvar\b", code) is not None
+    has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
+
+    if has_plot_df and has_var and has_fh:
+        replacement = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
+    elif has_plot_df and has_var:
+        replacement = "sns.boxplot(data=plot_df, y=var, ax=ax)"
+    elif has_plot_df:
+        replacement = "sns.boxplot(data=plot_df, ax=ax)"
+    else:
+        replacement = "sns.boxplot(ax=ax)"
+
+    fixed = pattern.sub(replacement, code)
+
+    # Ensure 'fig, ax = plt.subplots(...)' exists
+    if "ax=" in replacement and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", fixed):
+        # Insert right before the first seaborn call
+        m = re.search(r"^\s*sns\.", fixed, flags=re.MULTILINE)
+        insert_at = m.start() if m else 0
+        fixed = fixed[:insert_at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + fixed[insert_at:]
+
+    return fixed
+
+
+def fix_seaborn_barplot_nameerror(code: str) -> str:
+    """
+    Fix bad calls like: sns.barplot(barplot)
+    Strategy mirrors boxplot fixer: prefer data=plot_df with x/y if available,
+    otherwise degrade safely to an empty call on an existing Axes.
+    """
+    import re
+    pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
+    if not pattern.search(code):
+        return code
+
+    has_plot_df = re.search(r"\bplot_df\b", code) is not None
+    has_var = re.search(r"\bvar\b", code) is not None
+    has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
+
+    if has_plot_df and has_var and has_fh:
+        replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
+    elif has_plot_df and has_var:
+        replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
+    elif has_plot_df:
+        replacement = "sns.barplot(data=plot_df, ax=ax)"
+    else:
+        replacement = "sns.barplot(ax=ax)"
+
+    # ensure an Axes 'ax' exists (no-op if already present)
+    if "ax =" not in code:
+        code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code
+
+    return pattern.sub(replacement, code)
+
+
+def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
+    """
+    Parses the raw text to extract and format the 'refined question',
+    'intents (tasks)', and 'chronology of tasks' sections.
+    Args:
+        raw_text: The complete input string containing the ML pipeline structure.
+    Returns:
+        A tuple containing:
+        (formatted_question_str, formatted_intents_str, formatted_chronology_str)
+    """
+    # --- 1. Regex Pattern to Extract Sections ---
+    # The pattern uses capturing groups to look for the section headers
+    # (e.g., 'refined question:') and captures all the content until the next
+    # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.
+
+    pattern = re.compile(
+        r"refined question:(?P<question>.*?)"
+        r"intents \(tasks\):(?P<intents>.*?)"
+        r"Chronology of tasks:(?P<chronology>.*)",
+        re.IGNORECASE | re.DOTALL
+    )
+
+    match = pattern.search(raw_text)
+
+    if not match:
+        raise ValueError("Input text structure does not match the expected pattern.")
+
+    # --- 2. Extract Content ---
+    question_content = match.group('question').strip()
+    intents_content = match.group('intents').strip()
+    chronology_content = match.group('chronology').strip()
+
+    # --- 3. Formatting Functions ---
+
+    def format_question(content):
+        """Formats the Refined Question section."""
+        # Clean up leading/trailing whitespace and ensure clean paragraphs
+        content = content.strip().replace('\n', ' ').replace('  ', ' ')
+
+        # Simple formatting using Markdown headers and bolding
+        formatted = (
+            # "## 1. Project Goal and Objectives\n\n"
+            "<b> Refined Question:</b>\n"
+            f"{content}\n"
+        )
+        return formatted
+
+    def format_intents(content):
+        """Formats the Intents (Tasks) section as a structured list."""
+        # Use regex to find and format each numbered task
+        # It finds 'N. **Text** - ...' and breaks it down.
+
+        tasks = []
+        # Pattern: N. **Text** - Content (including newlines, non-greedy)
+        # We need to explicitly handle the list items starting with '-' within the content
+        task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)
+
+        # Split the content by lines and join tasks back into clean strings
+        raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]
+
+        for task in raw_tasks:
+            # Replace the initial task number and **Heading** with a Heading 3
+            task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)
+
+            # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
+            task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
+            tasks.append(task)
+
+        formatted_tasks = "\n\n".join(tasks)
+
+        return (
+            "\n---\n"
+            "## 2. Methodology and Tasks\n\n"
+            f"{formatted_tasks}\n"
+        )
+
+    def format_chronology(content):
+        """Formats the Chronology section."""
+        # Uses the given LaTeX format
+        content = content.strip().replace(' ', ' \\rightarrow ')
+        formatted = (
+            "\n---\n"
+            "## 3. Chronology of Tasks\n"
+            f"$$\\text{{{content}}}$$"
+        )
+        return formatted
+
+    # --- 4. Format and Return ---
+    formatted_question = format_question(question_content)
+    formatted_intents = format_intents(intents_content)
+    formatted_chronology = format_chronology(chronology_content)
+
+    return formatted_question, formatted_intents, formatted_chronology
+
+
+def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
+    """Combines all formatted parts into a final report string."""
+    return (
+        "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
+        f"{formatted_question}\n"
+        f"{formatted_intents}\n"
+        f"{formatted_chronology}\n"
+    )
+
+
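A hypothetical input for the report formatter pair above:

    raw = (
        "refined question: Which features predict family history?\n"
        "intents (tasks): 1. **EDA** - inspect distributions\n"
        "Chronology of tasks: EDA Modelling Evaluation"
    )
    q, intents, chrono = parse_and_format_ml_pipeline(raw)
    report = generate_full_report(q, intents, chrono)
    # chrono ends with the chronology rendered as display-math LaTeX:
    #   $$\text{EDA \rightarrow Modelling \rightarrow Evaluation}$$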
+def fix_confusion_matrix_for_multilabel(code: str) -> str:
+    """
+    Replace ConfusionMatrixDisplay.from_estimator(...) usages with
+    from_predictions(...) which works for multi-label loops without requiring
+    the estimator to expose _estimator_type.
+    """
+    return re.sub(
+        r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
+        r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
         code
     )
 
+
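The confusion-matrix rewrite, applied to a representative line (names hypothetical):

    before = "ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test)"
    print(fix_confusion_matrix_for_multilabel(before))
    # ConfusionMatrixDisplay.from_predictions(y_test, clf.predict(X_test))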
+def smx_auto_title_plots(ctx=None, fallback="Analysis"):
+    """
+    Ensure every Matplotlib/Seaborn Axes has a title.
+    Uses refined_question -> askai_question -> fallback.
+    Only sets a title if it's currently empty.
+    """
+    import matplotlib.pyplot as plt
+
+    def _all_figures():
+        try:
+            from matplotlib._pylab_helpers import Gcf
+            return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
+        except Exception:
+            # Best effort fallback
+            nums = plt.get_fignums()
+            return [plt.figure(n) for n in nums] if nums else []
+
+    # Choose a concise title
+    title = None
+    if isinstance(ctx, dict):
+        title = ctx.get("refined_question") or ctx.get("askai_question")
+    title = (str(title).strip().splitlines()[0][:120]) if title else fallback
+
+    for fig in _all_figures():
+        for ax in getattr(fig, "axes", []):
+            try:
+                if not (ax.get_title() or "").strip():
+                    ax.set_title(title)
+            except Exception:
+                pass
+        try:
+            fig.tight_layout()
+        except Exception:
+            pass
+
+
+def patch_fix_sentinel_plot_calls(code: str) -> str:
+    """
+    Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
+        SB_barplot(barplot)       -> SB_barplot()
+        SB_barplot(barplot, ...)  -> SB_barplot(...)
+        sns.barplot(barplot)      -> SB_barplot()
+        sns.barplot(barplot, ...) -> SB_barplot(...)
+    Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
+    """
+    names = ['histplot','boxplot','barplot','lineplot','countplot','heatmap','pairplot','scatterplot']
+    for n in names:
+        # SB_* with sentinel as the first arg (with or without trailing args)
+        code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
+        code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
+        # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
+        code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
+        code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
     return code
 
+
+def patch_rmse_calls(code: str) -> str:
+    """
+    Make RMSE robust across sklearn versions.
+      - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
+      - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
+    """
+    import re
+    # (a) Specific RMSE pattern
+    code = re.sub(
+        r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
+        r"_SMX_rmse(\1)",
+        code,
+        flags=re.DOTALL
+    )
+    # (b) Guard any other MSE calls
+    code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
+    return code