syntaxmatrix-1.4.6-py3-none-any.whl → syntaxmatrix-2.5.5.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. syntaxmatrix/__init__.py +13 -8
  2. syntaxmatrix/agentic/__init__.py +0 -0
  3. syntaxmatrix/agentic/agent_tools.py +24 -0
  4. syntaxmatrix/agentic/agents.py +810 -0
  5. syntaxmatrix/agentic/code_tools_registry.py +37 -0
  6. syntaxmatrix/agentic/model_templates.py +1790 -0
  7. syntaxmatrix/auth.py +308 -14
  8. syntaxmatrix/commentary.py +328 -0
  9. syntaxmatrix/core.py +993 -375
  10. syntaxmatrix/dataset_preprocessing.py +218 -0
  11. syntaxmatrix/db.py +92 -95
  12. syntaxmatrix/display.py +95 -121
  13. syntaxmatrix/generate_page.py +634 -0
  14. syntaxmatrix/gpt_models_latest.py +46 -0
  15. syntaxmatrix/history_store.py +26 -29
  16. syntaxmatrix/kernel_manager.py +96 -17
  17. syntaxmatrix/llm_store.py +1 -1
  18. syntaxmatrix/plottings.py +6 -0
  19. syntaxmatrix/profiles.py +64 -8
  20. syntaxmatrix/project_root.py +55 -43
  21. syntaxmatrix/routes.py +5072 -1398
  22. syntaxmatrix/session.py +19 -0
  23. syntaxmatrix/settings/logging.py +40 -0
  24. syntaxmatrix/settings/model_map.py +300 -33
  25. syntaxmatrix/settings/prompts.py +273 -62
  26. syntaxmatrix/settings/string_navbar.py +3 -3
  27. syntaxmatrix/static/docs.md +272 -0
  28. syntaxmatrix/static/icons/favicon.png +0 -0
  29. syntaxmatrix/static/icons/hero_bg.jpg +0 -0
  30. syntaxmatrix/templates/dashboard.html +608 -147
  31. syntaxmatrix/templates/docs.html +71 -0
  32. syntaxmatrix/templates/error.html +2 -3
  33. syntaxmatrix/templates/login.html +1 -0
  34. syntaxmatrix/templates/register.html +1 -0
  35. syntaxmatrix/ui_modes.py +14 -0
  36. syntaxmatrix/utils.py +2482 -159
  37. syntaxmatrix/vectorizer.py +16 -12
  38. {syntaxmatrix-1.4.6.dist-info → syntaxmatrix-2.5.5.4.dist-info}/METADATA +20 -17
  39. syntaxmatrix-2.5.5.4.dist-info/RECORD +68 -0
  40. syntaxmatrix/model_templates.py +0 -30
  41. syntaxmatrix/static/icons/favicon.ico +0 -0
  42. syntaxmatrix-1.4.6.dist-info/RECORD +0 -54
  43. {syntaxmatrix-1.4.6.dist-info → syntaxmatrix-2.5.5.4.dist-info}/WHEEL +0 -0
  44. {syntaxmatrix-1.4.6.dist-info → syntaxmatrix-2.5.5.4.dist-info}/licenses/LICENSE.txt +0 -0
  45. {syntaxmatrix-1.4.6.dist-info → syntaxmatrix-2.5.5.4.dist-info}/top_level.txt +0 -0
syntaxmatrix/utils.py CHANGED
@@ -1,50 +1,1188 @@
1
- import openai
2
- from openai import OpenAI
3
- import re, os, textwrap
1
+ from __future__ import annotations
2
+ import re, textwrap
4
3
  import pandas as pd
5
- import matplotlib.pyplot as plt
4
+ import numpy as np
6
5
  import warnings
7
- from .model_templates import classification
8
- import syntaxmatrix as smx
9
6
 
7
+ from difflib import get_close_matches
8
+ from typing import Iterable, Tuple, Dict
9
+ import inspect
10
+ from sklearn.preprocessing import OneHotEncoder
11
+
12
+
13
+ from syntaxmatrix.agentic.model_templates import (
14
+ classification, regression, multilabel_classification,
15
+ eda_overview, eda_correlation,
16
+ anomaly_detection, ts_anomaly_detection,
17
+ dimensionality_reduction, feature_selection,
18
+ time_series_forecasting, time_series_classification,
19
+ unknown_group_proxy_pack, viz_line,
20
+ clustering, recommendation, topic_modelling,
21
+ viz_pie, viz_count_bar, viz_box, viz_scatter,
22
+ viz_stacked_bar, viz_distribution, viz_area, viz_kde,
23
+ )
24
+ import ast
10
25
 
11
26
  warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
12
27
 
28
+ INJECTABLE_INTENTS = [
29
+ "classification",
30
+ "multilabel_classification",
31
+ "regression",
32
+ "anomaly_detection",
33
+ "time_series_forecasting",
34
+ "time_series_classification",
35
+ "ts_anomaly_detection",
36
+ "dimensionality_reduction",
37
+ "feature_selection",
38
+ "clustering",
39
+ "eda",
40
+ "correlation_analysis",
41
+ "visualisation",
42
+ "recommendation",
43
+ "topic_modelling",
44
+ ]
45
+
46
+ def classify_ml_job(prompt: str) -> str:
47
+ """
48
+ Very light intent classifier.
49
+ Returns one of:
50
+ 'greeting' | 'feature_selection' | 'dimensionality_reduction' | 'anomaly_detection'
51
+ 'stat_test' | 'time_series' | 'clustering' | 'classification' | 'regression' | 'eda'
52
+ """
53
+ p = prompt.lower()
54
+
55
+ greetings = {"hi", "hello", "hey", "good morning", "good afternoon", "good evening", "greetings"}
56
+ if any(p == g or p.startswith(g + " ") or p.startswith(g + ",") or p.startswith(g + "!") for g in greetings):
57
+ return "greeting"
13
58
 
14
- # def ai_generate_code(question, df):
15
- # provider = os.environ.get("provider", "openai")
16
- # model = os.environ.get("model", "gpt-4o-mini")
17
- # api_key = os.environ.get("OPENAI_API_KEY")
59
+ # Feature selection / importance intent
60
+ if any(k in p for k in (
61
+ "feature selection", "select k best", "selectkbest", "rfe",
62
+ "mutual information", "feature importance", "permutation importance",
63
+ "feature engineering suggestions"
64
+ )):
65
+ return "feature_selection"
66
+
67
+ # Dimensionality reduction intent
68
+ if any(k in p for k in (
69
+ "pca", "principal component", "dimensionality reduction",
70
+ "reduce dimension", "reduce dimensionality", "t-sne", "tsne", "umap"
71
+ )):
72
+ return "dimensionality_reduction"
73
+
74
+ # Anomaly / outlier intent
75
+ if any(k in p for k in (
76
+ "anomaly", "anomalies", "outlier", "outliers", "novelty",
77
+ "fraud", "deviation", "rare event", "rare events", "odd pattern",
78
+ "suspicious"
79
+ )):
80
+ return "anomaly_detection"
18
81
 
19
- # llm = OpenAI(api_key=api_key)
82
+ if any(k in p for k in ("t-test", "anova", "p-value")):
83
+ return "stat_test"
84
+ if "forecast" in p or "prophet" in p:
85
+ return "time_series"
86
+ if "cluster" in p or "kmeans" in p:
87
+ return "clustering"
88
+ if any(k in p for k in ("accuracy", "precision", "roc")):
89
+ return "classification"
90
+ if any(k in p for k in ("rmse", "r2", "mae")):
91
+ return "regression"
20
92
 
21
- # context = f"Columns: {list(df.columns)}\nDtypes: {df.dtypes.astype(str).to_dict()}\n"
22
- # prompt = (
23
- # f"You are an expert Python data analyst. Given the dataframe `df` with the following context:\n{context}\n"
24
- # f"Write clean, working Python code that answers the question below. "
25
- # f"DO NOT explain, just output the code only (NO comments or text):\n"
26
- # f"Question: {question}\n"
27
- # f"Output only the working code needed. Assume df is already defined."
28
- # f"Produce at least one visible result"
29
- # f"(syntaxmatrix.display.show(), display(), plt.show())."
30
- # )
31
-
32
- # if provider.lower() == "openai":
33
- # response = llm.chat.completions.create(
34
- # model=model,
35
- # messages=[{"role": "user", "content": prompt}],
36
- # temperature=0.0,
37
- # max_tokens=1024,
38
- # )
39
- # code = response.choices[0].message.content
40
- # if "```python" in code:
41
- # code = code.split("```python")[1].split("```")[0].strip()
42
- # elif "```" in code:
43
- # code = code.split("```")[1].split("```")[0].strip()
44
-
45
- # code = strip_describe_slice(code)
46
- # code = drop_bad_classification_metrics(code, df)
47
- # return code.strip()
93
+ return "eda"
94
+
95
+
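For illustration, a minimal sketch of the keyword routing above (hypothetical prompts; assumes the function is importable from syntaxmatrix.utils):

    from syntaxmatrix.utils import classify_ml_job

    print(classify_ml_job("Forecast next quarter's demand"))    # -> 'time_series'
    print(classify_ml_job("Flag outliers in the sensor data"))  # -> 'anomaly_detection'
    print(classify_ml_job("hello"))                             # -> 'greeting'
    print(classify_ml_job("Describe the dataset"))              # -> 'eda' (fallback)

Earlier checks win: a prompt that mentions both "pca" and "forecast" returns 'dimensionality_reduction' because that branch runs first.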
96
+ def harden_ai_code(code: str) -> str:
97
+ """
98
+ Make any AI-generated cell resilient:
99
+ - Safe seaborn wrappers + sentinel vars (boxplot/barplot/etc.)
100
+ - Remove 'numeric_only=' args
101
+ - Replace pd.concat(...) with _safe_concat(...)
102
+ - Relax 'required_cols' hard fails
103
+ - Make static numeric_vars dynamic
104
+ - Wrap the whole block in try/except so no exception bubbles up
105
+ """
106
+ # Remove any LLM-added try/except blocks (hardener adds its own)
107
+ import re
108
+
109
+ def strip_placeholders(code: str) -> str:
110
+ code = re.sub(r"\bshow\(\s*\.\.\.\s*\)",
111
+ "show('⚠ Block skipped due to an error.')",
112
+ code)
113
+ code = re.sub(r"\breturn\s+\.\.\.", "return None", code)
114
+ return code
115
+
116
+ def _indent(code: str, spaces: int = 4) -> str:
117
+ pad = " " * spaces
118
+ return "\n".join(pad + line for line in code.splitlines())
119
+
120
+ def _SMX_OHE(**k):
121
+ # normalise arg name across sklearn versions
122
+ if "sparse" in k and "sparse_output" not in k:
123
+ k["sparse_output"] = k.pop("sparse")
124
+ # default behaviour we want
125
+ k.setdefault("handle_unknown", "ignore")
126
+ k.setdefault("sparse_output", False)
127
+ try:
128
+ # if running on old sklearn without sparse_output, translate back
129
+ if "sparse_output" not in inspect.signature(OneHotEncoder).parameters:
130
+ if "sparse_output" in k:
131
+ k["sparse"] = k.pop("sparse_output")
132
+ return OneHotEncoder(**k)
133
+ except TypeError:
134
+ # final fallback: try legacy name
135
+ if "sparse_output" in k:
136
+ k["sparse"] = k.pop("sparse_output")
137
+ return OneHotEncoder(**k)
138
+
139
+ def _strip_stray_backrefs(code: str) -> str:
140
+ code = re.sub(r'(?m)^\s*\\\d+\s*', '', code)
141
+ code = re.sub(r'(?m)[;]\s*\\\d+\s*', '; ', code)
142
+ return code
143
+
144
+ def _wrap_metric_calls(code: str) -> str:
145
+ names = [
146
+ "r2_score","accuracy_score","precision_score","recall_score","f1_score",
147
+ "roc_auc_score","classification_report","confusion_matrix",
148
+ "mean_absolute_error","mean_absolute_percentage_error",
149
+ "explained_variance_score","log_loss","average_precision_score",
150
+ "precision_recall_fscore_support"
151
+ ]
152
+ pat = re.compile(r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\(")
153
+ def repl(m):
154
+ prefix = m.group(1) or "" # "", "metrics.", or "sklearn.metrics."
155
+ name = m.group(2)
156
+ return f"_SMX_call({prefix}{name}, "
157
+ return pat.sub(repl, code)
158
+
159
+ def _smx_patch_mean_squared_error_squared_kw():
160
+ """
161
+ sklearn<0.22 doesn't accept mean_squared_error(..., squared=False).
162
+ Patch the module attr so 'from sklearn.metrics import mean_squared_error'
163
+ receives a wrapper that drops 'squared' if the underlying call rejects it.
164
+ """
165
+ try:
166
+ import sklearn.metrics as _sm
167
+ _orig = getattr(_sm, "mean_squared_error", None)
168
+ if not callable(_orig):
169
+ return
170
+ def _mse_compat(y_true, y_pred, *a, **k):
171
+ if "squared" in k:
172
+ try:
173
+ return _orig(y_true, y_pred, *a, **k)
174
+ except TypeError:
175
+ k.pop("squared", None)
176
+ return _orig(y_true, y_pred, *a, **k)
177
+ return _orig(y_true, y_pred, *a, **k)
178
+ _sm.mean_squared_error = _mse_compat
179
+ except Exception:
180
+ pass
181
+
182
+ def _smx_patch_kmeans_n_init_auto():
183
+ """
184
+ sklearn>=1.4 accepts n_init='auto'; older versions want an int.
185
+ Patch sklearn.cluster.KMeans so 'auto' is converted to 10 if TypeError occurs.
186
+ """
187
+ try:
188
+ import sklearn.cluster as _sc
189
+ _Orig = getattr(_sc, "KMeans", None)
190
+ if _Orig is None:
191
+ return
192
+ class KMeansCompat(_Orig):
193
+ def __init__(self, *a, **k):
194
+ if isinstance(k.get("n_init", None), str):
195
+ try:
196
+ super().__init__(*a, **k)
197
+ return
198
+ except TypeError:
199
+ k["n_init"] = 10
200
+ super().__init__(*a, **k)
201
+ _sc.KMeans = KMeansCompat
202
+ except Exception:
203
+ pass
204
+
205
+ def _smx_patch_ohe_name_api():
206
+ """
207
+ Guard get_feature_names_out on older OneHotEncoder.
208
+ Your templates already use _SMX_OHE; this adds a soft fallback for feature names.
209
+ """
210
+ try:
211
+ from sklearn.preprocessing import OneHotEncoder as _OHE
212
+ _orig_get = getattr(_OHE, "get_feature_names_out", None)
213
+ if _orig_get is None:
214
+ # Monkey-patch instance method via mixin
215
+ def _fallback_get_feature_names_out(self, input_features=None):
216
+ cats = getattr(self, "categories_", None) or []
217
+ input_features = input_features or [f"x{i}" for i in range(len(cats))]
218
+ names = []
219
+ for base, cat_list in zip(input_features, cats):
220
+ for j, _ in enumerate(cat_list):
221
+ names.append(f"{base}__{j}")
222
+ return names
223
+ _OHE.get_feature_names_out = _fallback_get_feature_names_out
224
+ except Exception:
225
+ pass
226
+
227
+ # Register and run patches once per execution
228
+ for _patch in (
229
+ _smx_patch_mean_squared_error_squared_kw,
230
+ _smx_patch_kmeans_n_init_auto,
231
+ _smx_patch_ohe_name_api,
232
+ ):
233
+ try:
234
+ _patch()
235
+ except Exception:
236
+ pass
237
+
238
+ PREFACE = (
239
+ "# === SMX Auto-Hardening Preface (do not edit) ===\n"
240
+ "import warnings, numpy as np, pandas as pd, matplotlib.pyplot as plt\n"
241
+ "warnings.filterwarnings('ignore')\n"
242
+ "try:\n"
243
+ " import seaborn as sns\n"
244
+ "except Exception:\n"
245
+ " class _Dummy:\n"
246
+ " def __getattr__(self, name):\n"
247
+ " def _f(*a, **k):\n"
248
+ " from syntaxmatrix.display import show\n"
249
+ " show('⚠ seaborn not available; plot skipped.')\n"
250
+ " return _f\n"
251
+ " sns = _Dummy()\n"
252
+ "\n"
253
+ "from syntaxmatrix.display import show as _SMX_base_show\n"
254
+ "def _SMX_caption_from_ctx():\n"
255
+ " g = globals()\n"
256
+ " t = g.get('refined_question') or g.get('askai_question') or 'Table'\n"
257
+ " return str(t).strip().splitlines()[0][:120]\n"
258
+ "\n"
259
+ "def show(obj, title=None):\n"
260
+ " try:\n"
261
+ " import pandas as pd\n"
262
+ " if isinstance(obj, pd.DataFrame):\n"
263
+ " cap = (title or _SMX_caption_from_ctx())\n"
264
+ " try:\n"
265
+ " return _SMX_base_show(obj.style.set_caption(cap))\n"
266
+ " except Exception:\n"
267
+ " pass\n"
268
+ " except Exception:\n"
269
+ " pass\n"
270
+ " return _SMX_base_show(obj)\n"
271
+ "\n"
272
+ "def _SMX_axes_have_titles(fig=None):\n"
273
+ " import matplotlib.pyplot as _plt\n"
274
+ " fig = fig or _plt.gcf()\n"
275
+ " try:\n"
276
+ " for _ax in fig.get_axes():\n"
277
+ " if (_ax.get_title() or '').strip():\n"
278
+ " return True\n"
279
+ " except Exception:\n"
280
+ " pass\n"
281
+ " return False\n"
282
+ "\n"
283
+ "def _SMX_export_png():\n"
284
+ " import io, base64\n"
285
+ " fig = plt.gcf()\n"
286
+ " try:\n"
287
+ " if not _SMX_axes_have_titles(fig):\n"
288
+ " fig.suptitle(_SMX_caption_from_ctx(), fontsize=10)\n"
289
+ " except Exception:\n"
290
+ " pass\n"
291
+ " buf = io.BytesIO()\n"
292
+ " plt.savefig(buf, format='png', bbox_inches='tight')\n"
293
+ " buf.seek(0)\n"
294
+ " from IPython.display import display, HTML\n"
295
+ " _img = base64.b64encode(buf.read()).decode('ascii')\n"
296
+ " display(HTML(f\"<img src='data:image/png;base64,{_img}' style='max-width:100%;height:auto;border:1px solid #ccc;border-radius:4px;'/>\"))\n"
297
+ " plt.close()\n"
298
+ "\n"
299
+ "def _pick_df():\n"
300
+ " return globals().get('df', None)\n"
301
+ "\n"
302
+ "def _pick_ax_slot():\n"
303
+ " ax = None\n"
304
+ " try:\n"
305
+ " _axes = globals().get('axes', None)\n"
306
+ " import numpy as _np\n"
307
+ " if _axes is not None:\n"
308
+ " arr = _np.ravel(_axes)\n"
309
+ " for _a in arr:\n"
310
+ " try:\n"
311
+ " if hasattr(_a,'has_data') and not _a.has_data():\n"
312
+ " ax = _a; break\n"
313
+ " except Exception:\n"
314
+ " continue\n"
315
+ " except Exception:\n"
316
+ " ax = None\n"
317
+ " return ax\n"
318
+ "\n"
319
+ "def _first_numeric(_d):\n"
320
+ " import numpy as np, pandas as pd\n"
321
+ " try:\n"
322
+ " preferred = [\"median_house_value\", \"price\", \"value\", \"target\", \"label\", \"y\"]\n"
323
+ " for c in preferred:\n"
324
+ " if c in _d.columns and pd.api.types.is_numeric_dtype(_d[c]):\n"
325
+ " return c\n"
326
+ " cols = _d.select_dtypes(include=[np.number]).columns.tolist()\n"
327
+ " return cols[0] if cols else None\n"
328
+ " except Exception:\n"
329
+ " return None\n"
330
+ "\n"
331
+ "def _first_categorical(_d):\n"
332
+ " import pandas as pd, numpy as np\n"
333
+ " try:\n"
334
+ " num = set(_d.select_dtypes(include=[np.number]).columns.tolist())\n"
335
+ " cand = [c for c in _d.columns if c not in num and _d[c].nunique(dropna=True) <= 50]\n"
336
+ " return cand[0] if cand else None\n"
337
+ " except Exception:\n"
338
+ " return None\n"
339
+ "\n"
340
+ "boxplot = barplot = histplot = distplot = lineplot = countplot = heatmap = pairplot = None\n"
341
+ "\n"
342
+ "def _safe_plot(func, *args, **kwargs):\n"
343
+ " try:\n"
344
+ " ax = func(*args, **kwargs)\n"
345
+ " if ax is None:\n"
346
+ " ax = plt.gca()\n"
347
+ " try:\n"
348
+ " if hasattr(ax, 'has_data') and not ax.has_data():\n"
349
+ " from syntaxmatrix.display import show as _show\n"
350
+ " _show('⚠ Empty plot: no data drawn.')\n"
351
+ " except Exception:\n"
352
+ " pass\n"
353
+ " try: plt.tight_layout()\n"
354
+ " except Exception: pass\n"
355
+ " return ax\n"
356
+ " except Exception as e:\n"
357
+ " from syntaxmatrix.display import show as _show\n"
358
+ " _show(f'⚠ Plot skipped: {type(e).__name__}: {e}')\n"
359
+ " return None\n"
360
+ "\n"
361
+ "def SB_histplot(*a, **k):\n"
362
+ " _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
363
+ " _sentinel = (len(a) >= 1 and a[0] is None)\n"
364
+ " if (not a or _sentinel) and not k:\n"
365
+ " d = _pick_df()\n"
366
+ " if d is not None:\n"
367
+ " x = _first_numeric(d)\n"
368
+ " if x is not None:\n"
369
+ " def _draw():\n"
370
+ " plt.hist(d[x].dropna())\n"
371
+ " ax = plt.gca()\n"
372
+ " if not (ax.get_title() or '').strip():\n"
373
+ " ax.set_title(f'Distribution of {x}')\n"
374
+ " return ax\n"
375
+ " return _safe_plot(lambda **kw: _draw())\n"
376
+ " if _missing:\n"
377
+ " return _safe_plot(lambda **kw: plt.hist([]))\n"
378
+ " if _sentinel:\n"
379
+ " a = a[1:]\n"
380
+ " return _safe_plot(getattr(sns,'histplot', plt.hist), *a, **k)\n"
381
+ "\n"
382
+ "def SB_barplot(*a, **k):\n"
383
+ " _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
384
+ " _sentinel = (len(a) >= 1 and a[0] is None)\n"
385
+ " _ax = k.get('ax') or _pick_ax_slot()\n"
386
+ " if _ax is not None:\n"
387
+ " try: plt.sca(_ax)\n"
388
+ " except Exception: pass\n"
389
+ " k.setdefault('ax', _ax)\n"
390
+ " if (not a or _sentinel) and not k:\n"
391
+ " d = _pick_df()\n"
392
+ " if d is not None:\n"
393
+ " x = _first_categorical(d)\n"
394
+ " y = _first_numeric(d)\n"
395
+ " if x and y:\n"
396
+ " import pandas as _pd\n"
397
+ " g = d.groupby(x)[y].mean().reset_index()\n"
398
+ " def _draw():\n"
399
+ " if _missing:\n"
400
+ " plt.bar(g[x], g[y])\n"
401
+ " else:\n"
402
+ " sns.barplot(data=g, x=x, y=y, ax=k.get('ax'))\n"
403
+ " ax = plt.gca()\n"
404
+ " if not (ax.get_title() or '').strip():\n"
405
+ " ax.set_title(f'Mean {y} by {x}')\n"
406
+ " return ax\n"
407
+ " return _safe_plot(lambda **kw: _draw())\n"
408
+ " if _missing:\n"
409
+ " return _safe_plot(lambda **kw: plt.bar([], []))\n"
410
+ " if _sentinel:\n"
411
+ " a = a[1:]\n"
412
+ " return _safe_plot(sns.barplot, *a, **k)\n"
413
+ "\n"
414
+ "def SB_boxplot(*a, **k):\n"
415
+ " _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
416
+ " _sentinel = (len(a) >= 1 and a[0] is None)\n"
417
+ " _ax = k.get('ax') or _pick_ax_slot()\n"
418
+ " if _ax is not None:\n"
419
+ " try: plt.sca(_ax)\n"
420
+ " except Exception: pass\n"
421
+ " k.setdefault('ax', _ax)\n"
422
+ " if (not a or _sentinel) and not k:\n"
423
+ " d = _pick_df()\n"
424
+ " if d is not None:\n"
425
+ " x = _first_categorical(d)\n"
426
+ " y = _first_numeric(d)\n"
427
+ " if x and y:\n"
428
+ " def _draw():\n"
429
+ " if _missing:\n"
430
+ " plt.boxplot(d[y].dropna())\n"
431
+ " else:\n"
432
+ " sns.boxplot(data=d, x=x, y=y, ax=k.get('ax'))\n"
433
+ " ax = plt.gca()\n"
434
+ " if not (ax.get_title() or '').strip():\n"
435
+ " ax.set_title(f'Distribution of {y} by {x}')\n"
436
+ " return ax\n"
437
+ " return _safe_plot(lambda **kw: _draw())\n"
438
+ " if _missing:\n"
439
+ " return _safe_plot(lambda **kw: plt.boxplot([]))\n"
440
+ " if _sentinel:\n"
441
+ " a = a[1:]\n"
442
+ " return _safe_plot(sns.boxplot, *a, **k)\n"
443
+ "\n"
444
+ "def SB_scatterplot(*a, **k):\n"
445
+ " _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
446
+ " fn = getattr(sns,'scatterplot', None)\n"
447
+ " # If seaborn is unavailable OR the caller passed (data=..., x='col', y='col'),\n"
448
+ " # use a robust matplotlib path that looks up data and coerces to numeric.\n"
449
+ " if _missing or fn is None:\n"
450
+ " data = k.get('data'); x = k.get('x'); y = k.get('y')\n"
451
+ " if data is not None and isinstance(x, str) and isinstance(y, str) and x in data.columns and y in data.columns:\n"
452
+ " xs = pd.to_numeric(data[x], errors='coerce')\n"
453
+ " ys = pd.to_numeric(data[y], errors='coerce')\n"
454
+ " m = xs.notna() & ys.notna()\n"
455
+ " def _draw():\n"
456
+ " plt.scatter(xs[m], ys[m])\n"
457
+ " ax = plt.gca()\n"
458
+ " if not (ax.get_title() or '').strip():\n"
459
+ " ax.set_title(f'{y} vs {x}')\n"
460
+ " return ax\n"
461
+ " return _safe_plot(lambda **kw: _draw())\n"
462
+ " # else: fall back to auto-pick two numeric columns\n"
463
+ " d = _pick_df()\n"
464
+ " if d is not None:\n"
465
+ " num = d.select_dtypes(include=[np.number]).columns.tolist()\n"
466
+ " if len(num) >= 2:\n"
467
+ " def _draw2():\n"
468
+ " plt.scatter(d[num[0]], d[num[1]])\n"
469
+ " ax = plt.gca()\n"
470
+ " if not (ax.get_title() or '').strip():\n"
471
+ " ax.set_title(f'{num[1]} vs {num[0]}')\n"
472
+ " return ax\n"
473
+ " return _safe_plot(lambda **kw: _draw2())\n"
474
+ " return _safe_plot(lambda **kw: plt.scatter([], []))\n"
475
+ " # seaborn path\n"
476
+ " return _safe_plot(fn, *a, **k)\n"
477
+ "\n"
478
+ "def SB_heatmap(*a, **k):\n"
479
+ " _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
480
+ " data = None\n"
481
+ " if a:\n"
482
+ " data = a[0]\n"
483
+ " elif 'data' in k:\n"
484
+ " data = k['data']\n"
485
+ " if data is None:\n"
486
+ " d = _pick_df()\n"
487
+ " try:\n"
488
+ " if d is not None:\n"
489
+ " import numpy as _np\n"
490
+ " data = d.select_dtypes(include=[_np.number]).corr()\n"
491
+ " except Exception:\n"
492
+ " data = None\n"
493
+ " if data is None:\n"
494
+ " from syntaxmatrix.display import show as _show\n"
495
+ " _show('⚠ Heatmap skipped: no data.')\n"
496
+ " return None\n"
497
+ " if not _missing and hasattr(sns, 'heatmap'):\n"
498
+ " _k = {kk: vv for kk, vv in k.items() if kk != 'data'}\n"
499
+ " def _draw():\n"
500
+ " ax = sns.heatmap(data, **_k)\n"
501
+ " try:\n"
502
+ " ax = ax or plt.gca()\n"
503
+ " if not (ax.get_title() or '').strip():\n"
504
+ " ax.set_title('Correlation Heatmap')\n"
505
+ " except Exception:\n"
506
+ " pass\n"
507
+ " return ax\n"
508
+ " return _safe_plot(lambda **kw: _draw())\n"
509
+ " def _mat_heat():\n"
510
+ " im = plt.imshow(data, aspect='auto')\n"
511
+ " try: plt.colorbar()\n"
512
+ " except Exception: pass\n"
513
+ " try:\n"
514
+ " cols = list(getattr(data, 'columns', []))\n"
515
+ " rows = list(getattr(data, 'index', []))\n"
516
+ " if cols: plt.xticks(range(len(cols)), cols, rotation=90)\n"
517
+ " if rows: plt.yticks(range(len(rows)), rows)\n"
518
+ " except Exception:\n"
519
+ " pass\n"
520
+ " ax = plt.gca()\n"
521
+ " try:\n"
522
+ " if not (ax.get_title() or '').strip():\n"
523
+ " ax.set_title('Correlation Heatmap')\n"
524
+ " except Exception:\n"
525
+ " pass\n"
526
+ " return ax\n"
527
+ " return _safe_plot(lambda **kw: _mat_heat())\n"
528
+ "\n"
529
+ "def _safe_concat(objs, **kwargs):\n"
530
+ " import pandas as _pd\n"
531
+ " if objs is None: return _pd.DataFrame()\n"
532
+ " if isinstance(objs,(list,tuple)) and len(objs)==0: return _pd.DataFrame()\n"
533
+ " try: return _pd.concat(objs, **kwargs)\n"
534
+ " except Exception as e:\n"
535
+ " show(f'⚠ concat skipped: {e}')\n"
536
+ " return _pd.DataFrame()\n"
537
+ "\n"
538
+ "from sklearn.preprocessing import OneHotEncoder\n"
539
+ "import inspect\n"
540
+ "def _SMX_OHE(**k):\n"
541
+ " # normalise arg name across sklearn versions\n"
542
+ " if 'sparse' in k and 'sparse_output' not in k:\n"
543
+ " k['sparse_output'] = k.pop('sparse')\n"
544
+ " k.setdefault('handle_unknown','ignore')\n"
545
+ " k.setdefault('sparse_output', False)\n"
546
+ " try:\n"
547
+ " sig = inspect.signature(OneHotEncoder)\n"
548
+ " if 'sparse_output' not in sig.parameters and 'sparse_output' in k:\n"
549
+ " k['sparse'] = k.pop('sparse_output')\n"
550
+ " except Exception:\n"
551
+ " if 'sparse_output' in k:\n"
552
+ " k['sparse'] = k.pop('sparse_output')\n"
553
+ " return OneHotEncoder(**k)\n"
554
+ "\n"
555
+ "import numpy as _np\n"
556
+ "def _SMX_mm(a, b):\n"
557
+ " try:\n"
558
+ " return a @ b # normal path\n"
559
+ " except Exception:\n"
560
+ " try:\n"
561
+ " A = _np.asarray(a); B = _np.asarray(b)\n"
562
+ " # If same 2D shape (e.g. (n,k) & (n,k)), treat as row-wise dot\n"
563
+ " if A.ndim==2 and B.ndim==2 and A.shape==B.shape:\n"
564
+ " return (A * B).sum(axis=1)\n"
565
+ " # Otherwise try element-wise product (broadcast if possible)\n"
566
+ " return A * B\n"
567
+ " except Exception as e:\n"
568
+ " from syntaxmatrix.display import show\n"
569
+ " show(f'⚠ Matmul relaxed: {type(e).__name__}: {e}'); return _np.nan\n"
570
+ "\n"
571
+ "def _SMX_call(fn, *a, **k):\n"
572
+ " try:\n"
573
+ " return fn(*a, **k)\n"
574
+ " except TypeError as e:\n"
575
+ " msg = str(e)\n"
576
+ " if \"unexpected keyword argument 'squared'\" in msg:\n"
577
+ " k.pop('squared', None)\n"
578
+ " return fn(*a, **k)\n"
579
+ " raise\n"
580
+ "\n"
581
+ "def _SMX_rmse(y_true, y_pred):\n"
582
+ " try:\n"
583
+ " from sklearn.metrics import mean_squared_error as _mse\n"
584
+ " try:\n"
585
+ " return _mse(y_true, y_pred, squared=False)\n"
586
+ " except TypeError:\n"
587
+ " return (_mse(y_true, y_pred)) ** 0.5\n"
588
+ " except Exception:\n"
589
+ " import numpy as _np\n"
590
+ " yt = _np.asarray(y_true, dtype=float)\n"
591
+ " yp = _np.asarray(y_pred, dtype=float)\n"
592
+ " diff = yt - yp\n"
593
+ " return float((_np.mean(diff * diff)) ** 0.5)\n"
594
+ "\n"
595
+ "import pandas as _pd\n"
596
+ "import numpy as _np\n"
597
+ "def _SMX_autocoerce_dates(_df):\n"
598
+ " if _df is None or not hasattr(_df, 'columns'): return\n"
599
+ " for c in list(_df.columns):\n"
600
+ " s = _df[c]\n"
601
+ " n = str(c).lower()\n"
602
+ " if _pd.api.types.is_datetime64_any_dtype(s):\n"
603
+ " continue\n"
604
+ " if _pd.api.types.is_object_dtype(s) or ('date' in n or 'time' in n or 'timestamp' in n or n.endswith('_dt')):\n"
605
+ " try:\n"
606
+ " conv = _pd.to_datetime(s, errors='coerce', utc=True).dt.tz_localize(None)\n"
607
+ " # accept only if at least 10% (min 3) parse as dates\n"
608
+ " if getattr(conv, 'notna', lambda: _pd.Series([]))().sum() >= max(3, int(0.1*len(_df))):\n"
609
+ " _df[c] = conv\n"
610
+ " except Exception:\n"
611
+ " pass\n"
612
+ "\n"
613
+ "def _SMX_autocoerce_numeric(_df, cols):\n"
614
+ " if _df is None: return\n"
615
+ " for c in cols:\n"
616
+ " if c in getattr(_df, 'columns', []):\n"
617
+ " try:\n"
618
+ " _df[c] = _pd.to_numeric(_df[c], errors='coerce')\n"
619
+ " except Exception:\n"
620
+ " pass\n"
621
+ "\n"
622
+ "def show(obj, title=None):\n"
623
+ " try:\n"
624
+ " import pandas as pd, numbers\n"
625
+ " cap = (title or _SMX_caption_from_ctx())\n"
626
+ " # 1) DataFrame → Styler with caption\n"
627
+ " if isinstance(obj, pd.DataFrame):\n"
628
+ " try: return _SMX_base_show(obj.style.set_caption(cap))\n"
629
+ " except Exception: pass\n"
630
+ " # 2) dict of scalars → DataFrame with caption\n"
631
+ " if isinstance(obj, dict) and all(isinstance(v, numbers.Number) for v in obj.values()):\n"
632
+ " df_ = pd.DataFrame({'metric': list(obj.keys()), 'value': list(obj.values())})\n"
633
+ " try: return _SMX_base_show(df_.style.set_caption(cap))\n"
634
+ " except Exception: return _SMX_base_show(df_)\n"
635
+ " except Exception:\n"
636
+ " pass\n"
637
+ " return _SMX_base_show(obj)\n"
638
+ )
639
+
640
+ PREFACE_IMPORT = "from syntaxmatrix.smx_preface import *\n"
641
+ # if PREFACE not in code:
642
+ # code = PREFACE_IMPORT + code
643
+
644
+ fixed = strip_placeholders(code)
645
+
646
+ fixed = re.sub(
647
+ r"(?s)^\s*try:\s*(.*?)\s*except\s+Exception\s+as\s+\w+:\s*\n\s*show\([^\n]*\)\s*$",
648
+ r"\1",
649
+ fixed.strip()
650
+ )
651
+
652
+ # 1) Strip numeric_only=... (version-agnostic)
653
+ fixed = re.sub(r",\s*numeric_only\s*=\s*(True|False|None)", "", fixed, flags=re.I)
654
+ fixed = re.sub(r"\bnumeric_only\s*=\s*(True|False|None)\s*,\s*", "", fixed, flags=re.I)
655
+ fixed = re.sub(r"\bnumeric_only\s*=\s*(True|False|None)\b", "", fixed, flags=re.I)
656
+
657
+ # 2) Use safe seaborn wrappers
658
+ fixed = re.sub(r"\bsns\.boxplot\s*\(", "SB_boxplot(", fixed)
659
+ fixed = re.sub(r"\bsns\.barplot\s*\(", "SB_barplot(", fixed)
660
+ fixed = re.sub(r"\bsns\.histplot\s*\(", "SB_histplot(", fixed)
661
+ fixed = re.sub(r"\bsns\.scatterplot\s*\(", "SB_scatterplot(", fixed)
662
+
663
+ # 3) Guard concat calls
664
+ fixed = re.sub(r"\bpd\.concat\s*\(", "_safe_concat(", fixed)
665
+ fixed = re.sub(r"\bOneHotEncoder\s*\(", "_SMX_OHE(", fixed)
666
+ # Route np.dot to tolerant matmul
667
+ fixed = re.sub(r"\bnp\.dot\s*\(", "_SMX_mm(", fixed)
668
+ fixed = re.sub(r"(df\s*\[[^\]]+\])\s*\.dt", r"SMX_dt(\1).dt", fixed)
669
+
670
+
671
+ # 4) Relax any 'required_cols' hard failure blocks
672
+ fixed = re.sub(
673
+ r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
674
+ "required_cols = [c for c in df.columns]\n# (relaxed by SMX hardener)",
675
+ fixed,
676
+ flags=re.S,
677
+ )
678
+
679
+ # 5) Make static numeric_vars lists dynamic
680
+ fixed = re.sub(
681
+ r"\bnumeric_vars\s*=\s*\[.*?\]",
682
+ "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
683
+ fixed,
684
+ flags=re.S,
685
+ )
686
+ # normalise all .dt usages on df[...] / df.attr / df.loc[...] to SMX_dt(...)
687
+ fixed = re.sub(
688
+ r"((?:df\s*(?:\.\s*(?:loc|iloc)\s*)?\[[^\]]+\]|df\s*\.\s*[A-Za-z_]\w*))\s*\.dt\b",
689
+ lambda m: f"SMX_dt({m.group(1)}).dt",
690
+ fixed
691
+ )
692
+
693
+ try:
694
+ class _SMXMatmulRewriter(ast.NodeTransformer):
695
+ def visit_BinOp(self, node):
696
+ self.generic_visit(node)
697
+ if isinstance(node.op, ast.MatMult):
698
+ return ast.Call(func=ast.Name(id="_SMX_mm", ctx=ast.Load()),
699
+ args=[node.left, node.right], keywords=[])
700
+ return node
701
+ _tree = ast.parse(fixed)
702
+ _tree = _SMXMatmulRewriter().visit(_tree)
703
+ fixed = ast.unparse(_tree)
704
+ except Exception:
705
+ # If AST rewrite fails, keep original; _SMX_mm will still handle np.dot(...)
706
+ pass
707
+
708
+ # 6) Final safety wrapper
709
+ fixed = fixed.replace("\t", " ")
710
+ fixed = textwrap.dedent(fixed).strip("\n")
711
+
712
+ fixed = _strip_stray_backrefs(fixed)
713
+ fixed = _wrap_metric_calls(fixed)
714
+
715
+ # If the transformed code is still not syntactically valid, fall back to a
716
+ # very defensive generic snippet that depends only on `df`. This guarantees the cell still produces visible output.
717
+ try:
718
+ ast.parse(fixed)
719
+ except (SyntaxError, IndentationError):
720
+ fixed = (
721
+ "import pandas as pd\n"
722
+ "df = df.copy()\n"
723
+ "_info = {\n"
724
+ " 'rows': len(df),\n"
725
+ " 'cols': len(df.columns),\n"
726
+ " 'numeric_cols': len(df.select_dtypes(include=['number','bool']).columns),\n"
727
+ " 'categorical_cols': len(df.select_dtypes(exclude=['number','bool']).columns),\n"
728
+ "}\n"
729
+ "show(df.head(), title='Sample of data')\n"
730
+ "show(_info, title='Dataset summary')\n"
731
+ "try:\n"
732
+ " _num = df.select_dtypes(include=['number','bool']).columns.tolist()\n"
733
+ " if _num:\n"
734
+ " SB_histplot()\n"
735
+ " _SMX_export_png()\n"
736
+ "except Exception as e:\n"
737
+ " show(f\"⚠ Fallback visualisation failed: {type(e).__name__}: {e}\")\n"
738
+ )
739
+
740
+ # Fix placeholder Ellipsis handlers from LLM
741
+ fixed = re.sub(
742
+ r"except\s+Exception\s+as\s+e:\s*\n\s*show\(\.\.\.\)",
743
+ "except Exception as e:\n show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")",
744
+ fixed,
745
+ )
746
+
747
+ wrapped = PREFACE + "try:\n" + _indent(fixed) + "\nexcept Exception as e:\n    show(f'⚠ Block skipped due to: {type(e).__name__}: {e}')\n"
748
+ wrapped = wrapped.lstrip()
749
+ return wrapped
750
+
751
+
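A rough before/after sketch (hypothetical cell; output abridged — the full preface shown above is prepended first):

    from syntaxmatrix.utils import harden_ai_code

    cell = "sns.boxplot(data=df, x='group', y='value')\nres = pd.concat(frames)"
    print(harden_ai_code(cell))
    # (abridged) ...preface...
    # try:
    #     SB_boxplot(data=df, x='group', y='value')
    #     res = _safe_concat(frames)
    # except Exception as e:
    #     show(f'⚠ Block skipped due to: {type(e).__name__}: {e}')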
752
+ def indent_code(code: str, spaces: int = 4) -> str:
753
+ pad = " " * spaces
754
+ return "\n".join(pad + line for line in code.splitlines())
755
+
756
+
757
+ def wrap_llm_code_safe(code: str) -> str:
758
+ # Swallow any runtime error from the LLM block instead of crashing the run
759
+ return (
760
+ "# __SAFE_WRAPPED__\n"
761
+ "try:\n" + indent_code(code) + "\n"
762
+ "except Exception as e:\n"
763
+ " from syntaxmatrix.display import show\n"
764
+ " show(f\"⚠️ Skipped LLM block due to: {type(e).__name__}: {e}\")\n"
765
+ )
766
+
767
+
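For example, a failing one-liner comes back wrapped so execution reports the error instead of raising it:

    from syntaxmatrix.utils import wrap_llm_code_safe

    print(wrap_llm_code_safe("1/0"))
    # __SAFE_WRAPPED__
    # try:
    #     1/0
    # except Exception as e:
    #     from syntaxmatrix.display import show
    #     show(f"⚠️ Skipped LLM block due to: {type(e).__name__}: {e}")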
768
+ def fix_boxplot_placeholder(code: str) -> str:
769
+ # Replace invalid 'sns.boxplot(boxplot)' with a safe call using df/group_label/m
770
+ return re.sub(
771
+ r"sns\.boxplot\(\s*boxplot\s*\)",
772
+ "sns.boxplot(x=group_label, y=m, data=df.loc[df[m].notnull()], showfliers=False)",
773
+ code
774
+ )
775
+
776
+
777
+ def relax_required_columns(code: str) -> str:
778
+ # Remove hard failure on required_cols; keep a soft filter instead
779
+ return re.sub(
780
+ r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
781
+ "required_cols = [c for c in df.columns]\n",
782
+ code,
783
+ flags=re.S
784
+ )
785
+
786
+
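Sketch of the relaxation on a hypothetical generated guard:

    from syntaxmatrix.utils import relax_required_columns

    guard = (
        "required_cols = ['age', 'income']\n"
        "missing = [c for c in required_cols if c not in df.columns]\n"
        "if missing: raise KeyError(missing)"
    )
    print(relax_required_columns(guard))
    # -> required_cols = [c for c in df.columns]

A renamed column then degrades to a best-effort run instead of a hard failure.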
787
+ def make_numeric_vars_dynamic(code: str) -> str:
788
+ # Replace any static numeric_vars list with a dynamic selection
789
+ return re.sub(
790
+ r"numeric_vars\s*=\s*\[.*?\]",
791
+ "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
792
+ code,
793
+ flags=re.S
794
+ )
795
+
796
+
797
+ def auto_inject_template(code: str, intents, df) -> str:
798
+ """If the LLM forgot the core logic, prepend a skeleton block."""
799
+
800
+ has_fit = ".fit(" in code
801
+ has_plot = any(k in code for k in ("plt.", "sns.", ".plot(", ".hist("))
802
+
803
+ UNKNOWN_TOKENS = {
804
+ "unknown","not reported","not_reported","not known","n/a","na",
805
+ "none","nan","missing","unreported","unspecified","null","-",""
806
+ }
807
+
808
+ # --- Safe template caller: passes only supported kwargs, falls back cleanly ---
809
+ def _call_template(func, df, **hints):
810
+ import inspect
811
+ try:
812
+ params = inspect.signature(func).parameters
813
+ kw = {k: v for k, v in hints.items() if k in params}
814
+ try:
815
+ return func(df, **kw)
816
+ except TypeError:
817
+ # In case the template changed its signature at runtime
818
+ return func(df)
819
+ except Exception:
820
+ # Absolute safety net
821
+ try:
822
+ return func(df)
823
+ except Exception:
824
+ # As a last resort, return empty code so we don't 500
825
+ return ""
826
+
827
+ def _guess_classification_target(df: pd.DataFrame) -> str | None:
828
+ cols = list(df.columns)
829
+
830
+ # Helper: does this column look like a sensible label?
831
+ def _is_reasonable_class_col(s: pd.Series, col_name: str) -> bool:
832
+ try:
833
+ nunq = s.dropna().nunique()
834
+ except Exception:
835
+ return False
836
+ # need at least 2 classes, but not hundreds
837
+ if nunq < 2 or nunq > 20:
838
+ return False
839
+ bad_name_keys = ("id", "identifier", "index", "uuid", "key")
840
+ name = str(col_name).lower()
841
+ if any(k in name for k in bad_name_keys):
842
+ return False
843
+ return True
844
+
845
+ # 1) columns whose names look like labels
846
+ label_keys = ("target", "label", "outcome", "class", "y", "status")
847
+ name_candidates: list[str] = []
848
+ for key in label_keys:
849
+ for c in cols:
850
+ if key in str(c).lower():
851
+ name_candidates.append(c)
852
+ if name_candidates:
853
+ break # keep the earliest matching key-group
854
+
855
+ # prioritise name-based candidates that also look like proper label columns
856
+ for c in name_candidates:
857
+ if _is_reasonable_class_col(df[c], c):
858
+ return c
859
+ if name_candidates:
860
+ # fall back to the first name-based candidate if none passed the shape test
861
+ return name_candidates[0]
862
+
863
+ # 2) any column with a small number of distinct values (likely a class label)
864
+ for c in cols:
865
+ s = df[c]
866
+ if _is_reasonable_class_col(s, c):
867
+ return c
868
+
869
+ # Nothing suitable found
870
+ return None
871
+
872
+ def _guess_regression_target(df: pd.DataFrame) -> str | None:
873
+ num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
874
+ if not num_cols:
875
+ return None
876
+ # Avoid obvious ID-like columns
877
+ bad_keys = ("id", "identifier", "index")
878
+ candidates = [c for c in num_cols if not any(k in str(c).lower() for k in bad_keys)]
879
+ return (candidates or num_cols)[-1]
880
+
881
+ def _guess_time_col(df: pd.DataFrame) -> str | None:
882
+ # Prefer actual datetime dtype
883
+ dt_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.datetime64)]
884
+ if dt_cols:
885
+ return dt_cols[0]
886
+
887
+ # Fallback: name-based hints
888
+ name_keys = ["date", "time", "timestamp", "datetime", "ds", "period"]
889
+ for c in df.columns:
890
+ name = str(c).lower()
891
+ if any(k in name for k in name_keys):
892
+ return c
893
+ return None
894
+
895
+ def _guess_entity_col(df: pd.DataFrame) -> str | None:
896
+ # Typical sequence IDs: id, patient, subject, device, series, entity
897
+ keys = ["id", "patient", "subject", "device", "series", "entity"]
898
+ candidates = []
899
+ for c in df.columns:
900
+ name = str(c).lower()
901
+ if any(k in name for k in keys):
902
+ candidates.append(c)
903
+ return candidates[0] if candidates else None
904
+
905
+ def _guess_ts_class_target(df: pd.DataFrame) -> str | None:
906
+ # Try label-like names first
907
+ keys = ["target", "label", "class", "outcome", "y"]
908
+ for key in keys:
909
+ for c in df.columns:
910
+ if key in str(c).lower():
911
+ return c
912
+
913
+ # Fallback: any column with few distinct values (e.g. <= 10)
914
+ for c in df.columns:
915
+ s = df[c]
916
+ # avoid obvious IDs
917
+ if any(k in str(c).lower() for k in ["id", "index"]):
918
+ continue
919
+ try:
920
+ nunq = s.dropna().nunique()
921
+ except Exception:
922
+ continue
923
+ if 1 < nunq <= 10:
924
+ return c
925
+
926
+ return None
927
+
928
+ def _guess_multilabel_cols(df: pd.DataFrame) -> list[str]:
929
+ cols = list(df.columns)
930
+ lbl_like = [c for c in cols if str(c).startswith(("LBL_", "lbl_"))]
931
+ # also include boolean/binary columns with suitable names
932
+ for c in cols:
933
+ s = df[c]
934
+ try:
935
+ nunq = s.dropna().nunique()
936
+ except Exception:
937
+ continue
938
+ if nunq in (2,) and c not in lbl_like:
939
+ # avoid obvious IDs
940
+ if not any(k in str(c).lower() for k in ("id","index","uuid","identifier")):
941
+ lbl_like.append(c)
942
+ # keep at most, say, 12 to avoid accidental flood
943
+ return lbl_like[:12]
944
+
945
+ def _find_unknownish_column(df: pd.DataFrame) -> str | None:
946
+ # Search categorical-like columns for any 'unknown-like' values or high missingness
947
+ candidates = []
948
+ for c in df.columns:
949
+ s = df[c]
950
+ # focus on object/category/boolean-ish or low-card columns
951
+ if not (pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= 20):
952
+ continue
953
+ try:
954
+ vals = s.astype(str).str.strip().str.lower()
955
+ except Exception:
956
+ continue
957
+ # score: presence of unknown tokens + missing rate
958
+ token_hit = int(vals.isin(UNKNOWN_TOKENS).any())
959
+ miss_rate = s.isna().mean()
960
+ name_bonus = int(any(k in str(c).lower() for k in ("status","history","report","known","flag")))
961
+ score = 3*token_hit + 2*name_bonus + miss_rate
962
+ if token_hit or miss_rate > 0.05 or name_bonus:
963
+ candidates.append((score, c))
964
+ if not candidates:
965
+ return None
966
+ candidates.sort(reverse=True)
967
+ return candidates[0][1]
968
+
969
+ def _guess_numeric_cols(df: pd.DataFrame, max_n: int = 6) -> list[str]:
970
+ cols = [c for c in df.select_dtypes(include=[np.number, "bool"]).columns if not any(k in str(c).lower() for k in ("id","identifier","index","uuid"))]
971
+ # prefer non-constant columns
972
+ scored = []
973
+ for c in cols:
974
+ try:
975
+ v = df[c].dropna()
976
+ var = float(v.var()) if len(v) else 0.0
977
+ scored.append((var, c))
978
+ except Exception:
979
+ continue
980
+ scored.sort(reverse=True)
981
+ return [c for _, c in scored[:max_n]]
982
+
983
+ def _guess_categorical_cols(df: pd.DataFrame, exclude: set[str] | None = None, max_card: int = 12, max_n: int = 5) -> list[str]:
984
+ exclude = exclude or set()
985
+ picks = []
986
+ for c in df.columns:
987
+ if c in exclude:
988
+ continue
989
+ s = df[c]
990
+ if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= max_card:
991
+ nunq = s.dropna().nunique()
992
+ if 2 <= nunq <= max_card and not any(k in str(c).lower() for k in ("id","identifier","index","uuid")):
993
+ picks.append((nunq, c))
994
+ picks.sort(reverse=True)
995
+ return [c for _, c in picks[:max_n]]
996
+
997
+ def _guess_outcome_col(df: pd.DataFrame, exclude: set[str] | None = None) -> str | None:
998
+ exclude = exclude or set()
999
+ # name hints first
1000
+ name_keys = ("outcome","target","label","risk","score","result","prevalence","positivity")
1001
+ for c in df.columns:
1002
+ if c in exclude:
1003
+ continue
1004
+ name = str(c).lower()
1005
+ if any(k in name for k in name_keys) and pd.api.types.is_numeric_dtype(df[c]):
1006
+ return c
1007
+ # fallback: any binary numeric
1008
+ for c in df.select_dtypes(include=[np.number, "bool"]).columns:
1009
+ if c in exclude:
1010
+ continue
1011
+ try:
1012
+ if df[c].dropna().nunique() == 2:
1013
+ return c
1014
+ except Exception:
1015
+ continue
1016
+ return None
1017
+
1018
+
1019
+ def _pick_viz_template(signal: str):
1020
+ s = signal.lower()
1021
+
1022
+ # explicit chart requests
1023
+ if any(k in s for k in ("pie", "donut")):
1024
+ return viz_pie
1025
+
1026
+ if any(k in s for k in ("stacked", "100% stacked", "composition", "proportion", "share by")):
1027
+ return viz_stacked_bar
1028
+
1029
+ if any(k in s for k in ("distribution", "hist", "histogram", "bins")):
1030
+ return viz_distribution
1031
+
1032
+ if any(k in s for k in ("kde", "density")):
1033
+ return viz_kde
1034
+
1035
+ # these three you asked about
1036
+ if any(k in s for k in ("box", "boxplot", "violin", "spread", "outlier")):
1037
+ return viz_box
1038
+
1039
+ if any(k in s for k in ("scatter", "relationship", "vs ", "correlate")):
1040
+ return viz_scatter
1041
+
1042
+ if any(k in s for k in ("count", "counts", "frequency", "bar chart", "barplot")):
1043
+ return viz_count_bar
1044
+
1045
+ if any(k in s for k in ("area", "trend", "over time", "time series")):
1046
+ return viz_area
1047
+
1048
+ # fallback
1049
+ return viz_line
1050
+
1051
+ for intent in intents:
1052
+
1053
+ if intent not in INJECTABLE_INTENTS:
1054
+ continue  # non-injectable intent: check the next one instead of aborting
1055
+
1056
+ # Correlation analysis
1057
+ if intent == "correlation_analysis" and not has_fit:
1058
+ return eda_correlation(df) + "\n\n" + code
1059
+
1060
+ # Generic visualisation (keyword-based)
1061
+ if intent == "visualisation" and not has_fit and not has_plot:
1062
+ rq = str(globals().get("refined_question", ""))
1063
+ # aq = str(globals().get("askai_question", ""))
1064
+ signal = rq + "\n" + str(intents) + "\n" + code
1065
+ tpl = _pick_viz_template(signal)
1066
+ return tpl(df) + "\n\n" + code
1067
+
1068
+ if intent == "clustering" and not has_fit:
1069
+ return clustering(df) + "\n\n" + code
1070
+
1071
+ if intent == "recommendation" and not has_fit:
1072
+ return recommendation(df) + "\n\n" + code
1073
+
1074
+ if intent == "topic_modelling" and not has_fit:
1075
+ return topic_modelling(df) + "\n\n" + code
1076
+
1077
+ if intent == "eda" and not has_fit:
1078
+ return code + "\n\nSB_heatmap(df.select_dtypes(include='number').corr())"  # inject heatmap if 'eda' intent
1079
+
1080
+ # --- Classification ------------------------------------------------
1081
+ if intent == "classification" and not has_fit:
1082
+ target = _guess_classification_target(df)
1083
+ if target:
1084
+ return classification(df) + "\n\n" + code
1085
+ # return _call_template(classification, df, target) + "\n\n" + code
1086
+
1087
+ # --- Regression ----------------------------------------------------
1088
+ if intent == "regression" and not has_fit:
1089
+ target = _guess_regression_target(df)
1090
+ if target:
1091
+ return regression(df) + "\n\n" + code
1092
+ # return _call_template(regression, df, target) + "\n\n" + code
1093
+
1094
+ # --- Anomaly detection --------------------------------------------
1095
+ if intent == "anomaly_detection":
1096
+ uses_anomaly = any(k in code for k in ("IsolationForest", "LocalOutlierFactor", "OneClassSVM"))
1097
+ if not uses_anomaly:
1098
+ return anomaly_detection(df) + "\n\n" + code
1099
+
1100
+ # --- Time-series anomaly detection --------------------------------
1101
+ if intent == "ts_anomaly_detection":
1102
+ uses_ts = "STL(" in code or "seasonal_decompose(" in code
1103
+ if not uses_ts:
1104
+ return ts_anomaly_detection(df) + "\n\n" + code
1105
+
1106
+ # --- Time-series classification --------------------------------
1107
+ if intent == "time_series_classification" and not has_fit:
1108
+ time_col = _guess_time_col(df)
1109
+ entity_col = _guess_entity_col(df)
1110
+ target_col = _guess_ts_class_target(df)
1111
+
1112
+ # If we can't confidently identify these, do NOT inject anything
1113
+ if time_col and entity_col and target_col:
1114
+ return time_series_classification(df, entity_col, time_col, target_col) + "\n\n" + code
1115
+
1116
+ # --- Dimensionality reduction --------------------------------------
1117
+ if intent == "dimensionality_reduction":
1118
+ uses_dr = any(k in code for k in ("PCA(", "TSNE("))
1119
+ if not uses_dr:
1120
+ return dimensionality_reduction(df) + "\n\n" + code
1121
+
1122
+ # --- Feature selection ---------------------------------------------
1123
+ if intent == "feature_selection":
1124
+ uses_fs = any(k in code for k in (
1125
+ "mutual_info_", "permutation_importance(", "SelectKBest(", "RFE("
1126
+ ))
1127
+ if not uses_fs:
1128
+ return feature_selection(df) + "\n\n" + code
1129
+
1130
+ # --- EDA / correlation / visualisation -----------------------------
1131
+ if intent in ("eda", "correlation_analysis", "visualisation") and not has_plot:
1132
+ if intent == "correlation_analysis":
1133
+ return eda_correlation(df) + "\n\n" + code
1134
+ else:
1135
+ return eda_overview(df) + "\n\n" + code
1136
+
1137
+ # --- Time-series forecasting ---------------------------------------
1138
+ if intent == "time_series_forecasting" and not has_fit:
1139
+ uses_ts_forecast = any(k in code for k in (
1140
+ "ARIMA", "ExponentialSmoothing", "forecast", "predict("
1141
+ ))
1142
+ if not uses_ts_forecast:
1143
+ return time_series_forecasting(df) + "\n\n" + code
1144
+
1145
+ # --- Multi-label classification -----------------------------------
1146
+ if intent in ("multilabel_classification",) and not has_fit:
1147
+ label_cols = _guess_multilabel_cols(df)
1148
+ if len(label_cols) >= 2:
1149
+ return multilabel_classification(df, label_cols) + "\n\n" + code
1150
+
1151
+ group_col = _find_unknownish_column(df)
1152
+ if group_col:
1153
+ num_cols = _guess_numeric_cols(df)
1154
+ cat_cols = _guess_categorical_cols(df, exclude={group_col})
1155
+ outcome_col = None # generic; let template skip if not present
1156
+ tpl = unknown_group_proxy_pack(df, group_col, UNKNOWN_TOKENS, num_cols, cat_cols, outcome_col)
1157
+
1158
+ # Return template + guarded (repaired) LLM code, so it never crashes
1159
+ repaired = make_numeric_vars_dynamic(relax_required_columns(fix_boxplot_placeholder(code)))
1160
+ return tpl + "\n\n" + wrap_llm_code_safe(repaired)
1161
+
1162
+ return code
1163
+
1164
+
1165
+ def fix_values_sum_numeric_only_bug(code: str) -> str:
1166
+ """
1167
+ If a previous pass injected numeric_only=True into a NumPy-style sum,
1168
+ e.g. .values.sum(numeric_only=True), strip it and canonicalize to .to_numpy().sum().
1169
+ """
1170
+ # .values.sum(numeric_only=True, ...)
1171
+ code = re.sub(
1172
+ r"\.values\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1173
+ ".to_numpy().sum()",
1174
+ code,
1175
+ flags=re.IGNORECASE,
1176
+ )
1177
+ # .to_numpy().sum(numeric_only=True, ...)
1178
+ code = re.sub(
1179
+ r"\.to_numpy\(\)\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1180
+ ".to_numpy().sum()",
1181
+ code,
1182
+ flags=re.IGNORECASE,
1183
+ )
1184
+ return code
1185
+
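Illustration on a hypothetical generated line:

    from syntaxmatrix.utils import fix_values_sum_numeric_only_bug

    print(fix_values_sum_numeric_only_bug("total = df[cols].values.sum(numeric_only=True)"))
    # -> total = df[cols].to_numpy().sum()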
48
1186
 
49
1187
  def strip_describe_slice(code: str) -> str:
50
1188
  """
@@ -59,10 +1197,12 @@ def strip_describe_slice(code: str) -> str:
59
1197
  )
60
1198
  return pat.sub(r"\1)", code)
61
1199
 
1200
+
62
1201
  def remove_plt_show(code: str) -> str:
63
1202
  """Removes all plt.show() calls from the generated code string."""
64
1203
  return "\n".join(line for line in code.splitlines() if "plt.show()" not in line)
65
1204
 
1205
+
66
1206
  def patch_plot_with_table(code: str) -> str:
67
1207
  """
68
1208
  ▸ strips every `plt.show()` (avoids warnings)
@@ -149,7 +1289,7 @@ def patch_plot_with_table(code: str) -> str:
149
1289
  ")\n"
150
1290
  )
151
1291
 
152
- tbl_block += "from syntaxmatrix.display import show\nshow(summary_table)"
1292
+ tbl_block += "show(summary_table, title='Summary Statistics')"
153
1293
 
154
1294
  # 5. inject image-export block, then table block, after the plot
155
1295
  patched = (
@@ -289,10 +1429,10 @@ def refine_eda_question(raw_question, df=None, max_points=1000):
289
1429
  "Use: tbl = df.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}); display(tbl)"
290
1430
  )
291
1431
 
292
-
293
1432
  # 9. Fallback: return the raw question
294
1433
  return q
295
1434
 
1435
+
296
1436
  def patch_plot_code(code, df, user_question=None):
297
1437
 
298
1438
  # ── Early guard: abort nicely if the generated code references columns that
@@ -313,10 +1453,13 @@ def patch_plot_code(code, df, user_question=None):
313
1453
 
314
1454
  if missing_cols:
315
1455
  cols_list = ", ".join(missing_cols)
316
- return (
317
- f"print('⚠️ Column(s) \"{cols_list}\" not found in the dataset. "
318
- f"Please check the column names and try again.')"
1456
+ warning = (
1457
+ f"show('⚠️ Warning: code references missing column(s): \"{cols_list}\". "
1458
+ "These must either exist in df or be created earlier in the code; "
1459
+ "otherwise you may see a KeyError.')\n"
319
1460
  )
1461
+ # Prepend the warning but keep the original code so it can still run
1462
+ code = warning + code
320
1463
 
321
1464
  # 1. For line plots (auto-aggregate)
322
1465
  m_l = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
@@ -425,6 +1568,16 @@ def patch_plot_code(code, df, user_question=None):
425
1568
  # Fallback: Return original code
426
1569
  return code
427
1570
 
1571
+
1572
+ def ensure_matplotlib_title(code, title_var="refined_question"):
1573
+ import re
1574
+ makes_plot = re.search(r"\b(plt\.(plot|scatter|bar|hist)|ax\.(plot|scatter|bar|hist))\b", code)
1575
+ has_title = re.search(r"\b(plt\.title|ax\.set_title)\s*\(", code)
1576
+ if makes_plot and not has_title:
1577
+ code += f"\ntry:\n plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
1578
+ return code
1579
+
1580
+
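Sketch of the title injection (the appended try/except uses a one-space indent, which is still valid Python):

    from syntaxmatrix.utils import ensure_matplotlib_title

    print(ensure_matplotlib_title("plt.hist(df['age'])"))
    # plt.hist(df['age'])
    # try:
    #  plt.title(str(refined_question)[:120])
    # except Exception: pass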
428
1581
  def ensure_output(code: str) -> str:
429
1582
  """
430
1583
  Guarantees that AI-generated code actually surfaces results in the UI
@@ -441,7 +1594,6 @@ def ensure_output(code: str) -> str:
441
1594
  # not a comment / print / assignment / pyplot call
442
1595
  if (last and not last.startswith(("print(", "plt.", "#")) and "=" not in last):
443
1596
  lines[-1] = f"_out = {last}"
444
- lines.append("from syntaxmatrix.display import show")
445
1597
  lines.append("show(_out)")
446
1598
 
447
1599
  # ── 3· auto-surface common stats tuples (stat, p) ───────────────────
@@ -449,14 +1601,12 @@ def ensure_output(code: str) -> str:
449
1601
  if re.search(r"\bchi2\s*,\s*p\s*,", code) and "show((" in code:
450
1602
  pass # AI already shows the tuple
451
1603
  elif re.search(r"\bchi2\s*,\s*p\s*,", code):
452
- lines.append("from syntaxmatrix.display import show")
453
1604
  lines.append("show((chi2, p))")
454
1605
 
455
1606
  # ── 4· classification report (string) ───────────────────────────────
456
1607
  cr_match = re.search(r"^\s*(\w+)\s*=\s*classification_report\(", code, re.M)
457
1608
  if cr_match and f"show({cr_match.group(1)})" not in "\n".join(lines):
458
1609
  var = cr_match.group(1)
459
- lines.append("from syntaxmatrix.display import show")
460
1610
  lines.append(f"show({var})")
461
1611
 
462
1612
  # 5-bis · pivot tables (DataFrame)
@@ -493,18 +1643,17 @@ def ensure_output(code: str) -> str:
493
1643
  assign_scalar = re.match(r"\s*(\w+)\s*=\s*.+\.shape\[\s*0\s*\]\s*$", lines[-1])
494
1644
  if assign_scalar:
495
1645
  var = assign_scalar.group(1)
496
- lines.append("from syntaxmatrix.display import show")
497
1646
  lines.append(f"show({var})")
498
1647
 
499
1648
  # ── 8. utils.ensure_output()
500
1649
  assign_df = re.match(r"\s*(\w+)\s*=\s*df\[", lines[-1])
501
1650
  if assign_df:
502
1651
  var = assign_df.group(1)
503
- lines.append("from syntaxmatrix.display import show")
504
1652
  lines.append(f"show({var})")
505
1653
 
506
1654
  return "\n".join(lines)
507
1655
 
1656
+
508
1657
  def get_plotting_imports(code):
509
1658
  imports = []
510
1659
  if "plt." in code and "import matplotlib.pyplot as plt" not in code:
@@ -524,6 +1673,7 @@ def get_plotting_imports(code):
524
1673
  code = "\n".join(imports) + "\n\n" + code
525
1674
  return code
526
1675
 
1676
+
527
1677
  def patch_pairplot(code, df):
528
1678
  if "sns.pairplot" in code:
529
1679
  # Always assign and print pairgrid
@@ -534,29 +1684,82 @@ def patch_pairplot(code, df):
534
1684
  code += "\nprint(pairgrid)"
535
1685
  return code
536
1686
 
1687
+
537
1688
  def ensure_image_output(code: str) -> str:
538
1689
  """
539
- Injects a PNG exporter in front of every plt.show() so dashboards
540
- get real <img> HTML instead of a blank cell.
1690
+ Replace each plt.show() with an indented _SMX_export_png() call.
1691
+ This keeps block indentation valid and still renders images in the dashboard.
541
1692
  """
542
1693
  if "plt.show()" not in code:
543
1694
  return code
544
1695
 
545
- exporter = (
546
- # -- NEW: use display(), not print() --------------------------
547
- "import io, base64\n"
548
- "buf = io.BytesIO()\n"
549
- "plt.savefig(buf, format='png', bbox_inches='tight')\n"
550
- "buf.seek(0)\n"
551
- "img_b64 = base64.b64encode(buf.read()).decode('utf-8')\n"
552
- "from IPython.display import display, HTML\n"
553
- "display(HTML(f'<img src=\"data:image/png;base64,{img_b64}\" "
554
- "style=\"max-width:100%;\">'))\n"
555
- "plt.close()\n"
1696
+ import re
1697
+ out_lines = []
1698
+ for ln in code.splitlines():
1699
+ if "plt.show()" not in ln:
1700
+ out_lines.append(ln)
1701
+ continue
1702
+
1703
+ # works for:
1704
+ # plt.show()
1705
+ # plt.tight_layout(); plt.show()
1706
+ # ... ; plt.show(); ... (multiple on one line)
1707
+ indent = re.match(r"^(\s*)", ln).group(1)
1708
+ parts = ln.split("plt.show()")
1709
+
1710
+ # keep whatever is before the first plt.show()
1711
+ if parts[0].strip():
1712
+ out_lines.append(parts[0].rstrip())
1713
+
1714
+ # for every plt.show() we removed, insert exporter at same indent
1715
+ for _ in range(len(parts) - 1):
1716
+ out_lines.append(indent + "_SMX_export_png()")
1717
+
1718
+ # keep whatever comes after the last plt.show()
1719
+ if parts[-1].strip():
1720
+ out_lines.append(indent + parts[-1].lstrip())
1721
+
1722
+ return "\n".join(out_lines)
1723
+
1724
+
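So a line that mixes layout and display splits cleanly at the same indentation:

    from syntaxmatrix.utils import ensure_image_output

    print(ensure_image_output("    plt.tight_layout(); plt.show()"))
    #     plt.tight_layout();
    #     _SMX_export_png()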
1725
+ def clean_llm_code(code: str) -> str:
1726
+ """
1727
+ Make LLM output safe to exec:
1728
+ - If fenced blocks exist, keep the largest one (usually the real code).
1729
+ - Otherwise strip any stray ``` / ```python lines.
1730
+ - Remove common markdown/preamble junk.
1731
+ """
1732
+ code = str(code or "")
1733
+
1734
+ # Extract fenced blocks (```python ... ``` or ``` ... ```)
1735
+ blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
1736
+
1737
+ if blocks:
1738
+ # pick the largest block; small trailing blocks are usually garbage
1739
+ largest = max(blocks, key=lambda b: len(b.strip()))
1740
+ if len(largest.strip().splitlines()) >= 10:
1741
+ code = largest
1742
+ else:
1743
+ # if no meaningful block, just remove fence markers
1744
+ code = re.sub(r"^```.*?$", "", code, flags=re.M)
1745
+ else:
1746
+ # no complete blocks — still remove any stray fence lines
1747
+ code = re.sub(r"^```.*?$", "", code, flags=re.M)
1748
+
1749
+ # Strip common markdown/preamble lines
1750
+ drop_prefixes = (
1751
+ "here is", "here's", "below is", "sure,", "certainly",
1752
+ "explanation", "note:", "```"
556
1753
  )
1754
+ cleaned_lines = []
1755
+ for ln in code.splitlines():
1756
+ s = ln.strip().lower()
1757
+ if any(s.startswith(p) for p in drop_prefixes):
1758
+ continue
1759
+ cleaned_lines.append(ln)
1760
+
1761
+ return "\n".join(cleaned_lines).strip()
557
1762
 
558
- # exporter BEFORE the original plt.show()
559
- return code.replace("plt.show()", exporter + "plt.show()")
560
1763
 
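A minimal illustration of clean_llm_code on typical LLM output (hypothetical input; a fenced block shorter than 10 lines would instead only have its fence markers stripped in place):

    raw = "Sure, here is the code:\n```python\nimport pandas as pd\n# ...10+ more lines...\n```"
    cleaned = clean_llm_code(raw)
    # keeps only the body of the largest fenced block;
    # "Sure," / "Note:" style preamble lines are dropped as well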
561
1764
  def fix_groupby_describe_slice(code: str) -> str:
562
1765
  """
@@ -579,6 +1782,7 @@ def fix_groupby_describe_slice(code: str) -> str:
579
1782
  )
580
1783
  return pat.sub(repl, code)
581
1784
 
1785
+
582
1786
  def fix_importance_groupby(code: str) -> str:
583
1787
  pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
584
1788
  if "importance_df" in code:
@@ -625,10 +1829,12 @@ def inject_auto_preprocessing(code: str) -> str:
625
1829
  # simply prepend; model code that follows can wrap estimator in a Pipeline
626
1830
  return prep_snippet + code
627
1831
 
1832
+
628
1833
  def fix_to_datetime_errors(code: str) -> str:
629
1834
  """
630
1835
  Force every pd.to_datetime(…) call to ignore bad dates so that
631
- ‘year 16500 is out of range’ and similar issues don’t crash runs.
1836
+
1837
+ 'year 16500 is out of range' and similar issues don’t crash runs.
632
1838
  """
633
1839
  import re
634
1840
  # look for any pd.to_datetime( … )
@@ -641,25 +1847,67 @@ def fix_to_datetime_errors(code: str) -> str:
641
1847
  return f"pd.to_datetime({inside}, errors='coerce')"
642
1848
  return pat.sub(repl, code)
643
1849
 
1850
+
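The effect of fix_to_datetime_errors, on one hypothetical call:

    # before
    df['Date'] = pd.to_datetime(df['Date'], format='%Y-%m-%d')
    # after -- unparseable rows become NaT instead of raising
    df['Date'] = pd.to_datetime(df['Date'], format='%Y-%m-%d', errors='coerce')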
644
1851
  def fix_numeric_sum(code: str) -> str:
645
1852
  """
646
- Rewrites every `.sum(` call so it becomes
647
- `.sum(numeric_only=True, …)` unless that keyword is already present.
1853
+ Make .sum(...) code safe across pandas versions by removing any
1854
+ numeric_only=... argument (True/False/None) from function calls.
1855
+
1856
+ This avoids errors on pandas versions where numeric_only is not
1857
+ supported for Series/grouped sums, and we rely instead on explicit
1858
+ numeric column selection (e.g. select_dtypes) in the generated code.
648
1859
  """
649
- pattern = re.compile(r"\.sum\(\s*([^\)]*)\)")
1860
+ # Case 1: ..., numeric_only=True/False/None
1861
+ code = re.sub(
1862
+ r",\s*numeric_only\s*=\s*(True|False|None)",
1863
+ "",
1864
+ code,
1865
+ flags=re.IGNORECASE,
1866
+ )
650
1867
 
651
- def _repl(match):
652
- args = match.group(1)
653
- if "numeric_only" in args: # already safe
654
- return match.group(0)
1868
+ # Case 2: numeric_only=True/False/None, ... (as first argument)
1869
+ code = re.sub(
1870
+ r"numeric_only\s*=\s*(True|False|None)\s*,\s*",
1871
+ "",
1872
+ code,
1873
+ flags=re.IGNORECASE,
1874
+ )
655
1875
 
656
- args = args.strip()
657
- if args: # keep existing positional / kw args
658
- args += ", "
659
- return f".sum({args}numeric_only=True)"
1876
+ # Case 3: numeric_only=True/False/None (only argument)
1877
+ code = re.sub(
1878
+ r"numeric_only\s*=\s*(True|False|None)",
1879
+ "",
1880
+ code,
1881
+ flags=re.IGNORECASE,
1882
+ )
1883
+
1884
+ return code
1885
+
1886
+
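Concretely, fix_numeric_sum drops the keyword wherever it appears in the argument list (hypothetical calls):

    df.groupby('Region').sum(numeric_only=True)     # -> df.groupby('Region').sum()
    df.sum(numeric_only=False, axis=0)              # -> df.sum(axis=0)
    df[['A', 'B']].mean(axis=1, numeric_only=None)  # -> df[['A', 'B']].mean(axis=1)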
1887
+ def fix_concat_empty_list(code: str) -> str:
1888
+ """
1889
+ Make pd.concat calls resilient to empty lists of objects.
1890
+
1891
+ Transforms patterns like:
1892
+ pd.concat(frames, ignore_index=True)
1893
+ pd.concat(frames)
1894
+
1895
+ into:
1896
+ pd.concat(frames or [pd.DataFrame()], ignore_index=True)
1897
+ pd.concat(frames or [pd.DataFrame()])
1898
+
1899
+ Only triggers when the first argument is a simple variable name.
1900
+ """
1901
+ pattern = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
1902
+
1903
+ def _repl(m):
1904
+ name = m.group(1)
1905
+ sep = m.group(2) # ',' or ')'
1906
+ return f"pd.concat({name} or [pd.DataFrame()]{sep}"
660
1907
 
661
1908
  return pattern.sub(_repl, code)
662
1909
 
1910
+
663
1911
  def fix_numeric_aggs(code: str) -> str:
664
1912
  _AGG_FUNCS = ("sum", "mean")
665
1913
  pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
@@ -673,69 +1921,66 @@ def fix_numeric_aggs(code: str) -> str:
673
1921
  return f".{func}({args}numeric_only=True)"
674
1922
  return pat.sub(_repl, code)
675
1923
 
1924
+
676
1925
  def ensure_accuracy_block(code: str) -> str:
677
1926
  """
678
- If the code fits an estimator but never prints accuracy,
679
- inject an evaluation block that re-uses *whatever variable name*
680
- appears immediately before `.fit(`.
1927
+ Inject a sensible evaluation block right after the last `<est>.fit(...)`
1928
+ Classification → accuracy + weighted F1
1929
+ Regression → R², RMSE, MAE
1930
+ Heuristic: infer task from estimator names present in the code.
681
1931
  """
682
- # Already prints accuracy? – bail out early
683
- if re.search(r"accuracy_score\s*\(", code):
1932
+ import re, textwrap
1933
+
1934
+ # If any proper metric already exists, do nothing
1935
+ if re.search(r"\b(accuracy_score|f1_score|r2_score|mean_squared_error|mean_absolute_error)\b", code):
684
1936
  return code
685
1937
 
686
- # Find the last `<var>.fit(` call
1938
+ # Find the last "<var>.fit(" occurrence to reuse the estimator variable name
687
1939
  m = list(re.finditer(r"(\w+)\.fit\s*\(", code))
688
1940
  if not m:
689
- return code # no model at all
1941
+ return code # no estimator
690
1942
 
691
- var = m[-1].group(1) # estimator variable name
1943
+ var = m[-1].group(1)
1944
+ # indent with same leading whitespace used on that line
692
1945
  indent = re.match(r"\s*", code[m[-1].start():]).group(0)
693
1946
 
694
- eval_block = textwrap.dedent(f"""
695
- {indent}# ── automatic accuracy evaluation ─────────
696
- {indent}from sklearn.metrics import accuracy_score
697
- {indent}y_pred = {var}.predict(X_test)
698
- {indent}acc = accuracy_score(y_test, y_pred)
699
- {indent}print(f"Model accuracy on hold-out set: {{acc:.2%}}")
700
- """)
1947
+ # Detect regression by estimator names / hints in code
1948
+ is_regression = bool(
1949
+ re.search(
1950
+ r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
1951
+ r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
1952
+ r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
1953
+ )
1954
+ or re.search(r"\bOLS\s*\(", code)
1955
+ or re.search(r"\bRegressor\b", code)
1956
+ )
1957
+
1958
+ if is_regression:
1959
+ # inject numpy import if needed for RMSE
1960
+ if "import numpy as np" not in code and "np." not in code:
1961
+ code = "import numpy as np\n" + code
1962
+ eval_block = textwrap.dedent(f"""
1963
+ {indent}# ── automatic regression evaluation ─────────
1964
+ {indent}from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
1965
+ {indent}y_pred = {var}.predict(X_test)
1966
+ {indent}r2 = r2_score(y_test, y_pred)
1967
+ {indent}rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
1968
+ {indent}mae = float(mean_absolute_error(y_test, y_pred))
1969
+ {indent}print(f"R²: {{r2:.4f}} | RMSE: {{rmse:.4f}} | MAE: {{mae:.4f}}")
1970
+ """)
1971
+ else:
1972
+ eval_block = textwrap.dedent(f"""
1973
+ {indent}# ── automatic classification evaluation ─────────
1974
+ {indent}from sklearn.metrics import accuracy_score, f1_score
1975
+ {indent}y_pred = {var}.predict(X_test)
1976
+ {indent}acc = accuracy_score(y_test, y_pred)
1977
+ {indent}f1 = f1_score(y_test, y_pred, average='weighted')
1978
+ {indent}print(f"Accuracy: {{acc:.2%}} | F1 (weighted): {{f1:.3f}}")
1979
+ """)
701
1980
 
702
1981
  insert_at = code.find("\n", m[-1].end()) + 1
703
1982
  return code[:insert_at] + eval_block + code[insert_at:]
704
1983
 
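For a classification cell, the injected block amounts to the following (a sketch assuming the generated cell already defines model, X_test and y_test):

    model.fit(X_train, y_train)
    # ── automatic classification evaluation ─────────
    from sklearn.metrics import accuracy_score, f1_score
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average='weighted')
    print(f"Accuracy: {acc:.2%} | F1 (weighted): {f1:.3f}")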
705
- def classify(prompt: str) -> str:
706
- """
707
- Very-light intent classifier.
708
- Returns one of:
709
- 'stat_test' | 'time_series' | 'clustering'
710
- 'classification' | 'regression' | 'eda'
711
- """
712
- p = prompt.lower().strip()
713
- greetings = {"hi", "hello", "hey", "good morning", "good afternoon", "good evening", "greetings"}
714
- if any(p.startswith(g) or p == g for g in greetings):
715
- return "greeting"
716
-
717
- if any(k in p for k in ("t-test", "anova", "p-value")):
718
- return "stat_test"
719
- if "forecast" in p or "prophet" in p:
720
- return "time_series"
721
- if "cluster" in p or "kmeans" in p:
722
- return "clustering"
723
- if any(k in p for k in ("accuracy", "precision", "roc")):
724
- return "classification"
725
- if any(k in p for k in ("rmse", "r2", "mae")):
726
- return "regression"
727
- return "eda"
728
-
729
- def auto_inject_template(code: str, intent: str, df) -> str:
730
- """If the LLM forgot the core logic, prepend a skeleton block."""
731
- has_fit = ".fit(" in code
732
-
733
- if intent == "classification" and not has_fit:
734
- # guess a y column that contains 'diabetes' as in your dataset
735
- target = next((c for c in df.columns if "diabetes" in c.lower()), None)
736
- if target:
737
- return classification(df, target) + "\n\n" + code
738
- return code
739
1984
 
740
1985
  def fix_scatter_and_summary(code: str) -> str:
741
1986
  """
@@ -763,6 +2008,7 @@ def fix_scatter_and_summary(code: str) -> str:
763
2008
 
764
2009
  return code
765
2010
 
2011
+
766
2012
  def auto_format_with_black(code: str) -> str:
767
2013
  """
768
2014
  Format the generated code with Black. Falls back silently if Black
@@ -777,6 +2023,7 @@ def auto_format_with_black(code: str) -> str:
777
2023
  except Exception:
778
2024
  return code
779
2025
 
2026
+
780
2027
  def ensure_preproc_in_pipeline(code: str) -> str:
781
2028
  """
782
2029
  If code defines `preproc = ColumnTransformer(...)` but then builds
@@ -789,52 +2036,1128 @@ def ensure_preproc_in_pipeline(code: str) -> str:
789
2036
  code
790
2037
  )
791
2038
 
2039
+
792
2040
  def fix_plain_prints(code: str) -> str:
793
2041
  """
794
- Rewrite print(<var>) → show(<var>) when <var> looks like
795
- a pandas / numpy / sklearn object (heuristic: not a string literal).
2042
+ Rewrite bare `print(var)` where var looks like a dataframe/series/ndarray/etc
2043
+ to go through SyntaxMatrix's smart display (so it renders in the dashboard).
2044
+ Keeps string prints alone.
796
2045
  """
797
- import re
798
- return re.sub(
799
- r"print\((\w+)\)",
2046
+
2047
+ # Skip obvious string-literal prints
2048
+ new = re.sub(
2049
+ r"(?m)^\s*print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
800
2050
  r"from syntaxmatrix.display import show\nshow(\1)",
801
2051
  code,
802
2052
  )
2053
+ return new
2054
+
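So a bare, top-level print of a name like summary_df becomes:

    from syntaxmatrix.display import show
    show(summary_df)

while print("done") is left untouched, since a string literal never matches the identifier-only pattern.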
803
2055
 
804
- # --------------------------------------------------------------------------
805
- # ✂
806
- # --------------------------------------------------------------------------
807
- def drop_bad_classification_metrics(code: str, y) -> str:
2056
+ def fix_print_html(code: str) -> str:
808
2057
  """
809
- If the prediction target is continuous (i.e. a regression task) and the
810
- generated code mistakenly calls classification metrics such as
811
- `accuracy_score`, `classification_report`, or `confusion_matrix`,
812
- comment those lines out so the cell can still run.
2058
+ Ensure that HTML / DataFrame HTML are *displayed* (and captured by the kernel),
2059
+ not printed as `<IPython.core.display.HTML object>` to the server console.
2060
+ - Rewrites: print(HTML(...)) → display(HTML(...))
2061
+ print(display(...)) → display(...)
2062
+ print(df.to_html(...)) → display(HTML(df.to_html(...)))
2063
+ Also prepends `from IPython.display import display, HTML` if required.
2064
+ """
2065
+ import re
2066
+
2067
+ new = code
2068
+
2069
+ # 1) print(HTML(...)) -> display(HTML(...))
2070
+ new = re.sub(r"(?m)^\s*print\s*\(\s*HTML\s*\(", "display(HTML(", new)
2071
+
2072
+ # 2) print(display(...)) -> display(...)
2073
+ new = re.sub(r"(?m)^\s*print\s*\(\s*display\s*\(", "display(", new)
2074
+
2075
+ # 3) print(<expr>.to_html(...)) -> display(HTML(<expr>.to_html(...)))
2076
+ new = re.sub(
2077
+ r"(?m)^\s*print\s*\(\s*([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*)\s*\.to_html\s*\(",
2078
+ r"display(HTML(\1.to_html(", new
2079
+ )
2080
+
2081
+ # If code references HTML() or display() make sure the import exists
2082
+ if ("HTML(" in new or re.search(r"\bdisplay\s*\(", new)) and \
2083
+ "from IPython.display import display, HTML" not in new:
2084
+ new = "from IPython.display import display, HTML\n" + new
2085
+
2086
+ return new
2087
+
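Taken together, a hypothetical line such as

    print(report_df.to_html(index=False))

is rewritten to

    display(HTML(report_df.to_html(index=False)))

with the display/HTML import prepended when missing.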
813
2088
 
814
- Works whether `y` is:
815
- • a pandas Series -> y.dtype.kind is available
816
- • a pandas DataFrame (multi-column) -> we infer by looking at *all*
2089
+ def ensure_ipy_display(code: str) -> str:
817
2090
  """
818
- # ── decide whether y looks continuous ────────────────────────────────
819
- try:
820
- kind = y.dtype.kind # Series path
821
- except AttributeError:
822
- # DataFrame path: regression if *every* column’s dtype is numeric/datetime
823
- numeric_kinds = set("fiuM") # float, int, unsigned, datetime
824
- col_kinds = {dt.kind for dt in getattr(y, "dtypes", [])}
825
- kind = "f" if col_kinds and col_kinds.issubset(numeric_kinds) else "O"
826
-
827
- # ── if regression, strip classification lines ───────────────────────
828
- if kind in "fM": # float or datetime
829
- patterns = [
830
- r"\n.*accuracy_score[^\n]*",
831
- r"\n.*classification_report[^\n]*",
832
- r"\n.*confusion_matrix[^\n]*",
833
- ]
834
- for pat in patterns:
2091
+ Guarantee that the cell has proper IPython display imports so that
2092
+ display(HTML(...)) produces 'display_data' events the kernel captures.
2093
+ """
2094
+ if "display(" in code and "from IPython.display import display, HTML" not in code:
2095
+ return "from IPython.display import display, HTML\n" + code
2096
+ return code
2097
+
2098
+
2099
+ def drop_bad_classification_metrics(code: str, y_or_df) -> str:
2100
+ """
2101
+ Remove classification metrics (accuracy_score, classification_report, confusion_matrix)
2102
+ if the generated cell is *regression*. We infer this from:
2103
+ 1) The estimator names in the code (LinearRegression, OLS, Regressor*, etc.), OR
2104
+ 2) The target dtype if we can parse y = df['...'] and have the DataFrame.
2105
+ Safe across datasets and queries.
2106
+ """
2107
+ import re
2108
+ import pandas as pd
2109
+
2110
+ # 1) Heuristic by estimator names in the *code* (fast path)
2111
+ regression_by_model = bool(re.search(
2112
+ r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
2113
+ r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
2114
+ r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
2115
+ ) or re.search(r"\bOLS\s*\(", code))
2116
+
2117
+ is_regression = regression_by_model
2118
+
2119
+ # 2) If not obvious from the model, try to infer from y dtype (if we can)
2120
+ if not is_regression:
2121
+ try:
2122
+ # Try to parse: y = df['target']
2123
+ m = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
2124
+ if m and hasattr(y_or_df, "columns") and m.group(1) in getattr(y_or_df, "columns", []):
2125
+ y = y_or_df[m.group(1)]
2126
+ if pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
2127
+ is_regression = True
2128
+ else:
2129
+ # If a Series was passed
2130
+ y = y_or_df
2131
+ if hasattr(y, "dtype") and pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
2132
+ is_regression = True
2133
+ except Exception:
2134
+ pass
2135
+
2136
+ if is_regression:
2137
+ # Strip classification-only lines
2138
+ for pat in (r"\n.*accuracy_score[^\n]*", r"\n.*classification_report[^\n]*", r"\n.*confusion_matrix[^\n]*"):
835
2139
  code = re.sub(pat, "", code, flags=re.I)
836
2140
 
837
2141
  return code
838
2142
 
839
- # from syntaxmatrix.core import SyntaxMUI
840
- # ai_generate_code = SyntaxMUI.ai_generate_code
2143
+
2144
+ def force_capture_display(code: str) -> str:
2145
+ """
2146
+ Ensure our executor captures HTML output:
2147
+ - Remove any import that would override our 'display' hook.
2148
+ - Keep/allow importing HTML only.
2149
+ - Handle alias cases like 'display as d'.
2150
+ """
2151
+ import re
2152
+ new = code
2153
+
2154
+ # 'from IPython.display import display, HTML' -> keep HTML only
2155
+ new = re.sub(
2156
+ r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*,\s*HTML\s*(?:as\s+([A-Za-z_]\w*))?\s*$",
2157
+ r"from IPython.display import HTML\1", new
2158
+ )
2159
+
2160
+ # 'from IPython.display import display as d' -> 'd = display'
2161
+ new = re.sub(
2162
+ r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s+as\s+([A-Za-z_]\w+)\s*$",
2163
+ r"\1 = display", new
2164
+ )
2165
+
2166
+ # 'from IPython.display import display' -> remove (use our injected display)
2167
+ new = re.sub(
2168
+ r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*$",
2169
+ r"# display import removed (SMX capture active)", new
2170
+ )
2171
+
2172
+ # If someone does 'import IPython.display as disp' and calls disp.display(...), rewrite to display(...)
2173
+ new = re.sub(
2174
+ r"(?m)\bIPython\.display\.display\s*\(",
2175
+ "display(", new
2176
+ )
2177
+ # resolve 'import IPython.display as <alias>' aliases, then rewrite
2178
+ # '<alias>.display(' calls so our capture hook sees them
2179
+ for _alias in re.findall(r"(?m)^\s*import\s+IPython\.display\s+as\s+([A-Za-z_]\w*)\s*$", new):
2180
+ new = re.sub(rf"(?m)\b{_alias}\.display\s*\(", "display(", new)
2181
2182
+ return new
2183
+
2184
+
2185
+ def strip_matplotlib_show(code: str) -> str:
2186
+ """Remove blocking plt.show() calls (we export base64 instead)."""
2187
+ import re
2188
+ return re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code)
2189
+
2190
+
2191
+ def inject_display_shim(code: str) -> str:
2192
+ """
2193
+ Provide display()/HTML() if missing, forwarding to our executor hook.
2194
+ Harmless if the names already exist.
2195
+ """
2196
+ shim = (
2197
+ "try:\n"
2198
+ " display\n"
2199
+ "except NameError:\n"
2200
+ " def display(obj=None, **kwargs):\n"
2201
+ " __builtins__.get('_smx_display', print)(obj)\n"
2202
+ "try:\n"
2203
+ " HTML\n"
2204
+ "except NameError:\n"
2205
+ " class HTML:\n"
2206
+ " def __init__(self, data): self.data = str(data)\n"
2207
+ " def _repr_html_(self): return self.data\n"
2208
+ "\n"
2209
+ )
2210
+ return shim + code
2211
+
2212
+
2213
+ def strip_spurious_column_tokens(code: str) -> str:
2214
+ """
2215
+ Remove common stop-words ('the','whether', ...) when they appear
2216
+ inside column lists, e.g.:
2217
+ predictors = ['BMI','the','HbA1c']
2218
+ df[['GGT','whether','BMI']]
2219
+ Leaves other strings intact.
2220
+ """
2221
+ STOP = {
2222
+ "the","whether","a","an","and","or","of","to","in","on","for","by",
2223
+ "with","as","at","from","that","this","these","those","is","are","was","were",
2224
+ "coef", "Coef", "coefficient", "Coefficient"
2225
+ }
2226
+
2227
+ def _norm(s: str) -> str:
2228
+ return re.sub(r"[^a-z0-9]+", "", s.lower())
2229
+
2230
+ def _clean_list(content: str) -> str:
2231
+ # Rebuild a string list, keeping only non-stopword items
2232
+ items = re.findall(r"(['\"])(.*?)\1", content)
2233
+ if not items:
2234
+ return "[" + content + "]"
2235
+ keep = [f"{q}{s}{q}" for (q, s) in items if _norm(s) not in STOP]
2236
+ return "[" + ", ".join(keep) + "]"
2237
+
2238
+ # Variable assignments: predictors/features/columns/cols = [...]
2239
+ code = re.sub(
2240
+ r"(?m)\b(predictors|features|columns|cols)\s*=\s*\[([^\]]+)\]",
2241
+ lambda m: f"{m.group(1)} = " + _clean_list(m.group(2)),
2242
+ code
2243
+ )
2244
+
2245
+ # df[[ ... ]] selections
2246
+ code = re.sub(
2247
+ r"df\s*\[\s*\[([^\]]+)\]\s*\]", lambda m: "df[" + _clean_list(m.group(1)) + "]", code)
2248
+
2249
+ return code
2250
+
2251
+
2252
+ def patch_prefix_seaborn_calls(code: str) -> str:
2253
+ """
2254
+ Ensure bare seaborn calls are prefixed with `sns.`.
2255
+ E.g., `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, etc.
2256
+ """
2257
+ if "sns." in code:
2258
+ # still fix any leftover bare calls alongside prefixed ones
2259
+ pass
2260
+
2261
+ # functions commonly used from seaborn
2262
+ funcs = [
2263
+ "barplot","countplot","boxplot","violinplot","stripplot","swarmplot",
2264
+ "histplot","kdeplot","jointplot","pairplot","heatmap","clustermap",
2265
+ "scatterplot","lineplot","catplot","displot","lmplot"
2266
+ ]
2267
+ # Replace bare function calls not already qualified by a dot (e.g., obj.barplot)
2268
+ # (?<![\w.]) ensures no preceding word char or dot; avoids touching obj.barplot or mybarplot
2269
+ pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(", flags=re.MULTILINE)
2270
+
2271
+ def _add_prefix(m):
2272
+ fn = m.group(1)
2273
+ return f"sns.{fn}("
2274
+
2275
+ return pattern.sub(_add_prefix, code)
2276
+
2277
+
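e.g. a bare heatmap(df.corr(), annot=True) becomes sns.heatmap(df.corr(), annot=True), while grid.barplot(...) and mybarplot(...) are left alone thanks to the (?<![\w.]) guard:

    patched = patch_prefix_seaborn_calls("countplot(x='Sex', data=df)")
    # -> "sns.countplot(x='Sex', data=df)"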
2278
+ def patch_ensure_seaborn_import(code: str) -> str:
2279
+ """
2280
+ If seaborn is used (sns.) ensure `import seaborn as sns` exists once.
2281
+ Also set a quiet theme for consistent visuals.
2282
+ """
2283
+ needs_sns = "sns." in code
2284
+ has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
2285
+ if needs_sns and not has_import:
2286
+ # Insert after the first block of imports if possible, else at top
2287
+ import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
2288
+ inject = "import seaborn as sns\ntry:\n sns.set_theme()\nexcept Exception:\n pass\n"
2289
+ if import_block:
2290
+ start = import_block.end()
2291
+ code = code[:start] + inject + code[start:]
2292
+ else:
2293
+ code = inject + code
2294
+ return code
2295
+
2296
+
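The injected prelude is equivalent to:

    import seaborn as sns
    try:
        sns.set_theme()
    except Exception:
        pass

placed just after the first import block when one exists, else at the top of the cell.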
2297
+ def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
2298
+ """
2299
+ Normalise pie-chart requests.
2300
+
2301
+ Supports three patterns:
2302
+ A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
2303
+ B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
2304
+ → one pie per facet level (grid) + counts bar of facet sizes.
2305
+ C) Single pie when no split/facet is requested.
2306
+
2307
+ Notes:
2308
+ - Pie variables must be categorical (or numeric binned).
2309
+ - Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
2310
+ """
2311
+
2312
+ q = (user_question or "")
2313
+ q_low = q.lower()
2314
+
2315
+ # Prefer explicit: df['col'].value_counts()
2316
+ m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
2317
+ col = m.group(1) if m else None
2318
+
2319
+ # ---------- helpers ----------
2320
+ def _is_cat(col):
2321
+ return (str(df[col].dtype).startswith("category")
2322
+ or df[col].dtype == "object"
2323
+ or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
2324
+
2325
+ def _cats_from_question(question: str):
2326
+ found = []
2327
+ for c in df.columns:
2328
+ if c.lower() in question.lower() and _is_cat(c):
2329
+ found.append(c)
2330
+ # dedupe preserve order
2331
+ seen, out = set(), []
2332
+ for c in found:
2333
+ if c not in seen:
2334
+ out.append(c); seen.add(c)
2335
+ return out
2336
+
2337
+ def _fallback_cat():
2338
+ cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2339
+ if not cats: return None
2340
+ cats.sort(key=lambda t: t[1])
2341
+ return cats[0][0]
2342
+
2343
+ def _infer_comp_pref(question: str) -> str:
2344
+ ql = (question or "").lower()
2345
+ if "heatmap" in ql or "matrix" in ql:
2346
+ return "heatmap"
2347
+ if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
2348
+ return "stacked_bar_pct"
2349
+ if "stacked" in ql:
2350
+ return "stacked_bar"
2351
+ if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
2352
+ return "grouped_bar"
2353
+ return "counts_bar"
2354
+
2355
+ # parse threshold split like "HbA1c ≥ 6.5"
2356
+ def _parse_split(question: str):
2357
+ ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
2358
+ m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
2359
+ if not m: return None
2360
+ col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
2361
+ op = ops_map.get(op_raw)
2362
+ if not op: return None
2363
+ # case-insensitive column match
2364
+ candidates = {c.lower(): c for c in df.columns}
2365
+ col = candidates.get(col_raw.lower())
2366
+ if not col: return None
2367
+ try: val = float(val_raw)
2368
+ except Exception: return None
2369
+ return (col, op, val)
2370
+
2371
+ # facet extractor: "by/ across / within each / per <col>", or "bin <col>", or named category list
2372
+ def _extract_facet(question: str):
2373
+ # 1) explicit "by/ across / within / per <col>"
2374
+ for kw in [" by ", " across ", " within ", " within each ", " per "]:
2375
+ m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
2376
+ if m:
2377
+ col_raw = m.group(1).strip()
2378
+ candidates = {c.lower(): c for c in df.columns}
2379
+ if col_raw.lower() in candidates:
2380
+ return (candidates[col_raw.lower()], "auto")
2381
+ # 2) "bin <col>"
2382
+ m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
2383
+ if m2:
2384
+ col_raw = m2.group(1).strip()
2385
+ candidates = {c.lower(): c for c in df.columns}
2386
+ if col_raw.lower() in candidates:
2387
+ return (candidates[col_raw.lower()], "bin")
2388
+ # 3) BMI special: mentions of normal/overweight/obese imply BMI categories
2389
+ if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
2390
+ any(c.lower() == "bmi" for c in df.columns.str.lower()):
2391
+ bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
2392
+ return (bmi_col, "bmi")
2393
+ return None
2394
+
2395
+ def _bmi_bins(series: pd.Series):
2396
+ # WHO cutoffs
2397
+ bins = [-np.inf, 18.5, 25, 30, np.inf]
2398
+ labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
2399
+ return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
2400
+
2401
+ wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
2402
+ if not wants_pie:
2403
+ return code
2404
+
2405
+ split = _parse_split(q)
2406
+ facet = _extract_facet(q)
2407
+ cats = _cats_from_question(q)
2408
+ _comp_pref = _infer_comp_pref(q)
2409
+
2410
+ # Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
2411
+ for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
2412
+ if hard in df.columns and hard not in cats and hard.lower() in q_low:
2413
+ cats.append(hard)
2414
+
2415
+ # --------------- CASE A: threshold split (cohorts) ---------------
2416
+ if split:
2417
+ if not (cats or any(_is_cat(c) for c in df.columns)):
2418
+ return code
2419
+ if not cats:
2420
+ pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2421
+ pool.sort(key=lambda t: t[1])
2422
+ cats = [t[0] for t in pool[:3]] if pool else []
2423
+ if not cats:
2424
+ return code
2425
+
2426
+ split_col, op, val = split
2427
+ cond_str = f"(df['{split_col}'] {op} {val})"
2428
+ snippet = f"""
2429
+ import numpy as np
2430
+ import pandas as pd
2431
+ import matplotlib.pyplot as plt
2432
+
2433
+ _mask_a = ({cond_str}) & df['{split_col}'].notna()
2434
+ _mask_b = (~({cond_str})) & df['{split_col}'].notna()
2435
+
2436
+ _cohort_a_name = "{split_col} {op} {val}"
2437
+ _cohort_b_name = "NOT ({split_col} {op} {val})"
2438
+
2439
+ _cat_cols = {cats!r}
2440
+ n = len(_cat_cols)
2441
+ fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
2442
+ if n == 1:
2443
+ axes = np.array([axes])
2444
+
2445
+ for i, col in enumerate(_cat_cols):
2446
+ s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
2447
+ s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})
2448
+
2449
+ ax_a = axes[i, 0]; ax_b = axes[i, 1]
2450
+ if len(s_a) > 0:
2451
+ ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
2452
+ autopct='%1.1f%%', startangle=90, counterclock=False)
2453
+ ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')
2454
+
2455
+ if len(s_b) > 0:
2456
+ ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
2457
+ autopct='%1.1f%%', startangle=90, counterclock=False)
2458
+ ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')
2459
+
2460
+ plt.tight_layout(); plt.show()
2461
+
2462
+ # grouped bar complement
2463
+ for col in _cat_cols:
2464
+ _tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
2465
+ .assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
2466
+ _tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
2467
+ _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2468
+
2469
+ if _comp_pref == "grouped_bar":
2470
+ ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
2471
+ ax.set_title(f"{col} by cohort (grouped)")
2472
+ ax.set_xlabel(col); ax.set_ylabel("Count")
2473
+ plt.tight_layout(); plt.show()
2474
+
2475
+ elif _comp_pref == "stacked_bar":
2476
+ ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2477
+ ax.set_title(f"{col} by cohort (stacked)")
2478
+ ax.set_xlabel(col); ax.set_ylabel("Count")
2479
+ plt.tight_layout(); plt.show()
2480
+
2481
+ elif _comp_pref == "stacked_bar_pct":
2482
+ _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2483
+ ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2484
+ ax.set_title(f"{col} by cohort (100% stacked)")
2485
+ ax.set_xlabel(col); ax.set_ylabel("Percent")
2486
+ plt.tight_layout(); plt.show()
2487
+
2488
+ elif _comp_pref == "heatmap":
2489
+ _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2490
+ import numpy as np
2491
+ fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
2492
+ im = ax.imshow(_perc.values, aspect='auto')
2493
+ ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2494
+ ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2495
+ ax.set_title(f"{col} by cohort — % heatmap")
2496
+ for i in range(_perc.shape[0]):
2497
+ for j in range(_perc.shape[1]):
2498
+ ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2499
+ fig.colorbar(im, ax=ax, label="%")
2500
+ plt.tight_layout(); plt.show()
2501
+
2502
+ else: # counts_bar (default)
2503
+ ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
2504
+ ax.set_title(f"{col}: total counts (both cohorts)")
2505
+ ax.set_xlabel(col); ax.set_ylabel("Count")
2506
+ plt.tight_layout(); plt.show()
2507
+ """.lstrip()
2508
+ return snippet
2509
+
2510
+ # --------------- CASE B: facet-by (categories/bins) ---------------
2511
+ if facet:
2512
+ facet_col, how = facet
2513
+ # Build facet series
2514
+ if pd.api.types.is_numeric_dtype(df[facet_col]):
2515
+ if how == "bmi":
2516
+ facet_series = _bmi_bins(df[facet_col])
2517
+ else:
2518
+ # generic numeric bins: 3 equal-width bins by default
2519
+ facet_series = pd.cut(df[facet_col].astype(float), bins=3)
2520
+ else:
2521
+ facet_series = df[facet_col].astype(str)
2522
+
2523
+ # Choose pie dimension (categorical to count inside each facet)
2524
+ pie_dim = None
2525
+ for c in cats:
2526
+ if c in df.columns and _is_cat(c):
2527
+ pie_dim = c; break
2528
+ if pie_dim is None:
2529
+ pie_dim = _fallback_cat()
2530
+ if pie_dim is None:
2531
+ return code
2532
+
2533
+ snippet = f"""
2534
+ import math
2535
+ import pandas as pd
2536
+ import numpy as np
+ import matplotlib.pyplot as plt
2537
+
2538
+ df = df.copy()
2539
+ _preferred = "{facet_col}" if "{facet_col}" in df.columns else None
2540
+
2541
+ def _select_facet_col(df, preferred=None):
2542
+ if preferred is not None:
2543
+ return preferred
2544
+ # Prefer low-cardinality categoricals (readable pies/grids)
2545
+ cat_cols = [
2546
+ c for c in df.columns
2547
+ if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
2548
+ and df[c].nunique() > 1 and df[c].nunique() <= 20
2549
+ ]
2550
+ if cat_cols:
2551
+ cat_cols.sort(key=lambda c: df[c].nunique())
2552
+ return cat_cols[0]
2553
+ # Else fall back to first usable numeric
2554
+ num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
2555
+ return num_cols[0] if num_cols else None
2556
+
2557
+ _facet_col = _select_facet_col(df, _preferred)
2558
+
2559
+ if _facet_col is None:
2560
+ # Nothing suitable → single facet keeps pipeline alive
2561
+ df["__facet__"] = "All"
2562
+ else:
2563
+ s = df[_facet_col]
2564
+ if pd.api.types.is_numeric_dtype(s):
2565
+ # Robust numeric binning: quantiles first, fallback to equal-width
2566
+ uniq = pd.Series(s).dropna().nunique()
2567
+ q = 3 if uniq < 10 else 4 if uniq < 30 else 5
2568
+ try:
2569
+ df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
2570
+ except Exception:
2571
+ df["__facet__"] = pd.cut(s.astype(float), bins=q)
2572
+ else:
2573
+ # Cap long tails; keep top categories
2574
+ vc = s.astype(str).value_counts()
2575
+ keep = vc.index[:{top_n}]
2576
+ df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")
2577
+
2578
+ levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
2579
+ levels = [x for x in levels if x != "nan"]
2580
+ levels.sort()
2581
+
2582
+ m = len(levels)
2583
+ cols = 3 if m >= 3 else m or 1
2584
+ rows = int(math.ceil(m / cols))
2585
+
2586
+ fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
2587
+ if not isinstance(axes, (list, np.ndarray)):
2588
+ axes = np.array([[axes]])
2589
+ axes = axes.reshape(rows, cols)
2590
+
2591
+ for i, lvl in enumerate(levels):
2592
+ r, c = divmod(i, cols)
2593
+ ax = axes[r, c]
2594
+ s = (df.loc[df["__facet"].astype(str) == str(lvl), "{pie_dim}"]
2595
+ .astype(str).value_counts().nlargest({top_n}))
2596
+ if len(s) > 0:
2597
+ ax.pie(s.values, labels=[str(x) for x in s.index],
2598
+ autopct='%1.1f%%', startangle=90, counterclock=False)
2599
+ ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')
2600
+
2601
+ # hide any empty subplots
2602
+ for j in range(m, rows*cols):
2603
+ r, c = divmod(j, cols)
2604
+ axes[r, c].axis("off")
2605
+
2606
+ plt.tight_layout(); plt.show()
2607
+
2608
+ # --- companion visual (adaptive) ---
2609
+ _comp_pref = "{_comp_pref}"
2610
+ # build contingency table: pie_dim x facet
2611
+ _tab = (df[["__facet__", "{pie_dim}"]]
2612
+ .dropna()
2613
+ .astype({{"__facet__": str, "{pie_dim}": str}})
2614
+ .value_counts()
2615
+ .unstack(level="__facet__")
2616
+ .fillna(0))
2617
+
2618
+ # keep top categories by overall size
2619
+ _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2620
+
2621
+ if _comp_pref == "grouped_bar":
2622
+ ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2623
+ ax.set_title("{pie_dim} by {facet_col} (grouped)")
2624
+ ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2625
+ plt.tight_layout(); plt.show()
2626
+
2627
+ elif _comp_pref == "stacked_bar":
2628
+ ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2629
+ ax.set_title("{pie_dim} by {facet_col} (stacked)")
2630
+ ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2631
+ plt.tight_layout(); plt.show()
2632
+
2633
+ elif _comp_pref == "stacked_bar_pct":
2634
+ _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100 # column-normalised to 100%
2635
+ ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
2636
+ ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
2637
+ ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
2638
+ plt.tight_layout(); plt.show()
2639
+
2640
+ elif _comp_pref == "heatmap":
2641
+ _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
2642
+ import numpy as np
2643
+ fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
2644
+ im = ax.imshow(_perc.values, aspect='auto')
2645
+ ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2646
+ ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2647
+ ax.set_title("{pie_dim} by {facet_col} — % heatmap")
2648
+ for i in range(_perc.shape[0]):
2649
+ for j in range(_perc.shape[1]):
2650
+ ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2651
+ fig.colorbar(im, ax=ax, label="%")
2652
+ plt.tight_layout(); plt.show()
2653
+
2654
+ else: # counts_bar (default denominators)
2655
+ _counts = df["__facet"].value_counts()
2656
+ ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
2657
+ ax.set_title("Counts by {facet_col}")
2658
+ ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2659
+ plt.tight_layout(); plt.show()
2660
+
2661
+ """.lstrip()
2662
+ return snippet
2663
+
2664
+ # --------------- CASE C: single pie ---------------
2665
+ chosen = None
2666
+ for c in cats:
2667
+ if c in df.columns and _is_cat(c):
2668
+ chosen = c; break
2669
+ if chosen is None:
2670
+ chosen = _fallback_cat()
2671
+
2672
+ if chosen:
2673
+ snippet = f"""
2674
+ import matplotlib.pyplot as plt
2675
+ counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
2676
+ fig, ax = plt.subplots()
2677
+ if len(counts) > 0:
2678
+ ax.pie(counts.values, labels=[str(i) for i in counts.index],
2679
+ autopct='%1.1f%%', startangle=90, counterclock=False)
2680
+ ax.set_title('Distribution of {chosen} (top {top_n})')
2681
+ ax.axis('equal')
2682
+ plt.show()
2683
+ """.lstrip()
2684
+ return snippet
2685
+
2686
+ # numeric last resort
2687
+ num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
2688
+ if num_cols:
2689
+ col = num_cols[0]
2690
+ snippet = f"""
2691
+ import pandas as pd
2692
+ import matplotlib.pyplot as plt
2693
+ bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
2694
+ counts = bins.value_counts().sort_index()
2695
+ fig, ax = plt.subplots()
2696
+ if len(counts) > 0:
2697
+ ax.pie(counts.values, labels=[str(i) for i in counts.index],
2698
+ autopct='%1.1f%%', startangle=90, counterclock=False)
2699
+ ax.set_title('Distribution of {col} (binned)')
2700
+ ax.axis('equal')
2701
+ plt.show()
2702
+ """.lstrip()
2703
+ return snippet
2704
+
2705
+ return code
2706
+
2707
+
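A hypothetical call down the facet path (CASE B), assuming a df that carries an Ethnicity column and a numeric BMI column:

    new_code = patch_pie_chart(
        code="df['Ethnicity'].value_counts().plot(kind='pie')",
        df=df,
        user_question="Pie charts of Ethnicity across BMI categories (Normal/Overweight/Obese)",
    )
    # returns a self-contained snippet: one pie per BMI bin plus a companion bar/heatmap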
2708
+ def patch_fix_seaborn_palette_calls(code: str) -> str:
2709
+ """
2710
+ Removes seaborn `palette=` when no `hue=` is present in the same call.
2711
+ Fixes FutureWarning: 'Passing `palette` without assigning `hue` ...'.
2712
+ """
2713
+ if "sns." not in code:
2714
+ return code
2715
+
2716
+ # Targets common seaborn plotters
2717
+ funcs = r"(boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
2718
+ pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)
2719
+
2720
+ def _fix_call(m):
2721
+ head, inner = m.group(1), m.group(2)
2722
+ # If there's already hue=, keep as is
2723
+ if re.search(r"(?<!\w)hue\s*=", inner):
2724
+ return f"{head}{inner})"
2725
+ # Otherwise remove palette=... safely (and any adjacent comma spacing)
2726
+ inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
2727
+ inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
2728
+ inner2 = re.sub(r"\s*,\s*\)", ")", f"{inner2})")[:-1] # clean trailing comma before ')'
2729
+ return f"{head}{inner2})"
2730
+
2731
+ return pattern.sub(_fix_call, code)
2732
+
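e.g.:

    sns.boxplot(data=df, x='Sex', y='BMI', palette='Set2')
    # -> sns.boxplot(data=df, x='Sex', y='BMI')
    sns.boxplot(data=df, x='Sex', y='BMI', hue='Sex', palette='Set2')  # unchanged: hue is present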
2733
+
2734
+ def patch_quiet_specific_warnings(code: str) -> str:
2735
+ """
2736
+ Inserts targeted warning filters (not blanket ignores).
2737
+ - seaborn palette/hue deprecation
2738
+ - python-dotenv parse chatter
2739
+ """
2740
+ prelude = (
2741
+ "import warnings\n"
2742
+ "warnings.filterwarnings(\n"
2743
+ " 'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
2744
+ "warnings.filterwarnings(\n"
2745
+ " 'ignore', message=r'python-dotenv could not parse statement.*')\n"
2746
+ )
2747
+ # If warnings already imported once, just add filters; else insert full prelude.
2748
+ if "import warnings" in code:
2749
+ code = re.sub(
2750
+ r"(import warnings[^\n]*\n)",
2751
+ lambda m: m.group(1) + prelude.replace("import warnings\n", ""),
2752
+ code,
2753
+ count=1
2754
+ )
2755
+
2756
+ else:
2757
+ # place after first import block if possible
2758
+ m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
2759
+ if m:
2760
+ idx = m.end()
2761
+ code = code[:idx] + prelude + code[idx:]
2762
+ else:
2763
+ code = prelude + code
2764
+ return code
2765
+
2766
+
2767
+ def _norm_col_name(s: str) -> str:
2768
+ """normalise a column name: lowercase + strip non-alphanumerics."""
2769
+ return re.sub(r"[^a-z0-9]+", "", str(s).lower())
2770
+
2771
+
2772
+ def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
2773
+ """return the actual df column that matches any candidate (after normalisation)."""
2774
+ norm_map = {_norm_col_name(c): c for c in df.columns}
2775
+ for cand in candidates:
2776
+ hit = norm_map.get(_norm_col_name(cand))
2777
+ if hit is not None:
2778
+ return hit
2779
+ return None
2780
+
2781
+
2782
+ def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
2783
+ """
2784
+ If any alias exists, materialise a canonical copy at `target` (don’t drop the original).
2785
+ Returns (df, found_bool).
2786
+ """
2787
+ if target in df.columns:
2788
+ return df, True
2789
+ col = _first_present(df, [target, *aliases])
2790
+ if col is None:
2791
+ return df, False
2792
+ df[target] = df[col]
2793
+ return df, True
2794
+
2795
+
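e.g., with hypothetical alias spellings:

    df, ok = _ensure_canonical_alias(df, "HbA1c", ["hba1c_percent", "HBA1C (%)"])
    # ok is True if any spelling matched after normalisation ('hba1cpercent', 'hba1c'),
    # and df then also carries a canonical 'HbA1c' column copied from the match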
2796
+ def strip_python_dotenv(code: str) -> str:
2797
+ """
2798
+ Remove any use of python-dotenv from generated code, including:
2799
+ - single and multi-line 'from dotenv import ...'
2800
+ - 'import dotenv' (with or without alias) and calls via any alias
2801
+ - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
2802
+ - IPython magics (%load_ext dotenv, %dotenv, %env …)
2803
+ - shell installs like '!pip install python-dotenv'
2804
+ """
2805
+ original = code
2806
+
2807
+ # 0) Kill IPython magics & shell installs referencing dotenv
2808
+ code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
2809
+ code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
2810
+ code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
2811
+ code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)
2812
+
2813
+ # 1) Remove single-line 'from dotenv import ...'
2814
+ code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)
2815
+
2816
+ # 2) Remove multi-line 'from dotenv import ( ... )' blocks
2817
+ code = re.sub(
2818
+ r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
2819
+ "",
2820
+ code,
2821
+ flags=re.MULTILINE,
2822
+ )
2823
+
2824
+ # 3) Remove 'import dotenv' (with optional alias). Capture alias names.
2825
+ aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
2826
+ code, flags=re.MULTILINE)
2827
+ code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
2828
+ "", code, flags=re.MULTILINE)
2829
+
2830
+ # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
2831
+ # e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
2832
+ fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
2833
+ # bare calls
2834
+ code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2835
+ # dotted calls with any identifier prefix (alias or module)
2836
+ code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
2837
+ "", code, flags=re.MULTILINE)
2838
+
2839
+ # 5) If any alias imported earlier slipped through (method chains etc.), remove lines using that alias.
2840
+ for al in aliases:
2841
+ code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2842
+
2843
+ # 6) Tidy excess blank lines
2844
+ code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
2845
+ return code
2846
+
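e.g. all of these hypothetical lines are removed outright:

    from dotenv import load_dotenv
    import dotenv as denv
    load_dotenv()
    denv.load_dotenv('.env')
    %load_ext dotenv
    !pip install python-dotenv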
2847
+
2848
+ def fix_predict_calls_records_arg(code: str) -> str:
2849
+ """
2850
+ If generated code calls predict_* with a list-of-dicts via .to_dict('records')
2851
+ (or orient='records'), strip the .to_dict(...) so a DataFrame is passed instead.
2852
+ Works line-by-line to avoid over-rewrites elsewhere.
2853
+ Examples fixed:
2854
+ predict_patient(X_test.iloc[:5].to_dict('records'))
2855
+ predict_risk(df.head(3).to_dict(orient="records"))
2856
+ → predict_patient(X_test.iloc[:5])
2857
+ """
2858
+ fixed_lines = []
2859
+ for line in code.splitlines():
2860
+ if "predict_" in line and "to_dict" in line and "records" in line:
2861
+ line = re.sub(
2862
+ r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)",
2863
+ "",
2864
+ line
2865
+ )
2866
+ fixed_lines.append(line)
2867
+ return "\n".join(fixed_lines)
2868
+
2869
+
2870
+ def fix_fstring_backslash_paths(code: str) -> str:
2871
+ """
2872
+ Fix bad f-strings like: f"...{out_dir\\plots\\img.png}..."
2873
+ → f"...{os.path.join(out_dir, r'plots\\img.png')}"
2874
+ Only touches f-strings that contain a backslash path inside {...}.
2875
+ """
2876
+ def _fix_line(line: str) -> str:
2877
+ # quick check: only f-strings need scanning
2878
+ if not (("f\"" in line) or ("f'" in line) or ("f\"\"\"" in line) or ("f'''" in line)):
2879
+ return line
2880
+ # {var\rest-of-path} where var can be dotted (e.g., cfg.out)
2881
+ pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
2882
+ def repl(m):
2883
+ left = m.group(1)
2884
+ right = m.group(2).strip().replace('"', '\\"')
2885
+ return "{os.path.join(" + left + ', r"' + right + '")}'
2886
+ return pattern.sub(repl, line)
2887
+
2888
+ return "\n".join(_fix_line(ln) for ln in code.splitlines())
2889
+
2890
+
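e.g. (ensure_os_import, just below, supplies the import os this rewrite relies on):

    f"saved to {out_dir\plots\img.png}"
    # -> f"saved to {os.path.join(out_dir, r"plots\img.png")}"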
2891
+ def ensure_os_import(code: str) -> str:
2892
+ """
2893
+ If os.path.join is used but 'import os' is missing, inject it at the top.
2894
+ """
2895
+ needs = "os.path.join(" in code
2896
+ has_import_os = re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
2897
+ has_from_os = re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
2898
+ if needs and not (has_import_os or has_from_os):
2899
+ return "import os\n" + code
2900
+ return code
2901
+
2902
+
2903
+ def fix_seaborn_boxplot_nameerror(code: str) -> str:
2904
+ """
2905
+ Fix bad calls like: sns.boxplot(boxplot)
2906
+ Heuristic:
2907
+ - If plot_df + FH_status + var exist → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
2908
+ - Else if plot_df + var → sns.boxplot(data=plot_df, y=var, ax=ax)
2909
+ - Else if plot_df only → sns.boxplot(data=plot_df, ax=ax)
2910
+ - Else → sns.boxplot(ax=ax)
2911
+ Ensures a matplotlib Axes 'ax' exists.
2912
+ """
2913
+ pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
2914
+ if not pattern.search(code):
2915
+ return code
2916
+
2917
+ has_plot_df = re.search(r"\bplot_df\b", code) is not None
2918
+ has_var = re.search(r"\bvar\b", code) is not None
2919
+ has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2920
+
2921
+ if has_plot_df and has_var and has_fh:
2922
+ replacement = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2923
+ elif has_plot_df and has_var:
2924
+ replacement = "sns.boxplot(data=plot_df, y=var, ax=ax)"
2925
+ elif has_plot_df:
2926
+ replacement = "sns.boxplot(data=plot_df, ax=ax)"
2927
+ else:
2928
+ replacement = "sns.boxplot(ax=ax)"
2929
+
2930
+ fixed = pattern.sub(replacement, code)
2931
+
2932
+ # Ensure 'fig, ax = plt.subplots(...)' exists
2933
+ if "ax=" in replacement and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", fixed):
2934
+ # Insert right before the first seaborn call
2935
+ m = re.search(r"^\s*sns\.", fixed, flags=re.MULTILINE)
2936
+ insert_at = m.start() if m else 0
2937
+ fixed = fixed[:insert_at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + fixed[insert_at:]
2938
+
2939
+ return fixed
2940
+
2941
+
2942
+ def fix_seaborn_barplot_nameerror(code: str) -> str:
2943
+ """
2944
+ Fix bad calls like: sns.barplot(barplot)
2945
+ Strategy mirrors boxplot fixer: prefer data=plot_df with x/y if available,
2946
+ otherwise degrade safely to an empty call on an existing Axes.
2947
+ """
2948
+ import re
2949
+ pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
2950
+ if not pattern.search(code):
2951
+ return code
2952
+
2953
+ has_plot_df = re.search(r"\bplot_df\b", code) is not None
2954
+ has_var = re.search(r"\bvar\b", code) is not None
2955
+ has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2956
+
2957
+ if has_plot_df and has_var and has_fh:
2958
+ replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2959
+ elif has_plot_df and has_var:
2960
+ replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
2961
+ elif has_plot_df:
2962
+ replacement = "sns.barplot(data=plot_df, ax=ax)"
2963
+ else:
2964
+ replacement = "sns.barplot(ax=ax)"
2965
+
2966
+ # ensure an Axes 'ax' exists (no-op if already present)
2967
+ if "ax =" not in code:
2968
+ code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code
2969
+
2970
+ return pattern.sub(replacement, code)
2971
+
2972
+
2973
+ def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
2974
+ """
2975
+ Parses the raw text to extract and format the 'refined question',
2976
+ 'intents (tasks)', and 'chronology of tasks' sections.
2977
+ Args:
2978
+ raw_text: The complete input string containing the ML pipeline structure.
2979
+ Returns:
2980
+ A tuple containing:
2981
+ (formatted_question_str, formatted_intents_str, formatted_chronology_str)
2982
+ """
2983
+ # --- 1. Regex Pattern to Extract Sections ---
2984
+ # The pattern uses capturing groups (?) to look for the section headers
2985
+ # (e.g., 'refined question:') and captures all the content until the next
2986
+ # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.
2987
+
2988
+ pattern = re.compile(
2989
+ r"refined question:(?P<question>.*?)"
2990
+ r"intents \(tasks\):(?P<intents>.*?)"
2991
+ r"Chronology of tasks:(?P<chronology>.*)",
2992
+ re.IGNORECASE | re.DOTALL
2993
+ )
2994
+
2995
+ match = pattern.search(raw_text)
2996
+
2997
+ if not match:
2998
+ raise ValueError("Input text structure does not match the expected pattern.")
2999
+
3000
+ # --- 2. Extract Content ---
3001
+ question_content = match.group('question').strip()
3002
+ intents_content = match.group('intents').strip()
3003
+ chronology_content = match.group('chronology').strip()
3004
+
3005
+ # --- 3. Formatting Functions ---
3006
+
3007
+ def format_question(content):
3008
+ """Formats the Refined Question section."""
3009
+ # Clean up leading/trailing whitespace and ensure clean paragraphs
3010
+ content = content.strip().replace('\n', ' ').replace('  ', ' ')
3011
+
3012
+ # Simple formatting using Markdown headers and bolding
3013
+ formatted = (
3014
+ # "## 1. Project Goal and Objectives\n\n"
3015
+ "<b> Refined Question:</b>\n"
3016
+ f"{content}\n"
3017
+ )
3018
+ return formatted
3019
+
3020
+ def format_intents(content):
3021
+ """Formats the Intents (Tasks) section as a structured list."""
3022
+ # Use regex to find and format each numbered task
3023
+ # It finds 'N. **Text** - ...' and breaks it down.
3024
+
3025
+ tasks = []
3026
+ # Pattern: N. **Text** - Content (including newlines, non-greedy)
3027
+ # We need to explicitly handle the list items starting with '-' within the content
3028
+ task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)
3029
+
3030
+ # Split the content by lines and join tasks back into clean strings
3031
+ raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]
3032
+
3033
+ for task in raw_tasks:
3034
+ # Replace the initial task number and **Heading** with a Heading 3
3035
+ task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)
3036
+
3037
+ # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
3038
+ task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
3039
+ tasks.append(task)
3040
+
3041
+ formatted_tasks = "\n\n".join(tasks)
3042
+
3043
+ return (
3044
+ "\n---\n"
3045
+ "## 2. Methodology and Tasks\n\n"
3046
+ f"{formatted_tasks}\n"
3047
+ )
3048
+
3049
+ def format_chronology(content):
3050
+ """Formats the Chronology section."""
3051
+ # Uses the given LaTeX format
3052
+ content = content.strip().replace(' → ', ' \\rightarrow ')
3053
+ formatted = (
3054
+ "\n---\n"
3055
+ "## 3. Chronology of Tasks\n"
3056
+ f"$$\\text{{{content}}}$$"
3057
+ )
3058
+ return formatted
3059
+
3060
+ # --- 4. Format and Return ---
3061
+ formatted_question = format_question(question_content)
3062
+ formatted_intents = format_intents(intents_content)
3063
+ formatted_chronology = format_chronology(chronology_content)
3064
+
3065
+ return formatted_question, formatted_intents, formatted_chronology
3066
+
3067
+
3068
+ def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
3069
+ """Combines all formatted parts into a final report string."""
3070
+ return (
3071
+ "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
3072
+ f"{formatted_question}\n"
3073
+ f"{formatted_intents}\n"
3074
+ f"{formatted_chronology}\n"
3075
+ )
3076
+
3077
+
3078
+ def fix_confusion_matrix_for_multilabel(code: str) -> str:
3079
+ """
3080
+ Replace ConfusionMatrixDisplay.from_estimator(...) usages with
3081
+ from_predictions(...) which works for multi-label loops without requiring
3082
+ the estimator to expose _estimator_type.
3083
+ """
3084
+ return re.sub(
3085
+ r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
3086
+ r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
3087
+ code
3088
+ )
3089
+
3090
+
3091
+ def smx_auto_title_plots(ctx=None, fallback="Analysis"):
3092
+ """
3093
+ Ensure every Matplotlib/Seaborn Axes has a title.
3094
+ Uses refined_question -> askai_question -> fallback.
3095
+ Only sets a title if it's currently empty.
3096
+ """
3097
+ import matplotlib.pyplot as plt
3098
+
3099
+ def _all_figures():
3100
+ try:
3101
+ from matplotlib._pylab_helpers import Gcf
3102
+ return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
3103
+ except Exception:
3104
+ # Best effort fallback
3105
+ nums = plt.get_fignums()
3106
+ return [plt.figure(n) for n in nums] if nums else []
3107
+
3108
+ # Choose a concise title
3109
+ title = None
3110
+ if isinstance(ctx, dict):
3111
+ title = ctx.get("refined_question") or ctx.get("askai_question")
3112
+ title = (str(title).strip().splitlines()[0][:120]) if title else fallback
3113
+
3114
+ for fig in _all_figures():
3115
+ for ax in getattr(fig, "axes", []):
3116
+ try:
3117
+ if not (ax.get_title() or "").strip():
3118
+ ax.set_title(title)
3119
+ except Exception:
3120
+ pass
3121
+ try:
3122
+ fig.tight_layout()
3123
+ except Exception:
3124
+ pass
3125
+
3126
+
3127
+ def patch_fix_sentinel_plot_calls(code: str) -> str:
3128
+ """
3129
+ Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
3130
+ SB_barplot(barplot) -> SB_barplot()
3131
+ SB_barplot(barplot, ...) -> SB_barplot(...)
3132
+ sns.barplot(barplot) -> SB_barplot()
3133
+ sns.barplot(barplot, ...) -> SB_barplot(...)
3134
+ Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
3135
+ """
3136
+ names = ['histplot','boxplot','barplot','lineplot','countplot','heatmap','pairplot','scatterplot']
3137
+ for n in names:
3138
+ # SB_* with sentinel as the first arg (with or without trailing args)
3139
+ code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
3140
+ code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
3141
+ # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
3142
+ code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
3143
+ code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
3144
+ return code
3145
+
3146
+
3147
+ def patch_rmse_calls(code: str) -> str:
3148
+ """
3149
+ Make RMSE robust across sklearn versions.
3150
+ - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
3151
+ - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
3152
+ """
3153
+ import re
3154
+ # (a) Specific RMSE pattern
3155
+ code = re.sub(
3156
+ r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
3157
+ r"_SMX_rmse(\1)",
3158
+ code,
3159
+ flags=re.DOTALL
3160
+ )
3161
+ # (b) Guard any other MSE calls
3162
+ code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
3163
+ return code
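e.g., assuming _SMX_rmse and _SMX_call are provided by the executor prelude:

    rmse = mean_squared_error(y_test, y_pred, squared=False)
    # -> rmse = _SMX_rmse(y_test, y_pred)
    mse = mean_squared_error(y_test, y_pred)
    # -> mse = _SMX_call(mean_squared_error, y_test, y_pred)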