syntaxmatrix 2.3.5__py3-none-any.whl → 2.5.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
syntaxmatrix/utils.py CHANGED
@@ -1,15 +1,1189 @@
1
- import openai
2
- from openai import OpenAI
3
- import re, os, textwrap
1
+ from __future__ import annotations
2
+ import re, textwrap
4
3
  import pandas as pd
5
- import matplotlib.pyplot as plt
4
+ import numpy as np
6
5
  import warnings
7
- from .model_templates import classification
8
- import syntaxmatrix as smx
9
6
 
7
+ from difflib import get_close_matches
8
+ from typing import Iterable, Tuple, Dict
9
+ import inspect
10
+ from sklearn.preprocessing import OneHotEncoder
11
+
12
+
13
+ from syntaxmatrix.agentic.model_templates import (
14
+ classification, regression, multilabel_classification,
15
+ eda_overview, eda_correlation,
16
+ anomaly_detection, ts_anomaly_detection,
17
+ dimensionality_reduction, feature_selection,
18
+ time_series_forecasting, time_series_classification,
19
+ unknown_group_proxy_pack, viz_line,
20
+ clustering, recommendation, topic_modelling,
21
+ viz_pie, viz_count_bar, viz_box, viz_scatter,
22
+ viz_stacked_bar, viz_distribution, viz_area, viz_kde,
23
+ )
24
+ import ast
10
25
 
11
26
  warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
12
27
 
28
+ INJECTABLE_INTENTS = [
29
+ "classification",
30
+ "multilabel_classification",
31
+ "regression",
32
+ "anomaly_detection",
33
+ "time_series_forecasting",
34
+ "time_series_classification",
35
+ "ts_anomaly_detection",
36
+ "dimensionality_reduction",
37
+ "feature_selection",
38
+ "clustering",
39
+ "eda",
40
+ "correlation_analysis",
41
+ "visualisation",
42
+ "recommendation",
43
+ "topic_modelling",
44
+ ]
45
+
46
+ def classify_ml_job(prompt: str) -> str:
47
+ """
48
+ Very-light intent classifier.
49
+ Returns one of:
50
+ 'greeting' | 'stat_test' | 'time_series' | 'clustering' | 'feature_selection'
+ 'dimensionality_reduction' | 'anomaly_detection' | 'classification' | 'regression' | 'eda'
52
+ """
53
+ p = prompt.lower()
54
+
55
+ greetings = {"hi", "hello", "hey", "good morning", "good afternoon", "good evening", "greetings"}
56
+ if any(p.startswith(g) or p == g for g in greetings):
57
+ return "greeting"
58
+
59
+ # Feature selection / importance intent
60
+ if any(k in p for k in (
61
+ "feature selection", "select k best", "selectkbest", "rfe",
62
+ "mutual information", "feature importance", "permutation importance",
63
+ "feature engineering suggestions"
64
+ )):
65
+ return "feature_selection"
66
+
67
+ # Dimensionality reduction intent
68
+ if any(k in p for k in (
69
+ "pca", "principal component", "dimensionality reduction",
70
+ "reduce dimension", "reduce dimensionality", "t-sne", "tsne", "umap"
71
+ )):
72
+ return "dimensionality_reduction"
73
+
74
+ # Anomaly / outlier intent
75
+ if any(k in p for k in (
76
+ "anomaly", "anomalies", "outlier", "outliers", "novelty",
77
+ "fraud", "deviation", "rare event", "rare events", "odd pattern",
78
+ "suspicious"
79
+ )):
80
+ return "anomaly_detection"
81
+
82
+ if any(k in p for k in ("t-test", "anova", "p-value")):
83
+ return "stat_test"
84
+ if "forecast" in p or "prophet" in p:
85
+ return "time_series"
86
+ if "cluster" in p or "kmeans" in p:
87
+ return "clustering"
88
+ if any(k in p for k in ("accuracy", "precision", "roc")):
89
+ return "classification"
90
+ if any(k in p for k in ("rmse", "r2", "mae")):
91
+ return "regression"
92
+
93
+ return "eda"
94
+
95
+
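A quick sanity check of the keyword routing above, doctest-style (hypothetical prompts; earlier branches win when several keywords match):

    >>> classify_ml_job("Show feature importance with permutation importance")
    'feature_selection'
    >>> classify_ml_job("Forecast monthly sales with Prophet")
    'time_series'
    >>> classify_ml_job("Plot the age distribution")
    'eda'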
96
+ def harden_ai_code(code: str) -> str:
97
+ """
98
+ Make any AI-generated cell resilient:
99
+ - Safe seaborn wrappers + sentinel vars (boxplot/barplot/etc.)
100
+ - Remove 'numeric_only=' args
101
+ - Replace pd.concat(...) with _safe_concat(...)
102
+ - Relax 'required_cols' hard fails
103
+ - Make static numeric_vars dynamic
104
+ - Wrap the whole block in try/except so no exception bubbles up
105
+ """
106
+ # Remove any LLM-added try/except blocks (hardener adds its own)
107
+ import re
108
+
109
+ def strip_placeholders(code: str) -> str:
110
+ code = re.sub(r"\bshow\(\s*\.\.\.\s*\)",
111
+ "show('⚠ Block skipped due to an error.')",
112
+ code)
113
+ code = re.sub(r"\breturn\s+\.\.\.", "return None", code)
114
+ return code
115
+
116
+ def _indent(code: str, spaces: int = 4) -> str:
117
+ pad = " " * spaces
118
+ return "\n".join(pad + line for line in code.splitlines())
119
+
120
+ def _SMX_OHE(**k):
121
+ # normalise arg name across sklearn versions
122
+ if "sparse" in k and "sparse_output" not in k:
123
+ k["sparse_output"] = k.pop("sparse")
124
+ # default behaviour we want
125
+ k.setdefault("handle_unknown", "ignore")
126
+ k.setdefault("sparse_output", False)
127
+ try:
128
+ # if running on old sklearn without sparse_output, translate back
129
+ if "sparse_output" not in inspect.signature(OneHotEncoder).parameters:
130
+ if "sparse_output" in k:
131
+ k["sparse"] = k.pop("sparse_output")
132
+ return OneHotEncoder(**k)
133
+ except TypeError:
134
+ # final fallback: try legacy name
135
+ if "sparse_output" in k:
136
+ k["sparse"] = k.pop("sparse_output")
137
+ return OneHotEncoder(**k)
138
+
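Illustrative use of the shim above (a sketch; the exact encoder attributes depend on the installed sklearn, but the legacy 'sparse' kwarg is normalised either way):

    >>> enc = _SMX_OHE(sparse=True)   # old-style kwarg
    >>> enc.handle_unknown            # default applied by the shim
    'ignore'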
139
+ def _strip_stray_backrefs(code: str) -> str:
140
+ code = re.sub(r'(?m)^\s*\\\d+\s*', '', code)
141
+ code = re.sub(r'(?m)[;]\s*\\\d+\s*', '; ', code)
142
+ return code
143
+
144
+ def _wrap_metric_calls(code: str) -> str:
145
+ names = [
146
+ "r2_score","accuracy_score","precision_score","recall_score","f1_score",
147
+ "roc_auc_score","classification_report","confusion_matrix",
148
+ "mean_absolute_error","mean_absolute_percentage_error",
149
+ "explained_variance_score","log_loss","average_precision_score",
150
+ "precision_recall_fscore_support"
151
+ ]
152
+ pat = re.compile(r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\(")
153
+ def repl(m):
154
+ prefix = m.group(1) or "" # "", "metrics.", or "sklearn.metrics."
155
+ name = m.group(2)
156
+ return f"_SMX_call({prefix}{name}, "
157
+ return pat.sub(repl, code)
158
+
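The rewrite is purely textual; a hypothetical before/after:

    >>> _wrap_metric_calls("acc = accuracy_score(y, y_pred)")
    'acc = _SMX_call(accuracy_score, y, y_pred)'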
159
+ def _smx_patch_mean_squared_error_squared_kw():
160
+ """
161
+ sklearn<0.22 (and sklearn>=1.6, which removed the keyword) doesn't accept mean_squared_error(..., squared=False).
162
+ Patch the module attr so 'from sklearn.metrics import mean_squared_error'
163
+ receives a wrapper that drops 'squared' if the underlying call rejects it.
164
+ """
165
+ try:
166
+ import sklearn.metrics as _sm
167
+ _orig = getattr(_sm, "mean_squared_error", None)
168
+ if not callable(_orig):
169
+ return
170
+ def _mse_compat(y_true, y_pred, *a, **k):
171
+ if "squared" in k:
172
+ try:
173
+ return _orig(y_true, y_pred, *a, **k)
174
+ except TypeError:
+ # keep RMSE semantics when the installed sklearn no longer
+ # accepts the 'squared' keyword (just dropping it would return MSE)
+ want_rmse = not k.pop("squared", True)
+ val = _orig(y_true, y_pred, *a, **k)
+ return val ** 0.5 if want_rmse else val
177
+ return _orig(y_true, y_pred, *a, **k)
178
+ _sm.mean_squared_error = _mse_compat
179
+ except Exception:
180
+ pass
181
+
182
+ def _smx_patch_kmeans_n_init_auto():
183
+ """
184
+ sklearn>=1.2 accepts n_init='auto'; older versions want an int.
185
+ Patch sklearn.cluster.KMeans so 'auto' degrades to an int on versions without support.
186
+ """
187
+ try:
188
+ import sklearn.cluster as _sc
189
+ _Orig = getattr(_sc, "KMeans", None)
190
+ if _Orig is None:
191
+ return
192
+ class KMeansCompat(_Orig):
+ def __init__(self, *a, **k):
+ if isinstance(k.get("n_init", None), str):
+ # older sklearn defers n_init validation to fit(), so a TypeError
+ # would never surface in __init__; check version support up front
+ try:
+ import sklearn as _sk
+ if tuple(int(x) for x in _sk.__version__.split(".")[:2]) < (1, 2):
+ k["n_init"] = 10
+ except Exception:
+ k["n_init"] = 10
+ super().__init__(*a, **k)
201
+ _sc.KMeans = KMeansCompat
202
+ except Exception:
203
+ pass
204
+
205
+ def _smx_patch_ohe_name_api():
206
+ """
207
+ Guard get_feature_names_out on older OneHotEncoder.
208
+ The templates already use _SMX_OHE; this adds a soft fallback for feature names.
209
+ """
210
+ try:
211
+ from sklearn.preprocessing import OneHotEncoder as _OHE
212
+ _orig_get = getattr(_OHE, "get_feature_names_out", None)
213
+ if _orig_get is None:
214
+ # Monkey-patch instance method via mixin
215
+ def _fallback_get_feature_names_out(self, input_features=None):
216
+ cats = getattr(self, "categories_", None) or []
217
+ input_features = input_features or [f"x{i}" for i in range(len(cats))]
218
+ names = []
219
+ for base, cat_list in zip(input_features, cats):
220
+ for j, _ in enumerate(cat_list):
221
+ names.append(f"{base}__{j}")
222
+ return names
223
+ _OHE.get_feature_names_out = _fallback_get_feature_names_out
224
+ except Exception:
225
+ pass
226
+
227
+ # Register and run patches once per execution
228
+ for _patch in (
229
+ _smx_patch_mean_squared_error_squared_kw,
230
+ _smx_patch_kmeans_n_init_auto,
231
+ _smx_patch_ohe_name_api,
232
+ ):
233
+ try:
234
+ _patch()
235
+ except Exception:
236
+ pass
237
+
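With the patches applied, the 'squared' keyword keeps working (or is emulated by the compat wrapper above) regardless of sklearn version; a minimal check, assuming sklearn is installed:

    >>> _smx_patch_mean_squared_error_squared_kw()
    >>> from sklearn.metrics import mean_squared_error
    >>> round(mean_squared_error([0, 1], [0, 0], squared=False), 3)
    0.707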
238
+ PREFACE = (
239
+ "# === SMX Auto-Hardening Preface (do not edit) ===\n"
240
+ "import warnings, numpy as np, pandas as pd, matplotlib.pyplot as plt\n"
241
+ "warnings.filterwarnings('ignore')\n"
242
+ "try:\n"
243
+ " import seaborn as sns\n"
244
+ "except Exception:\n"
245
+ " class _Dummy:\n"
246
+ " def __getattr__(self, name):\n"
247
+ " def _f(*a, **k):\n"
248
+ " from syntaxmatrix.display import show\n"
249
+ " show('⚠ seaborn not available; plot skipped.')\n"
250
+ " return _f\n"
251
+ " sns = _Dummy()\n"
252
+ "\n"
253
+ "from syntaxmatrix.display import show as _SMX_base_show\n"
254
+ "def _SMX_caption_from_ctx():\n"
255
+ " g = globals()\n"
256
+ " t = g.get('refined_question') or g.get('askai_question') or 'Table'\n"
257
+ " return str(t).strip().splitlines()[0][:120]\n"
258
+ "\n"
259
+ "def show(obj, title=None):\n"
260
+ " try:\n"
261
+ " import pandas as pd\n"
262
+ " if isinstance(obj, pd.DataFrame):\n"
263
+ " cap = (title or _SMX_caption_from_ctx())\n"
264
+ " try:\n"
265
+ " return _SMX_base_show(obj.style.set_caption(cap))\n"
266
+ " except Exception:\n"
267
+ " pass\n"
268
+ " except Exception:\n"
269
+ " pass\n"
270
+ " return _SMX_base_show(obj)\n"
271
+ "\n"
272
+ "def _SMX_axes_have_titles(fig=None):\n"
273
+ " import matplotlib.pyplot as _plt\n"
274
+ " fig = fig or _plt.gcf()\n"
275
+ " try:\n"
276
+ " for _ax in fig.get_axes():\n"
277
+ " if (_ax.get_title() or '').strip():\n"
278
+ " return True\n"
279
+ " except Exception:\n"
280
+ " pass\n"
281
+ " return False\n"
282
+ "\n"
283
+ "def _SMX_export_png():\n"
284
+ " import io, base64\n"
285
+ " fig = plt.gcf()\n"
286
+ " try:\n"
287
+ " if not _SMX_axes_have_titles(fig):\n"
288
+ " fig.suptitle(_SMX_caption_from_ctx(), fontsize=10)\n"
289
+ " except Exception:\n"
290
+ " pass\n"
291
+ " buf = io.BytesIO()\n"
292
+ " plt.savefig(buf, format='png', bbox_inches='tight')\n"
293
+ " buf.seek(0)\n"
294
+ " from IPython.display import display, HTML\n"
295
+ " _img = base64.b64encode(buf.read()).decode('ascii')\n"
296
+ " display(HTML(f\"<img src='data:image/png;base64,{_img}' style='max-width:100%;height:auto;border:1px solid #ccc;border-radius:4px;'/>\"))\n"
297
+ " plt.close()\n"
298
+ "\n"
299
+ "def _pick_df():\n"
300
+ " return globals().get('df', None)\n"
301
+ "\n"
302
+ "def _pick_ax_slot():\n"
303
+ " ax = None\n"
304
+ " try:\n"
305
+ " _axes = globals().get('axes', None)\n"
306
+ " import numpy as _np\n"
307
+ " if _axes is not None:\n"
308
+ " arr = _np.ravel(_axes)\n"
309
+ " for _a in arr:\n"
310
+ " try:\n"
311
+ " if hasattr(_a,'has_data') and not _a.has_data():\n"
312
+ " ax = _a; break\n"
313
+ " except Exception:\n"
314
+ " continue\n"
315
+ " except Exception:\n"
316
+ " ax = None\n"
317
+ " return ax\n"
318
+ "\n"
319
+ "def _first_numeric(_d):\n"
320
+ " import numpy as np, pandas as pd\n"
321
+ " try:\n"
322
+ " preferred = [\"median_house_value\", \"price\", \"value\", \"target\", \"label\", \"y\"]\n"
323
+ " for c in preferred:\n"
324
+ " if c in _d.columns and pd.api.types.is_numeric_dtype(_d[c]):\n"
325
+ " return c\n"
326
+ " cols = _d.select_dtypes(include=[np.number]).columns.tolist()\n"
327
+ " return cols[0] if cols else None\n"
328
+ " except Exception:\n"
329
+ " return None\n"
330
+ "\n"
331
+ "def _first_categorical(_d):\n"
332
+ " import pandas as pd, numpy as np\n"
333
+ " try:\n"
334
+ " num = set(_d.select_dtypes(include=[np.number]).columns.tolist())\n"
335
+ " cand = [c for c in _d.columns if c not in num and _d[c].nunique(dropna=True) <= 50]\n"
336
+ " return cand[0] if cand else None\n"
337
+ " except Exception:\n"
338
+ " return None\n"
339
+ "\n"
340
+ "boxplot = barplot = histplot = distplot = lineplot = countplot = heatmap = pairplot = None\n"
341
+ "\n"
342
+ "def _safe_plot(func, *args, **kwargs):\n"
343
+ " try:\n"
344
+ " ax = func(*args, **kwargs)\n"
345
+ " if ax is None:\n"
346
+ " ax = plt.gca()\n"
347
+ " try:\n"
348
+ " if hasattr(ax, 'has_data') and not ax.has_data():\n"
349
+ " from syntaxmatrix.display import show as _show\n"
350
+ " _show('⚠ Empty plot: no data drawn.')\n"
351
+ " except Exception:\n"
352
+ " pass\n"
353
+ " try: plt.tight_layout()\n"
354
+ " except Exception: pass\n"
355
+ " return ax\n"
356
+ " except Exception as e:\n"
357
+ " from syntaxmatrix.display import show as _show\n"
358
+ " _show(f'⚠ Plot skipped: {type(e).__name__}: {e}')\n"
359
+ " return None\n"
360
+ "\n"
361
+ "def SB_histplot(*a, **k):\n"
362
+ " _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
363
+ " _sentinel = (len(a) >= 1 and a[0] is None)\n"
364
+ " if (not a or _sentinel) and not k:\n"
365
+ " d = _pick_df()\n"
366
+ " if d is not None:\n"
367
+ " x = _first_numeric(d)\n"
368
+ " if x is not None:\n"
369
+ " def _draw():\n"
370
+ " plt.hist(d[x].dropna())\n"
371
+ " ax = plt.gca()\n"
372
+ " if not (ax.get_title() or '').strip():\n"
373
+ " ax.set_title(f'Distribution of {x}')\n"
374
+ " return ax\n"
375
+ " return _safe_plot(lambda **kw: _draw())\n"
376
+ " if _missing:\n"
377
+ " return _safe_plot(lambda **kw: plt.hist([]))\n"
378
+ " if _sentinel:\n"
379
+ " a = a[1:]\n"
380
+ " return _safe_plot(getattr(sns,'histplot', plt.hist), *a, **k)\n"
381
+ "\n"
382
+ "def SB_barplot(*a, **k):\n"
383
+ " _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
384
+ " _sentinel = (len(a) >= 1 and a[0] is None)\n"
385
+ " _ax = k.get('ax') or _pick_ax_slot()\n"
386
+ " if _ax is not None:\n"
387
+ " try: plt.sca(_ax)\n"
388
+ " except Exception: pass\n"
389
+ " k.setdefault('ax', _ax)\n"
390
+ " if (not a or _sentinel) and not k:\n"
391
+ " d = _pick_df()\n"
392
+ " if d is not None:\n"
393
+ " x = _first_categorical(d)\n"
394
+ " y = _first_numeric(d)\n"
395
+ " if x and y:\n"
396
+ " import pandas as _pd\n"
397
+ " g = d.groupby(x)[y].mean().reset_index()\n"
398
+ " def _draw():\n"
399
+ " if _missing:\n"
400
+ " plt.bar(g[x], g[y])\n"
401
+ " else:\n"
402
+ " sns.barplot(data=g, x=x, y=y, ax=k.get('ax'))\n"
403
+ " ax = plt.gca()\n"
404
+ " if not (ax.get_title() or '').strip():\n"
405
+ " ax.set_title(f'Mean {y} by {x}')\n"
406
+ " return ax\n"
407
+ " return _safe_plot(lambda **kw: _draw())\n"
408
+ " if _missing:\n"
409
+ " return _safe_plot(lambda **kw: plt.bar([], []))\n"
410
+ " if _sentinel:\n"
411
+ " a = a[1:]\n"
412
+ " return _safe_plot(sns.barplot, *a, **k)\n"
413
+ "\n"
414
+ "def SB_boxplot(*a, **k):\n"
415
+ " _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
416
+ " _sentinel = (len(a) >= 1 and a[0] is None)\n"
417
+ " _ax = k.get('ax') or _pick_ax_slot()\n"
418
+ " if _ax is not None:\n"
419
+ " try: plt.sca(_ax)\n"
420
+ " except Exception: pass\n"
421
+ " k.setdefault('ax', _ax)\n"
422
+ " if (not a or _sentinel) and not k:\n"
423
+ " d = _pick_df()\n"
424
+ " if d is not None:\n"
425
+ " x = _first_categorical(d)\n"
426
+ " y = _first_numeric(d)\n"
427
+ " if x and y:\n"
428
+ " def _draw():\n"
429
+ " if _missing:\n"
430
+ " plt.boxplot(d[y].dropna())\n"
431
+ " else:\n"
432
+ " sns.boxplot(data=d, x=x, y=y, ax=k.get('ax'))\n"
433
+ " ax = plt.gca()\n"
434
+ " if not (ax.get_title() or '').strip():\n"
435
+ " ax.set_title(f'Distribution of {y} by {x}')\n"
436
+ " return ax\n"
437
+ " return _safe_plot(lambda **kw: _draw())\n"
438
+ " if _missing:\n"
439
+ " return _safe_plot(lambda **kw: plt.boxplot([]))\n"
440
+ " if _sentinel:\n"
441
+ " a = a[1:]\n"
442
+ " return _safe_plot(sns.boxplot, *a, **k)\n"
443
+ "\n"
444
+ "def SB_scatterplot(*a, **k):\n"
445
+ " _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
446
+ " fn = getattr(sns,'scatterplot', None)\n"
447
+ " # If seaborn is unavailable OR the caller passed (data=..., x='col', y='col'),\n"
448
+ " # use a robust matplotlib path that looks up data and coerces to numeric.\n"
449
+ " if _missing or fn is None:\n"
450
+ " data = k.get('data'); x = k.get('x'); y = k.get('y')\n"
451
+ " if data is not None and isinstance(x, str) and isinstance(y, str) and x in data.columns and y in data.columns:\n"
452
+ " xs = pd.to_numeric(data[x], errors='coerce')\n"
453
+ " ys = pd.to_numeric(data[y], errors='coerce')\n"
454
+ " m = xs.notna() & ys.notna()\n"
455
+ " def _draw():\n"
456
+ " plt.scatter(xs[m], ys[m])\n"
457
+ " ax = plt.gca()\n"
458
+ " if not (ax.get_title() or '').strip():\n"
459
+ " ax.set_title(f'{y} vs {x}')\n"
460
+ " return ax\n"
461
+ " return _safe_plot(lambda **kw: _draw())\n"
462
+ " # else: fall back to auto-pick two numeric columns\n"
463
+ " d = _pick_df()\n"
464
+ " if d is not None:\n"
465
+ " num = d.select_dtypes(include=[np.number]).columns.tolist()\n"
466
+ " if len(num) >= 2:\n"
467
+ " def _draw2():\n"
468
+ " plt.scatter(d[num[0]], d[num[1]])\n"
469
+ " ax = plt.gca()\n"
470
+ " if not (ax.get_title() or '').strip():\n"
471
+ " ax.set_title(f'{num[1]} vs {num[0]}')\n"
472
+ " return ax\n"
473
+ " return _safe_plot(lambda **kw: _draw2())\n"
474
+ " return _safe_plot(lambda **kw: plt.scatter([], []))\n"
475
+ " # seaborn path\n"
476
+ " return _safe_plot(fn, *a, **k)\n"
477
+ "\n"
478
+ "def SB_heatmap(*a, **k):\n"
479
+ " _missing = (getattr(sns, '__class__', type(sns)).__name__ == '_Dummy')\n"
480
+ " data = None\n"
481
+ " if a:\n"
482
+ " data = a[0]\n"
483
+ " elif 'data' in k:\n"
484
+ " data = k['data']\n"
485
+ " if data is None:\n"
486
+ " d = _pick_df()\n"
487
+ " try:\n"
488
+ " if d is not None:\n"
489
+ " import numpy as _np\n"
490
+ " data = d.select_dtypes(include=[_np.number]).corr()\n"
491
+ " except Exception:\n"
492
+ " data = None\n"
493
+ " if data is None:\n"
494
+ " from syntaxmatrix.display import show as _show\n"
495
+ " _show('⚠ Heatmap skipped: no data.')\n"
496
+ " return None\n"
497
+ " if not _missing and hasattr(sns, 'heatmap'):\n"
498
+ " _k = {kk: vv for kk, vv in k.items() if kk != 'data'}\n"
499
+ " def _draw():\n"
500
+ " ax = sns.heatmap(data, **_k)\n"
501
+ " try:\n"
502
+ " ax = ax or plt.gca()\n"
503
+ " if not (ax.get_title() or '').strip():\n"
504
+ " ax.set_title('Correlation Heatmap')\n"
505
+ " except Exception:\n"
506
+ " pass\n"
507
+ " return ax\n"
508
+ " return _safe_plot(lambda **kw: _draw())\n"
509
+ " def _mat_heat():\n"
510
+ " im = plt.imshow(data, aspect='auto')\n"
511
+ " try: plt.colorbar()\n"
512
+ " except Exception: pass\n"
513
+ " try:\n"
514
+ " cols = list(getattr(data, 'columns', []))\n"
515
+ " rows = list(getattr(data, 'index', []))\n"
516
+ " if cols: plt.xticks(range(len(cols)), cols, rotation=90)\n"
517
+ " if rows: plt.yticks(range(len(rows)), rows)\n"
518
+ " except Exception:\n"
519
+ " pass\n"
520
+ " ax = plt.gca()\n"
521
+ " try:\n"
522
+ " if not (ax.get_title() or '').strip():\n"
523
+ " ax.set_title('Correlation Heatmap')\n"
524
+ " except Exception:\n"
525
+ " pass\n"
526
+ " return ax\n"
527
+ " return _safe_plot(lambda **kw: _mat_heat())\n"
528
+ "\n"
529
+ "def _safe_concat(objs, **kwargs):\n"
530
+ " import pandas as _pd\n"
531
+ " if objs is None: return _pd.DataFrame()\n"
532
+ " if isinstance(objs,(list,tuple)) and len(objs)==0: return _pd.DataFrame()\n"
533
+ " try: return _pd.concat(objs, **kwargs)\n"
534
+ " except Exception as e:\n"
535
+ " show(f'⚠ concat skipped: {e}')\n"
536
+ " return _pd.DataFrame()\n"
537
+ "\n"
538
+ "from sklearn.preprocessing import OneHotEncoder\n"
539
+ "import inspect\n"
540
+ "def _SMX_OHE(**k):\n"
541
+ " # normalise arg name across sklearn versions\n"
542
+ " if 'sparse' in k and 'sparse_output' not in k:\n"
543
+ " k['sparse_output'] = k.pop('sparse')\n"
544
+ " k.setdefault('handle_unknown','ignore')\n"
545
+ " k.setdefault('sparse_output', False)\n"
546
+ " try:\n"
547
+ " sig = inspect.signature(OneHotEncoder)\n"
548
+ " if 'sparse_output' not in sig.parameters and 'sparse_output' in k:\n"
549
+ " k['sparse'] = k.pop('sparse_output')\n"
550
+ " except Exception:\n"
551
+ " if 'sparse_output' in k:\n"
552
+ " k['sparse'] = k.pop('sparse_output')\n"
553
+ " return OneHotEncoder(**k)\n"
554
+ "\n"
555
+ "import numpy as _np\n"
556
+ "def _SMX_mm(a, b):\n"
557
+ " try:\n"
558
+ " return a @ b # normal path\n"
559
+ " except Exception:\n"
560
+ " try:\n"
561
+ " A = _np.asarray(a); B = _np.asarray(b)\n"
562
+ " # If same 2D shape (e.g. (n,k) & (n,k)), treat as row-wise dot\n"
563
+ " if A.ndim==2 and B.ndim==2 and A.shape==B.shape:\n"
564
+ " return (A * B).sum(axis=1)\n"
565
+ " # Otherwise try element-wise product (broadcast if possible)\n"
566
+ " return A * B\n"
567
+ " except Exception as e:\n"
568
+ " from syntaxmatrix.display import show\n"
569
+ " show(f'⚠ Matmul relaxed: {type(e).__name__}: {e}'); return _np.nan\n"
570
+ "\n"
571
+ "def _SMX_call(fn, *a, **k):\n"
572
+ " try:\n"
573
+ " return fn(*a, **k)\n"
574
+ " except TypeError as e:\n"
575
+ " msg = str(e)\n"
576
+ " if \"unexpected keyword argument 'squared'\" in msg:\n"
577
+ " k.pop('squared', None)\n"
578
+ " return fn(*a, **k)\n"
579
+ " raise\n"
580
+ "\n"
581
+ "def _SMX_rmse(y_true, y_pred):\n"
582
+ " try:\n"
583
+ " from sklearn.metrics import mean_squared_error as _mse\n"
584
+ " try:\n"
585
+ " return _mse(y_true, y_pred, squared=False)\n"
586
+ " except TypeError:\n"
587
+ " return (_mse(y_true, y_pred)) ** 0.5\n"
588
+ " except Exception:\n"
589
+ " import numpy as _np\n"
590
+ " yt = _np.asarray(y_true, dtype=float)\n"
591
+ " yp = _np.asarray(y_pred, dtype=float)\n"
592
+ " diff = yt - yp\n"
593
+ " return float((_np.mean(diff * diff)) ** 0.5)\n"
594
+ "\n"
595
+ "import pandas as _pd\n"
596
+ "import numpy as _np\n"
597
+ "def _SMX_autocoerce_dates(_df):\n"
598
+ " if _df is None or not hasattr(_df, 'columns'): return\n"
599
+ " for c in list(_df.columns):\n"
600
+ " s = _df[c]\n"
601
+ " n = str(c).lower()\n"
602
+ " if _pd.api.types.is_datetime64_any_dtype(s):\n"
603
+ " continue\n"
604
+ " if _pd.api.types.is_object_dtype(s) or ('date' in n or 'time' in n or 'timestamp' in n or n.endswith('_dt')):\n"
605
+ " try:\n"
606
+ " conv = _pd.to_datetime(s, errors='coerce', utc=True).dt.tz_localize(None)\n"
607
+ " # accept only if at least 10% (min 3) parse as dates\n"
608
+ " if getattr(conv, 'notna', lambda: _pd.Series([]))().sum() >= max(3, int(0.1*len(_df))):\n"
609
+ " _df[c] = conv\n"
610
+ " except Exception:\n"
611
+ " pass\n"
612
+ "\n"
613
+ "def _SMX_autocoerce_numeric(_df, cols):\n"
614
+ " if _df is None: return\n"
615
+ " for c in cols:\n"
616
+ " if c in getattr(_df, 'columns', []):\n"
617
+ " try:\n"
618
+ " _df[c] = _pd.to_numeric(_df[c], errors='coerce')\n"
619
+ " except Exception:\n"
620
+ " pass\n"
621
+ "\n"
622
+ "def show(obj, title=None):\n"
623
+ " try:\n"
624
+ " import pandas as pd, numbers\n"
625
+ " cap = (title or _SMX_caption_from_ctx())\n"
626
+ " # 1) DataFrame → Styler with caption\n"
627
+ " if isinstance(obj, pd.DataFrame):\n"
628
+ " try: return _SMX_base_show(obj.style.set_caption(cap))\n"
629
+ " except Exception: pass\n"
630
+ " # 2) dict of scalars → DataFrame with caption\n"
631
+ " if isinstance(obj, dict) and all(isinstance(v, numbers.Number) for v in obj.values()):\n"
632
+ " df_ = pd.DataFrame({'metric': list(obj.keys()), 'value': list(obj.values())})\n"
633
+ " try: return _SMX_base_show(df_.style.set_caption(cap))\n"
634
+ " except Exception: return _SMX_base_show(df_)\n"
635
+ " except Exception:\n"
636
+ " pass\n"
637
+ " return _SMX_base_show(obj)\n"
638
+ )
639
+
640
+ PREFACE_IMPORT = "from syntaxmatrix.smx_preface import *\n"
641
+ # if PREFACE not in code:
642
+ # code = PREFACE_IMPORT + code
643
+
644
+ fixed = code
645
+
646
+ fixed = re.sub(
647
+ r"(?s)^\s*try:\s*(.*?)\s*except\s+Exception\s+as\s+\w+:\s*\n\s*show\([^\n]*\)\s*$",
648
+ r"\1",
649
+ fixed.strip()
650
+ )
651
+
652
+ # 1) Strip numeric_only=... (version-agnostic)
653
+ fixed = re.sub(r",\s*numeric_only\s*=\s*(True|False|None)", "", fixed, flags=re.I)
654
+ fixed = re.sub(r"\bnumeric_only\s*=\s*(True|False|None)\s*,\s*", "", fixed, flags=re.I)
655
+ fixed = re.sub(r"\bnumeric_only\s*=\s*(True|False|None)\b", "", fixed, flags=re.I)
656
+
657
+ # 2) Use safe seaborn wrappers
658
+ fixed = re.sub(r"\bsns\.boxplot\s*\(", "SB_boxplot(", fixed)
659
+ fixed = re.sub(r"\bsns\.barplot\s*\(", "SB_barplot(", fixed)
660
+ fixed = re.sub(r"\bsns\.histplot\s*\(", "SB_histplot(", fixed)
661
+ fixed = re.sub(r"\bsns\.scatterplot\s*\(", "SB_scatterplot(", fixed)
662
+
663
+ # 3) Guard concat calls
664
+ fixed = re.sub(r"\bpd\.concat\s*\(", "_safe_concat(", fixed)
665
+ fixed = re.sub(r"\bOneHotEncoder\s*\(", "_SMX_OHE(", fixed)
666
+ # Route np.dot to tolerant matmul
667
+ fixed = re.sub(r"\bnp\.dot\s*\(", "_SMX_mm(", fixed)
668
+ fixed = re.sub(r"(df\s*\[[^\]]+\])\s*\.dt", r"SMX_dt(\1).dt", fixed)
669
+
670
+
671
+ # 4) Relax any 'required_cols' hard failure blocks
672
+ fixed = re.sub(
673
+ r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
674
+ "required_cols = [c for c in df.columns]\n# (relaxed by SMX hardener)",
675
+ fixed,
676
+ flags=re.S,
677
+ )
678
+
679
+ # 5) Make static numeric_vars lists dynamic
680
+ fixed = re.sub(
681
+ r"\bnumeric_vars\s*=\s*\[.*?\]",
682
+ "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
683
+ fixed,
684
+ flags=re.S,
685
+ )
686
+ # normalise all .dt usages on df[...] / df.attr / df.loc[...] to SMX_dt(...)
687
+ fixed = re.sub(
688
+ r"((?:df\s*(?:\.\s*(?:loc|iloc)\s*)?\[[^\]]+\]|df\s*\.\s*[A-Za-z_]\w*))\s*\.dt\b",
689
+ lambda m: f"SMX_dt({m.group(1)}).dt",
690
+ fixed
691
+ )
692
+
693
+ try:
694
+ class _SMXMatmulRewriter(ast.NodeTransformer):
695
+ def visit_BinOp(self, node):
696
+ self.generic_visit(node)
697
+ if isinstance(node.op, ast.MatMult):
698
+ return ast.Call(func=ast.Name(id="_SMX_mm", ctx=ast.Load()),
699
+ args=[node.left, node.right], keywords=[])
700
+ return node
701
+ _tree = ast.parse(fixed)
702
+ _tree = _SMXMatmulRewriter().visit(_tree)
703
+ fixed = ast.unparse(_tree)
704
+ except Exception:
705
+ # If AST rewrite fails, keep original; _SMX_mm will still handle np.dot(...)
706
+ pass
707
+
708
+ # 6) Final safety wrapper
709
+ fixed = fixed.replace("\t", " ")
710
+ fixed = textwrap.dedent(fixed).strip("\n")
711
+
712
+ fixed = _strip_stray_backrefs(fixed)
713
+ fixed = _wrap_metric_calls(fixed)
714
+
715
+ # If the transformed code is still not syntactically valid, fall back to a
716
+ # very defensive generic snippet that depends only on `df`. This guarantees
+ # the cell still executes and surfaces at least a basic summary.
+ try:
718
+ ast.parse(fixed)
719
+ except (SyntaxError, IndentationError):
720
+ fixed = (
721
+ "import pandas as pd\n"
722
+ "df = df.copy()\n"
723
+ "_info = {\n"
724
+ " 'rows': len(df),\n"
725
+ " 'cols': len(df.columns),\n"
726
+ " 'numeric_cols': len(df.select_dtypes(include=['number','bool']).columns),\n"
727
+ " 'categorical_cols': len(df.select_dtypes(exclude=['number','bool']).columns),\n"
728
+ "}\n"
729
+ "show(df.head(), title='Sample of data')\n"
730
+ "show(_info, title='Dataset summary')\n"
731
+ "try:\n"
732
+ " _num = df.select_dtypes(include=['number','bool']).columns.tolist()\n"
733
+ " if _num:\n"
734
+ " SB_histplot()\n"
735
+ " _SMX_export_png()\n"
736
+ "except Exception as e:\n"
737
+ " show(f\"⚠ Fallback visualisation failed: {type(e).__name__}: {e}\")\n"
738
+ )
739
+
740
+ # Fix placeholder Ellipsis handlers from LLM
741
+ fixed = re.sub(
742
+ r"except\s+Exception\s+as\s+e:\s*\n\s*show\(\.\.\.\)",
743
+ "except Exception as e:\n show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")",
744
+ fixed,
745
+ )
746
+
747
+ wrapped = PREFACE + "try:\n" + _indent(fixed) + "\nexcept Exception as e:\n show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")\n"
748
+ wrapped = wrapped.lstrip()
749
+ return wrapped
750
+
751
+
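End-to-end, the hardener both rewrites fragile calls and strips version-specific kwargs; a minimal sketch with a made-up LLM cell:

    >>> out = harden_ai_code("pd.concat(frames).sum(numeric_only=True)")
    >>> "_safe_concat(frames" in out and "numeric_only" not in out
    True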
752
+ def indent_code(code: str, spaces: int = 4) -> str:
753
+ pad = " " * spaces
754
+ return "\n".join(pad + line for line in code.splitlines())
755
+
756
+
757
+ def wrap_llm_code_safe(code: str) -> str:
758
+ # Swallow any runtime error from the LLM block instead of crashing the run
759
+ return (
760
+ "# __SAFE_WRAPPED__\n"
761
+ "try:\n" + indent_code(code) + "\n"
762
+ "except Exception as e:\n"
763
+ " from syntaxmatrix.display import show\n"
764
+ " show(f\"⚠️ Skipped LLM block due to: {type(e).__name__}: {e}\")\n"
765
+ )
766
+
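For example (hypothetical input), the wrapper marks and guards the block rather than letting it raise:

    >>> wrapped = wrap_llm_code_safe("raise ValueError('boom')")
    >>> wrapped.splitlines()[0]
    '# __SAFE_WRAPPED__'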
767
+
768
+ def fix_boxplot_placeholder(code: str) -> str:
769
+ # Replace invalid 'sns.boxplot(boxplot)' with a safe call using df/group_label/m
770
+ return re.sub(
771
+ r"sns\.boxplot\(\s*boxplot\s*\)",
772
+ "sns.boxplot(x=group_label, y=m, data=df.loc[df[m].notnull()], showfliers=False)",
773
+ code
774
+ )
775
+
776
+
777
+ def relax_required_columns(code: str) -> str:
778
+ # Remove hard failure on required_cols; keep a soft filter instead
779
+ return re.sub(
780
+ r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
781
+ "required_cols = [c for c in df.columns]\n",
782
+ code,
783
+ flags=re.S
784
+ )
785
+
786
+
787
+ def make_numeric_vars_dynamic(code: str) -> str:
788
+ # Replace any static numeric_vars list with a dynamic selection
789
+ return re.sub(
790
+ r"numeric_vars\s*=\s*\[.*?\]",
791
+ "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
792
+ code,
793
+ flags=re.S
794
+ )
795
+
796
+
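A hypothetical before/after for the rewrite:

    >>> make_numeric_vars_dynamic("numeric_vars = ['age', 'bmi']")
    "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()"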
797
+ def auto_inject_template(code: str, intents, df) -> str:
798
+ """If the LLM forgot the core logic, prepend a skeleton block."""
799
+
800
+ has_fit = ".fit(" in code
801
+ has_plot = any(k in code for k in ("plt.", "sns.", ".plot(", ".hist("))
802
+
803
+ UNKNOWN_TOKENS = {
804
+ "unknown","not reported","not_reported","not known","n/a","na",
805
+ "none","nan","missing","unreported","unspecified","null","-",""
806
+ }
807
+
808
+ # --- Safe template caller: passes only supported kwargs, falls back cleanly ---
809
+ def _call_template(func, df, **hints):
810
+ import inspect
811
+ try:
812
+ params = inspect.signature(func).parameters
813
+ kw = {k: v for k, v in hints.items() if k in params}
814
+ try:
815
+ return func(df, **kw)
816
+ except TypeError:
817
+ # In case the template changed its signature at runtime
818
+ return func(df)
819
+ except Exception:
820
+ # Absolute safety net
821
+ try:
822
+ return func(df)
823
+ except Exception:
824
+ # As a last resort, return empty code so we don't 500
825
+ return ""
826
+
827
+ def _guess_classification_target(df: pd.DataFrame) -> str | None:
828
+ cols = list(df.columns)
829
+
830
+ # Helper: does this column look like a sensible label?
831
+ def _is_reasonable_class_col(s: pd.Series, col_name: str) -> bool:
832
+ try:
833
+ nunq = s.dropna().nunique()
834
+ except Exception:
835
+ return False
836
+ # need at least 2 classes, but not hundreds
837
+ if nunq < 2 or nunq > 20:
838
+ return False
839
+ bad_name_keys = ("id", "identifier", "index", "uuid", "key")
840
+ name = str(col_name).lower()
841
+ if any(k in name for k in bad_name_keys):
842
+ return False
843
+ return True
844
+
845
+ # 1) columns whose names look like labels
846
+ label_keys = ("target", "label", "outcome", "class", "y", "status")
847
+ name_candidates: list[str] = []
848
+ for key in label_keys:
849
+ for c in cols:
850
+ if key in str(c).lower():
851
+ name_candidates.append(c)
852
+ if name_candidates:
853
+ break # keep the earliest matching key-group
854
+
855
+ # prioritise name-based candidates that also look like proper label columns
856
+ for c in name_candidates:
857
+ if _is_reasonable_class_col(df[c], c):
858
+ return c
859
+ if name_candidates:
860
+ # fall back to the first name-based candidate if none passed the shape test
861
+ return name_candidates[0]
862
+
863
+ # 2) any column with a small number of distinct values (likely a class label)
864
+ for c in cols:
865
+ s = df[c]
866
+ if _is_reasonable_class_col(s, c):
867
+ return c
868
+
869
+ # Nothing suitable found
870
+ return None
871
+
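Illustrating the heuristic on a toy frame (hypothetical; the helper is nested inside auto_inject_template, so this assumes it is lifted to module scope for testing):

    >>> import pandas as pd
    >>> d = pd.DataFrame({"patient_id": range(100), "status": ["pos", "neg"] * 50})
    >>> _guess_classification_target(d)
    'status'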
872
+ def _guess_regression_target(df: pd.DataFrame) -> str | None:
873
+ num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
874
+ if not num_cols:
875
+ return None
876
+ # Avoid obvious ID-like columns
877
+ bad_keys = ("id", "identifier", "index")
878
+ candidates = [c for c in num_cols if not any(k in str(c).lower() for k in bad_keys)]
879
+ return (candidates or num_cols)[-1]
880
+
881
+ def _guess_time_col(df: pd.DataFrame) -> str | None:
882
+ # Prefer actual datetime dtype
883
+ dt_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.datetime64)]
884
+ if dt_cols:
885
+ return dt_cols[0]
886
+
887
+ # Fallback: name-based hints
888
+ name_keys = ["date", "time", "timestamp", "datetime", "ds", "period"]
889
+ for c in df.columns:
890
+ name = str(c).lower()
891
+ if any(k in name for k in name_keys):
892
+ return c
893
+ return None
894
+
895
+ def _guess_entity_col(df: pd.DataFrame) -> str | None:
896
+ # Typical sequence IDs: id, patient, subject, device, series, entity
897
+ keys = ["id", "patient", "subject", "device", "series", "entity"]
898
+ candidates = []
899
+ for c in df.columns:
900
+ name = str(c).lower()
901
+ if any(k in name for k in keys):
902
+ candidates.append(c)
903
+ return candidates[0] if candidates else None
904
+
905
+ def _guess_ts_class_target(df: pd.DataFrame) -> str | None:
906
+ # Try label-like names first
907
+ keys = ["target", "label", "class", "outcome", "y"]
908
+ for key in keys:
909
+ for c in df.columns:
910
+ if key in str(c).lower():
911
+ return c
912
+
913
+ # Fallback: any column with few distinct values (e.g. <= 10)
914
+ for c in df.columns:
915
+ s = df[c]
916
+ # avoid obvious IDs
917
+ if any(k in str(c).lower() for k in ["id", "index"]):
918
+ continue
919
+ try:
920
+ nunq = s.dropna().nunique()
921
+ except Exception:
922
+ continue
923
+ if 1 < nunq <= 10:
924
+ return c
925
+
926
+ return None
927
+
928
+ def _guess_multilabel_cols(df: pd.DataFrame) -> list[str]:
929
+ cols = list(df.columns)
930
+ lbl_like = [c for c in cols if str(c).startswith(("LBL_", "lbl_"))]
931
+ # also include boolean/binary columns with suitable names
932
+ for c in cols:
933
+ s = df[c]
934
+ try:
935
+ nunq = s.dropna().nunique()
936
+ except Exception:
937
+ continue
938
+ if nunq in (2,) and c not in lbl_like:
939
+ # avoid obvious IDs
940
+ if not any(k in str(c).lower() for k in ("id","index","uuid","identifier")):
941
+ lbl_like.append(c)
942
+ # keep at most, say, 12 to avoid accidental flood
943
+ return lbl_like[:12]
944
+
945
+ def _find_unknownish_column(df: pd.DataFrame) -> str | None:
946
+ # Search categorical-like columns for any 'unknown-like' values or high missingness
947
+ candidates = []
948
+ for c in df.columns:
949
+ s = df[c]
950
+ # focus on object/category/boolean-ish or low-card columns
951
+ if not (pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= 20):
952
+ continue
953
+ try:
954
+ vals = s.astype(str).str.strip().str.lower()
955
+ except Exception:
956
+ continue
957
+ # score: presence of unknown tokens + missing rate
958
+ token_hit = int(vals.isin(UNKNOWN_TOKENS).any())
959
+ miss_rate = s.isna().mean()
960
+ name_bonus = int(any(k in str(c).lower() for k in ("status","history","report","known","flag")))
961
+ score = 3*token_hit + 2*name_bonus + miss_rate
962
+ if token_hit or miss_rate > 0.05 or name_bonus:
963
+ candidates.append((score, c))
964
+ if not candidates:
965
+ return None
966
+ candidates.sort(reverse=True)
967
+ return candidates[0][1]
968
+
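A toy illustration (again assuming the nested helper is exposed for testing):

    >>> d = pd.DataFrame({"smoking_status": ["yes", "no", "unknown"] * 10, "age": range(30)})
    >>> _find_unknownish_column(d)
    'smoking_status'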
969
+ def _guess_numeric_cols(df: pd.DataFrame, max_n: int = 6) -> list[str]:
970
+ cols = [c for c in df.select_dtypes(include=[np.number, "bool"]).columns if not any(k in str(c).lower() for k in ("id","identifier","index","uuid"))]
971
+ # prefer non-constant columns
972
+ scored = []
973
+ for c in cols:
974
+ try:
975
+ v = df[c].dropna()
976
+ var = float(v.var()) if len(v) else 0.0
977
+ scored.append((var, c))
978
+ except Exception:
979
+ continue
980
+ scored.sort(reverse=True)
981
+ return [c for _, c in scored[:max_n]]
982
+
983
+ def _guess_categorical_cols(df: pd.DataFrame, exclude: set[str] | None = None, max_card: int = 12, max_n: int = 5) -> list[str]:
984
+ exclude = exclude or set()
985
+ picks = []
986
+ for c in df.columns:
987
+ if c in exclude:
988
+ continue
989
+ s = df[c]
990
+ if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= max_card:
991
+ nunq = s.dropna().nunique()
992
+ if 2 <= nunq <= max_card and not any(k in str(c).lower() for k in ("id","identifier","index","uuid")):
993
+ picks.append((nunq, c))
994
+ picks.sort(reverse=True)
995
+ return [c for _, c in picks[:max_n]]
996
+
997
+ def _guess_outcome_col(df: pd.DataFrame, exclude: set[str] | None = None) -> str | None:
998
+ exclude = exclude or set()
999
+ # name hints first
1000
+ name_keys = ("outcome","target","label","risk","score","result","prevalence","positivity")
1001
+ for c in df.columns:
1002
+ if c in exclude:
1003
+ continue
1004
+ name = str(c).lower()
1005
+ if any(k in name for k in name_keys) and pd.api.types.is_numeric_dtype(df[c]):
1006
+ return c
1007
+ # fallback: any binary numeric
1008
+ for c in df.select_dtypes(include=[np.number, "bool"]).columns:
1009
+ if c in exclude:
1010
+ continue
1011
+ try:
1012
+ if df[c].dropna().nunique() == 2:
1013
+ return c
1014
+ except Exception:
1015
+ continue
1016
+ return None
1017
+
1018
+
1019
+ def _pick_viz_template(signal: str):
1020
+ s = signal.lower()
1021
+
1022
+ # explicit chart requests
1023
+ if any(k in s for k in ("pie", "donut")):
1024
+ return viz_pie
1025
+
1026
+ if any(k in s for k in ("stacked", "100% stacked", "composition", "proportion", "share by")):
1027
+ return viz_stacked_bar
1028
+
1029
+ if any(k in s for k in ("distribution", "hist", "histogram", "bins")):
1030
+ return viz_distribution
1031
+
1032
+ if any(k in s for k in ("kde", "density")):
1033
+ return viz_kde
1034
+
1035
+ # these three you asked about
1036
+ if any(k in s for k in ("box", "boxplot", "violin", "spread", "outlier")):
1037
+ return viz_box
1038
+
1039
+ if any(k in s for k in ("scatter", "relationship", "vs ", "correlate")):
1040
+ return viz_scatter
1041
+
1042
+ if any(k in s for k in ("count", "counts", "frequency", "bar chart", "barplot")):
1043
+ return viz_count_bar
1044
+
1045
+ if any(k in s for k in ("area", "trend", "over time", "time series")):
1046
+ return viz_area
1047
+
1048
+ # fallback
1049
+ return viz_line
1050
+
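Routing is keyword-first with a line-chart fallback; e.g. (hypothetical signal text, nested helper assumed importable):

    >>> _pick_viz_template("share by region as a pie chart") is viz_pie
    True
    >>> _pick_viz_template("quarterly revenue") is viz_line
    True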
1051
+ for intent in intents:
1052
+
1053
+ if intent not in INJECTABLE_INTENTS:
1054
+ return code  # NOTE: bails out on the first non-injectable intent; 'continue' would skip it instead
1055
+
1056
+ # Correlation analysis
1057
+ if intent == "correlation_analysis" and not has_fit:
1058
+ return eda_correlation(df) + "\n\n" + code
1059
+
1060
+ # Generic visualisation (keyword-based)
1061
+ if intent == "visualisation" and not has_fit and not has_plot:
1062
+ rq = str(globals().get("refined_question", ""))
1063
+ # aq = str(globals().get("askai_question", ""))
1064
+ signal = rq + "\n" + str(intents) + "\n" + code
1065
+ tpl = _pick_viz_template(signal)
1066
+ return tpl(df) + "\n\n" + code
1067
+
1068
+ if intent == "clustering" and not has_fit:
1069
+ return clustering(df) + "\n\n" + code
1070
+
1071
+ if intent == "recommendation" and not has_fit:
1072
+ return recommendation(df) + "\n\n" + code
1073
+
1074
+ if intent == "topic_modelling" and not has_fit:
1075
+ return topic_modelling(df) + "\n\n" + code
1076
+
1077
+ if intent == "eda" and not has_fit:
1078
+ return code + "\n\nSB_heatmap()"  # inject heatmap for 'eda'; SB_heatmap() defaults to a numeric-only correlation matrix
1079
+
1080
+ # --- Classification ------------------------------------------------
1081
+ if intent == "classification" and not has_fit:
1082
+ target = _guess_classification_target(df)
1083
+ if target:
1084
+ return classification(df) + "\n\n" + code
1085
+ # return _call_template(classification, df, target) + "\n\n" + code
1086
+
1087
+ # --- Regression ----------------------------------------------------
1088
+ if intent == "regression" and not has_fit:
1089
+ target = _guess_regression_target(df)
1090
+ if target:
1091
+ return regression(df) + "\n\n" + code
1092
+ # return _call_template(regression, df, target) + "\n\n" + code
1093
+
1094
+ # --- Anomaly detection --------------------------------------------
1095
+ if intent == "anomaly_detection":
1096
+ uses_anomaly = any(k in code for k in ("IsolationForest", "LocalOutlierFactor", "OneClassSVM"))
1097
+ if not uses_anomaly:
1098
+ return anomaly_detection(df) + "\n\n" + code
1099
+
1100
+ # --- Time-series anomaly detection --------------------------------
1101
+ if intent == "ts_anomaly_detection":
1102
+ uses_ts = "STL(" in code or "seasonal_decompose(" in code
1103
+ if not uses_ts:
1104
+ return ts_anomaly_detection(df) + "\n\n" + code
1105
+
1106
+ # --- Time-series classification --------------------------------
1107
+ if intent == "time_series_classification" and not has_fit:
1108
+ time_col = _guess_time_col(df)
1109
+ entity_col = _guess_entity_col(df)
1110
+ target_col = _guess_ts_class_target(df)
1111
+
1112
+ # If we can't confidently identify these, do NOT inject anything
1113
+ if time_col and entity_col and target_col:
1114
+ return time_series_classification(df, entity_col, time_col, target_col) + "\n\n" + code
1115
+
1116
+ # --- Dimensionality reduction --------------------------------------
1117
+ if intent == "dimensionality_reduction":
1118
+ uses_dr = any(k in code for k in ("PCA(", "TSNE("))
1119
+ if not uses_dr:
1120
+ return dimensionality_reduction(df) + "\n\n" + code
1121
+
1122
+ # --- Feature selection ---------------------------------------------
1123
+ if intent == "feature_selection":
1124
+ uses_fs = any(k in code for k in (
1125
+ "mutual_info_", "permutation_importance(", "SelectKBest(", "RFE("
1126
+ ))
1127
+ if not uses_fs:
1128
+ return feature_selection(df) + "\n\n" + code
1129
+
1130
+ # --- EDA / correlation / visualisation -----------------------------
1131
+ if intent in ("eda", "correlation_analysis", "visualisation") and not has_plot:
1132
+ if intent == "correlation_analysis":
1133
+ return eda_correlation(df) + "\n\n" + code
1134
+ else:
1135
+ return eda_overview(df) + "\n\n" + code
1136
+
1137
+ # --- Time-series forecasting ---------------------------------------
1138
+ if intent == "time_series_forecasting" and not has_fit:
1139
+ uses_ts_forecast = any(k in code for k in (
1140
+ "ARIMA", "ExponentialSmoothing", "forecast", "predict("
1141
+ ))
1142
+ if not uses_ts_forecast:
1143
+ return time_series_forecasting(df) + "\n\n" + code
1144
+
1145
+ # --- Multi-label classification -----------------------------------
1146
+ if intent in ("multilabel_classification",) and not has_fit:
1147
+ label_cols = _guess_multilabel_cols(df)
1148
+ if len(label_cols) >= 2:
1149
+ return multilabel_classification(df, label_cols) + "\n\n" + code
1150
+
1151
+ group_col = _find_unknownish_column(df)
1152
+ if group_col:
1153
+ num_cols = _guess_numeric_cols(df)
1154
+ cat_cols = _guess_categorical_cols(df, exclude={group_col})
1155
+ outcome_col = None # generic; let template skip if not present
1156
+ tpl = unknown_group_proxy_pack(df, group_col, UNKNOWN_TOKENS, num_cols, cat_cols, outcome_col)
1157
+
1158
+ # Return template + guarded (repaired) LLM code, so it never crashes
1159
+ repaired = make_numeric_vars_dynamic(relax_required_columns(fix_boxplot_placeholder(code)))
1160
+ return tpl + "\n\n" + wrap_llm_code_safe(repaired)
1161
+
1162
+ return code
1163
+
1164
+
1165
+ def fix_values_sum_numeric_only_bug(code: str) -> str:
1166
+ """
1167
+ If a previous pass injected numeric_only=True into a NumPy-style sum,
1168
+ e.g. .values.sum(numeric_only=True), strip it and canonicalize to .to_numpy().sum().
1169
+ """
1170
+ # .values.sum(numeric_only=True, ...)
1171
+ code = re.sub(
1172
+ r"\.values\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1173
+ ".to_numpy().sum()",
1174
+ code,
1175
+ flags=re.IGNORECASE,
1176
+ )
1177
+ # .to_numpy().sum(numeric_only=True, ...)
1178
+ code = re.sub(
1179
+ r"\.to_numpy\(\)\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1180
+ ".to_numpy().sum()",
1181
+ code,
1182
+ flags=re.IGNORECASE,
1183
+ )
1184
+ return code
1185
+
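A hypothetical before/after:

    >>> fix_values_sum_numeric_only_bug("total = df.values.sum(numeric_only=True)")
    'total = df.to_numpy().sum()'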
1186
+
13
1187
  def strip_describe_slice(code: str) -> str:
14
1188
  """
15
1189
  Remove any pattern like df.groupby(...).describe()[[ ... ]] because
@@ -23,10 +1197,12 @@ def strip_describe_slice(code: str) -> str:
23
1197
  )
24
1198
  return pat.sub(r"\1)", code)
25
1199
 
1200
+
26
1201
  def remove_plt_show(code: str) -> str:
27
1202
  """Removes all plt.show() calls from the generated code string."""
28
1203
  return "\n".join(line for line in code.splitlines() if "plt.show()" not in line)
29
1204
 
1205
+
30
1206
  def patch_plot_with_table(code: str) -> str:
31
1207
  """
32
1208
  ▸ strips every `plt.show()` (avoids warnings)
@@ -113,7 +1289,7 @@ def patch_plot_with_table(code: str) -> str:
113
1289
  ")\n"
114
1290
  )
115
1291
 
116
- tbl_block += "from syntaxmatrix.display import show\nshow(summary_table)"
1292
+ tbl_block += "show(summary_table, title='Summary Statistics')"
117
1293
 
118
1294
  # 5. inject image-export block, then table block, after the plot
119
1295
  patched = (
@@ -253,10 +1429,10 @@ def refine_eda_question(raw_question, df=None, max_points=1000):
253
1429
  "Use: tbl = df.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}); display(tbl)"
254
1430
  )
255
1431
 
256
-
257
1432
  # 9. Fallback: return the raw question
258
1433
  return q
259
1434
 
1435
+
260
1436
  def patch_plot_code(code, df, user_question=None):
261
1437
 
262
1438
  # ── Early guard: abort nicely if the generated code references columns that
@@ -277,10 +1453,13 @@ def patch_plot_code(code, df, user_question=None):
277
1453
 
278
1454
  if missing_cols:
279
1455
  cols_list = ", ".join(missing_cols)
280
- return (
281
- f"print('⚠️ Column(s) \"{cols_list}\" not found in the dataset. "
282
- f"Please check the column names and try again.')"
1456
+ warning = (
1457
+ f"show('⚠️ Warning: code references missing column(s): \"{cols_list}\". "
1458
+ "These must either exist in df or be created earlier in the code; "
1459
+ "otherwise you may see a KeyError.')\n"
283
1460
  )
1461
+ # Prepend the warning but keep the original code so it can still run
1462
+ code = warning + code
284
1463
 
285
1464
  # 1. For line plots (auto-aggregate)
286
1465
  m_l = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
@@ -389,6 +1568,16 @@ def patch_plot_code(code, df, user_question=None):
389
1568
  # Fallback: Return original code
390
1569
  return code
391
1570
 
1571
+
1572
+ def ensure_matplotlib_title(code, title_var="refined_question"):
1573
+ import re
1574
+ makes_plot = re.search(r"\b(plt\.(plot|scatter|bar|hist)|ax\.(plot|scatter|bar|hist))\b", code)
1575
+ has_title = re.search(r"\b(plt\.title|ax\.set_title)\s*\(", code)
1576
+ if makes_plot and not has_title:
1577
+ code += f"\ntry:\n plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
1578
+ return code
1579
+
1580
+
392
1581
  def ensure_output(code: str) -> str:
393
1582
  """
394
1583
  Guarantees that AI-generated code actually surfaces results in the UI
@@ -405,7 +1594,6 @@ def ensure_output(code: str) -> str:
405
1594
  # not a comment / print / assignment / pyplot call
406
1595
  if (last and not last.startswith(("print(", "plt.", "#")) and "=" not in last):
407
1596
  lines[-1] = f"_out = {last}"
408
- lines.append("from syntaxmatrix.display import show")
409
1597
  lines.append("show(_out)")
410
1598
 
411
1599
  # ── 3· auto-surface common stats tuples (stat, p) ───────────────────
@@ -413,14 +1601,12 @@ def ensure_output(code: str) -> str:
413
1601
  if re.search(r"\bchi2\s*,\s*p\s*,", code) and "show((" in code:
414
1602
  pass # AI already shows the tuple
415
1603
  elif re.search(r"\bchi2\s*,\s*p\s*,", code):
416
- lines.append("from syntaxmatrix.display import show")
417
1604
  lines.append("show((chi2, p))")
418
1605
 
419
1606
  # ── 4· classification report (string) ───────────────────────────────
420
1607
  cr_match = re.search(r"^\s*(\w+)\s*=\s*classification_report\(", code, re.M)
421
1608
  if cr_match and f"show({cr_match.group(1)})" not in "\n".join(lines):
422
1609
  var = cr_match.group(1)
423
- lines.append("from syntaxmatrix.display import show")
424
1610
  lines.append(f"show({var})")
425
1611
 
426
1612
  # 5-bis · pivot tables (DataFrame)
@@ -457,18 +1643,17 @@ def ensure_output(code: str) -> str:
457
1643
  assign_scalar = re.match(r"\s*(\w+)\s*=\s*.+\.shape\[\s*0\s*\]\s*$", lines[-1])
458
1644
  if assign_scalar:
459
1645
  var = assign_scalar.group(1)
460
- lines.append("from syntaxmatrix.display import show")
461
1646
  lines.append(f"show({var})")
462
1647
 
463
1648
  # ── 8. utils.ensure_output()
464
1649
  assign_df = re.match(r"\s*(\w+)\s*=\s*df\[", lines[-1])
465
1650
  if assign_df:
466
1651
  var = assign_df.group(1)
467
- lines.append("from syntaxmatrix.display import show")
468
1652
  lines.append(f"show({var})")
469
1653
 
470
1654
  return "\n".join(lines)
471
1655
 
1656
+
472
1657
  def get_plotting_imports(code):
473
1658
  imports = []
474
1659
  if "plt." in code and "import matplotlib.pyplot as plt" not in code:
@@ -488,6 +1673,7 @@ def get_plotting_imports(code):
488
1673
  code = "\n".join(imports) + "\n\n" + code
489
1674
  return code
490
1675
 
1676
+
491
1677
  def patch_pairplot(code, df):
492
1678
  if "sns.pairplot" in code:
493
1679
  # Always assign and print pairgrid
@@ -498,29 +1684,82 @@ def patch_pairplot(code, df):
498
1684
  code += "\nprint(pairgrid)"
499
1685
  return code
500
1686
 
1687
+
501
1688
  def ensure_image_output(code: str) -> str:
502
1689
  """
503
- Injects a PNG exporter in front of every plt.show() so dashboards
504
- get real <img> HTML instead of a blank cell.
1690
+ Replace each plt.show() with an indented _SMX_export_png() call.
1691
+ This keeps block indentation valid and still renders images in the dashboard.
505
1692
  """
506
1693
  if "plt.show()" not in code:
507
1694
  return code
508
1695
 
509
- exporter = (
510
- # -- NEW: use display(), not print() --------------------------
511
- "import io, base64\n"
512
- "buf = io.BytesIO()\n"
513
- "plt.savefig(buf, format='png', bbox_inches='tight')\n"
514
- "buf.seek(0)\n"
515
- "img_b64 = base64.b64encode(buf.read()).decode('utf-8')\n"
516
- "from IPython.display import display, HTML\n"
517
- "display(HTML(f'<img src=\"data:image/png;base64,{img_b64}\" "
518
- "style=\"max-width:100%;\">'))\n"
519
- "plt.close()\n"
1696
+ import re
1697
+ out_lines = []
1698
+ for ln in code.splitlines():
1699
+ if "plt.show()" not in ln:
1700
+ out_lines.append(ln)
1701
+ continue
1702
+
1703
+ # works for:
1704
+ # plt.show()
1705
+ # plt.tight_layout(); plt.show()
1706
+ # ... ; plt.show(); ... (multiple on one line)
1707
+ indent = re.match(r"^(\s*)", ln).group(1)
1708
+ parts = ln.split("plt.show()")
1709
+
1710
+ # keep whatever is before the first plt.show()
1711
+ if parts[0].strip():
1712
+ out_lines.append(parts[0].rstrip())
1713
+
1714
+ # for every plt.show() we removed, insert exporter at same indent
1715
+ for _ in range(len(parts) - 1):
1716
+ out_lines.append(indent + "_SMX_export_png()")
1717
+
1718
+ # keep whatever comes after the last plt.show()
1719
+ if parts[-1].strip():
1720
+ out_lines.append(indent + parts[-1].lstrip())
1721
+
1722
+ return "\n".join(out_lines)
1723
+
1724
+
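A hypothetical one-liner shows the in-place substitution at matching indentation:

    >>> ensure_image_output("plt.plot(x); plt.show()")
    'plt.plot(x);\n_SMX_export_png()'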
1725
+ def clean_llm_code(code: str) -> str:
1726
+ """
1727
+ Make LLM output safe to exec:
1728
+ - If fenced blocks exist, keep the largest one (usually the real code).
1729
+ - Otherwise strip any stray ``` / ```python lines.
1730
+ - Remove common markdown/preamble junk.
1731
+ """
1732
+ code = str(code or "")
1733
+
1734
+ # Extract fenced blocks (```python ... ``` or ``` ... ```)
1735
+ blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
1736
+
1737
+ if blocks:
1738
+ # pick the largest block; small trailing blocks are usually garbage
1739
+ largest = max(blocks, key=lambda b: len(b.strip()))
1740
+ if len(largest.strip().splitlines()) >= 10:
1741
+ code = largest
1742
+ else:
1743
+ # if no meaningful block, just remove fence markers
1744
+ code = re.sub(r"^```.*?$", "", code, flags=re.M)
1745
+ else:
1746
+ # no complete blocks — still remove any stray fence lines
1747
+ code = re.sub(r"^```.*?$", "", code, flags=re.M)
1748
+
1749
+ # Strip common markdown/preamble lines
1750
+ drop_prefixes = (
1751
+ "here is", "here's", "below is", "sure,", "certainly",
1752
+ "explanation", "note:", "```"
520
1753
  )
1754
+ cleaned_lines = []
1755
+ for ln in code.splitlines():
1756
+ s = ln.strip().lower()
1757
+ if any(s.startswith(p) for p in drop_prefixes):
1758
+ continue
1759
+ cleaned_lines.append(ln)
1760
+
1761
+ return "\n".join(cleaned_lines).strip()
521
1762
 
522
- # exporter BEFORE the original plt.show()
523
- return code.replace("plt.show()", exporter + "plt.show()")
524
1763
 
525
1764
  def fix_groupby_describe_slice(code: str) -> str:
526
1765
  """
@@ -543,6 +1782,7 @@ def fix_groupby_describe_slice(code: str) -> str:
543
1782
  )
544
1783
  return pat.sub(repl, code)
545
1784
 
1785
+
546
1786
  def fix_importance_groupby(code: str) -> str:
547
1787
  pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
548
1788
  if "importance_df" in code:
@@ -589,10 +1829,12 @@ def inject_auto_preprocessing(code: str) -> str:
589
1829
  # simply prepend; model code that follows can wrap estimator in a Pipeline
590
1830
  return prep_snippet + code
591
1831
 
1832
+
592
1833
  def fix_to_datetime_errors(code: str) -> str:
593
1834
  """
594
1835
  Force every pd.to_datetime(…) call to ignore bad dates so that
595
- ‘year 16500 is out of range’ and similar issues don’t crash runs.
1836
+ 'year 16500 is out of range' and similar issues don’t crash runs.
596
1838
  """
597
1839
  import re
598
1840
  # look for any pd.to_datetime( … )
@@ -605,25 +1847,67 @@ def fix_to_datetime_errors(code: str) -> str:
605
1847
  return f"pd.to_datetime({inside}, errors='coerce')"
606
1848
  return pat.sub(repl, code)
607
1849
 
1850
+
608
1851
  def fix_numeric_sum(code: str) -> str:
609
1852
  """
610
- Rewrites every `.sum(` call so it becomes
611
- `.sum(numeric_only=True, …)` unless that keyword is already present.
1853
+ Make .sum(...) code safe across pandas versions by removing any
1854
+ numeric_only=... argument (True/False/None) from function calls.
1855
+
1856
+ This avoids errors on pandas versions where numeric_only is not
1857
+ supported for Series/grouped sums, and we rely instead on explicit
1858
+ numeric column selection (e.g. select_dtypes) in the generated code.
612
1859
  """
613
- pattern = re.compile(r"\.sum\(\s*([^\)]*)\)")
1860
+ # Case 1: ..., numeric_only=True/False/None
1861
+ code = re.sub(
1862
+ r",\s*numeric_only\s*=\s*(True|False|None)",
1863
+ "",
1864
+ code,
1865
+ flags=re.IGNORECASE,
1866
+ )
614
1867
 
615
- def _repl(match):
616
- args = match.group(1)
617
- if "numeric_only" in args: # already safe
618
- return match.group(0)
1868
+ # Case 2: numeric_only=True/False/None, ... (as first argument)
1869
+ code = re.sub(
1870
+ r"numeric_only\s*=\s*(True|False|None)\s*,\s*",
1871
+ "",
1872
+ code,
1873
+ flags=re.IGNORECASE,
1874
+ )
619
1875
 
620
- args = args.strip()
621
- if args: # keep existing positional / kw args
622
- args += ", "
623
- return f".sum({args}numeric_only=True)"
1876
+ # Case 3: numeric_only=True/False/None (only argument)
1877
+ code = re.sub(
1878
+ r"numeric_only\s*=\s*(True|False|None)",
1879
+ "",
1880
+ code,
1881
+ flags=re.IGNORECASE,
1882
+ )
1883
+
1884
+ return code
1885
+
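A hypothetical before/after for the strip:

    >>> fix_numeric_sum("df.groupby('g').sum(numeric_only=True)")
    "df.groupby('g').sum()"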
1886
+
1887
+ def fix_concat_empty_list(code: str) -> str:
1888
+ """
1889
+ Make pd.concat calls resilient to empty lists of objects.
1890
+
1891
+ Transforms patterns like:
1892
+ pd.concat(frames, ignore_index=True)
1893
+ pd.concat(frames)
1894
+
1895
+ into:
1896
+ pd.concat(frames or [pd.DataFrame()], ignore_index=True)
1897
+ pd.concat(frames or [pd.DataFrame()])
1898
+
1899
+ Only triggers when the first argument is a simple variable name.
1900
+ """
1901
+ pattern = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
1902
+
1903
+ def _repl(m):
1904
+ name = m.group(1)
1905
+ sep = m.group(2) # ',' or ')'
1906
+ return f"pd.concat({name} or [pd.DataFrame()]{sep}"
624
1907
 
625
1908
  return pattern.sub(_repl, code)
626
1909
 
1910
+
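For example (illustrative):

    >>> from syntaxmatrix.utils import fix_concat_empty_list
    >>> fix_concat_empty_list("pd.concat(frames, ignore_index=True)")
    'pd.concat(frames or [pd.DataFrame()], ignore_index=True)'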
627
1911
  def fix_numeric_aggs(code: str) -> str:
628
1912
  _AGG_FUNCS = ("sum", "mean")
629
1913
  pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
@@ -637,6 +1921,7 @@ def fix_numeric_aggs(code: str) -> str:
637
1921
  return f".{func}({args}numeric_only=True)"
638
1922
  return pat.sub(_repl, code)
639
1923
 
1924
+
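The visible replacement suggests behaviour along these lines (sketch only; most of the _repl body sits outside this hunk):

    from syntaxmatrix.utils import fix_numeric_aggs
    fix_numeric_aggs("df.groupby('g').mean(axis=0)")
    # -> "df.groupby('g').mean(axis=0, numeric_only=True)"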
640
1925
  def ensure_accuracy_block(code: str) -> str:
641
1926
  """
642
1927
  Inject a sensible evaluation block right after the last `<est>.fit(...)`
@@ -696,40 +1981,6 @@ def ensure_accuracy_block(code: str) -> str:
696
1981
  insert_at = code.find("\n", m[-1].end()) + 1
697
1982
  return code[:insert_at] + eval_block + code[insert_at:]
698
1983
 
699
- def classify(prompt: str) -> str:
700
- """
701
- Very-light intent classifier.
702
- Returns one of:
703
- 'stat_test' | 'time_series' | 'clustering'
704
- 'classification' | 'regression' | 'eda'
705
- """
706
- p = prompt.lower().strip()
707
- greetings = {"hi", "hello", "hey", "good morning", "good afternoon", "good evening", "greetings"}
708
- if any(p.startswith(g) or p == g for g in greetings):
709
- return "greeting"
710
-
711
- if any(k in p for k in ("t-test", "anova", "p-value")):
712
- return "stat_test"
713
- if "forecast" in p or "prophet" in p:
714
- return "time_series"
715
- if "cluster" in p or "kmeans" in p:
716
- return "clustering"
717
- if any(k in p for k in ("accuracy", "precision", "roc")):
718
- return "classification"
719
- if any(k in p for k in ("rmse", "r2", "mae")):
720
- return "regression"
721
- return "eda"
722
-
723
- def auto_inject_template(code: str, intent: str, df) -> str:
724
- """If the LLM forgot the core logic, prepend a skeleton block."""
725
- has_fit = ".fit(" in code
726
-
727
- if intent == "classification" and not has_fit:
728
- # guess a y column that contains 'diabetes' as in your dataset
729
- target = next((c for c in df.columns if "diabetes" in c.lower()), None)
730
- if target:
731
- return classification(df, target) + "\n\n" + code
732
- return code
733
1984
 
734
1985
  def fix_scatter_and_summary(code: str) -> str:
735
1986
  """
@@ -757,6 +2008,7 @@ def fix_scatter_and_summary(code: str) -> str:
757
2008
 
758
2009
  return code
759
2010
 
2011
+
760
2012
  def auto_format_with_black(code: str) -> str:
761
2013
  """
762
2014
  Format the generated code with Black. Falls back silently if Black
@@ -771,6 +2023,7 @@ def auto_format_with_black(code: str) -> str:
771
2023
  except Exception:
772
2024
  return code
773
2025
 
2026
+
774
2027
  def ensure_preproc_in_pipeline(code: str) -> str:
775
2028
  """
776
2029
  If code defines `preproc = ColumnTransformer(...)` but then builds
@@ -783,13 +2036,14 @@ def ensure_preproc_in_pipeline(code: str) -> str:
783
2036
  code
784
2037
  )
785
2038
 
2039
+
786
2040
  def fix_plain_prints(code: str) -> str:
787
2041
  """
788
2042
  Rewrite bare `print(var)` where var looks like a dataframe/series/ndarray/etc
789
2043
  to go through SyntaxMatrix's smart display (so it renders in the dashboard).
790
2044
  Keeps string prints alone.
791
2045
  """
792
- import re
2046
+
793
2047
  # Skip obvious string-literal prints
794
2048
  new = re.sub(
795
2049
  r"(?m)^\s*print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
@@ -798,6 +2052,7 @@ def fix_plain_prints(code: str) -> str:
798
2052
  )
799
2053
  return new
800
2054
 
2055
+
801
2056
  def fix_print_html(code: str) -> str:
802
2057
  """
803
2058
  Ensure that HTML / DataFrame HTML are *displayed* (and captured by the kernel),
@@ -830,6 +2085,7 @@ def fix_print_html(code: str) -> str:
830
2085
 
831
2086
  return new
832
2087
 
2088
+
833
2089
  def ensure_ipy_display(code: str) -> str:
834
2090
  """
835
2091
  Guarantee that the cell has proper IPython display imports so that
@@ -838,9 +2094,8 @@ def ensure_ipy_display(code: str) -> str:
838
2094
  if "display(" in code and "from IPython.display import display, HTML" not in code:
839
2095
  return "from IPython.display import display, HTML\n" + code
840
2096
  return code
841
- # --------------------------------------------------------------------------
842
- # ✂
843
- # --------------------------------------------------------------------------
2097
+
2098
+
844
2099
  def drop_bad_classification_metrics(code: str, y_or_df) -> str:
845
2100
  """
846
2101
  Remove classification metrics (accuracy_score, classification_report, confusion_matrix)
@@ -885,6 +2140,7 @@ def drop_bad_classification_metrics(code: str, y_or_df) -> str:
885
2140
 
886
2141
  return code
887
2142
 
2143
+
888
2144
  def force_capture_display(code: str) -> str:
889
2145
  """
890
2146
  Ensure our executor captures HTML output:
@@ -925,11 +2181,13 @@ def force_capture_display(code: str) -> str:
925
2181
  )
926
2182
  return new
927
2183
 
2184
+
928
2185
  def strip_matplotlib_show(code: str) -> str:
929
2186
  """Remove blocking plt.show() calls (we export base64 instead)."""
930
2187
  import re
931
2188
  return re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code)
932
2189
 
2190
+
933
2191
  def inject_display_shim(code: str) -> str:
934
2192
  """
935
2193
  Provide display()/HTML() if missing, forwarding to our executor hook.
@@ -951,6 +2209,7 @@ def inject_display_shim(code: str) -> str:
951
2209
  )
952
2210
  return shim + code
953
2211
 
2212
+
954
2213
  def strip_spurious_column_tokens(code: str) -> str:
955
2214
  """
956
2215
  Remove common stop-words ('the','whether', ...) when they appear
@@ -961,7 +2220,8 @@ def strip_spurious_column_tokens(code: str) -> str:
961
2220
  """
962
2221
  STOP = {
963
2222
  "the","whether","a","an","and","or","of","to","in","on","for","by",
964
- "with","as","at","from","that","this","these","those","is","are","was","were"
2223
+ "with","as","at","from","that","this","these","those","is","are","was","were",
2224
+ "coef", "Coef", "coefficient", "Coefficient"
965
2225
  }
966
2226
 
967
2227
  def _norm(s: str) -> str:
@@ -984,10 +2244,920 @@ def strip_spurious_column_tokens(code: str) -> str:
984
2244
 
985
2245
  # df[[ ... ]] selections
986
2246
  code = re.sub(
987
- r"df\s*\[\s*\[([^\]]+)\]\s*\]",
988
- lambda m: "df[" + _clean_list(m.group(1)) + "]",
2247
+ r"df\s*\[\s*\[([^\]]+)\]\s*\]", lambda m: "df[" + _clean_list(m.group(1)) + "]", code)
2248
+
2249
+ return code
2250
+
2251
+
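Sketch of the intended effect (assuming _clean_list, defined in an earlier hunk, re-emits the bracketed list minus the stop-words):

    from syntaxmatrix.utils import strip_spurious_column_tokens
    strip_spurious_column_tokens("df[['Age', 'whether', 'BMI']]")
    # -> roughly "df[['Age', 'BMI']]"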
2252
+ def patch_prefix_seaborn_calls(code: str) -> str:
2253
+ """
2254
+ Ensure bare seaborn calls are prefixed with `sns.`.
2255
+ E.g., `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, etc.
2256
+ """
2257
+ if "sns." in code:
2258
+ # still fix any leftover bare calls alongside prefixed ones
2259
+ pass
2260
+
2261
+ # functions commonly used from seaborn
2262
+ funcs = [
2263
+ "barplot","countplot","boxplot","violinplot","stripplot","swarmplot",
2264
+ "histplot","kdeplot","jointplot","pairplot","heatmap","clustermap",
2265
+ "scatterplot","lineplot","catplot","displot","lmplot"
2266
+ ]
2267
+ # Replace bare function calls not already qualified by a dot (e.g., obj.barplot)
2268
+ # (?<![\w.]) ensures no preceding word char or dot; avoids touching obj.barplot or mybarplot
2269
+ pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(", flags=re.MULTILINE)
2270
+
2271
+ def _add_prefix(m):
2272
+ fn = m.group(1)
2273
+ return f"sns.{fn}("
2274
+
2275
+ return pattern.sub(_add_prefix, code)
2276
+
2277
+
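For example (illustrative; already-qualified calls such as sns.barplot(...) are protected by the (?<![\w.]) look-behind):

    >>> from syntaxmatrix.utils import patch_prefix_seaborn_calls
    >>> patch_prefix_seaborn_calls("heatmap(corr, annot=True)")
    'sns.heatmap(corr, annot=True)'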
2278
+ def patch_ensure_seaborn_import(code: str) -> str:
2279
+ """
2280
+ If seaborn is used (sns.), ensure `import seaborn as sns` exists once.
2281
+ Also set a quiet theme for consistent visuals.
2282
+ """
2283
+ needs_sns = "sns." in code
2284
+ has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
2285
+ if needs_sns and not has_import:
2286
+ # Insert after the first block of imports if possible, else at top
2287
+ import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
2288
+ inject = "import seaborn as sns\ntry:\n sns.set_theme()\nexcept Exception:\n pass\n"
2289
+ if import_block:
2290
+ start = import_block.end()
2291
+ code = code[:start] + inject + code[start:]
2292
+ else:
2293
+ code = inject + code
2294
+ return code
2295
+
2296
+
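A small sketch (illustrative):

    >>> from syntaxmatrix.utils import patch_ensure_seaborn_import
    >>> out = patch_ensure_seaborn_import("sns.histplot(df['Age'])")
    >>> out.startswith("import seaborn as sns")
    True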
2297
+ def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
2298
+ """
2299
+ Normalise pie-chart requests.
2300
+
2301
+ Supports three patterns:
2302
+ A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
2303
+ B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
2304
+ → one pie per facet level (grid) + counts bar of facet sizes.
2305
+ C) Single pie when no split/facet is requested.
2306
+
2307
+ Notes:
2308
+ - Pie variables must be categorical (or numeric binned).
2309
+ - Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
2310
+ """
2311
+
2312
+ q = (user_question or "")
2313
+ q_low = q.lower()
2314
+
2315
+ # Prefer explicit: df['col'].value_counts()
2316
+ m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
2317
+ col = m.group(1) if m else None
2318
+
2319
+ # ---------- helpers ----------
2320
+ def _is_cat(col):
2321
+ return (str(df[col].dtype).startswith("category")
2322
+ or df[col].dtype == "object"
2323
+ or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
2324
+
2325
+ def _cats_from_question(question: str):
2326
+ found = []
2327
+ for c in df.columns:
2328
+ if c.lower() in question.lower() and _is_cat(c):
2329
+ found.append(c)
2330
+ # dedupe preserve order
2331
+ seen, out = set(), []
2332
+ for c in found:
2333
+ if c not in seen:
2334
+ out.append(c); seen.add(c)
2335
+ return out
2336
+
2337
+ def _fallback_cat():
2338
+ cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2339
+ if not cats: return None
2340
+ cats.sort(key=lambda t: t[1])
2341
+ return cats[0][0]
2342
+
2343
+ def _infer_comp_pref(question: str) -> str:
2344
+ ql = (question or "").lower()
2345
+ if "heatmap" in ql or "matrix" in ql:
2346
+ return "heatmap"
2347
+ if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
2348
+ return "stacked_bar_pct"
2349
+ if "stacked" in ql:
2350
+ return "stacked_bar"
2351
+ if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
2352
+ return "grouped_bar"
2353
+ return "counts_bar"
2354
+
2355
+ # parse threshold split like "HbA1c ≥ 6.5"
2356
+ def _parse_split(question: str):
2357
+ ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
2358
+ m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
2359
+ if not m: return None
2360
+ col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
2361
+ op = ops_map.get(op_raw)
2362
+ if not op: return None
2363
+ # case-insensitive column match
2364
+ candidates = {c.lower(): c for c in df.columns}
2365
+ col = candidates.get(col_raw.lower())
+ if not col:
+ # the greedy name class above also swallows preceding words ("show pies where HbA1c"),
+ # so retry progressively shorter suffixes of the captured text
+ words = col_raw.split()
+ for i in range(1, len(words)):
+ cand = " ".join(words[i:]).lower()
+ if cand in candidates:
+ col = candidates[cand]; break
+ if not col: return None
2367
+ try: val = float(val_raw)
2368
+ except Exception: return None
2369
+ return (col, op, val)
2370
+
2371
+ # facet extractor: "by/ across / within each / per <col>", or "bin <col>", or named category list
2372
+ def _extract_facet(question: str):
2373
+ # 1) explicit "by/ across / within / per <col>"
2374
+ for kw in [" by ", " across ", " within ", " within each ", " per "]:
2375
+ m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
2376
+ if m:
2377
+ col_raw = m.group(1).strip()
2378
+ candidates = {c.lower(): c for c in df.columns}
2379
+ # greedy capture can also drag in trailing words ("Ethnicity groups"),
+ # so test progressively shorter prefixes of the captured text
+ words = col_raw.split()
+ for j in range(len(words), 0, -1):
+ cand = " ".join(words[:j]).lower()
+ if cand in candidates:
+ return (candidates[cand], "auto")
2381
+ # 2) "bin <col>"
2382
+ m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
2383
+ if m2:
2384
+ col_raw = m2.group(1).strip()
2385
+ candidates = {c.lower(): c for c in df.columns}
2386
+ if col_raw.lower() in candidates:
2387
+ return (candidates[col_raw.lower()], "bin")
2388
+ # 3) BMI special: mentions of normal/overweight/obese imply BMI categories
2389
+ if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
2390
+ any(c.lower() == "bmi" for c in df.columns.str.lower()):
2391
+ bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
2392
+ return (bmi_col, "bmi")
2393
+ return None
2394
+
2395
+ def _bmi_bins(series: pd.Series):
2396
+ # WHO cutoffs
2397
+ bins = [-np.inf, 18.5, 25, 30, np.inf]
2398
+ labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
2399
+ return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
2400
+
2401
+ wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
2402
+ if not wants_pie:
2403
+ return code
2404
+
2405
+ split = _parse_split(q)
2406
+ facet = _extract_facet(q)
2407
+ cats = _cats_from_question(q)
2408
+ _comp_pref = _infer_comp_pref(q)
2409
+
2410
+ # Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
2411
+ for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
2412
+ if hard in df.columns and hard not in cats and hard.lower() in q_low:
2413
+ cats.append(hard)
2414
+
2415
+ # --------------- CASE A: threshold split (cohorts) ---------------
2416
+ if split:
2417
+ if not (cats or any(_is_cat(c) for c in df.columns)):
2418
+ return code
2419
+ if not cats:
2420
+ pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2421
+ pool.sort(key=lambda t: t[1])
2422
+ cats = [t[0] for t in pool[:3]] if pool else []
2423
+ if not cats:
2424
+ return code
2425
+
2426
+ split_col, op, val = split
2427
+ cond_str = f"(df['{split_col}'] {op} {val})"
2428
+ snippet = f"""
2429
+ import numpy as np
2430
+ import pandas as pd
2431
+ import matplotlib.pyplot as plt
2432
+
2433
+ _mask_a = ({cond_str}) & df['{split_col}'].notna()
2434
+ _mask_b = (~({cond_str})) & df['{split_col}'].notna()
2435
+
2436
+ _cohort_a_name = "{split_col} {op} {val}"
2437
+ _cohort_b_name = "NOT ({split_col} {op} {val})"
2438
+
2439
+ _cat_cols = {cats!r}
2440
+ n = len(_cat_cols)
2441
+ fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
2442
+ if n == 1:
2443
+ axes = np.array([axes])
2444
+
2445
+ for i, col in enumerate(_cat_cols):
2446
+ s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
2447
+ s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})
2448
+
2449
+ ax_a = axes[i, 0]; ax_b = axes[i, 1]
2450
+ if len(s_a) > 0:
2451
+ ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
2452
+ autopct='%1.1f%%', startangle=90, counterclock=False)
2453
+ ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')
2454
+
2455
+ if len(s_b) > 0:
2456
+ ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
2457
+ autopct='%1.1f%%', startangle=90, counterclock=False)
2458
+ ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')
2459
+
2460
+ plt.tight_layout(); plt.show()
2461
+
2462
+ # grouped bar complement
2463
+ for col in _cat_cols:
2464
+ _tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
2465
+ .assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
2466
+ _tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
2467
+ _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2468
+
2469
+ if _comp_pref == "grouped_bar":
2470
+ ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
2471
+ ax.set_title(f"{col} by cohort (grouped)")
2472
+ ax.set_xlabel(col); ax.set_ylabel("Count")
2473
+ plt.tight_layout(); plt.show()
2474
+
2475
+ elif _comp_pref == "stacked_bar":
2476
+ ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2477
+ ax.set_title(f"{col} by cohort (stacked)")
2478
+ ax.set_xlabel(col); ax.set_ylabel("Count")
2479
+ plt.tight_layout(); plt.show()
2480
+
2481
+ elif _comp_pref == "stacked_bar_pct":
2482
+ _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2483
+ ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2484
+ ax.set_title(f"{col} by cohort (100% stacked)")
2485
+ ax.set_xlabel(col); ax.set_ylabel("Percent")
2486
+ plt.tight_layout(); plt.show()
2487
+
2488
+ elif _comp_pref == "heatmap":
2489
+ _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2490
+ import numpy as np
2491
+ fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
2492
+ im = ax.imshow(_perc.values, aspect='auto')
2493
+ ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2494
+ ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2495
+ ax.set_title(f"{col} by cohort — % heatmap")
2496
+ for i in range(_perc.shape[0]):
2497
+ for j in range(_perc.shape[1]):
2498
+ ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2499
+ fig.colorbar(im, ax=ax, label="%")
2500
+ plt.tight_layout(); plt.show()
2501
+
2502
+ else: # counts_bar (default)
2503
+ ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
2504
+ ax.set_title(f"{col}: total counts (both cohorts)")
2505
+ ax.set_xlabel(col); ax.set_ylabel("Count")
2506
+ plt.tight_layout(); plt.show()
2507
+ """.lstrip()
2508
+ return snippet
2509
+
2510
+ # --------------- CASE B: facet-by (categories/bins) ---------------
2511
+ if facet:
2512
+ facet_col, how = facet
2513
+ # Build facet series
2514
+ if pd.api.types.is_numeric_dtype(df[facet_col]):
2515
+ if how == "bmi":
2516
+ facet_series = _bmi_bins(df[facet_col])
2517
+ else:
2518
+ # generic numeric bins: 3 equal-width bins by default
2519
+ facet_series = pd.cut(df[facet_col].astype(float), bins=3)
2520
+ else:
2521
+ facet_series = df[facet_col].astype(str)
2522
+
2523
+ # Choose pie dimension (categorical to count inside each facet)
2524
+ pie_dim = None
2525
+ for c in cats:
2526
+ if c in df.columns and _is_cat(c):
2527
+ pie_dim = c; break
2528
+ if pie_dim is None:
2529
+ pie_dim = _fallback_cat()
2530
+ if pie_dim is None:
2531
+ return code
2532
+
2533
+ snippet = f"""
2534
+ import math
2535
+ import pandas as pd
2536
+ import matplotlib.pyplot as plt
+ import numpy as np  # needed for the axes-array checks below
2537
+
2538
+ df = df.copy()
2539
+ _preferred = "{facet_col}" if "{facet_col}" in df.columns else None
2540
+
2541
+ def _select_facet_col(df, preferred=None):
2542
+ if preferred is not None:
2543
+ return preferred
2544
+ # Prefer low-cardinality categoricals (readable pies/grids)
2545
+ cat_cols = [
2546
+ c for c in df.columns
2547
+ if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
2548
+ and df[c].nunique() > 1 and df[c].nunique() <= 20
2549
+ ]
2550
+ if cat_cols:
2551
+ cat_cols.sort(key=lambda c: df[c].nunique())
2552
+ return cat_cols[0]
2553
+ # Else fall back to first usable numeric
2554
+ num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
2555
+ return num_cols[0] if num_cols else None
2556
+
2557
+ _facet_col = _select_facet_col(df, _preferred)
2558
+
2559
+ if _facet_col is None:
2560
+ # Nothing suitable → single facet keeps pipeline alive
2561
+ df["__facet__"] = "All"
2562
+ else:
2563
+ s = df[_facet_col]
2564
+ if pd.api.types.is_numeric_dtype(s):
2565
+ # Robust numeric binning: quantiles first, fallback to equal-width
2566
+ uniq = pd.Series(s).dropna().nunique()
2567
+ q = 3 if uniq < 10 else 4 if uniq < 30 else 5
2568
+ try:
2569
+ df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
2570
+ except Exception:
2571
+ df["__facet__"] = pd.cut(s.astype(float), bins=q)
2572
+ else:
2573
+ # Cap long tails; keep top categories
2574
+ vc = s.astype(str).value_counts()
2575
+ keep = vc.index[:{top_n}]
2576
+ df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")
2577
+
2578
+ levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
2579
+ levels = [x for x in levels if x != "nan"]
2580
+ levels.sort()
2581
+
2582
+ m = len(levels)
2583
+ cols = 3 if m >= 3 else m or 1
2584
+ rows = int(math.ceil(m / cols))
2585
+
2586
+ fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
2587
+ if not isinstance(axes, (list, np.ndarray)):
2588
+ axes = np.array([[axes]])
2589
+ axes = axes.reshape(rows, cols)
2590
+
2591
+ for i, lvl in enumerate(levels):
2592
+ r, c = divmod(i, cols)
2593
+ ax = axes[r, c]
2594
+ s = (df.loc[df["__facet"].astype(str) == str(lvl), "{pie_dim}"]
2595
+ .astype(str).value_counts().nlargest({top_n}))
2596
+ if len(s) > 0:
2597
+ ax.pie(s.values, labels=[str(x) for x in s.index],
2598
+ autopct='%1.1f%%', startangle=90, counterclock=False)
2599
+ ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')
2600
+
2601
+ # hide any empty subplots
2602
+ for j in range(m, rows*cols):
2603
+ r, c = divmod(j, cols)
2604
+ axes[r, c].axis("off")
2605
+
2606
+ plt.tight_layout(); plt.show()
2607
+
2608
+ # --- companion visual (adaptive) ---
2609
+ _comp_pref = "{_comp_pref}"
2610
+ # build contingency table: pie_dim x facet
2611
+ _tab = (df[["__facet__", "{pie_dim}"]]
2612
+ .dropna()
2613
+ .astype({{"__facet__": str, "{pie_dim}": str}})
2614
+ .value_counts()
2615
+ .unstack(level="__facet__")
2616
+ .fillna(0))
2617
+
2618
+ # keep top categories by overall size
2619
+ _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2620
+
2621
+ if _comp_pref == "grouped_bar":
2622
+ ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2623
+ ax.set_title("{pie_dim} by {facet_col} (grouped)")
2624
+ ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2625
+ plt.tight_layout(); plt.show()
2626
+
2627
+ elif _comp_pref == "stacked_bar":
2628
+ ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2629
+ ax.set_title("{pie_dim} by {facet_col} (stacked)")
2630
+ ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2631
+ plt.tight_layout(); plt.show()
2632
+
2633
+ elif _comp_pref == "stacked_bar_pct":
2634
+ _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100 # column-normalised to 100%
2635
+ ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
2636
+ ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
2637
+ ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
2638
+ plt.tight_layout(); plt.show()
2639
+
2640
+ elif _comp_pref == "heatmap":
2641
+ _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
2642
+ import numpy as np
2643
+ fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
2644
+ im = ax.imshow(_perc.values, aspect='auto')
2645
+ ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2646
+ ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2647
+ ax.set_title("{pie_dim} by {facet_col} — % heatmap")
2648
+ for i in range(_perc.shape[0]):
2649
+ for j in range(_perc.shape[1]):
2650
+ ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2651
+ fig.colorbar(im, ax=ax, label="%")
2652
+ plt.tight_layout(); plt.show()
2653
+
2654
+ else: # counts_bar (default denominators)
2655
+ _counts = df["__facet"].value_counts()
2656
+ ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
2657
+ ax.set_title("Counts by {facet_col}")
2658
+ ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2659
+ plt.tight_layout(); plt.show()
2660
+
2661
+ """.lstrip()
2662
+ return snippet
2663
+
2664
+ # --------------- CASE C: single pie ---------------
2665
+ chosen = None
2666
+ for c in cats:
2667
+ if c in df.columns and _is_cat(c):
2668
+ chosen = c; break
2669
+ if chosen is None:
2670
+ chosen = _fallback_cat()
2671
+
2672
+ if chosen:
2673
+ snippet = f"""
2674
+ import matplotlib.pyplot as plt
2675
+ counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
2676
+ fig, ax = plt.subplots()
2677
+ if len(counts) > 0:
2678
+ ax.pie(counts.values, labels=[str(i) for i in counts.index],
2679
+ autopct='%1.1f%%', startangle=90, counterclock=False)
2680
+ ax.set_title('Distribution of {chosen} (top {top_n})')
2681
+ ax.axis('equal')
2682
+ plt.show()
2683
+ """.lstrip()
2684
+ return snippet
2685
+
2686
+ # numeric last resort
2687
+ num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
2688
+ if num_cols:
2689
+ col = num_cols[0]
2690
+ snippet = f"""
2691
+ import pandas as pd
2692
+ import matplotlib.pyplot as plt
2693
+ bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
2694
+ counts = bins.value_counts().sort_index()
2695
+ fig, ax = plt.subplots()
2696
+ if len(counts) > 0:
2697
+ ax.pie(counts.values, labels=[str(i) for i in counts.index],
2698
+ autopct='%1.1f%%', startangle=90, counterclock=False)
2699
+ ax.set_title('Distribution of {col} (binned)')
2700
+ ax.axis('equal')
2701
+ plt.show()
2702
+ """.lstrip()
2703
+ return snippet
2704
+
2705
+ return code
2706
+
2707
+
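A hedged usage sketch; the column names are hypothetical and the snippet returned depends on which of the three cases fires:

    # df is assumed to have categorical 'Ethnicity' and 'Smoking_Status' columns
    code = "df['Ethnicity'].value_counts().plot(kind='pie')"
    new_code = patch_pie_chart(code, df, user_question="pie of Ethnicity by Smoking_Status")
    # "by Smoking_Status" routes to the facet branch (CASE B): one pie per level plus a companion chart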
2708
+ def patch_fix_seaborn_palette_calls(code: str) -> str:
2709
+ """
2710
+ Removes seaborn `palette=` when no `hue=` is present in the same call.
2711
+ Fixes FutureWarning: 'Passing `palette` without assigning `hue` ...'.
2712
+ """
2713
+ if "sns." not in code:
2714
+ return code
2715
+
2716
+ # Targets common seaborn plotters
2717
+ funcs = r"(boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
2718
+ pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)
2719
+
2720
+ def _fix_call(m):
2721
+ head, inner = m.group(1), m.group(3)  # group(2) is the function name captured inside `funcs`
2722
+ # If there's already hue=, keep as is
2723
+ if re.search(r"(?<!\w)hue\s*=", inner):
2724
+ return f"{head}{inner})"
2725
+ # Otherwise remove palette=... safely (and any adjacent comma spacing)
2726
+ inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
2727
+ inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
2728
+ inner2 = re.sub(r"\s*,\s*\)", ")", f"{inner2})")[:-1] # clean trailing comma before ')'
2729
+ return f"{head}{inner2})"
2730
+
2731
+ return pattern.sub(_fix_call, code)
2732
+
2733
+
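With the group-index fix above, a call without hue= loses its palette kwarg (illustrative):

    >>> from syntaxmatrix.utils import patch_fix_seaborn_palette_calls
    >>> patch_fix_seaborn_palette_calls("sns.boxplot(x='grp', y='val', data=df, palette='Set2')")
    "sns.boxplot(x='grp', y='val', data=df)"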
2734
+ def patch_quiet_specific_warnings(code: str) -> str:
2735
+ """
2736
+ Inserts targeted warning filters (not blanket ignores).
2737
+ - seaborn palette/hue deprecation
2738
+ - python-dotenv parse chatter
2739
+ """
2740
+ prelude = (
2741
+ "import warnings\n"
2742
+ "warnings.filterwarnings(\n"
2743
+ " 'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
2744
+ "warnings.filterwarnings(\n"
2745
+ " 'ignore', message=r'python-dotenv could not parse statement.*')\n"
2746
+ )
2747
+ # If warnings already imported once, just add filters; else insert full prelude.
2748
+ if "import warnings" in code:
2749
+ code = re.sub(
2750
+ r"(import warnings[^\n]*\n)",
2751
+ lambda m: m.group(1) + prelude.replace("import warnings\n", ""),
2752
+ code,
2753
+ count=1
2754
+ )
2755
+
2756
+ else:
2757
+ # place after first import block if possible
2758
+ m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
2759
+ if m:
2760
+ idx = m.end()
2761
+ code = code[:idx] + prelude + code[idx:]
2762
+ else:
2763
+ code = prelude + code
2764
+ return code
2765
+
2766
+
2767
+ def _norm_col_name(s: str) -> str:
2768
+ """normalise a column name: lowercase + strip non-alphanumerics."""
2769
+ return re.sub(r"[^a-z0-9]+", "", str(s).lower())
2770
+
2771
+
2772
+ def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
2773
+ """return the actual df column that matches any candidate (after normalisation)."""
2774
+ norm_map = {_norm_col_name(c): c for c in df.columns}
2775
+ for cand in candidates:
2776
+ hit = norm_map.get(_norm_col_name(cand))
2777
+ if hit is not None:
2778
+ return hit
2779
+ return None
2780
+
2781
+
2782
+ def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
2783
+ """
2784
+ If any alias exists, materialise a canonical copy at `target` (don’t drop the original).
2785
+ Returns (df, found_bool).
2786
+ """
2787
+ if target in df.columns:
2788
+ return df, True
2789
+ col = _first_present(df, [target, *aliases])
2790
+ if col is None:
2791
+ return df, False
2792
+ df[target] = df[col]
2793
+ return df, True
2794
+
2795
+
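Sketch with hypothetical aliases; a normalised match such as 'Body Mass Index' would also hit:

    df, found = _ensure_canonical_alias(df, "BMI", ["body_mass_index", "bmi_kg_m2"])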
2796
+ def strip_python_dotenv(code: str) -> str:
2797
+ """
2798
+ Remove any use of python-dotenv from generated code, including:
2799
+ - single and multi-line 'from dotenv import ...'
2800
+ - 'import dotenv' (with or without alias) and calls via any alias
2801
+ - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
2802
+ - IPython magics (%load_ext dotenv, %dotenv, %env …)
2803
+ - shell installs like '!pip install python-dotenv'
2804
+ """
2805
+ original = code
2806
+
2807
+ # 0) Kill IPython magics & shell installs referencing dotenv
2808
+ code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
2809
+ code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
2810
+ code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
2811
+ code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)
2812
+
2813
+ # 1) Remove single-line 'from dotenv import ...'
2814
+ code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)
2815
+
2816
+ # 2) Remove multi-line 'from dotenv import ( ... )' blocks
2817
+ code = re.sub(
2818
+ r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
2819
+ "",
2820
+ code,
2821
+ flags=re.MULTILINE,
2822
+ )
2823
+
2824
+ # 3) Remove 'import dotenv' (with optional alias). Capture alias names.
2825
+ aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
2826
+ code, flags=re.MULTILINE)
2827
+ code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
2828
+ "", code, flags=re.MULTILINE)
2829
+
2830
+ # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
2831
+ # e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
2832
+ fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
2833
+ # bare calls
2834
+ code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2835
+ # dotted calls with any identifier prefix (alias or module)
2836
+ code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
2837
+ "", code, flags=re.MULTILINE)
2838
+
2839
+ # 5) If any alias imported earlier slipped through (method chains etc.), remove lines using that alias.
2840
+ for al in aliases:
2841
+ code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2842
+
2843
+ # 6) Tidy excess blank lines
2844
+ code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
2845
+ return code
2846
+
2847
+
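For example (illustrative):

    >>> from syntaxmatrix.utils import strip_python_dotenv
    >>> strip_python_dotenv("from dotenv import load_dotenv\nload_dotenv()\nprint('hi')\n")
    "print('hi')\n"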
2848
+ def fix_predict_calls_records_arg(code: str) -> str:
2849
+ """
2850
+ If generated code calls predict_* with a list-of-dicts via .to_dict('records')
2851
+ (or orient='records'), strip the .to_dict(...) so a DataFrame is passed instead.
2852
+ Works line-by-line to avoid over-rewrites elsewhere.
2853
+ Examples fixed:
2854
+ predict_patient(X_test.iloc[:5].to_dict('records'))
2855
+ predict_risk(df.head(3).to_dict(orient="records"))
2856
+ → predict_patient(X_test.iloc[:5])
2857
+ """
2858
+ fixed_lines = []
2859
+ for line in code.splitlines():
2860
+ if "predict_" in line and "to_dict" in line and "records" in line:
2861
+ line = re.sub(
2862
+ r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)",
2863
+ "",
2864
+ line
2865
+ )
2866
+ fixed_lines.append(line)
2867
+ return "\n".join(fixed_lines)
2868
+
2869
+
2870
+ def fix_fstring_backslash_paths(code: str) -> str:
2871
+ """
2872
+ Fix bad f-strings like: f"...{out_dir\\plots\\img.png}..."
2873
+ → f"...{os.path.join(out_dir, r'plots\\img.png')}"
2874
+ Only touches f-strings that contain a backslash path inside {...}.
2875
+ """
2876
+ def _fix_line(line: str) -> str:
2877
+ # quick check: only f-strings need scanning
2878
+ if not (("f\"" in line) or ("f'" in line) or ("f\"\"\"" in line) or ("f'''" in line)):
2879
+ return line
2880
+ # {var\rest-of-path} where var can be dotted (e.g., cfg.out)
2881
+ pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
2882
+ def repl(m):
2883
+ left = m.group(1)
2884
+ right = m.group(2).strip().replace('"', '\\"')
2885
+ return "{os.path.join(" + left + ', r"' + right + '")}'
2886
+ return pattern.sub(repl, line)
2887
+
2888
+ return "\n".join(_fix_line(ln) for ln in code.splitlines())
2889
+
2890
+
2891
+ def ensure_os_import(code: str) -> str:
2892
+ """
2893
+ If os.path.join is used but 'import os' is missing, inject it at the top.
2894
+ """
2895
+ needs = "os.path.join(" in code
2896
+ has_import_os = re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
2897
+ has_from_os = re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
2898
+ if needs and not (has_import_os or has_from_os):
2899
+ return "import os\n" + code
2900
+ return code
2901
+
2902
+
2903
+ def fix_seaborn_boxplot_nameerror(code: str) -> str:
2904
+ """
2905
+ Fix bad calls like: sns.boxplot(boxplot)
2906
+ Heuristic:
2907
+ - If plot_df + FH_status + var exist → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
2908
+ - Else if plot_df + var → sns.boxplot(data=plot_df, y=var, ax=ax)
2909
+ - Else if plot_df only → sns.boxplot(data=plot_df, ax=ax)
2910
+ - Else → sns.boxplot(ax=ax)
2911
+ Ensures a matplotlib Axes 'ax' exists.
2912
+ """
2913
+ pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
2914
+ if not pattern.search(code):
2915
+ return code
2916
+
2917
+ has_plot_df = re.search(r"\bplot_df\b", code) is not None
2918
+ has_var = re.search(r"\bvar\b", code) is not None
2919
+ has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2920
+
2921
+ if has_plot_df and has_var and has_fh:
2922
+ replacement = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2923
+ elif has_plot_df and has_var:
2924
+ replacement = "sns.boxplot(data=plot_df, y=var, ax=ax)"
2925
+ elif has_plot_df:
2926
+ replacement = "sns.boxplot(data=plot_df, ax=ax)"
2927
+ else:
2928
+ replacement = "sns.boxplot(ax=ax)"
2929
+
2930
+ fixed = pattern.sub(replacement, code)
2931
+
2932
+ # Ensure 'fig, ax = plt.subplots(...)' exists
2933
+ if "ax=" in replacement and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", fixed):
2934
+ # Insert right before the first seaborn call
2935
+ m = re.search(r"^\s*sns\.", fixed, flags=re.MULTILINE)
2936
+ insert_at = m.start() if m else 0
2937
+ fixed = fixed[:insert_at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + fixed[insert_at:]
2938
+
2939
+ return fixed
2940
+
2941
+
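When none of plot_df/var/FH_status is in scope, the fixer degrades to an empty Axes call (illustrative):

    >>> from syntaxmatrix.utils import fix_seaborn_boxplot_nameerror
    >>> print(fix_seaborn_boxplot_nameerror("sns.boxplot(boxplot)"))
    fig, ax = plt.subplots(figsize=(8,4))
    sns.boxplot(ax=ax)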
2942
+ def fix_seaborn_barplot_nameerror(code: str) -> str:
2943
+ """
2944
+ Fix bad calls like: sns.barplot(barplot)
2945
+ Strategy mirrors boxplot fixer: prefer data=plot_df with x/y if available,
2946
+ otherwise degrade safely to an empty call on an existing Axes.
2947
+ """
2948
+ import re
2949
+ pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
2950
+ if not pattern.search(code):
2951
+ return code
2952
+
2953
+ has_plot_df = re.search(r"\bplot_df\b", code) is not None
2954
+ has_var = re.search(r"\bvar\b", code) is not None
2955
+ has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2956
+
2957
+ if has_plot_df and has_var and has_fh:
2958
+ replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2959
+ elif has_plot_df and has_var:
2960
+ replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
2961
+ elif has_plot_df:
2962
+ replacement = "sns.barplot(data=plot_df, ax=ax)"
2963
+ else:
2964
+ replacement = "sns.barplot(ax=ax)"
2965
+
2966
+ # ensure an Axes 'ax' exists (no-op if already present)
2967
+ if "ax =" not in code:
2968
+ code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code
2969
+
2970
+ return pattern.sub(replacement, code)
2971
+
2972
+
2973
+ def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
2974
+ """
2975
+ Parses the raw text to extract and format the 'refined question',
2976
+ 'intents (tasks)', and 'chronology of tasks' sections.
2977
+ Args:
2978
+ raw_text: The complete input string containing the ML pipeline structure.
2979
+ Returns:
2980
+ A tuple containing:
2981
+ (formatted_question_str, formatted_intents_str, formatted_chronology_str)
2982
+ """
2983
+ # --- 1. Regex Pattern to Extract Sections ---
2984
+ # The pattern uses named capturing groups (?P<...>) to find the section headers
2985
+ # (e.g., 'refined question:') and captures all the content until the next
2986
+ # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.
2987
+
2988
+ pattern = re.compile(
2989
+ r"refined question:(?P<question>.*?)"
2990
+ r"intents \(tasks\):(?P<intents>.*?)"
2991
+ r"Chronology of tasks:(?P<chronology>.*)",
2992
+ re.IGNORECASE | re.DOTALL
2993
+ )
2994
+
2995
+ match = pattern.search(raw_text)
2996
+
2997
+ if not match:
2998
+ raise ValueError("Input text structure does not match the expected pattern.")
2999
+
3000
+ # --- 2. Extract Content ---
3001
+ question_content = match.group('question').strip()
3002
+ intents_content = match.group('intents').strip()
3003
+ chronology_content = match.group('chronology').strip()
3004
+
3005
+ # --- 3. Formatting Functions ---
3006
+
3007
+ def format_question(content):
3008
+ """Formats the Refined Question section."""
3009
+ # Clean up leading/trailing whitespace and ensure clean paragraphs
3010
+ content = content.strip().replace('\n', ' ').replace('  ', ' ')
3011
+
3012
+ # Simple formatting using Markdown headers and bolding
3013
+ formatted = (
3014
+ # "## 1. Project Goal and Objectives\n\n"
3015
+ "<b> Refined Question:</b>\n"
3016
+ f"{content}\n"
3017
+ )
3018
+ return formatted
3019
+
3020
+ def format_intents(content):
3021
+ """Formats the Intents (Tasks) section as a structured list."""
3022
+ # Use regex to find and format each numbered task
3023
+ # It finds 'N. **Text** - ...' and breaks it down.
3024
+
3025
+ tasks = []
3026
+ # Pattern: N. **Text** - Content (including newlines, non-greedy)
3027
+ # We need to explicitly handle the list items starting with '-' within the content
3028
+ task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)
3029
+
3030
+ # Split the content by lines and join tasks back into clean strings
3031
+ raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]
3032
+
3033
+ for task in raw_tasks:
3034
+ # Replace the initial task number and **Heading** with a Heading 3
3035
+ task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)
3036
+
3037
+ # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
3038
+ task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
3039
+ tasks.append(task)
3040
+
3041
+ formatted_tasks = "\n\n".join(tasks)
3042
+
3043
+ return (
3044
+ "\n---\n"
3045
+ "## 2. Methodology and Tasks\n\n"
3046
+ f"{formatted_tasks}\n"
3047
+ )
3048
+
3049
+ def format_chronology(content):
3050
+ """Formats the Chronology section."""
3051
+ # Uses the given LaTeX format
3052
+ content = content.strip().replace(' ', r' \rightarrow ')  # raw string: '\r' would otherwise be a carriage return
3053
+ formatted = (
3054
+ "\n---\n"
3055
+ "## 3. Chronology of Tasks\n"
3056
+ f"$$\\text{{{content}}}$$"
3057
+ )
3058
+ return formatted
3059
+
3060
+ # --- 4. Format and Return ---
3061
+ formatted_question = format_question(question_content)
3062
+ formatted_intents = format_intents(intents_content)
3063
+ formatted_chronology = format_chronology(chronology_content)
3064
+
3065
+ return formatted_question, formatted_intents, formatted_chronology
3066
+
3067
+
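A minimal round-trip sketch (hypothetical planner output):

    raw = ("refined question: Which features predict the target?\n"
           "intents (tasks): 1. **EDA** - profile the data\n"
           "Chronology of tasks: EDA Modelling Evaluation")
    q, intents, chron = parse_and_format_ml_pipeline(raw)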
3068
+ def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
3069
+ """Combines all formatted parts into a final report string."""
3070
+ return (
3071
+ "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
3072
+ f"{formatted_question}\n"
3073
+ f"{formatted_intents}\n"
3074
+ f"{formatted_chronology}\n"
3075
+ )
3076
+
3077
+
3078
+ def fix_confusion_matrix_for_multilabel(code: str) -> str:
3079
+ """
3080
+ Replace ConfusionMatrixDisplay.from_estimator(...) usages with
3081
+ from_predictions(...) which works for multi-label loops without requiring
3082
+ the estimator to expose _estimator_type.
3083
+ """
3084
+ return re.sub(
3085
+ r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
3086
+ r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
989
3087
  code
990
3088
  )
991
3089
 
3090
+
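For example (illustrative):

    >>> from syntaxmatrix.utils import fix_confusion_matrix_for_multilabel
    >>> fix_confusion_matrix_for_multilabel("ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test)")
    'ConfusionMatrixDisplay.from_predictions(y_test, clf.predict(X_test))'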
3091
+ def smx_auto_title_plots(ctx=None, fallback="Analysis"):
3092
+ """
3093
+ Ensure every Matplotlib/Seaborn Axes has a title.
3094
+ Uses refined_question -> askai_question -> fallback.
3095
+ Only sets a title if it's currently empty.
3096
+ """
3097
+ import matplotlib.pyplot as plt
3098
+
3099
+ def _all_figures():
3100
+ try:
3101
+ from matplotlib._pylab_helpers import Gcf
3102
+ return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
3103
+ except Exception:
3104
+ # Best effort fallback
3105
+ nums = plt.get_fignums()
3106
+ return [plt.figure(n) for n in nums] if nums else []
3107
+
3108
+ # Choose a concise title
3109
+ title = None
3110
+ if isinstance(ctx, dict):
3111
+ title = ctx.get("refined_question") or ctx.get("askai_question")
3112
+ title = (str(title).strip().splitlines()[0][:120]) if title else fallback
3113
+
3114
+ for fig in _all_figures():
3115
+ for ax in getattr(fig, "axes", []):
3116
+ try:
3117
+ if not (ax.get_title() or "").strip():
3118
+ ax.set_title(title)
3119
+ except Exception:
3120
+ pass
3121
+ try:
3122
+ fig.tight_layout()
3123
+ except Exception:
3124
+ pass
3125
+
3126
+
3127
+ def patch_fix_sentinel_plot_calls(code: str) -> str:
3128
+ """
3129
+ Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
3130
+ SB_barplot(barplot) -> SB_barplot()
3131
+ SB_barplot(barplot, ...) -> SB_barplot(...)
3132
+ sns.barplot(barplot) -> SB_barplot()
3133
+ sns.barplot(barplot, ...) -> SB_barplot(...)
3134
+ Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
3135
+ """
3136
+ names = ['histplot','boxplot','barplot','lineplot','countplot','heatmap','pairplot','scatterplot']
3137
+ for n in names:
3138
+ # SB_* with sentinel as the first arg (with or without trailing args)
3139
+ code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
3140
+ code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
3141
+ # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
3142
+ code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
3143
+ code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
992
3144
  return code
993
3145
 
3146
+
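For example (illustrative; SB_barplot is the package's own wrapper that supplies sane defaults):

    >>> from syntaxmatrix.utils import patch_fix_sentinel_plot_calls
    >>> patch_fix_sentinel_plot_calls("sns.barplot(barplot)")
    'SB_barplot()'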
3147
+ def patch_rmse_calls(code: str) -> str:
3148
+ """
3149
+ Make RMSE robust across sklearn versions.
3150
+ - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
3151
+ - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
3152
+ """
3153
+ import re
3154
+ # (a) Specific RMSE pattern
3155
+ code = re.sub(
3156
+ r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
3157
+ r"_SMX_rmse(\1)",
3158
+ code,
3159
+ flags=re.DOTALL
3160
+ )
3161
+ # (b) Guard any other MSE calls
3162
+ code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
3163
+ return code
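For example (illustrative; _SMX_rmse and _SMX_call are runtime helpers assumed to be provided by the executor):

    >>> from syntaxmatrix.utils import patch_rmse_calls
    >>> patch_rmse_calls("rmse = mean_squared_error(y_true, y_pred, squared=False)")
    'rmse = _SMX_rmse(y_true, y_pred)'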