syntaxmatrix-2.5.8.1-py3-none-any.whl → syntaxmatrix-2.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. syntaxmatrix/agentic/agents.py +1150 -55
  2. syntaxmatrix/agentic/agents_orchestrer.py +326 -0
  3. syntaxmatrix/agentic/code_tools_registry.py +27 -32
  4. syntaxmatrix/commentary.py +16 -16
  5. syntaxmatrix/core.py +107 -70
  6. syntaxmatrix/db.py +416 -4
  7. syntaxmatrix/{display.py → display_html.py} +2 -6
  8. syntaxmatrix/gpt_models_latest.py +1 -1
  9. syntaxmatrix/media/__init__.py +0 -0
  10. syntaxmatrix/media/media_pixabay.py +277 -0
  11. syntaxmatrix/models.py +1 -1
  12. syntaxmatrix/page_builder_defaults.py +183 -0
  13. syntaxmatrix/page_builder_generation.py +1122 -0
  14. syntaxmatrix/page_layout_contract.py +644 -0
  15. syntaxmatrix/page_patch_publish.py +1471 -0
  16. syntaxmatrix/preface.py +128 -8
  17. syntaxmatrix/profiles.py +28 -10
  18. syntaxmatrix/routes.py +1347 -427
  19. syntaxmatrix/selftest_page_templates.py +360 -0
  20. syntaxmatrix/settings/client_items.py +28 -0
  21. syntaxmatrix/settings/model_map.py +1022 -208
  22. syntaxmatrix/settings/prompts.py +328 -130
  23. syntaxmatrix/static/assets/hero-default.svg +22 -0
  24. syntaxmatrix/static/icons/bot-icon.png +0 -0
  25. syntaxmatrix/static/icons/favicon.png +0 -0
  26. syntaxmatrix/static/icons/logo.png +0 -0
  27. syntaxmatrix/static/icons/logo2.png +0 -0
  28. syntaxmatrix/static/icons/logo3.png +0 -0
  29. syntaxmatrix/templates/admin_secretes.html +108 -0
  30. syntaxmatrix/templates/dashboard.html +116 -72
  31. syntaxmatrix/templates/edit_page.html +2535 -0
  32. syntaxmatrix/utils.py +2365 -2411
  33. {syntaxmatrix-2.5.8.1.dist-info → syntaxmatrix-2.6.0.dist-info}/METADATA +6 -2
  34. {syntaxmatrix-2.5.8.1.dist-info → syntaxmatrix-2.6.0.dist-info}/RECORD +37 -24
  35. syntaxmatrix/generate_page.py +0 -644
  36. syntaxmatrix/static/icons/hero_bg.jpg +0 -0
  37. {syntaxmatrix-2.5.8.1.dist-info → syntaxmatrix-2.6.0.dist-info}/WHEEL +0 -0
  38. {syntaxmatrix-2.5.8.1.dist-info → syntaxmatrix-2.6.0.dist-info}/licenses/LICENSE.txt +0 -0
  39. {syntaxmatrix-2.5.8.1.dist-info → syntaxmatrix-2.6.0.dist-info}/top_level.txt +0 -0
syntaxmatrix/utils.py CHANGED
@@ -6,89 +6,41 @@ from difflib import get_close_matches
  from typing import Iterable, Tuple, Dict
  import inspect
  from sklearn.preprocessing import OneHotEncoder
-
-
- from syntaxmatrix.agentic.model_templates import (
- classification, regression, multilabel_classification,
- eda_overview, eda_correlation,
- anomaly_detection, ts_anomaly_detection,
- dimensionality_reduction, feature_selection,
- time_series_forecasting, time_series_classification,
- unknown_group_proxy_pack, viz_line,
- clustering, recommendation, topic_modelling,
- viz_pie, viz_count_bar, viz_box, viz_scatter,
- viz_stacked_bar, viz_distribution, viz_area, viz_kde,
- )
  import ast

  warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

- INJECTABLE_INTENTS = [
- "classification",
- "multilabel_classification",
- "regression",
- "anomaly_detection",
- "time_series_forecasting",
- "time_series_classification",
- "ts_anomaly_detection",
- "dimensionality_reduction",
- "feature_selection",
- "clustering",
- "eda",
- "correlation_analysis",
- "visualisation",
- "recommendation",
- "topic_modelling",
- ]
-
- def classify_ml_job(prompt: str) -> str:
+ def patch_quiet_specific_warnings(code: str) -> str:
  """
- Very-light intent classifier.
- Returns one of:
- 'stat_test' | 'time_series' | 'clustering'
- 'classification' | 'regression' | 'eda'
+ Inserts targeted warning filters (not blanket ignores).
+ - seaborn palette/hue deprecation
+ - python-dotenv parse chatter
  """
- p = prompt.lower()
-
- greetings = {"hi", "hello", "hey", "good morning", "good afternoon", "good evening", "greetings"}
- if any(p.startswith(g) or p == g for g in greetings):
- return "greeting"
-
- # Feature selection / importance intent
- if any(k in p for k in (
- "feature selection", "select k best", "selectkbest", "rfe",
- "mutual information", "feature importance", "permutation importance",
- "feature engineering suggestions"
- )):
- return "feature_selection"
-
- # Dimensionality reduction intent
- if any(k in p for k in (
- "pca", "principal component", "dimensionality reduction",
- "reduce dimension", "reduce dimensionality", "t-sne", "tsne", "umap"
- )):
- return "dimensionality_reduction"
-
- # Anomaly / outlier intent
- if any(k in p for k in (
- "anomaly", "anomalies", "outlier", "outliers", "novelty",
- "fraud", "deviation", "rare event", "rare events", "odd pattern",
- "suspicious"
- )):
- return "anomaly_detection"
-
- if any(k in p for k in ("t-test", "anova", "p-value")):
- return "stat_test"
- if "forecast" in p or "prophet" in p:
- return "time_series"
- if "cluster" in p or "kmeans" in p:
- return "clustering"
- if any(k in p for k in ("accuracy", "precision", "roc")):
- return "classification"
- if any(k in p for k in ("rmse", "r2", "mae")):
- return "regression"
-
- return "eda"
+ prelude = (
+ "import warnings\n"
+ "warnings.filterwarnings(\n"
+ " 'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
+ "warnings.filterwarnings(\n"
+ " 'ignore', message=r'python-dotenv could not parse statement.*')\n"
+ )
+ # If warnings already imported once, just add filters; else insert full prelude.
+ if "import warnings" in code:
+ code = re.sub(
+ r"(import warnings[^\n]*\n)",
+ lambda m: m.group(1) + prelude.replace("import warnings\n", ""),
+ code,
+ count=1
+ )
+
+ else:
+ # place after first import block if possible
+ m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
+ if m:
+ idx = m.end()
+ code = code[:idx] + prelude + code[idx:]
+ else:
+ code = prelude + code
+ return code


  def _indent(code: str, spaces: int = 4) -> str:
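
For illustration, a minimal sketch of what the new `patch_quiet_specific_warnings` pass does to a generated cell (assumes syntaxmatrix 2.6.0 is installed; the sample cell below is hypothetical):

# Sketch only: run one generated cell through the targeted warning-filter pass.
from syntaxmatrix.utils import patch_quiet_specific_warnings

cell = "import pandas as pd\n\ndf = pd.DataFrame({'a': [1, 2, 3]})\n"
print(patch_quiet_specific_warnings(cell))
# The prelude (import warnings plus the two targeted filterwarnings calls for the
# seaborn palette/hue FutureWarning and python-dotenv parse chatter) is inserted
# after the first import statement, instead of a blanket warnings ignore.
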
@@ -121,7 +73,7 @@ def wrap_llm_code_safe(body: str) -> str:
  " df_local = globals().get('df')\n"
  " if df_local is not None:\n"
  " import pandas as pd\n"
- " from syntaxmatrix.preface import SB_histplot, _SMX_export_png\n"
+ " from syntaxmatrix.preface import SB_histplot, SB_boxplot, SB_scatterplot, SB_heatmap, _SMX_export_png\n"
  " num_cols = df_local.select_dtypes(include=['number', 'bool']).columns.tolist()\n"
  " cat_cols = [c for c in df_local.columns if c not in num_cols]\n"
  " info = {\n"
@@ -135,11 +87,78 @@ def wrap_llm_code_safe(body: str) -> str:
  " if num_cols:\n"
  " SB_histplot()\n"
  " _SMX_export_png()\n"
+ " if len(num_cols) >= 2:\n"
+ " SB_scatterplot()\n"
+ " _SMX_export_png()\n"
+ " if num_cols and cat_cols:\n"
+ " SB_boxplot()\n"
+ " _SMX_export_png()\n"
+ " if len(num_cols) >= 2:\n"
+ " SB_heatmap()\n"
+ " _SMX_export_png()\n"
  " except Exception as _f:\n"
  " show(f\"⚠️ Fallback EDA failed: {type(_f).__name__}: {_f}\")\n"
  )


+ def fix_print_html(code: str) -> str:
+ """
+ Ensure that HTML / DataFrame HTML are *displayed* (and captured by the kernel),
+ not printed as `<IPython.core.display.HTML object>` to the server console.
+ - Rewrites: print(HTML(...)) → display(HTML(...))
+ print(display(...)) → display(...)
+ print(df.to_html(...)) → display(HTML(df.to_html(...)))
+ Also prepends `from IPython.display import display, HTML` if required.
+ """
+ import re
+
+ new = code
+
+ # 1) print(HTML(...)) -> display(HTML(...))
+ new = re.sub(r"(?m)^\s*print\s*\(\s*HTML\s*\(", "display(HTML(", new)
+
+ # 2) print(display(...)) -> display(...)
+ new = re.sub(r"(?m)^\s*print\s*\(\s*display\s*\(", "display(", new)
+
+ # 3) print(<expr>.to_html(...)) -> display(HTML(<expr>.to_html(...)))
+ new = re.sub(
+ r"(?m)^\s*print\s*\(\s*([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*)\s*\.to_html\s*\(",
+ r"display(HTML(\1.to_html(", new
+ )
+
+ # If code references HTML() or display() make sure the import exists
+ if ("HTML(" in new or re.search(r"\bdisplay\s*\(", new)) and \
+ "from IPython.display import display, HTML" not in new:
+ new = "from IPython.display import display, HTML\n" + new
+
+ return new
+
+
+ def ensure_ipy_display(code: str) -> str:
+ """
+ Guarantee that the cell has proper IPython display imports so that
+ display(HTML(...)) produces 'display_data' events the kernel captures.
+ """
+ if "display(" in code and "from IPython.display import display, HTML" not in code:
+ return "from IPython.display import display, HTML\n" + code
+ return code
+
+
+ def fix_plain_prints(code: str) -> str:
+ """
+ Rewrite bare `print(var)` where var looks like a dataframe/series/ndarray/etc
+ to go through SyntaxMatrix's smart display (so it renders in the dashboard).
+ Keeps string prints alone.
+ """
+
+ # Skip obvious string-literal prints
+ new = re.sub(
+ r"(?m)^\s*print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
+ r"from syntaxmatrix.display import show\nshow(\1)",
+ code,
+ )
+ return new
+
  def harden_ai_code(code: str) -> str:
  """
  Make any AI-generated cell resilient:
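
A hedged illustration of the new print-to-display rewrite (sketch only; assumes `fix_print_html` is exposed at module level in syntaxmatrix.utils in 2.6.0):

from syntaxmatrix.utils import fix_print_html

cell = 'print(HTML("<b>Summary</b>"))\n'
print(fix_print_html(cell))
# Expected (sketch):
#   from IPython.display import display, HTML
#   display(HTML("<b>Summary</b>"))
# The print() wrapper becomes display() and the IPython import is prepended,
# so the kernel records the HTML as display_data instead of echoing its repr.
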
@@ -202,21 +221,37 @@ def harden_ai_code(code: str) -> str:

  return pattern.sub(repl, code)

+
  def _wrap_metric_calls(code: str) -> str:
  names = [
- "r2_score","accuracy_score","precision_score","recall_score","f1_score",
- "roc_auc_score","classification_report","confusion_matrix",
- "mean_absolute_error","mean_absolute_percentage_error",
- "explained_variance_score","log_loss","average_precision_score",
- "precision_recall_fscore_support"
+ "r2_score",
+ "accuracy_score",
+ "precision_score",
+ "recall_score",
+ "f1_score",
+ "roc_auc_score",
+ "classification_report",
+ "confusion_matrix",
+ "mean_absolute_error",
+ "mean_absolute_percentage_error",
+ "explained_variance_score",
+ "log_loss",
+ "average_precision_score",
+ "precision_recall_fscore_support",
+ "mean_squared_error",
  ]
- pat = re.compile(r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\(")
+ pat = re.compile(
+ r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\("
+ )
+
  def repl(m):
  prefix = m.group(1) or "" # "", "metrics.", or "sklearn.metrics."
  name = m.group(2)
  return f"_SMX_call({prefix}{name}, "
+
  return pat.sub(repl, code)

+
  def _smx_patch_mean_squared_error_squared_kw():
  """
  sklearn<0.22 doesn't accept mean_squared_error(..., squared=False).
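
The metric-wrapping pass above now also covers mean_squared_error; a standalone sketch of the rewrite it performs (pattern and replacement copied from the diff, the sample cell is hypothetical):

import re

names = ["accuracy_score", "mean_squared_error"]  # subset of the list above
pat = re.compile(r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\(")

cell = "rmse = mean_squared_error(y_test, y_pred, squared=False)"
print(pat.sub(lambda m: f"_SMX_call({m.group(1) or ''}{m.group(2)}, ", cell))
# -> rmse = _SMX_call(mean_squared_error, y_test, y_pred, squared=False)
# _SMX_call is provided by the syntaxmatrix.preface star-import mentioned below.
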
@@ -291,6 +326,8 @@ def harden_ai_code(code: str) -> str:
  needed.add("r2_score")
  if "mean_absolute_error" in code:
  needed.add("mean_absolute_error")
+ if "mean_squared_error" in code:
+ needed.add("mean_squared_error")
  # ... add others if you like ...

  if not needed:
@@ -351,7 +388,8 @@ def harden_ai_code(code: str) -> str:
  - then falls back to generic but useful EDA.

  It assumes `from syntaxmatrix.preface import *` has already been done,
- so `_SMX_OHE`, `_SMX_call`, `SB_histplot`, `_SMX_export_png` and the
+ so `_SMX_OHE`, `_SMX_call`, `SB_histplot`, `SB_boxplot`,
+ `SB_scatterplot`, `SB_heatmap`, `_SMX_export_png` and the
  patched `show()` are available.
  """
  import textwrap
@@ -461,10 +499,25 @@ def harden_ai_code(code: str) -> str:
  show(df.head(), title='Sample of data')
  show(info, title='Dataset summary')

- # Quick univariate look if we have numeric columns
+ # 1) Distribution of a numeric column
  if num_cols:
  SB_histplot()
  _SMX_export_png()
+
+ # 2) Relationship between two numeric columns
+ if len(num_cols) >= 2:
+ SB_scatterplot()
+ _SMX_export_png()
+
+ # 3) Distribution of a numeric by the first categorical column
+ if num_cols and cat_cols:
+ SB_boxplot()
+ _SMX_export_png()
+
+ # 4) Correlation heatmap across numeric columns
+ if len(num_cols) >= 2:
+ SB_heatmap()
+ _SMX_export_png()
  except Exception as _eda_e:
  show(f"⚠ EDA fallback failed: {type(_eda_e).__name__}: {_eda_e}")
  """
@@ -596,19 +649,21 @@ def harden_ai_code(code: str) -> str:
  except (SyntaxError, IndentationError):
  fixed = _fallback_snippet()

+
+ # Fix placeholder Ellipsis handlers from LLM
  fixed = re.sub(
  r"except\s+Exception\s+as\s+e:\s*\n\s*show\(\.\.\.\)",
  "except Exception as e:\n show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")",
  fixed,
  )

- # Fix placeholder Ellipsis handlers from LLM
+ # redirect that import to the real template module.
  fixed = re.sub(
- r"except\s+Exception\s+as\s+e:\s*\n\s*show\(\.\.\.\)",
- "except Exception as e:\n show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")",
+ r"from\s+syntaxmatrix\.templates\s+import\s+([^\n]+)",
+ r"from syntaxmatrix.agentic.model_templates import \1",
  fixed,
  )
-
+
  try:
  class _SMXMatmulRewriter(ast.NodeTransformer):
  def visit_BinOp(self, node):
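
The import-redirect pass added above can be exercised on its own; a sketch using the same regex (the offending import line is hypothetical):

import re

cell = "from syntaxmatrix.templates import viz_count_bar"
print(re.sub(
    r"from\s+syntaxmatrix\.templates\s+import\s+([^\n]+)",
    r"from syntaxmatrix.agentic.model_templates import \1",
    cell,
))
# -> from syntaxmatrix.agentic.model_templates import viz_count_bar
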
@@ -628,14 +683,6 @@ def harden_ai_code(code: str) -> str:
  fixed = fixed.replace("\t", " ")
  fixed = textwrap.dedent(fixed).strip("\n")

- # Normalise any mistaken template imports the LLM may have invented.
- # If the model writes "from syntaxmatrix.templates import viz_count_bar",
- # redirect that import to the real template module.
- fixed = re.sub(
- r"from\s+syntaxmatrix\.templates\s+import\s+([^\n]+)",
- r"from syntaxmatrix.agentic.model_templates import \1",
- fixed,
- )
  fixed = _ensure_metrics_imports(fixed)
  fixed = _strip_stray_backrefs(fixed)
  fixed = _wrap_metric_calls(fixed)
@@ -661,822 +708,822 @@ def harden_ai_code(code: str) -> str:
661
708
  return wrapped
662
709
 
663
710
 
664
- def indent_code(code: str, spaces: int = 4) -> str:
665
- pad = " " * spaces
666
- return "\n".join(pad + line for line in code.splitlines())
711
+ # def indent_code(code: str, spaces: int = 4) -> str:
712
+ # pad = " " * spaces
713
+ # return "\n".join(pad + line for line in code.splitlines())
667
714
 
668
715
 
669
- def fix_boxplot_placeholder(code: str) -> str:
670
- # Replace invalid 'sns.boxplot(boxplot)' with a safe call using df/group_label/m
671
- return re.sub(
672
- r"sns\.boxplot\(\s*boxplot\s*\)",
673
- "sns.boxplot(x=group_label, y=m, data=df.loc[df[m].notnull()], showfliers=False)",
674
- code
675
- )
716
+ # def fix_boxplot_placeholder(code: str) -> str:
717
+ # # Replace invalid 'sns.boxplot(boxplot)' with a safe call using df/group_label/m
718
+ # return re.sub(
719
+ # r"sns\.boxplot\(\s*boxplot\s*\)",
720
+ # "sns.boxplot(x=group_label, y=m, data=df.loc[df[m].notnull()], showfliers=False)",
721
+ # code
722
+ # )
676
723
 
677
724
 
678
- def relax_required_columns(code: str) -> str:
679
- # Remove hard failure on required_cols; keep a soft filter instead
680
- return re.sub(
681
- r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
682
- "required_cols = [c for c in df.columns]\n",
683
- code,
684
- flags=re.S
685
- )
725
+ # def relax_required_columns(code: str) -> str:
726
+ # # Remove hard failure on required_cols; keep a soft filter instead
727
+ # return re.sub(
728
+ # r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
729
+ # "required_cols = [c for c in df.columns]\n",
730
+ # code,
731
+ # flags=re.S
732
+ # )
686
733
 
687
734
 
688
- def make_numeric_vars_dynamic(code: str) -> str:
689
- # Replace any static numeric_vars list with a dynamic selection
690
- return re.sub(
691
- r"numeric_vars\s*=\s*\[.*?\]",
692
- "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
693
- code,
694
- flags=re.S
695
- )
735
+ # def make_numeric_vars_dynamic(code: str) -> str:
736
+ # # Replace any static numeric_vars list with a dynamic selection
737
+ # return re.sub(
738
+ # r"numeric_vars\s*=\s*\[.*?\]",
739
+ # "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
740
+ # code,
741
+ # flags=re.S
742
+ # )
696
743
 
697
744
 
698
- def auto_inject_template(code: str, intents, df) -> str:
699
- """If the LLM forgot the core logic, prepend a skeleton block."""
745
+ # def auto_inject_template(code: str, intents, df) -> str:
746
+ # """If the LLM forgot the core logic, prepend a skeleton block."""
700
747
 
701
- has_fit = ".fit(" in code
702
- has_plot = any(k in code for k in ("plt.", "sns.", ".plot(", ".hist("))
748
+ # has_fit = ".fit(" in code
749
+ # has_plot = any(k in code for k in ("plt.", "sns.", ".plot(", ".hist("))
703
750
 
704
- UNKNOWN_TOKENS = {
705
- "unknown","not reported","not_reported","not known","n/a","na",
706
- "none","nan","missing","unreported","unspecified","null","-",""
707
- }
708
-
709
- # --- Safe template caller: passes only supported kwargs, falls back cleanly ---
710
- def _call_template(func, df, **hints):
711
- import inspect
712
- try:
713
- params = inspect.signature(func).parameters
714
- kw = {k: v for k, v in hints.items() if k in params}
715
- try:
716
- return func(df, **kw)
717
- except TypeError:
718
- # In case the template changed its signature at runtime
719
- return func(df)
720
- except Exception:
721
- # Absolute safety net
722
- try:
723
- return func(df)
724
- except Exception:
725
- # As a last resort, return empty code so we don't 500
726
- return ""
727
-
728
- def _guess_classification_target(df: pd.DataFrame) -> str | None:
729
- cols = list(df.columns)
730
-
731
- # Helper: does this column look like a sensible label?
732
- def _is_reasonable_class_col(s: pd.Series, col_name: str) -> bool:
733
- try:
734
- nunq = s.dropna().nunique()
735
- except Exception:
736
- return False
737
- # need at least 2 classes, but not hundreds
738
- if nunq < 2 or nunq > 20:
739
- return False
740
- bad_name_keys = ("id", "identifier", "index", "uuid", "key")
741
- name = str(col_name).lower()
742
- if any(k in name for k in bad_name_keys):
743
- return False
744
- return True
745
-
746
- # 1) columns whose names look like labels
747
- label_keys = ("target", "label", "outcome", "class", "y", "status")
748
- name_candidates: list[str] = []
749
- for key in label_keys:
750
- for c in cols:
751
- if key in str(c).lower():
752
- name_candidates.append(c)
753
- if name_candidates:
754
- break # keep the earliest matching key-group
755
-
756
- # prioritise name-based candidates that also look like proper label columns
757
- for c in name_candidates:
758
- if _is_reasonable_class_col(df[c], c):
759
- return c
760
- if name_candidates:
761
- # fall back to the first name-based candidate if none passed the shape test
762
- return name_candidates[0]
763
-
764
- # 2) any column with a small number of distinct values (likely a class label)
765
- for c in cols:
766
- s = df[c]
767
- if _is_reasonable_class_col(s, c):
768
- return c
769
-
770
- # Nothing suitable found
771
- return None
772
-
773
- def _guess_regression_target(df: pd.DataFrame) -> str | None:
774
- num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
775
- if not num_cols:
776
- return None
777
- # Avoid obvious ID-like columns
778
- bad_keys = ("id", "identifier", "index")
779
- candidates = [c for c in num_cols if not any(k in str(c).lower() for k in bad_keys)]
780
- return (candidates or num_cols)[-1]
751
+ # UNKNOWN_TOKENS = {
752
+ # "unknown","not reported","not_reported","not known","n/a","na",
753
+ # "none","nan","missing","unreported","unspecified","null","-",""
754
+ # }
755
+
756
+ # # --- Safe template caller: passes only supported kwargs, falls back cleanly ---
757
+ # def _call_template(func, df, **hints):
758
+ # import inspect
759
+ # try:
760
+ # params = inspect.signature(func).parameters
761
+ # kw = {k: v for k, v in hints.items() if k in params}
762
+ # try:
763
+ # return func(df, **kw)
764
+ # except TypeError:
765
+ # # In case the template changed its signature at runtime
766
+ # return func(df)
767
+ # except Exception:
768
+ # # Absolute safety net
769
+ # try:
770
+ # return func(df)
771
+ # except Exception:
772
+ # # As a last resort, return empty code so we don't 500
773
+ # return ""
774
+
775
+ # def _guess_classification_target(df: pd.DataFrame) -> str | None:
776
+ # cols = list(df.columns)
777
+
778
+ # # Helper: does this column look like a sensible label?
779
+ # def _is_reasonable_class_col(s: pd.Series, col_name: str) -> bool:
780
+ # try:
781
+ # nunq = s.dropna().nunique()
782
+ # except Exception:
783
+ # return False
784
+ # # need at least 2 classes, but not hundreds
785
+ # if nunq < 2 or nunq > 20:
786
+ # return False
787
+ # bad_name_keys = ("id", "identifier", "index", "uuid", "key")
788
+ # name = str(col_name).lower()
789
+ # if any(k in name for k in bad_name_keys):
790
+ # return False
791
+ # return True
792
+
793
+ # # 1) columns whose names look like labels
794
+ # label_keys = ("target", "label", "outcome", "class", "y", "status")
795
+ # name_candidates: list[str] = []
796
+ # for key in label_keys:
797
+ # for c in cols:
798
+ # if key in str(c).lower():
799
+ # name_candidates.append(c)
800
+ # if name_candidates:
801
+ # break # keep the earliest matching key-group
802
+
803
+ # # prioritise name-based candidates that also look like proper label columns
804
+ # for c in name_candidates:
805
+ # if _is_reasonable_class_col(df[c], c):
806
+ # return c
807
+ # if name_candidates:
808
+ # # fall back to the first name-based candidate if none passed the shape test
809
+ # return name_candidates[0]
810
+
811
+ # # 2) any column with a small number of distinct values (likely a class label)
812
+ # for c in cols:
813
+ # s = df[c]
814
+ # if _is_reasonable_class_col(s, c):
815
+ # return c
816
+
817
+ # # Nothing suitable found
818
+ # return None
819
+
820
+ # def _guess_regression_target(df: pd.DataFrame) -> str | None:
821
+ # num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
822
+ # if not num_cols:
823
+ # return None
824
+ # # Avoid obvious ID-like columns
825
+ # bad_keys = ("id", "identifier", "index")
826
+ # candidates = [c for c in num_cols if not any(k in str(c).lower() for k in bad_keys)]
827
+ # return (candidates or num_cols)[-1]
781
828
 
782
- def _guess_time_col(df: pd.DataFrame) -> str | None:
783
- # Prefer actual datetime dtype
784
- dt_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.datetime64)]
785
- if dt_cols:
786
- return dt_cols[0]
787
-
788
- # Fallback: name-based hints
789
- name_keys = ["date", "time", "timestamp", "datetime", "ds", "period"]
790
- for c in df.columns:
791
- name = str(c).lower()
792
- if any(k in name for k in name_keys):
793
- return c
794
- return None
795
-
796
- def _guess_entity_col(df: pd.DataFrame) -> str | None:
797
- # Typical sequence IDs: id, patient, subject, device, series, entity
798
- keys = ["id", "patient", "subject", "device", "series", "entity"]
799
- candidates = []
800
- for c in df.columns:
801
- name = str(c).lower()
802
- if any(k in name for k in keys):
803
- candidates.append(c)
804
- return candidates[0] if candidates else None
805
-
806
- def _guess_ts_class_target(df: pd.DataFrame) -> str | None:
807
- # Try label-like names first
808
- keys = ["target", "label", "class", "outcome", "y"]
809
- for key in keys:
810
- for c in df.columns:
811
- if key in str(c).lower():
812
- return c
813
-
814
- # Fallback: any column with few distinct values (e.g. <= 10)
815
- for c in df.columns:
816
- s = df[c]
817
- # avoid obvious IDs
818
- if any(k in str(c).lower() for k in ["id", "index"]):
819
- continue
820
- try:
821
- nunq = s.dropna().nunique()
822
- except Exception:
823
- continue
824
- if 1 < nunq <= 10:
825
- return c
826
-
827
- return None
828
-
829
- def _guess_multilabel_cols(df: pd.DataFrame) -> list[str]:
830
- cols = list(df.columns)
831
- lbl_like = [c for c in cols if str(c).startswith(("LBL_", "lbl_"))]
832
- # also include boolean/binary columns with suitable names
833
- for c in cols:
834
- s = df[c]
835
- try:
836
- nunq = s.dropna().nunique()
837
- except Exception:
838
- continue
839
- if nunq in (2,) and c not in lbl_like:
840
- # avoid obvious IDs
841
- if not any(k in str(c).lower() for k in ("id","index","uuid","identifier")):
842
- lbl_like.append(c)
843
- # keep at most, say, 12 to avoid accidental flood
844
- return lbl_like[:12]
845
-
846
- def _find_unknownish_column(df: pd.DataFrame) -> str | None:
847
- # Search categorical-like columns for any 'unknown-like' values or high missingness
848
- candidates = []
849
- for c in df.columns:
850
- s = df[c]
851
- # focus on object/category/boolean-ish or low-card columns
852
- if not (pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= 20):
853
- continue
854
- try:
855
- vals = s.astype(str).str.strip().str.lower()
856
- except Exception:
857
- continue
858
- # score: presence of unknown tokens + missing rate
859
- token_hit = int(vals.isin(UNKNOWN_TOKENS).any())
860
- miss_rate = s.isna().mean()
861
- name_bonus = int(any(k in str(c).lower() for k in ("status","history","report","known","flag")))
862
- score = 3*token_hit + 2*name_bonus + miss_rate
863
- if token_hit or miss_rate > 0.05 or name_bonus:
864
- candidates.append((score, c))
865
- if not candidates:
866
- return None
867
- candidates.sort(reverse=True)
868
- return candidates[0][1]
869
-
870
- def _guess_numeric_cols(df: pd.DataFrame, max_n: int = 6) -> list[str]:
871
- cols = [c for c in df.select_dtypes(include=[np.number, "bool"]).columns if not any(k in str(c).lower() for k in ("id","identifier","index","uuid"))]
872
- # prefer non-constant columns
873
- scored = []
874
- for c in cols:
875
- try:
876
- v = df[c].dropna()
877
- var = float(v.var()) if len(v) else 0.0
878
- scored.append((var, c))
879
- except Exception:
880
- continue
881
- scored.sort(reverse=True)
882
- return [c for _, c in scored[:max_n]]
883
-
884
- def _guess_categorical_cols(df: pd.DataFrame, exclude: set[str] | None = None, max_card: int = 12, max_n: int = 5) -> list[str]:
885
- exclude = exclude or set()
886
- picks = []
887
- for c in df.columns:
888
- if c in exclude:
889
- continue
890
- s = df[c]
891
- if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= max_card:
892
- nunq = s.dropna().nunique()
893
- if 2 <= nunq <= max_card and not any(k in str(c).lower() for k in ("id","identifier","index","uuid")):
894
- picks.append((nunq, c))
895
- picks.sort(reverse=True)
896
- return [c for _, c in picks[:max_n]]
897
-
898
- def _guess_outcome_col(df: pd.DataFrame, exclude: set[str] | None = None) -> str | None:
899
- exclude = exclude or set()
900
- # name hints first
901
- name_keys = ("outcome","target","label","risk","score","result","prevalence","positivity")
902
- for c in df.columns:
903
- if c in exclude:
904
- continue
905
- name = str(c).lower()
906
- if any(k in name for k in name_keys) and pd.api.types.is_numeric_dtype(df[c]):
907
- return c
908
- # fallback: any binary numeric
909
- for c in df.select_dtypes(include=[np.number, "bool"]).columns:
910
- if c in exclude:
911
- continue
912
- try:
913
- if df[c].dropna().nunique() == 2:
914
- return c
915
- except Exception:
916
- continue
917
- return None
829
+ # def _guess_time_col(df: pd.DataFrame) -> str | None:
830
+ # # Prefer actual datetime dtype
831
+ # dt_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.datetime64)]
832
+ # if dt_cols:
833
+ # return dt_cols[0]
834
+
835
+ # # Fallback: name-based hints
836
+ # name_keys = ["date", "time", "timestamp", "datetime", "ds", "period"]
837
+ # for c in df.columns:
838
+ # name = str(c).lower()
839
+ # if any(k in name for k in name_keys):
840
+ # return c
841
+ # return None
842
+
843
+ # def _guess_entity_col(df: pd.DataFrame) -> str | None:
844
+ # # Typical sequence IDs: id, patient, subject, device, series, entity
845
+ # keys = ["id", "patient", "subject", "device", "series", "entity"]
846
+ # candidates = []
847
+ # for c in df.columns:
848
+ # name = str(c).lower()
849
+ # if any(k in name for k in keys):
850
+ # candidates.append(c)
851
+ # return candidates[0] if candidates else None
852
+
853
+ # def _guess_ts_class_target(df: pd.DataFrame) -> str | None:
854
+ # # Try label-like names first
855
+ # keys = ["target", "label", "class", "outcome", "y"]
856
+ # for key in keys:
857
+ # for c in df.columns:
858
+ # if key in str(c).lower():
859
+ # return c
860
+
861
+ # # Fallback: any column with few distinct values (e.g. <= 10)
862
+ # for c in df.columns:
863
+ # s = df[c]
864
+ # # avoid obvious IDs
865
+ # if any(k in str(c).lower() for k in ["id", "index"]):
866
+ # continue
867
+ # try:
868
+ # nunq = s.dropna().nunique()
869
+ # except Exception:
870
+ # continue
871
+ # if 1 < nunq <= 10:
872
+ # return c
873
+
874
+ # return None
875
+
876
+ # def _guess_multilabel_cols(df: pd.DataFrame) -> list[str]:
877
+ # cols = list(df.columns)
878
+ # lbl_like = [c for c in cols if str(c).startswith(("LBL_", "lbl_"))]
879
+ # # also include boolean/binary columns with suitable names
880
+ # for c in cols:
881
+ # s = df[c]
882
+ # try:
883
+ # nunq = s.dropna().nunique()
884
+ # except Exception:
885
+ # continue
886
+ # if nunq in (2,) and c not in lbl_like:
887
+ # # avoid obvious IDs
888
+ # if not any(k in str(c).lower() for k in ("id","index","uuid","identifier")):
889
+ # lbl_like.append(c)
890
+ # # keep at most, say, 12 to avoid accidental flood
891
+ # return lbl_like[:12]
892
+
893
+ # def _find_unknownish_column(df: pd.DataFrame) -> str | None:
894
+ # # Search categorical-like columns for any 'unknown-like' values or high missingness
895
+ # candidates = []
896
+ # for c in df.columns:
897
+ # s = df[c]
898
+ # # focus on object/category/boolean-ish or low-card columns
899
+ # if not (pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= 20):
900
+ # continue
901
+ # try:
902
+ # vals = s.astype(str).str.strip().str.lower()
903
+ # except Exception:
904
+ # continue
905
+ # # score: presence of unknown tokens + missing rate
906
+ # token_hit = int(vals.isin(UNKNOWN_TOKENS).any())
907
+ # miss_rate = s.isna().mean()
908
+ # name_bonus = int(any(k in str(c).lower() for k in ("status","history","report","known","flag")))
909
+ # score = 3*token_hit + 2*name_bonus + miss_rate
910
+ # if token_hit or miss_rate > 0.05 or name_bonus:
911
+ # candidates.append((score, c))
912
+ # if not candidates:
913
+ # return None
914
+ # candidates.sort(reverse=True)
915
+ # return candidates[0][1]
916
+
917
+ # def _guess_numeric_cols(df: pd.DataFrame, max_n: int = 6) -> list[str]:
918
+ # cols = [c for c in df.select_dtypes(include=[np.number, "bool"]).columns if not any(k in str(c).lower() for k in ("id","identifier","index","uuid"))]
919
+ # # prefer non-constant columns
920
+ # scored = []
921
+ # for c in cols:
922
+ # try:
923
+ # v = df[c].dropna()
924
+ # var = float(v.var()) if len(v) else 0.0
925
+ # scored.append((var, c))
926
+ # except Exception:
927
+ # continue
928
+ # scored.sort(reverse=True)
929
+ # return [c for _, c in scored[:max_n]]
930
+
931
+ # def _guess_categorical_cols(df: pd.DataFrame, exclude: set[str] | None = None, max_card: int = 12, max_n: int = 5) -> list[str]:
932
+ # exclude = exclude or set()
933
+ # picks = []
934
+ # for c in df.columns:
935
+ # if c in exclude:
936
+ # continue
937
+ # s = df[c]
938
+ # if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= max_card:
939
+ # nunq = s.dropna().nunique()
940
+ # if 2 <= nunq <= max_card and not any(k in str(c).lower() for k in ("id","identifier","index","uuid")):
941
+ # picks.append((nunq, c))
942
+ # picks.sort(reverse=True)
943
+ # return [c for _, c in picks[:max_n]]
944
+
945
+ # def _guess_outcome_col(df: pd.DataFrame, exclude: set[str] | None = None) -> str | None:
946
+ # exclude = exclude or set()
947
+ # # name hints first
948
+ # name_keys = ("outcome","target","label","risk","score","result","prevalence","positivity")
949
+ # for c in df.columns:
950
+ # if c in exclude:
951
+ # continue
952
+ # name = str(c).lower()
953
+ # if any(k in name for k in name_keys) and pd.api.types.is_numeric_dtype(df[c]):
954
+ # return c
955
+ # # fallback: any binary numeric
956
+ # for c in df.select_dtypes(include=[np.number, "bool"]).columns:
957
+ # if c in exclude:
958
+ # continue
959
+ # try:
960
+ # if df[c].dropna().nunique() == 2:
961
+ # return c
962
+ # except Exception:
963
+ # continue
964
+ # return None
918
965
 
919
966
 
920
- def _pick_viz_template(signal: str):
921
- s = signal.lower()
967
+ # def _pick_viz_template(signal: str):
968
+ # s = signal.lower()
922
969
 
923
- # explicit chart requests
924
- if any(k in s for k in ("pie", "donut")):
925
- return viz_pie
970
+ # # explicit chart requests
971
+ # if any(k in s for k in ("pie", "donut")):
972
+ # return viz_pie
926
973
 
927
- if any(k in s for k in ("stacked", "100% stacked", "composition", "proportion", "share by")):
928
- return viz_stacked_bar
974
+ # if any(k in s for k in ("stacked", "100% stacked", "composition", "proportion", "share by")):
975
+ # return viz_stacked_bar
929
976
 
930
- if any(k in s for k in ("distribution", "hist", "histogram", "bins")):
931
- return viz_distribution
977
+ # if any(k in s for k in ("distribution", "hist", "histogram", "bins")):
978
+ # return viz_distribution
932
979
 
933
- if any(k in s for k in ("kde", "density")):
934
- return viz_kde
980
+ # if any(k in s for k in ("kde", "density")):
981
+ # return viz_kde
935
982
 
936
- # these three you asked about
937
- if any(k in s for k in ("box", "boxplot", "violin", "spread", "outlier")):
938
- return viz_box
983
+ # # these three you asked about
984
+ # if any(k in s for k in ("box", "boxplot", "violin", "spread", "outlier")):
985
+ # return viz_box
939
986
 
940
- if any(k in s for k in ("scatter", "relationship", "vs ", "correlate")):
941
- return viz_scatter
987
+ # if any(k in s for k in ("scatter", "relationship", "vs ", "correlate")):
988
+ # return viz_scatter
942
989
 
943
- if any(k in s for k in ("count", "counts", "frequency", "bar chart", "barplot")):
944
- return viz_count_bar
990
+ # if any(k in s for k in ("count", "counts", "frequency", "bar chart", "barplot")):
991
+ # return viz_count_bar
945
992
 
946
- if any(k in s for k in ("area", "trend", "over time", "time series")):
947
- return viz_area
993
+ # if any(k in s for k in ("area", "trend", "over time", "time series")):
994
+ # return viz_area
948
995
 
949
- # fallback
950
- return viz_line
996
+ # # fallback
997
+ # return viz_line
951
998
 
952
- for intent in intents:
999
+ # for intent in intents:
953
1000
 
954
- if intent not in INJECTABLE_INTENTS:
955
- return code
1001
+ # if intent not in INJECTABLE_INTENTS:
1002
+ # return code
956
1003
 
957
- # Correlation analysis
958
- if intent == "correlation_analysis" and not has_fit:
959
- return eda_correlation(df) + "\n\n" + code
960
-
961
- # Generic visualisation (keyword-based)
962
- if intent == "visualisation" and not has_fit and not has_plot:
963
- rq = str(globals().get("refined_question", ""))
964
- # aq = str(globals().get("askai_question", ""))
965
- signal = rq + "\n" + str(intents) + "\n" + code
966
- tpl = _pick_viz_template(signal)
967
- return tpl(df) + "\n\n" + code
1004
+ # # Correlation analysis
1005
+ # if intent == "correlation_analysis" and not has_fit:
1006
+ # return eda_correlation(df) + "\n\n" + code
1007
+
1008
+ # # Generic visualisation (keyword-based)
1009
+ # if intent == "visualisation" and not has_fit and not has_plot:
1010
+ # rq = str(globals().get("refined_question", ""))
1011
+ # # aq = str(globals().get("askai_question", ""))
1012
+ # signal = rq + "\n" + str(intents) + "\n" + code
1013
+ # tpl = _pick_viz_template(signal)
1014
+ # return tpl(df) + "\n\n" + code
968
1015
 
969
- if intent == "clustering" and not has_fit:
970
- return clustering(df) + "\n\n" + code
1016
+ # if intent == "clustering" and not has_fit:
1017
+ # return clustering(df) + "\n\n" + code
971
1018
 
972
- if intent == "recommendation" and not has_fit:
973
- return recommendation(df) + "\\n\\n" + code
1019
+ # if intent == "recommendation" and not has_fit:
1020
+ # return recommendation(df) + "\\n\\n" + code
974
1021
 
975
- if intent == "topic_modelling" and not has_fit:
976
- return topic_modelling(df) + "\\n\\n" + code
1022
+ # if intent == "topic_modelling" and not has_fit:
1023
+ # return topic_modelling(df) + "\\n\\n" + code
977
1024
 
978
- if intent == "eda" and not has_fit:
979
- return code + "\n\nSB_heatmap(df.corr())" # Inject heatmap if 'eda' intent
980
-
981
- # --- Classification ------------------------------------------------
982
- if intent == "classification" and not has_fit:
983
- target = _guess_classification_target(df)
984
- if target:
985
- return classification(df) + "\n\n" + code
986
- # return _call_template(classification, df, target) + "\n\n" + code
987
-
988
- # --- Regression ----------------------------------------------------
989
- if intent == "regression" and not has_fit:
990
- target = _guess_regression_target(df)
991
- if target:
992
- return regression(df) + "\n\n" + code
993
- # return _call_template(regression, df, target) + "\n\n" + code
994
-
995
- # --- Anomaly detection --------------------------------------------
996
- if intent == "anomaly_detection":
997
- uses_anomaly = any(k in code for k in ("IsolationForest", "LocalOutlierFactor", "OneClassSVM"))
998
- if not uses_anomaly:
999
- return anomaly_detection(df) + "\n\n" + code
1000
-
1001
- # --- Time-series anomaly detection --------------------------------
1002
- if intent == "ts_anomaly_detection":
1003
- uses_ts = "STL(" in code or "seasonal_decompose(" in code
1004
- if not uses_ts:
1005
- return ts_anomaly_detection(df) + "\n\n" + code
1025
+ # if intent == "eda" and not has_fit:
1026
+ # return code + "\n\nSB_heatmap(df.corr())" # Inject heatmap if 'eda' intent
1027
+
1028
+ # # --- Classification ------------------------------------------------
1029
+ # if intent == "classification" and not has_fit:
1030
+ # target = _guess_classification_target(df)
1031
+ # if target:
1032
+ # return classification(df) + "\n\n" + code
1033
+ # # return _call_template(classification, df, target) + "\n\n" + code
1034
+
1035
+ # # --- Regression ----------------------------------------------------
1036
+ # if intent == "regression" and not has_fit:
1037
+ # target = _guess_regression_target(df)
1038
+ # if target:
1039
+ # return regression(df) + "\n\n" + code
1040
+ # # return _call_template(regression, df, target) + "\n\n" + code
1041
+
1042
+ # # --- Anomaly detection --------------------------------------------
1043
+ # if intent == "anomaly_detection":
1044
+ # uses_anomaly = any(k in code for k in ("IsolationForest", "LocalOutlierFactor", "OneClassSVM"))
1045
+ # if not uses_anomaly:
1046
+ # return anomaly_detection(df) + "\n\n" + code
1047
+
1048
+ # # --- Time-series anomaly detection --------------------------------
1049
+ # if intent == "ts_anomaly_detection":
1050
+ # uses_ts = "STL(" in code or "seasonal_decompose(" in code
1051
+ # if not uses_ts:
1052
+ # return ts_anomaly_detection(df) + "\n\n" + code
1006
1053
 
1007
- # --- Time-series classification --------------------------------
1008
- if intent == "time_series_classification" and not has_fit:
1009
- time_col = _guess_time_col(df)
1010
- entity_col = _guess_entity_col(df)
1011
- target_col = _guess_ts_class_target(df)
1012
-
1013
- # If we can't confidently identify these, do NOT inject anything
1014
- if time_col and entity_col and target_col:
1015
- return time_series_classification(df, entity_col, time_col, target_col) + "\n\n" + code
1016
-
1017
- # --- Dimensionality reduction --------------------------------------
1018
- if intent == "dimensionality_reduction":
1019
- uses_dr = any(k in code for k in ("PCA(", "TSNE("))
1020
- if not uses_dr:
1021
- return dimensionality_reduction(df) + "\n\n" + code
1022
-
1023
- # --- Feature selection ---------------------------------------------
1024
- if intent == "feature_selection":
1025
- uses_fs = any(k in code for k in (
1026
- "mutual_info_", "permutation_importance(", "SelectKBest(", "RFE("
1027
- ))
1028
- if not uses_fs:
1029
- return feature_selection(df) + "\n\n" + code
1030
-
1031
- # --- EDA / correlation / visualisation -----------------------------
1032
- if intent in ("eda", "correlation_analysis", "visualisation") and not has_plot:
1033
- if intent == "correlation_analysis":
1034
- return eda_correlation(df) + "\n\n" + code
1035
- else:
1036
- return eda_overview(df) + "\n\n" + code
1037
-
1038
- # --- Time-series forecasting ---------------------------------------
1039
- if intent == "time_series_forecasting" and not has_fit:
1040
- uses_ts_forecast = any(k in code for k in (
1041
- "ARIMA", "ExponentialSmoothing", "forecast", "predict("
1042
- ))
1043
- if not uses_ts_forecast:
1044
- return time_series_forecasting(df) + "\n\n" + code
1054
+ # # --- Time-series classification --------------------------------
1055
+ # if intent == "time_series_classification" and not has_fit:
1056
+ # time_col = _guess_time_col(df)
1057
+ # entity_col = _guess_entity_col(df)
1058
+ # target_col = _guess_ts_class_target(df)
1059
+
1060
+ # # If we can't confidently identify these, do NOT inject anything
1061
+ # if time_col and entity_col and target_col:
1062
+ # return time_series_classification(df, entity_col, time_col, target_col) + "\n\n" + code
1063
+
1064
+ # # --- Dimensionality reduction --------------------------------------
1065
+ # if intent == "dimensionality_reduction":
1066
+ # uses_dr = any(k in code for k in ("PCA(", "TSNE("))
1067
+ # if not uses_dr:
1068
+ # return dimensionality_reduction(df) + "\n\n" + code
1069
+
1070
+ # # --- Feature selection ---------------------------------------------
1071
+ # if intent == "feature_selection":
1072
+ # uses_fs = any(k in code for k in (
1073
+ # "mutual_info_", "permutation_importance(", "SelectKBest(", "RFE("
1074
+ # ))
1075
+ # if not uses_fs:
1076
+ # return feature_selection(df) + "\n\n" + code
1077
+
1078
+ # # --- EDA / correlation / visualisation -----------------------------
1079
+ # if intent in ("eda", "correlation_analysis", "visualisation") and not has_plot:
1080
+ # if intent == "correlation_analysis":
1081
+ # return eda_correlation(df) + "\n\n" + code
1082
+ # else:
1083
+ # return eda_overview(df) + "\n\n" + code
1084
+
1085
+ # # --- Time-series forecasting ---------------------------------------
1086
+ # if intent == "time_series_forecasting" and not has_fit:
1087
+ # uses_ts_forecast = any(k in code for k in (
1088
+ # "ARIMA", "ExponentialSmoothing", "forecast", "predict("
1089
+ # ))
1090
+ # if not uses_ts_forecast:
1091
+ # return time_series_forecasting(df) + "\n\n" + code
1045
1092
 
1046
- # --- Multi-label classification -----------------------------------
1047
- if intent in ("multilabel_classification",) and not has_fit:
1048
- label_cols = _guess_multilabel_cols(df)
1049
- if len(label_cols) >= 2:
1050
- return multilabel_classification(df, label_cols) + "\n\n" + code
1051
-
1052
- group_col = _find_unknownish_column(df)
1053
- if group_col:
1054
- num_cols = _guess_numeric_cols(df)
1055
- cat_cols = _guess_categorical_cols(df, exclude={group_col})
1056
- outcome_col = None # generic; let template skip if not present
1057
- tpl = unknown_group_proxy_pack(df, group_col, UNKNOWN_TOKENS, num_cols, cat_cols, outcome_col)
1058
-
1059
- # Return template + guarded (repaired) LLM code, so it never crashes
1060
- repaired = make_numeric_vars_dynamic(relax_required_columns(fix_boxplot_placeholder(code)))
1061
- return tpl + "\n\n" + wrap_llm_code_safe(repaired)
1093
+ # # --- Multi-label classification -----------------------------------
1094
+ # if intent in ("multilabel_classification",) and not has_fit:
1095
+ # label_cols = _guess_multilabel_cols(df)
1096
+ # if len(label_cols) >= 2:
1097
+ # return multilabel_classification(df, label_cols) + "\n\n" + code
1098
+
1099
+ # group_col = _find_unknownish_column(df)
1100
+ # if group_col:
1101
+ # num_cols = _guess_numeric_cols(df)
1102
+ # cat_cols = _guess_categorical_cols(df, exclude={group_col})
1103
+ # outcome_col = None # generic; let template skip if not present
1104
+ # tpl = unknown_group_proxy_pack(df, group_col, UNKNOWN_TOKENS, num_cols, cat_cols, outcome_col)
1105
+
1106
+ # # Return template + guarded (repaired) LLM code, so it never crashes
1107
+ # repaired = make_numeric_vars_dynamic(relax_required_columns(fix_boxplot_placeholder(code)))
1108
+ # return tpl + "\n\n" + wrap_llm_code_safe(repaired)
1062
1109
 
1063
- return code
1064
-
1065
-
1066
- def fix_values_sum_numeric_only_bug(code: str) -> str:
1067
- """
1068
- If a previous pass injected numeric_only=True into a NumPy-style sum,
1069
- e.g. .values.sum(numeric_only=True), strip it and canonicalize to .to_numpy().sum().
1070
- """
1071
- # .values.sum(numeric_only=True, ...)
1072
- code = re.sub(
1073
- r"\.values\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1074
- ".to_numpy().sum()",
1075
- code,
1076
- flags=re.IGNORECASE,
1077
- )
1078
- # .to_numpy().sum(numeric_only=True, ...)
1079
- code = re.sub(
1080
- r"\.to_numpy\(\)\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1081
- ".to_numpy().sum()",
1082
- code,
1083
- flags=re.IGNORECASE,
1084
- )
1085
- return code
1086
-
1087
-
1088
- def strip_describe_slice(code: str) -> str:
1089
- """
1090
- Remove any pattern like df.groupby(...).describe()[[ ... ]] because
1091
- slicing a SeriesGroupBy.describe() causes AttributeError.
1092
- We leave the plain .describe() in place (harmless) and let our own
1093
- table patcher add the correct .agg() table afterwards.
1094
- """
1095
- pat = re.compile(
1096
- r"(df\.groupby\([^)]+\)\[[^\]]+\]\.describe\()\s*\[[^\]]+\]\)",
1097
- flags=re.DOTALL,
1098
- )
1099
- return pat.sub(r"\1)", code)
1100
-
1101
-
1102
- def remove_plt_show(code: str) -> str:
1103
- """Removes all plt.show() calls from the generated code string."""
1104
- return "\n".join(line for line in code.splitlines() if "plt.show()" not in line)
1105
-
1106
-
1107
- def patch_plot_with_table(code: str) -> str:
1108
- """
1109
- ▸ strips every `plt.show()` (avoids warnings)
1110
- ▸ converts the *last* Matplotlib / Seaborn figure to PNG-HTML so it is
1111
- rendered in the dashboard
1112
- ▸ appends a summary-stats table **after** the plot
1113
- """
1114
- # 0. drop plt.show()
1115
- lines = [ln for ln in code.splitlines() if "plt.show()" not in ln]
1116
-
1117
- # 1. locate the last plotting line
1118
- plot_kw = ['plt.', 'sns.', '.plot(', '.boxplot(', '.hist(']
1119
- last_plot = max((i for i,l in enumerate(lines) if any(k in l for k in plot_kw)), default=-1)
1120
- if last_plot == -1:
1121
- return "\n".join(lines) # nothing to do
1122
-
1123
- whole = "\n".join(lines)
1124
-
1125
- # 2. detect group / feature (if any)
1126
- group, feature = None, None
1127
- xm = re.search(r"x\s*=\s*['\"](\w+)['\"]", whole)
1128
- ym = re.search(r"y\s*=\s*['\"](\w+)['\"]", whole)
1129
- if xm and ym:
1130
- group, feature = xm.group(1), ym.group(1)
1131
- else:
1132
- cm = re.search(r"column\s*=\s*['\"](\w+)['\"].*by\s*=\s*['\"](\w+)['\"]", whole)
1133
- if cm:
1134
- feature, group = cm.group(1), cm.group(2)
1135
-
1136
- # 3. code that captures current fig → PNG → HTML
1137
- img_block = textwrap.dedent("""
1138
- import io, base64
1139
- buf = io.BytesIO()
1140
- plt.savefig(buf, format='png', bbox_inches='tight')
1141
- buf.seek(0)
1142
- img_b64 = base64.b64encode(buf.read()).decode('utf-8')
1143
- from IPython.display import display, HTML
1144
- display(HTML(f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%;">'))
1145
- plt.close()
1146
- """)
1147
-
1148
- # 4. build summary-table code
1149
- if group and feature:
1150
- tbl_block = (
1151
- f"summary_table = (\n"
1152
- f" df.groupby('{group}')['{feature}']\n"
1153
- f" .agg(['count','mean','std','min','median','max'])\n"
1154
- f" .rename(columns={{'median':'50%'}})\n"
1155
- f")\n"
1156
- )
1157
- elif ym:
1158
- feature = ym.group(1)
1159
- tbl_block = (
1160
- f"summary_table = (\n"
1161
- f" df['{feature}']\n"
1162
- f" .agg(['count','mean','std','min','median','max'])\n"
1163
- f" .rename(columns={{'median':'50%'}})\n"
1164
- f")\n"
1165
- )
1110
+ # return code
1111
+
1112
+
1113
+ # def fix_values_sum_numeric_only_bug(code: str) -> str:
1114
+ # """
1115
+ # If a previous pass injected numeric_only=True into a NumPy-style sum,
1116
+ # e.g. .values.sum(numeric_only=True), strip it and canonicalize to .to_numpy().sum().
1117
+ # """
1118
+ # # .values.sum(numeric_only=True, ...)
1119
+ # code = re.sub(
1120
+ # r"\.values\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1121
+ # ".to_numpy().sum()",
1122
+ # code,
1123
+ # flags=re.IGNORECASE,
1124
+ # )
1125
+ # # .to_numpy().sum(numeric_only=True, ...)
1126
+ # code = re.sub(
1127
+ # r"\.to_numpy\(\)\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1128
+ # ".to_numpy().sum()",
1129
+ # code,
1130
+ # flags=re.IGNORECASE,
1131
+ # )
1132
+ # return code
1133
+
1134
+
1135
+ # def strip_describe_slice(code: str) -> str:
1136
+ # """
1137
+ # Remove any pattern like df.groupby(...).describe()[[ ... ]] because
1138
+ # slicing a SeriesGroupBy.describe() causes AttributeError.
1139
+ # We leave the plain .describe() in place (harmless) and let our own
1140
+ # table patcher add the correct .agg() table afterwards.
1141
+ # """
1142
+ # pat = re.compile(
1143
+ # r"(df\.groupby\([^)]+\)\[[^\]]+\]\.describe\()\s*\[[^\]]+\]\)",
1144
+ # flags=re.DOTALL,
1145
+ # )
1146
+ # return pat.sub(r"\1)", code)
1147
+
1148
+
1149
+ # def remove_plt_show(code: str) -> str:
1150
+ # """Removes all plt.show() calls from the generated code string."""
1151
+ # return "\n".join(line for line in code.splitlines() if "plt.show()" not in line)
1152
+
1153
+
1154
+ # def patch_plot_with_table(code: str) -> str:
1155
+ # """
1156
+ # ▸ strips every `plt.show()` (avoids warnings)
1157
+ # ▸ converts the *last* Matplotlib / Seaborn figure to PNG-HTML so it is
1158
+ # rendered in the dashboard
1159
+ # ▸ appends a summary-stats table **after** the plot
1160
+ # """
1161
+ # # 0. drop plt.show()
1162
+ # lines = [ln for ln in code.splitlines() if "plt.show()" not in ln]
1163
+
1164
+ # # 1. locate the last plotting line
1165
+ # plot_kw = ['plt.', 'sns.', '.plot(', '.boxplot(', '.hist(']
1166
+ # last_plot = max((i for i,l in enumerate(lines) if any(k in l for k in plot_kw)), default=-1)
1167
+ # if last_plot == -1:
1168
+ # return "\n".join(lines) # nothing to do
1169
+
1170
+ # whole = "\n".join(lines)
1171
+
1172
+ # # 2. detect group / feature (if any)
1173
+ # group, feature = None, None
1174
+ # xm = re.search(r"x\s*=\s*['\"](\w+)['\"]", whole)
1175
+ # ym = re.search(r"y\s*=\s*['\"](\w+)['\"]", whole)
1176
+ # if xm and ym:
1177
+ # group, feature = xm.group(1), ym.group(1)
1178
+ # else:
1179
+ # cm = re.search(r"column\s*=\s*['\"](\w+)['\"].*by\s*=\s*['\"](\w+)['\"]", whole)
1180
+ # if cm:
1181
+ # feature, group = cm.group(1), cm.group(2)
1182
+
1183
+ # # 3. code that captures current fig → PNG → HTML
1184
+ # img_block = textwrap.dedent("""
1185
+ # import io, base64
1186
+ # buf = io.BytesIO()
1187
+ # plt.savefig(buf, format='png', bbox_inches='tight')
1188
+ # buf.seek(0)
1189
+ # img_b64 = base64.b64encode(buf.read()).decode('utf-8')
1190
+ # from IPython.display import display, HTML
1191
+ # display(HTML(f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%;">'))
1192
+ # plt.close()
1193
+ # """)
1194
+
1195
+ # # 4. build summary-table code
1196
+ # if group and feature:
1197
+ # tbl_block = (
1198
+ # f"summary_table = (\n"
1199
+ # f" df.groupby('{group}')['{feature}']\n"
1200
+ # f" .agg(['count','mean','std','min','median','max'])\n"
1201
+ # f" .rename(columns={{'median':'50%'}})\n"
1202
+ # f")\n"
1203
+ # )
1204
+ # elif ym:
1205
+ # feature = ym.group(1)
1206
+ # tbl_block = (
1207
+ # f"summary_table = (\n"
1208
+ # f" df['{feature}']\n"
1209
+ # f" .agg(['count','mean','std','min','median','max'])\n"
1210
+ # f" .rename(columns={{'median':'50%'}})\n"
1211
+ # f")\n"
1212
+ # )
1166
1213
 
1167
- # 3️⃣ grid-search results
1168
- elif "GridSearchCV(" in code:
1169
- tbl_block = textwrap.dedent("""
1170
- # build tidy CV-results table
1171
- cv_df = (
1172
- pd.DataFrame(grid_search.cv_results_)
1173
- .loc[:, ['param_n_estimators', 'param_max_depth',
1174
- 'mean_test_score', 'std_test_score']]
1175
- .rename(columns={
1176
- 'param_n_estimators': 'n_estimators',
1177
- 'param_max_depth': 'max_depth',
1178
- 'mean_test_score': 'mean_cv_accuracy',
1179
- 'std_test_score': 'std'
1180
- })
1181
- .sort_values('mean_cv_accuracy', ascending=False)
1182
- .reset_index(drop=True)
1183
- )
1184
- summary_table = cv_df
1185
- """).strip() + "\n"
1186
- else:
1187
- tbl_block = (
1188
- "summary_table = (\n"
1189
- " df.describe().T[['count','mean','std','min','50%','max']]\n"
1190
- ")\n"
1191
- )
1192
-
1193
- tbl_block += "show(summary_table, title='Summary Statistics')"
1194
-
1195
- # 5. inject image-export block, then table block, after the plot
1196
- patched = (
1197
- lines[:last_plot+1]
1198
- + img_block.splitlines()
1199
- + tbl_block.splitlines()
1200
- + lines[last_plot+1:]
1201
- )
1202
- patched_code = "\n".join(patched)
1203
- # ⬇️ strip every accidental left-indent so top-level lines are flush‐left
1204
- return textwrap.dedent(patched_code)
1205
-
1206
-
1207
- def refine_eda_question(raw_question, df=None, max_points=1000):
1208
- """
1209
- Rewrites user's EDA question to avoid classic mistakes:
1210
- - For line plots and scatter: recommend aggregation or sampling if large.
1211
- - For histograms/bar: clarify which variable to plot and bin count.
1212
- - For correlation: suggest a heatmap.
1213
- - For counts: direct request for df.shape[0].
1214
- df (optional): pass DataFrame for column inspection.
1215
- """
1216
-
1217
- # --- SPECIFIC PEARSON CORRELATION DETECTION ----------------------
1218
- pc = re.match(
1219
- r".*\bpearson\b.*\bcorrelation\b.*between\s+(\w+)\s+(and|vs)\s+(\w+)",
1220
- raw_question, re.I
1221
- )
1222
- if pc:
1223
- col1, col2 = pc.group(1), pc.group(3)
1224
- # Return an instruction that preserves the exact intent
1225
- return (
1226
- f"Compute the Pearson correlation coefficient (r) and p-value "
1227
- f"between {col1} and {col2}. "
1228
- f"Print a short interpretation."
1229
- )
1230
- # -----------------------------------------------------------------
1231
- # ── Detect "predict <column>" intent ──────────────────────────────
1232
- c = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", raw_question, re.I)
1233
- if c:
1234
- target = c.group(1)
1235
- raw_question += (
1236
- f" IMPORTANT: do NOT recreate or overwrite the existing target column "
1237
- f"“{target}”. Use it as-is for y = df['{target}']."
1238
- )
1239
-
1240
- q = raw_question.strip()
1241
- # REMOVE explicit summary-table instructions
1242
- # ── strip any “table” request: “…table of …”, “…include table…”, “…with a table…”
1243
- q = re.sub(r"\b(include|with|and)\b[^.]*\btable[s]?\b[^.]*", "", q, flags=re.I).strip()
1244
- q = re.sub(r"\s*,\s*$", "", q) # drop trailing comma, if any
1245
-
1246
- ql = q.lower()
1247
-
1248
- # ── NEW: if the text contains an exact column name, leave it alone ──
1249
- if df is not None:
1250
- for col in df.columns:
1251
- if col.lower() in ql:
1252
- return q
1253
-
1254
- modelling_keywords = (
1255
- "random forest", "gradient-boost", "tree-based model",
1256
- "feature importance", "feature importances",
1257
- "overall accuracy", "train a model", "predict "
1258
- )
1259
- if any(k in ql for k in modelling_keywords):
1260
- return q
1261
-
1262
- # 1. Line plots: average if plotting raw numeric vs numeric
1263
- if "line plot" in ql and any(word in ql for word in ["over", "by", "vs"]):
1264
- match = re.search(r'line plot of ([\w_]+) (over|by|vs) ([\w_]+)', ql)
1265
- if match:
1266
- y, _, x = match.groups()
1267
- return f"Show me the average {y} by {x} as a line plot."
1268
-
1269
- # 2. Scatter plots: sample if too large
1270
- if "scatter" in ql or "scatter plot" in ql:
1271
- if df is not None and df.shape[0] > max_points:
1272
- return q + " (use only a random sample of 1000 points to avoid overplotting)"
1273
- else:
1274
- return q
1275
-
1276
- # 3. Histogram: specify bins and column
1277
- if "histogram" in ql:
1278
- match = re.search(r'histogram of ([\w_]+)', ql)
1279
- if match:
1280
- col = match.group(1)
1281
- return f"Show me a histogram of {col} using 20 bins."
1282
-
1283
- # Special case: histogram for column with most missing values
1284
- if "most missing" in ql:
1285
- return (
1286
- "Show a histogram for the column with the most missing values. "
1287
- "First, select the column using: "
1288
- "column_with_most_missing = df.isnull().sum().idxmax(); "
1289
- "then plot its histogram with: "
1290
- "df[column_with_most_missing].hist()"
1291
- )
1292
-
1293
- # 4. Bar plot: show top N
1294
- if "bar plot" in ql or "bar chart" in ql:
1295
- match = re.search(r'bar (plot|chart) of ([\w_]+)', ql)
1296
- if match:
1297
- col = match.group(2)
1298
- return f"Show me a bar plot of the top 10 {col} values."
1299
-
1300
- # 5. Correlation or heatmap
1301
- if "correlation" in ql:
1302
- return (
1303
- "Show a correlation heatmap for all numeric columns only. "
1304
- "Use: correlation_matrix = df.select_dtypes(include='number').corr()"
1305
- )
1306
-
1307
-
1308
- # 6. Counts/size
1309
- if "how many record" in ql or "row count" in ql or "number of rows" in ql:
1310
- return "How many rows are in the dataset?"
1311
-
1312
- # 7. General best-practices fallback: add axis labels/titles
1313
- if "plot" in ql:
1314
- return q + " (make sure the axes are labeled and the plot is readable)"
1214
+ # # 3️⃣ grid-search results
1215
+ # elif "GridSearchCV(" in code:
1216
+ # tbl_block = textwrap.dedent("""
1217
+ # # build tidy CV-results table
1218
+ # cv_df = (
1219
+ # pd.DataFrame(grid_search.cv_results_)
1220
+ # .loc[:, ['param_n_estimators', 'param_max_depth',
1221
+ # 'mean_test_score', 'std_test_score']]
1222
+ # .rename(columns={
1223
+ # 'param_n_estimators': 'n_estimators',
1224
+ # 'param_max_depth': 'max_depth',
1225
+ # 'mean_test_score': 'mean_cv_accuracy',
1226
+ # 'std_test_score': 'std'
1227
+ # })
1228
+ # .sort_values('mean_cv_accuracy', ascending=False)
1229
+ # .reset_index(drop=True)
1230
+ # )
1231
+ # summary_table = cv_df
1232
+ # """).strip() + "\n"
1233
+ # else:
1234
+ # tbl_block = (
1235
+ # "summary_table = (\n"
1236
+ # " df.describe().T[['count','mean','std','min','50%','max']]\n"
1237
+ # ")\n"
1238
+ # )
1239
+
1240
+ # tbl_block += "show(summary_table, title='Summary Statistics')"
1241
+
1242
+ # # 5. inject image-export block, then table block, after the plot
1243
+ # patched = (
1244
+ # lines[:last_plot+1]
1245
+ # + img_block.splitlines()
1246
+ # + tbl_block.splitlines()
1247
+ # + lines[last_plot+1:]
1248
+ # )
1249
+ # patched_code = "\n".join(patched)
1250
+ # # ⬇️ strip every accidental left-indent so top-level lines are flush‐left
1251
+ # return textwrap.dedent(patched_code)
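Note on the GridSearchCV branch above: the whole trick is reshaping grid_search.cv_results_ into a small, sorted table before handing it to show(). A standalone sketch of that reshape, assuming scikit-learn and pandas are available (the iris data, the two-parameter grid, and the final print are illustrative only, not package code):

import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True)
grid_search = GridSearchCV(
    RandomForestClassifier(random_state=0),
    param_grid={"n_estimators": [50, 100], "max_depth": [2, 4]},
    cv=3,
)
grid_search.fit(X, y)

# cv_results_ is a dict of parallel arrays; selecting the param_* columns and
# the mean/std test scores gives one tidy row per parameter combination.
cv_df = (
    pd.DataFrame(grid_search.cv_results_)
    .loc[:, ["param_n_estimators", "param_max_depth",
             "mean_test_score", "std_test_score"]]
    .rename(columns={
        "param_n_estimators": "n_estimators",
        "param_max_depth": "max_depth",
        "mean_test_score": "mean_cv_accuracy",
        "std_test_score": "std",
    })
    .sort_values("mean_cv_accuracy", ascending=False)
    .reset_index(drop=True)
)
print(cv_df)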
1252
+
1253
+
1254
+ # def refine_eda_question(raw_question, df=None, max_points=1000):
1255
+ # """
1256
+ # Rewrites user's EDA question to avoid classic mistakes:
1257
+ # - For line plots and scatter: recommend aggregation or sampling if large.
1258
+ # - For histograms/bar: clarify which variable to plot and bin count.
1259
+ # - For correlation: suggest a heatmap.
1260
+ # - For counts: direct request for df.shape[0].
1261
+ # df (optional): pass DataFrame for column inspection.
1262
+ # """
1263
+
1264
+ # # --- SPECIFIC PEARSON CORRELATION DETECTION ----------------------
1265
+ # pc = re.match(
1266
+ # r".*\bpearson\b.*\bcorrelation\b.*between\s+(\w+)\s+(and|vs)\s+(\w+)",
1267
+ # raw_question, re.I
1268
+ # )
1269
+ # if pc:
1270
+ # col1, col2 = pc.group(1), pc.group(3)
1271
+ # # Return an instruction that preserves the exact intent
1272
+ # return (
1273
+ # f"Compute the Pearson correlation coefficient (r) and p-value "
1274
+ # f"between {col1} and {col2}. "
1275
+ # f"Print a short interpretation."
1276
+ # )
1277
+ # # -----------------------------------------------------------------
1278
+ # # ── Detect "predict <column>" intent ──────────────────────────────
1279
+ # c = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", raw_question, re.I)
1280
+ # if c:
1281
+ # target = c.group(1)
1282
+ # raw_question += (
1283
+ # f" IMPORTANT: do NOT recreate or overwrite the existing target column "
1284
+ # f"“{target}”. Use it as-is for y = df['{target}']."
1285
+ # )
1286
+
1287
+ # q = raw_question.strip()
1288
+ # # REMOVE explicit summary-table instructions
1289
+ # # ── strip any “table” request: “…table of …”, “…include table…”, “…with a table…”
1290
+ # q = re.sub(r"\b(include|with|and)\b[^.]*\btable[s]?\b[^.]*", "", q, flags=re.I).strip()
1291
+ # q = re.sub(r"\s*,\s*$", "", q) # drop trailing comma, if any
1292
+
1293
+ # ql = q.lower()
1294
+
1295
+ # # ── NEW: if the text contains an exact column name, leave it alone ──
1296
+ # if df is not None:
1297
+ # for col in df.columns:
1298
+ # if col.lower() in ql:
1299
+ # return q
1300
+
1301
+ # modelling_keywords = (
1302
+ # "random forest", "gradient-boost", "tree-based model",
1303
+ # "feature importance", "feature importances",
1304
+ # "overall accuracy", "train a model", "predict "
1305
+ # )
1306
+ # if any(k in ql for k in modelling_keywords):
1307
+ # return q
1308
+
1309
+ # # 1. Line plots: average if plotting raw numeric vs numeric
1310
+ # if "line plot" in ql and any(word in ql for word in ["over", "by", "vs"]):
1311
+ # match = re.search(r'line plot of ([\w_]+) (over|by|vs) ([\w_]+)', ql)
1312
+ # if match:
1313
+ # y, _, x = match.groups()
1314
+ # return f"Show me the average {y} by {x} as a line plot."
1315
+
1316
+ # # 2. Scatter plots: sample if too large
1317
+ # if "scatter" in ql or "scatter plot" in ql:
1318
+ # if df is not None and df.shape[0] > max_points:
1319
+ # return q + " (use only a random sample of 1000 points to avoid overplotting)"
1320
+ # else:
1321
+ # return q
1322
+
1323
+ # # 3. Histogram: specify bins and column
1324
+ # if "histogram" in ql:
1325
+ # match = re.search(r'histogram of ([\w_]+)', ql)
1326
+ # if match:
1327
+ # col = match.group(1)
1328
+ # return f"Show me a histogram of {col} using 20 bins."
1329
+
1330
+ # # Special case: histogram for column with most missing values
1331
+ # if "most missing" in ql:
1332
+ # return (
1333
+ # "Show a histogram for the column with the most missing values. "
1334
+ # "First, select the column using: "
1335
+ # "column_with_most_missing = df.isnull().sum().idxmax(); "
1336
+ # "then plot its histogram with: "
1337
+ # "df[column_with_most_missing].hist()"
1338
+ # )
1339
+
1340
+ # # 4. Bar plot: show top N
1341
+ # if "bar plot" in ql or "bar chart" in ql:
1342
+ # match = re.search(r'bar (plot|chart) of ([\w_]+)', ql)
1343
+ # if match:
1344
+ # col = match.group(2)
1345
+ # return f"Show me a bar plot of the top 10 {col} values."
1346
+
1347
+ # # 5. Correlation or heatmap
1348
+ # if "correlation" in ql:
1349
+ # return (
1350
+ # "Show a correlation heatmap for all numeric columns only. "
1351
+ # "Use: correlation_matrix = df.select_dtypes(include='number').corr()"
1352
+ # )
1353
+
1354
+
1355
+ # # 6. Counts/size
1356
+ # if "how many record" in ql or "row count" in ql or "number of rows" in ql:
1357
+ # return "How many rows are in the dataset?"
1358
+
1359
+ # # 7. General best-practices fallback: add axis labels/titles
1360
+ # if "plot" in ql:
1361
+ # return q + " (make sure the axes are labeled and the plot is readable)"
1315
1362
 
1316
- # 8.
1317
- if (("how often" in ql or "count" in ql or "frequency" in ql) and "category" in ql) or ("value_counts" in q):
1318
- match = re.search(r'(?:categories? in |bar plot of |bar chart of )([\w_]+)', ql)
1319
- col = match.group(1) if match else None
1320
- if col:
1321
- return (
1322
- f"Show a bar plot of the counts of {col} using: "
1323
- f"df['{col}'].value_counts().plot(kind='bar'); "
1324
- "add axis labels and a title, then plt.show()."
1325
- )
1363
+ # # 8.
1364
+ # if (("how often" in ql or "count" in ql or "frequency" in ql) and "category" in ql) or ("value_counts" in q):
1365
+ # match = re.search(r'(?:categories? in |bar plot of |bar chart of )([\w_]+)', ql)
1366
+ # col = match.group(1) if match else None
1367
+ # if col:
1368
+ # return (
1369
+ # f"Show a bar plot of the counts of {col} using: "
1370
+ # f"df['{col}'].value_counts().plot(kind='bar'); "
1371
+ # "add axis labels and a title, then plt.show()."
1372
+ # )
1326
1373
 
1327
- if ("mean" in ql and "median" in ql and "standard deviation" in ql) or ("summary statistics" in ql):
1328
- return (
1329
- "Show a table of the mean, median, and standard deviation for all numeric columns. "
1330
- "Use: tbl = df.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}); display(tbl)"
1331
- )
1374
+ # if ("mean" in ql and "median" in ql and "standard deviation" in ql) or ("summary statistics" in ql):
1375
+ # return (
1376
+ # "Show a table of the mean, median, and standard deviation for all numeric columns. "
1377
+ # "Use: tbl = df.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}); display(tbl)"
1378
+ # )
1332
1379
 
1333
- # 9. Fallback: return the raw question
1334
- return q
1380
+ # # 9. Fallback: return the raw question
1381
+ # return q
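The refine_eda_question helper above is essentially a rule table: each regex recognises one phrasing and returns a sharper, self-contained instruction. A compressed sketch of that pattern, keeping only the Pearson-correlation and histogram rules from the original (the function name refine and the sample questions are illustrative):

import re

def refine(question: str) -> str:
    # Rule 1: explicit Pearson correlation between two named columns.
    pc = re.match(
        r".*\bpearson\b.*\bcorrelation\b.*between\s+(\w+)\s+(and|vs)\s+(\w+)",
        question, re.I,
    )
    if pc:
        col1, col2 = pc.group(1), pc.group(3)
        return (f"Compute the Pearson correlation coefficient (r) and p-value "
                f"between {col1} and {col2}. Print a short interpretation.")
    # Rule 2: histograms get an explicit bin count.
    m = re.search(r"histogram of (\w+)", question, re.I)
    if m:
        return f"Show me a histogram of {m.group(1)} using 20 bins."
    # Fallback: leave the question untouched.
    return question.strip()

print(refine("What is the Pearson correlation between BMI and HbA1c?"))
print(refine("histogram of Age please"))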
1335
1382
 
1336
1383
 
1337
- def patch_plot_code(code, df, user_question=None):
1384
+ # def patch_plot_code(code, df, user_question=None):
1338
1385
 
1339
- # ── Early guard: abort nicely if the generated code references columns that
1340
- # do not exist in the DataFrame. This prevents KeyError crashes.
1341
- import re
1386
+ # # ── Early guard: abort nicely if the generated code references columns that
1387
+ # # do not exist in the DataFrame. This prevents KeyError crashes.
1388
+ # import re
1342
1389
 
1343
1390
 
1344
- # ── Detect columns referenced in the code ──────────────────────────
1345
- col_refs = re.findall(r"df\[['\"](\w+)['\"]\]", code)
1346
-
1347
- # Columns that will be newly CREATED (appear left of '=')
1348
- new_cols = re.findall(r"df\[['\"](\w+)['\"]\]\s*=", code)
1349
-
1350
- missing_cols = [
1351
- col for col in col_refs
1352
- if col not in df.columns and col not in new_cols
1353
- ]
1354
-
1355
- if missing_cols:
1356
- cols_list = ", ".join(missing_cols)
1357
- warning = (
1358
- f"show('⚠️ Warning: code references missing column(s): \"{cols_list}\". "
1359
- "These must either exist in df or be created earlier in the code; "
1360
- "otherwise you may see a KeyError.')\n"
1361
- )
1362
- # Prepend the warning but keep the original code so it can still run
1363
- code = warning + code
1391
+ # # ── Detect columns referenced in the code ──────────────────────────
1392
+ # col_refs = re.findall(r"df\[['\"](\w+)['\"]\]", code)
1393
+
1394
+ # # Columns that will be newly CREATED (appear left of '=')
1395
+ # new_cols = re.findall(r"df\[['\"](\w+)['\"]\]\s*=", code)
1396
+
1397
+ # missing_cols = [
1398
+ # col for col in col_refs
1399
+ # if col not in df.columns and col not in new_cols
1400
+ # ]
1401
+
1402
+ # if missing_cols:
1403
+ # cols_list = ", ".join(missing_cols)
1404
+ # warning = (
1405
+ # f"show('⚠️ Warning: code references missing column(s): \"{cols_list}\". "
1406
+ # "These must either exist in df or be created earlier in the code; "
1407
+ # "otherwise you may see a KeyError.')\n"
1408
+ # )
1409
+ # # Prepend the warning but keep the original code so it can still run
1410
+ # code = warning + code
1364
1411
 
1365
- # 1. For line plots (auto-aggregate)
1366
- m_l = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
1367
- if m_l:
1368
- x, y = m_l.groups()
1369
- if pd.api.types.is_numeric_dtype(df[x]) and pd.api.types.is_numeric_dtype(df[y]) and df[x].nunique() > 20:
1370
- return (
1371
- f"agg_df = df.groupby('{x}')['{y}'].mean().reset_index()\n"
1372
- f"plt.plot(agg_df['{x}'], agg_df['{y}'], marker='o')\n"
1373
- f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('Average {y} by {x}')\nplt.show()"
1374
- )
1375
-
1376
- # 2. For scatter plots: sample to 1000 points max
1377
- m_s = re.search(r"plt\.scatter\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
1378
- if m_s:
1379
- x, y = m_s.groups()
1380
- if len(df) > 1000:
1381
- return (
1382
- f"samp = df.sample(1000, random_state=42)\n"
1383
- f"plt.scatter(samp['{x}'], samp['{y}'])\n"
1384
- f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('{y} vs {x} (sampled)')\nplt.show()"
1385
- )
1386
-
1387
- # 3. For histograms: use bins=20 for numeric, value_counts for categorical
1388
- m_h = re.search(r"plt\.hist\(\s*df\[['\"](\w+)['\"]\]", code)
1389
- if m_h:
1390
- col = m_h.group(1)
1391
- if pd.api.types.is_numeric_dtype(df[col]):
1392
- return (
1393
- f"plt.hist(df['{col}'], bins=20, edgecolor='black')\n"
1394
- f"plt.xlabel('{col}')\nplt.ylabel('Frequency')\nplt.title('Histogram of {col}')\nplt.show()"
1395
- )
1396
- else:
1397
- # If categorical, show bar plot of value counts
1398
- return (
1399
- f"df['{col}'].value_counts().plot(kind='bar')\n"
1400
- f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Counts of {col}')\nplt.show()"
1401
- )
1402
-
1403
- # 4. For bar plots: show only top 20
1404
- m_b = re.search(r"(?:df\[['\"](\w+)['\"]\]\.value_counts\(\).plot\(kind=['\"]bar['\"]\))", code)
1405
- if m_b:
1406
- col = m_b.group(1)
1407
- if df[col].nunique() > 20:
1408
- return (
1409
- f"topN = df['{col}'].value_counts().head(20)\n"
1410
- f"topN.plot(kind='bar')\n"
1411
- f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Top 20 {col} Categories')\nplt.show()"
1412
- )
1413
-
1414
- # 5. For any DataFrame plot with len(df)>10000, sample before plotting!
1415
- if "df.plot" in code and len(df) > 10000:
1416
- return (
1417
- f"samp = df.sample(1000, random_state=42)\n"
1418
- + code.replace("df.", "samp.")
1419
- )
1412
+ # # 1. For line plots (auto-aggregate)
1413
+ # m_l = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
1414
+ # if m_l:
1415
+ # x, y = m_l.groups()
1416
+ # if pd.api.types.is_numeric_dtype(df[x]) and pd.api.types.is_numeric_dtype(df[y]) and df[x].nunique() > 20:
1417
+ # return (
1418
+ # f"agg_df = df.groupby('{x}')['{y}'].mean().reset_index()\n"
1419
+ # f"plt.plot(agg_df['{x}'], agg_df['{y}'], marker='o')\n"
1420
+ # f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('Average {y} by {x}')\nplt.show()"
1421
+ # )
1422
+
1423
+ # # 2. For scatter plots: sample to 1000 points max
1424
+ # m_s = re.search(r"plt\.scatter\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
1425
+ # if m_s:
1426
+ # x, y = m_s.groups()
1427
+ # if len(df) > 1000:
1428
+ # return (
1429
+ # f"samp = df.sample(1000, random_state=42)\n"
1430
+ # f"plt.scatter(samp['{x}'], samp['{y}'])\n"
1431
+ # f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('{y} vs {x} (sampled)')\nplt.show()"
1432
+ # )
1433
+
1434
+ # # 3. For histograms: use bins=20 for numeric, value_counts for categorical
1435
+ # m_h = re.search(r"plt\.hist\(\s*df\[['\"](\w+)['\"]\]", code)
1436
+ # if m_h:
1437
+ # col = m_h.group(1)
1438
+ # if pd.api.types.is_numeric_dtype(df[col]):
1439
+ # return (
1440
+ # f"plt.hist(df['{col}'], bins=20, edgecolor='black')\n"
1441
+ # f"plt.xlabel('{col}')\nplt.ylabel('Frequency')\nplt.title('Histogram of {col}')\nplt.show()"
1442
+ # )
1443
+ # else:
1444
+ # # If categorical, show bar plot of value counts
1445
+ # return (
1446
+ # f"df['{col}'].value_counts().plot(kind='bar')\n"
1447
+ # f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Counts of {col}')\nplt.show()"
1448
+ # )
1449
+
1450
+ # # 4. For bar plots: show only top 20
1451
+ # m_b = re.search(r"(?:df\[['\"](\w+)['\"]\]\.value_counts\(\).plot\(kind=['\"]bar['\"]\))", code)
1452
+ # if m_b:
1453
+ # col = m_b.group(1)
1454
+ # if df[col].nunique() > 20:
1455
+ # return (
1456
+ # f"topN = df['{col}'].value_counts().head(20)\n"
1457
+ # f"topN.plot(kind='bar')\n"
1458
+ # f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Top 20 {col} Categories')\nplt.show()"
1459
+ # )
1460
+
1461
+ # # 5. For any DataFrame plot with len(df)>10000, sample before plotting!
1462
+ # if "df.plot" in code and len(df) > 10000:
1463
+ # return (
1464
+ # f"samp = df.sample(1000, random_state=42)\n"
1465
+ # + code.replace("df.", "samp.")
1466
+ # )
1420
1467
 
1421
- # ── Block assignment to an existing target column ────────────────
1422
- #*******************************************************
1423
- target_match = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", user_question or "", re.I)
1424
- if target_match:
1425
- target = target_match.group(1)
1426
-
1427
- # pattern for an assignment to that target
1428
- assign_pat = rf"df\[['\"]{re.escape(target)}['\"]\]\s*="
1429
- assign_line = re.search(assign_pat + r".*", code)
1430
- if assign_line:
1431
- # runtime check: keep the assignment **only if** the column is absent
1432
- guard = (
1433
- f"if '{target}' in df.columns:\n"
1434
- f" print('⚠️ {target} already exists – overwrite skipped.');\n"
1435
- f"else:\n"
1436
- f" {assign_line.group(0)}"
1437
- )
1438
- # remove original assignment line and insert guarded block
1439
- code = code.replace(assign_line.group(0), guard, 1)
1440
- # ***************************************************
1468
+ # # ── Block assignment to an existing target column ────────────────
1469
+ # #*******************************************************
1470
+ # target_match = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", user_question or "", re.I)
1471
+ # if target_match:
1472
+ # target = target_match.group(1)
1473
+
1474
+ # # pattern for an assignment to that target
1475
+ # assign_pat = rf"df\[['\"]{re.escape(target)}['\"]\]\s*="
1476
+ # assign_line = re.search(assign_pat + r".*", code)
1477
+ # if assign_line:
1478
+ # # runtime check: keep the assignment **only if** the column is absent
1479
+ # guard = (
1480
+ # f"if '{target}' in df.columns:\n"
1481
+ # f" print('⚠️ {target} already exists – overwrite skipped.');\n"
1482
+ # f"else:\n"
1483
+ # f" {assign_line.group(0)}"
1484
+ # )
1485
+ # # remove original assignment line and insert guarded block
1486
+ # code = code.replace(assign_line.group(0), guard, 1)
1487
+ # # ***************************************************
1441
1488
 
1442
- # 6. Grouped bar plot for two categoricals
1443
- # Grouped bar plot for two categoricals (.value_counts().unstack() or .groupby().size().unstack())
1444
- if ".value_counts().unstack()" in code or ".groupby(" in code and ".size().unstack()" in code:
1445
- # Try to infer columns from user question if possible:
1446
- group, cat = None, None
1447
- if user_question:
1448
- # crude parse for "counts of X for each Y"
1449
- m = re.search(r"counts? of (\w+) for each (\w+)", user_question)
1450
- if m:
1451
- cat, group = m.groups()
1452
- if not (cat and group):
1453
- # fallback: use two most frequent categoricals
1454
- categoricals = [col for col in df.columns if pd.api.types.is_categorical_dtype(df[col]) or df[col].dtype == "object"]
1455
- if len(categoricals) >= 2:
1456
- cat, group = categoricals[:2]
1457
- else:
1458
- # fallback: any
1459
- cat, group = df.columns[:2]
1460
- return (
1461
- f"import pandas as pd\n"
1462
- f"import matplotlib.pyplot as plt\n"
1463
- f"ct = pd.crosstab(df['{group}'], df['{cat}'])\n"
1464
- f"ct.plot(kind='bar')\n"
1465
- f"plt.title('Counts of {cat} for each {group}')\n"
1466
- f"plt.xlabel('{group}')\nplt.ylabel('Count')\nplt.xticks(rotation=0)\nplt.show()"
1467
- )
1468
-
1469
- # Fallback: Return original code
1470
- return code
1471
-
1472
-
1473
- def ensure_matplotlib_title(code, title_var="refined_question"):
1474
- import re
1475
- makes_plot = re.search(r"\b(plt\.(plot|scatter|bar|hist)|ax\.(plot|scatter|bar|hist))\b", code)
1476
- has_title = re.search(r"\b(plt\.title|ax\.set_title)\s*\(", code)
1477
- if makes_plot and not has_title:
1478
- code += f"\ntry:\n plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
1479
- return code
1489
+ # # 6. Grouped bar plot for two categoricals
1490
+ # # Grouped bar plot for two categoricals (.value_counts().unstack() or .groupby().size().unstack())
1491
+ # if ".value_counts().unstack()" in code or ".groupby(" in code and ".size().unstack()" in code:
1492
+ # # Try to infer columns from user question if possible:
1493
+ # group, cat = None, None
1494
+ # if user_question:
1495
+ # # crude parse for "counts of X for each Y"
1496
+ # m = re.search(r"counts? of (\w+) for each (\w+)", user_question)
1497
+ # if m:
1498
+ # cat, group = m.groups()
1499
+ # if not (cat and group):
1500
+ # # fallback: use two most frequent categoricals
1501
+ # categoricals = [col for col in df.columns if pd.api.types.is_categorical_dtype(df[col]) or df[col].dtype == "object"]
1502
+ # if len(categoricals) >= 2:
1503
+ # cat, group = categoricals[:2]
1504
+ # else:
1505
+ # # fallback: any
1506
+ # cat, group = df.columns[:2]
1507
+ # return (
1508
+ # f"import pandas as pd\n"
1509
+ # f"import matplotlib.pyplot as plt\n"
1510
+ # f"ct = pd.crosstab(df['{group}'], df['{cat}'])\n"
1511
+ # f"ct.plot(kind='bar')\n"
1512
+ # f"plt.title('Counts of {cat} for each {group}')\n"
1513
+ # f"plt.xlabel('{group}')\nplt.ylabel('Count')\nplt.xticks(rotation=0)\nplt.show()"
1514
+ # )
1515
+
1516
+ # # Fallback: Return original code
1517
+ # return code
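The early guard at the top of patch_plot_code is the most reusable piece: it flags columns that generated code reads via df['...'] but that neither exist in the DataFrame nor get assigned earlier. A minimal sketch of just that check (missing_columns and the toy DataFrame are illustrative names, not package API):

import re
import pandas as pd

def missing_columns(code: str, df: pd.DataFrame):
    referenced = re.findall(r"df\[['\"](\w+)['\"]\]", code)
    created = re.findall(r"df\[['\"](\w+)['\"]\]\s*=", code)
    # A column is only a problem if it is read, never created, and absent from df.
    return [c for c in referenced if c not in df.columns and c not in created]

df = pd.DataFrame({"Age": [34, 51], "BMI": [22.5, 31.0]})
snippet = "df['Risk'] = df['BMI'] > 25\nplt.plot(df['Age'], df['Weight'])"
print(missing_columns(snippet, df))   # ['Weight']; 'Risk' is assigned first, so not flagged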
1518
+
1519
+
1520
+ # def ensure_matplotlib_title(code, title_var="refined_question"):
1521
+ # import re
1522
+ # makes_plot = re.search(r"\b(plt\.(plot|scatter|bar|hist)|ax\.(plot|scatter|bar|hist))\b", code)
1523
+ # has_title = re.search(r"\b(plt\.title|ax\.set_title)\s*\(", code)
1524
+ # if makes_plot and not has_title:
1525
+ # code += f"\ntry:\n plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
1526
+ # return code
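A reduced sketch of the title-guard idea above, limited to plt.* calls (the original also checks ax.set_title); the appended call is wrapped in try/except so a missing or odd title variable never breaks the run:

import re

def ensure_title(code: str, title_var: str = "refined_question") -> str:
    makes_plot = re.search(r"\bplt\.(plot|scatter|bar|hist)\b", code)
    has_title = re.search(r"\bplt\.title\s*\(", code)
    if makes_plot and not has_title:
        code += (f"\ntry:\n    plt.title(str({title_var})[:120])\n"
                 "except Exception:\n    pass\n")
    return code

print(ensure_title("plt.hist(df['Age'], bins=20)\nplt.show()"))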
1480
1527
 
1481
1528
 
1482
1529
  def ensure_output(code: str) -> str:
@@ -1555,1548 +1602,1455 @@ def ensure_output(code: str) -> str:
1555
1602
  return "\n".join(lines)
1556
1603
 
1557
1604
 
1558
- def get_plotting_imports(code):
1559
- imports = []
1560
- if "plt." in code and "import matplotlib.pyplot as plt" not in code:
1561
- imports.append("import matplotlib.pyplot as plt")
1562
- if "sns." in code and "import seaborn as sns" not in code:
1563
- imports.append("import seaborn as sns")
1564
- if "px." in code and "import plotly.express as px" not in code:
1565
- imports.append("import plotly.express as px")
1566
- if "pd." in code and "import pandas as pd" not in code:
1567
- imports.append("import pandas as pd")
1568
- if "np." in code and "import numpy as np" not in code:
1569
- imports.append("import numpy as np")
1570
- if "display(" in code and "from IPython.display import display" not in code:
1571
- imports.append("from IPython.display import display")
1572
- # Optionally, add more as you see usage (e.g., import scipy, statsmodels, etc)
1573
- if imports:
1574
- code = "\n".join(imports) + "\n\n" + code
1575
- return code
1576
-
1577
-
1578
- def patch_pairplot(code, df):
1579
- if "sns.pairplot" in code:
1580
- # Always assign and print pairgrid
1581
- code = re.sub(r"sns\.pairplot\((.+)\)", r"pairgrid = sns.pairplot(\1)", code)
1582
- if "plt.show()" not in code:
1583
- code += "\nplt.show()"
1584
- if "print(pairgrid)" not in code:
1585
- code += "\nprint(pairgrid)"
1586
- return code
1587
-
1588
-
1589
- def ensure_image_output(code: str) -> str:
1590
- """
1591
- Replace each plt.show() with an indented _SMX_export_png() call.
1592
- This keeps block indentation valid and still renders images in the dashboard.
1593
- """
1594
- if "plt.show()" not in code:
1595
- return code
1605
+ # def get_plotting_imports(code):
1606
+ # imports = []
1607
+ # if "plt." in code and "import matplotlib.pyplot as plt" not in code:
1608
+ # imports.append("import matplotlib.pyplot as plt")
1609
+ # if "sns." in code and "import seaborn as sns" not in code:
1610
+ # imports.append("import seaborn as sns")
1611
+ # if "px." in code and "import plotly.express as px" not in code:
1612
+ # imports.append("import plotly.express as px")
1613
+ # if "pd." in code and "import pandas as pd" not in code:
1614
+ # imports.append("import pandas as pd")
1615
+ # if "np." in code and "import numpy as np" not in code:
1616
+ # imports.append("import numpy as np")
1617
+ # if "display(" in code and "from IPython.display import display" not in code:
1618
+ # imports.append("from IPython.display import display")
1619
+ # # Optionally, add more as you see usage (e.g., import scipy, statsmodels, etc)
1620
+ # if imports:
1621
+ # code = "\n".join(imports) + "\n\n" + code
1622
+ # return code
1623
+
1624
+
1625
+ # def patch_pairplot(code, df):
1626
+ # if "sns.pairplot" in code:
1627
+ # # Always assign and print pairgrid
1628
+ # code = re.sub(r"sns\.pairplot\((.+)\)", r"pairgrid = sns.pairplot(\1)", code)
1629
+ # if "plt.show()" not in code:
1630
+ # code += "\nplt.show()"
1631
+ # if "print(pairgrid)" not in code:
1632
+ # code += "\nprint(pairgrid)"
1633
+ # return code
1634
+
1635
+
1636
+ # def ensure_image_output(code: str) -> str:
1637
+ # """
1638
+ # Replace each plt.show() with an indented _SMX_export_png() call.
1639
+ # This keeps block indentation valid and still renders images in the dashboard.
1640
+ # """
1641
+ # if "plt.show()" not in code:
1642
+ # return code
1643
+
1644
+ # import re
1645
+ # out_lines = []
1646
+ # for ln in code.splitlines():
1647
+ # if "plt.show()" not in ln:
1648
+ # out_lines.append(ln)
1649
+ # continue
1650
+
1651
+ # # works for:
1652
+ # # plt.show()
1653
+ # # plt.tight_layout(); plt.show()
1654
+ # # ... ; plt.show(); ... (multiple on one line)
1655
+ # indent = re.match(r"^(\s*)", ln).group(1)
1656
+ # parts = ln.split("plt.show()")
1657
+
1658
+ # # keep whatever is before the first plt.show()
1659
+ # if parts[0].strip():
1660
+ # out_lines.append(parts[0].rstrip())
1661
+
1662
+ # # for every plt.show() we removed, insert exporter at same indent
1663
+ # for _ in range(len(parts) - 1):
1664
+ # out_lines.append(indent + "_SMX_export_png()")
1665
+
1666
+ # # keep whatever comes after the last plt.show()
1667
+ # if parts[-1].strip():
1668
+ # out_lines.append(indent + parts[-1].lstrip())
1669
+
1670
+ # return "\n".join(out_lines)
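The plt.show() rewrite above is worth seeing outside the package context: every occurrence on a line becomes an exporter call at the same indentation, so code inside loops or if-blocks stays syntactically valid. A standalone sketch (replace_show is an illustrative name; _SMX_export_png is the hook the executor injects, per the code above):

import re

def replace_show(code: str, exporter: str = "_SMX_export_png()") -> str:
    out = []
    for ln in code.splitlines():
        if "plt.show()" not in ln:
            out.append(ln)
            continue
        indent = re.match(r"^(\s*)", ln).group(1)
        parts = ln.split("plt.show()")
        if parts[0].strip():                    # keep code before the first show()
            out.append(parts[0].rstrip())
        for _ in range(len(parts) - 1):         # one exporter call per removed show()
            out.append(indent + exporter)
        if parts[-1].strip():                   # keep code after the last show()
            out.append(indent + parts[-1].lstrip())
    return "\n".join(out)

print(replace_show("for c in cols:\n    plt.hist(df[c]); plt.show()"))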
1671
+
1672
+
1673
+ # def clean_llm_code(code: str) -> str:
1674
+ # """
1675
+ # Make LLM output safe to exec:
1676
+ # - If fenced blocks exist, keep the largest one (usually the real code).
1677
+ # - Otherwise strip any stray ``` / ```python lines.
1678
+ # - Remove common markdown/preamble junk.
1679
+ # """
1680
+ # code = str(code or "")
1681
+
1682
+ # # Special case: sometimes the OpenAI SDK object repr (e.g. ChatCompletion(...))
1683
+ # # is accidentally passed here as `code`. In that case, extract the actual
1684
+ # # Python code from the ChatCompletionMessage(content=...) field.
1685
+ # if "ChatCompletion(" in code and "ChatCompletionMessage" in code and "content=" in code:
1686
+ # try:
1687
+ # extracted = None
1688
+
1689
+ # class _ChatCompletionVisitor(ast.NodeVisitor):
1690
+ # def visit_Call(self, node):
1691
+ # nonlocal extracted
1692
+ # func = node.func
1693
+ # fname = getattr(func, "id", None) or getattr(func, "attr", None)
1694
+ # if fname == "ChatCompletionMessage":
1695
+ # for kw in node.keywords:
1696
+ # if kw.arg == "content" and isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str):
1697
+ # extracted = kw.value.value
1698
+ # self.generic_visit(node)
1699
+
1700
+ # tree = ast.parse(code, mode="exec")
1701
+ # _ChatCompletionVisitor().visit(tree)
1702
+ # if extracted:
1703
+ # code = extracted
1704
+ # except Exception:
1705
+ # # Best-effort regex fallback if AST parsing fails
1706
+ # m = re.search(r"content=(?P<q>['\\\"])(?P<body>.*?)(?P=q)", code, flags=re.S)
1707
+ # if m:
1708
+ # code = m.group("body")
1709
+
1710
+ # # Existing logic continues unchanged below...
1711
+ # # Extract fenced blocks (```python ... ``` or ``` ... ```)
1712
+ # blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
1713
+
1714
+ # if blocks:
1715
+ # # pick the largest block; small trailing blocks are usually garbage
1716
+ # largest = max(blocks, key=lambda b: len(b.strip()))
1717
+ # if len(largest.strip().splitlines()) >= 10:
1718
+ # code = largest
1719
+
1720
+ # # Extract fenced blocks (```python ... ``` or ``` ... ```)
1721
+ # blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
1722
+
1723
+ # if blocks:
1724
+ # # pick the largest block; small trailing blocks are usually garbage
1725
+ # largest = max(blocks, key=lambda b: len(b.strip()))
1726
+ # if len(largest.strip().splitlines()) >= 10:
1727
+ # code = largest
1728
+ # else:
1729
+ # # if no meaningful block, just remove fence markers
1730
+ # code = re.sub(r"^```.*?$", "", code, flags=re.M)
1731
+ # else:
1732
+ # # no complete blocks — still remove any stray fence lines
1733
+ # code = re.sub(r"^```.*?$", "", code, flags=re.M)
1734
+
1735
+ # # Strip common markdown/preamble lines
1736
+ # drop_prefixes = (
1737
+ # "here is", "here's", "below is", "sure,", "certainly",
1738
+ # "explanation", "note:", "```"
1739
+ # )
1740
+ # cleaned_lines = []
1741
+ # for ln in code.splitlines():
1742
+ # s = ln.strip().lower()
1743
+ # if any(s.startswith(p) for p in drop_prefixes):
1744
+ # continue
1745
+ # cleaned_lines.append(ln)
1746
+
1747
+ # return "\n".join(cleaned_lines).strip()
1748
+
1749
+
1750
+ # def fix_groupby_describe_slice(code: str) -> str:
1751
+ # """
1752
+ # Replaces df.groupby(...).describe()[[...] ] with a safe .agg(...)
1753
+ # so it works for both SeriesGroupBy and DataFrameGroupBy.
1754
+ # """
1755
+ # pat = re.compile(
1756
+ # r"(df\.groupby\(['\"][\w]+['\"]\)\['[\w]+['\"]\]\.describe\()\s*\[\[([^\]]+)\]\]\)",
1757
+ # re.MULTILINE
1758
+ # )
1759
+ # def repl(match):
1760
+ # inner = match.group(0)
1761
+ # # extract group and feature to build df.groupby('g')['f']
1762
+ # g = re.search(r"groupby\('([\w]+)'\)", inner).group(1)
1763
+ # f = re.search(r"\)\['([\w]+)'\]\.describe", inner).group(1)
1764
+ # return (
1765
+ # f"df.groupby('{g}')['{f}']"
1766
+ # ".agg(['count','mean','std','min','median','max'])"
1767
+ # ".rename(columns={'median':'50%'})"
1768
+ # )
1769
+ # return pat.sub(repl, code)
1770
+
1771
+
1772
+ # def fix_importance_groupby(code: str) -> str:
1773
+ # pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
1774
+ # if "importance_df" in code:
1775
+ # return pattern.sub("importance_df.groupby('Importance')['Importance']", code)
1776
+ # return code
1777
+
1778
+ # def inject_auto_preprocessing(code: str) -> str:
1779
+ # """
1780
+ # • Detects a RandomForestClassifier in the generated code.
1781
+ # • Finds the target column from `y = df['target']`.
1782
+ # • Prepends a fully-dedented preprocessing snippet that:
1783
+ # – auto-detects numeric & categorical columns
1784
+ # – builds a ColumnTransformer (OneHotEncoder + StandardScaler)
1785
+ # The dedent() call guarantees no leading-space IndentationError.
1786
+ # """
1787
+ # if "RandomForestClassifier" not in code:
1788
+ # return code # nothing to patch
1789
+
1790
+ # y_match = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
1791
+ # if not y_match:
1792
+ # return code # can't infer target safely
1793
+ # target = y_match.group(1)
1794
+
1795
+ # prep_snippet = textwrap.dedent(f"""
1796
+ # # ── automatic preprocessing ───────────────────────────────
1797
+ # num_cols = df.select_dtypes(include=['number']).columns.tolist()
1798
+ # cat_cols = df.select_dtypes(exclude=['number']).columns.tolist()
1799
+ # num_cols = [c for c in num_cols if c != '{target}']
1800
+ # cat_cols = [c for c in cat_cols if c != '{target}']
1801
+
1802
+ # from sklearn.compose import ColumnTransformer
1803
+ # from sklearn.preprocessing import StandardScaler, OneHotEncoder
1804
+
1805
+ # preproc = ColumnTransformer(
1806
+ # transformers=[
1807
+ # ('num', StandardScaler(), num_cols),
1808
+ # ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols),
1809
+ # ],
1810
+ # remainder='drop',
1811
+ # )
1812
+ # # ───────────────────────────────────────────────────────────
1813
+ # """).strip() + "\n\n"
1814
+
1815
+ # # simply prepend; model code that follows can wrap estimator in a Pipeline
1816
+ # return prep_snippet + code
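Written out directly, the preprocessing block that inject_auto_preprocessing prepends looks like the sketch below, here wired into a Pipeline the way ensure_preproc_in_pipeline (further down) expects. The toy DataFrame, the Diabetes target, and the model variable are illustrative only:

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder

df = pd.DataFrame({
    "Age": [34, 51, 42, 28],
    "Smoking_Status": ["Never", "Current", "Former", "Never"],
    "Diabetes": [0, 1, 1, 0],
})
target = "Diabetes"
num_cols = [c for c in df.select_dtypes(include="number").columns if c != target]
cat_cols = [c for c in df.select_dtypes(exclude="number").columns if c != target]

# Scale numerics, one-hot encode categoricals, drop anything else.
preproc = ColumnTransformer(
    transformers=[
        ("num", StandardScaler(), num_cols),
        ("cat", OneHotEncoder(handle_unknown="ignore"), cat_cols),
    ],
    remainder="drop",
)
model = Pipeline([("prep", preproc), ("clf", RandomForestClassifier(random_state=0))])
model.fit(df.drop(columns=[target]), df[target])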
1817
+
1818
+
1819
+ # def fix_to_datetime_errors(code: str) -> str:
1820
+ # """
1821
+ # Force every pd.to_datetime(…) call to ignore bad dates so that
1822
+
1823
+ # 'year 16500 is out of range' and similar issues don’t crash runs.
1824
+ # """
1825
+ # import re
1826
+ # # look for any pd.to_datetime( … )
1827
+ # pat = re.compile(r"pd\.to_datetime\(([^)]+)\)")
1828
+ # def repl(m):
1829
+ # inside = m.group(1)
1830
+ # # if the call already has errors=, leave it unchanged
1831
+ # if "errors=" in inside:
1832
+ # return m.group(0)
1833
+ # return f"pd.to_datetime({inside}, errors='coerce')"
1834
+ # return pat.sub(repl, code)
1835
+
1836
+
1837
+ # def fix_numeric_sum(code: str) -> str:
1838
+ # """
1839
+ # Make .sum(...) code safe across pandas versions by removing any
1840
+ # numeric_only=... argument (True/False/None) from function calls.
1841
+
1842
+ # This avoids errors on pandas versions where numeric_only is not
1843
+ # supported for Series/grouped sums, and we rely instead on explicit
1844
+ # numeric column selection (e.g. select_dtypes) in the generated code.
1845
+ # """
1846
+ # # Case 1: ..., numeric_only=True/False/None
1847
+ # code = re.sub(
1848
+ # r",\s*numeric_only\s*=\s*(True|False|None)",
1849
+ # "",
1850
+ # code,
1851
+ # flags=re.IGNORECASE,
1852
+ # )
1853
+
1854
+ # # Case 2: numeric_only=True/False/None, ... (as first argument)
1855
+ # code = re.sub(
1856
+ # r"numeric_only\s*=\s*(True|False|None)\s*,\s*",
1857
+ # "",
1858
+ # code,
1859
+ # flags=re.IGNORECASE,
1860
+ # )
1861
+
1862
+ # # Case 3: numeric_only=True/False/None (only argument)
1863
+ # code = re.sub(
1864
+ # r"numeric_only\s*=\s*(True|False|None)",
1865
+ # "",
1866
+ # code,
1867
+ # flags=re.IGNORECASE,
1868
+ # )
1869
+
1870
+ # return code
1871
+
1872
+
1873
+ # def fix_concat_empty_list(code: str) -> str:
1874
+ # """
1875
+ # Make pd.concat calls resilient to empty lists of objects.
1876
+
1877
+ # Transforms patterns like:
1878
+ # pd.concat(frames, ignore_index=True)
1879
+ # pd.concat(frames)
1880
+
1881
+ # into:
1882
+ # pd.concat(frames or [pd.DataFrame()], ignore_index=True)
1883
+ # pd.concat(frames or [pd.DataFrame()])
1884
+
1885
+ # Only triggers when the first argument is a simple variable name.
1886
+ # """
1887
+ # pattern = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
1888
+
1889
+ # def _repl(m):
1890
+ # name = m.group(1)
1891
+ # sep = m.group(2) # ',' or ')'
1892
+ # return f"pd.concat({name} or [pd.DataFrame()]{sep}"
1893
+
1894
+ # return pattern.sub(_repl, code)
1895
+
1896
+
1897
+ # def fix_numeric_aggs(code: str) -> str:
1898
+ # _AGG_FUNCS = ("sum", "mean")
1899
+ # pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
1900
+ # def _repl(m):
1901
+ # func, args = m.group(1), m.group(2) or ""
1902
+ # if "numeric_only" in args:
1903
+ # return m.group(0)
1904
+ # args = args.rstrip()
1905
+ # if args:
1906
+ # args += ", "
1907
+ # return f".{func}({args}numeric_only=True)"
1908
+ # return pat.sub(_repl, code)
1909
+
1910
+
1911
+ # def ensure_accuracy_block(code: str) -> str:
1912
+ # """
1913
+ # Inject a sensible evaluation block right after the last `<est>.fit(...)`
1914
+ # Classification → accuracy + weighted F1
1915
+ # Regression → R², RMSE, MAE
1916
+ # Heuristic: infer task from estimator names present in the code.
1917
+ # """
1918
+ # import re, textwrap
1919
+
1920
+ # # If any proper metric already exists, do nothing
1921
+ # if re.search(r"\b(accuracy_score|f1_score|r2_score|mean_squared_error|mean_absolute_error)\b", code):
1922
+ # return code
1923
+
1924
+ # # Find the last "<var>.fit(" occurrence to reuse the estimator variable name
1925
+ # m = list(re.finditer(r"(\w+)\.fit\s*\(", code))
1926
+ # if not m:
1927
+ # return code # no estimator
1928
+
1929
+ # var = m[-1].group(1)
1930
+ # # indent with same leading whitespace used on that line
1931
+ # indent = re.match(r"\s*", code[m[-1].start():]).group(0)
1932
+
1933
+ # # Detect regression by estimator names / hints in code
1934
+ # is_regression = bool(
1935
+ # re.search(
1936
+ # r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
1937
+ # r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
1938
+ # r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
1939
+ # )
1940
+ # or re.search(r"\bOLS\s*\(", code)
1941
+ # or re.search(r"\bRegressor\b", code)
1942
+ # )
1943
+
1944
+ # if is_regression:
1945
+ # # inject numpy import if needed for RMSE
1946
+ # if "import numpy as np" not in code and "np." not in code:
1947
+ # code = "import numpy as np\n" + code
1948
+ # eval_block = textwrap.dedent(f"""
1949
+ # {indent}# ── automatic regression evaluation ─────────
1950
+ # {indent}from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
1951
+ # {indent}y_pred = {var}.predict(X_test)
1952
+ # {indent}r2 = r2_score(y_test, y_pred)
1953
+ # {indent}rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
1954
+ # {indent}mae = float(mean_absolute_error(y_test, y_pred))
1955
+ # {indent}print(f"R²: {{r2:.4f}} | RMSE: {{rmse:.4f}} | MAE: {{mae:.4f}}")
1956
+ # """)
1957
+ # else:
1958
+ # eval_block = textwrap.dedent(f"""
1959
+ # {indent}# ── automatic classification evaluation ─────────
1960
+ # {indent}from sklearn.metrics import accuracy_score, f1_score
1961
+ # {indent}y_pred = {var}.predict(X_test)
1962
+ # {indent}acc = accuracy_score(y_test, y_pred)
1963
+ # {indent}f1 = f1_score(y_test, y_pred, average='weighted')
1964
+ # {indent}print(f"Accuracy: {{acc:.2%}} | F1 (weighted): {{f1:.3f}}")
1965
+ # """)
1966
+
1967
+ # insert_at = code.find("\n", m[-1].end()) + 1
1968
+ # return code[:insert_at] + eval_block + code[insert_at:]
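For reference, the regression branch of the injected evaluation block, written out as runnable code on synthetic data (make_regression, the LinearRegression model, and the X_test/y_test names are illustrative; the metric lines mirror the template above):

import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = LinearRegression().fit(X_train, y_train)

# ── automatic regression evaluation ─────────
y_pred = model.predict(X_test)
r2 = r2_score(y_test, y_pred)
rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))   # version-safe RMSE
mae = float(mean_absolute_error(y_test, y_pred))
print(f"R²: {r2:.4f} | RMSE: {rmse:.4f} | MAE: {mae:.4f}")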
1969
+
1970
+
1971
+ # def fix_scatter_and_summary(code: str) -> str:
1972
+ # """
1973
+ # 1. Change cmap='spectral' (any case) → cmap='Spectral'
1974
+ # 2. If the LLM forgets to close the parenthesis in
1975
+ # summary_table = ( df.describe()... <missing )>
1976
+ # insert the ')' right before the next 'from' or 'show('.
1977
+ # """
1978
+ # # 1️⃣ colormap case
1979
+ # code = re.sub(
1980
+ # r"cmap\s*=\s*['\"]spectral['\"]", # insensitive pattern
1981
+ # "cmap='Spectral'",
1982
+ # code,
1983
+ # flags=re.IGNORECASE,
1984
+ # )
1985
+
1986
+ # # 2️⃣ close summary_table = ( ... )
1987
+ # code = re.sub(
1988
+ # r"(summary_table\s*=\s*\(\s*df\.describe\([^\n]+?\n)"
1989
+ # r"(?=\s*(from|show\())", # look-ahead: next line starts with 'from' or 'show('
1990
+ # r"\1)", # keep group 1 and add ')'
1991
+ # code,
1992
+ # flags=re.MULTILINE,
1993
+ # )
1994
+
1995
+ # return code
1996
+
1997
+
1998
+ # def auto_format_with_black(code: str) -> str:
1999
+ # """
2000
+ # Format the generated code with Black. Falls back silently if Black
2001
+ # is missing or raises (so the dashboard never 500s).
2002
+ # """
2003
+ # try:
2004
+ # import black # make sure black is in your v-env: pip install black
2005
+
2006
+ # mode = black.FileMode() # default settings
2007
+ # return black.format_str(code, mode=mode)
2008
+
2009
+ # except Exception:
2010
+ # return code
2011
+
2012
+
2013
+ # def ensure_preproc_in_pipeline(code: str) -> str:
2014
+ # """
2015
+ # If code defines `preproc = ColumnTransformer(...)` but then builds
2016
+ # `Pipeline([('scaler', StandardScaler()), ('clf', ...)])`, replace
2017
+ # that stanza with `Pipeline([('prep', preproc), ('clf', ...)])`.
2018
+ # """
2019
+ # return re.sub(
2020
+ # r"Pipeline\(\s*\[\('scaler',\s*StandardScaler\(\)\)",
2021
+ # "Pipeline([('prep', preproc)",
2022
+ # code
2023
+ # )
2024
+
2025
+ # def drop_bad_classification_metrics(code: str, y_or_df) -> str:
2026
+ # """
2027
+ # Remove classification metrics (accuracy_score, classification_report, confusion_matrix)
2028
+ # if the generated cell is *regression*. We infer this from:
2029
+ # 1) The estimator names in the code (LinearRegression, OLS, Regressor*, etc.), OR
2030
+ # 2) The target dtype if we can parse y = df['...'] and have the DataFrame.
2031
+ # Safe across datasets and queries.
2032
+ # """
2033
+ # import re
2034
+ # import pandas as pd
2035
+
2036
+ # # 1) Heuristic by estimator names in the *code* (fast path)
2037
+ # regression_by_model = bool(re.search(
2038
+ # r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
2039
+ # r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
2040
+ # r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
2041
+ # ) or re.search(r"\bOLS\s*\(", code))
2042
+
2043
+ # is_regression = regression_by_model
2044
+
2045
+ # # 2) If not obvious from the model, try to infer from y dtype (if we can)
2046
+ # if not is_regression:
2047
+ # try:
2048
+ # # Try to parse: y = df['target']
2049
+ # m = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
2050
+ # if m and hasattr(y_or_df, "columns") and m.group(1) in getattr(y_or_df, "columns", []):
2051
+ # y = y_or_df[m.group(1)]
2052
+ # if pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
2053
+ # is_regression = True
2054
+ # else:
2055
+ # # If a Series was passed
2056
+ # y = y_or_df
2057
+ # if hasattr(y, "dtype") and pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
2058
+ # is_regression = True
2059
+ # except Exception:
2060
+ # pass
2061
+
2062
+ # if is_regression:
2063
+ # # Strip classification-only lines
2064
+ # for pat in (r"\n.*accuracy_score[^\n]*", r"\n.*classification_report[^\n]*", r"\n.*confusion_matrix[^\n]*"):
2065
+ # code = re.sub(pat, "", code, flags=re.I)
2066
+
2067
+ # return code
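The dtype fallback above boils down to one heuristic: a numeric target with more than ten distinct values is treated as regression, so classification-only metric lines get stripped. A sketch of just that test (looks_like_regression is an illustrative name):

import pandas as pd

def looks_like_regression(y: pd.Series, max_classes: int = 10) -> bool:
    return bool(pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > max_classes)

print(looks_like_regression(pd.Series([0, 1, 1, 0, 1])))      # False: looks categorical
print(looks_like_regression(pd.Series(range(100)) / 3.0))     # True: continuous target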
2068
+
2069
+
2070
+ # def force_capture_display(code: str) -> str:
2071
+ # """
2072
+ # Ensure our executor captures HTML output:
2073
+ # - Remove any import that would override our 'display' hook.
2074
+ # - Keep/allow importing HTML only.
2075
+ # - Handle alias cases like 'display as d'.
2076
+ # """
2077
+ # import re
2078
+ # new = code
2079
+
2080
+ # # 'from IPython.display import display, HTML' -> keep HTML only
2081
+ # new = re.sub(
2082
+ # r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*,\s*HTML\s*(?:as\s+([A-Za-z_]\w*))?\s*$",
2083
+ # r"from IPython.display import HTML\1", new
2084
+ # )
2085
+
2086
+ # # 'from IPython.display import display as d' -> 'd = display'
2087
+ # new = re.sub(
2088
+ # r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s+as\s+([A-Za-z_]\w+)\s*$",
2089
+ # r"\1 = display", new
2090
+ # )
2091
+
2092
+ # # 'from IPython.display import display' -> remove (use our injected display)
2093
+ # new = re.sub(
2094
+ # r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*$",
2095
+ # r"# display import removed (SMX capture active)", new
2096
+ # )
2097
+
2098
+ # # If someone does 'import IPython.display as disp' and calls disp.display(...), rewrite to display(...)
2099
+ # new = re.sub(
2100
+ # r"(?m)\bIPython\.display\.display\s*\(",
2101
+ # "display(", new
2102
+ # )
2103
+ # new = re.sub(
2104
+ # r"(?m)\b([A-Za-z_]\w*)\.display\s*\(" # handles 'disp.display(' after 'import IPython.display as disp'
2105
+ # r"(?=.*import\s+IPython\.display\s+as\s+\1)",
2106
+ # "display(", new
2107
+ # )
2108
+ # return new
2109
+
2110
+
2111
+ # def strip_matplotlib_show(code: str) -> str:
2112
+ # """Remove blocking plt.show() calls (we export base64 instead)."""
2113
+ # import re
2114
+ # return re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code)
2115
+
2116
+
2117
+ # def inject_display_shim(code: str) -> str:
2118
+ # """
2119
+ # Provide display()/HTML() if missing, forwarding to our executor hook.
2120
+ # Harmless if the names already exist.
2121
+ # """
2122
+ # shim = (
2123
+ # "try:\n"
2124
+ # " display\n"
2125
+ # "except NameError:\n"
2126
+ # " def display(obj=None, **kwargs):\n"
2127
+ # " __builtins__.get('_smx_display', print)(obj)\n"
2128
+ # "try:\n"
2129
+ # " HTML\n"
2130
+ # "except NameError:\n"
2131
+ # " class HTML:\n"
2132
+ # " def __init__(self, data): self.data = str(data)\n"
2133
+ # " def _repr_html_(self): return self.data\n"
2134
+ # "\n"
2135
+ # )
2136
+ # return shim + code
2137
+
2138
+
2139
+ # def strip_spurious_column_tokens(code: str) -> str:
2140
+ # """
2141
+ # Remove common stop-words ('the','whether', ...) when they appear
2142
+ # inside column lists, e.g.:
2143
+ # predictors = ['BMI','the','HbA1c']
2144
+ # df[['GGT','whether','BMI']]
2145
+ # Leaves other strings intact.
2146
+ # """
2147
+ # STOP = {
2148
+ # "the","whether","a","an","and","or","of","to","in","on","for","by",
2149
+ # "with","as","at","from","that","this","these","those","is","are","was","were",
2150
+ # "coef", "Coef", "coefficient", "Coefficient"
2151
+ # }
2152
+
2153
+ # def _norm(s: str) -> str:
2154
+ # return re.sub(r"[^a-z0-9]+", "", s.lower())
2155
+
2156
+ # def _clean_list(content: str) -> str:
2157
+ # # Rebuild a string list, keeping only non-stopword items
2158
+ # items = re.findall(r"(['\"])(.*?)\1", content)
2159
+ # if not items:
2160
+ # return "[" + content + "]"
2161
+ # keep = [f"{q}{s}{q}" for (q, s) in items if _norm(s) not in STOP]
2162
+ # return "[" + ", ".join(keep) + "]"
2163
+
2164
+ # # Variable assignments: predictors/features/columns/cols = [...]
2165
+ # code = re.sub(
2166
+ # r"(?m)\b(predictors|features|columns|cols)\s*=\s*\[([^\]]+)\]",
2167
+ # lambda m: f"{m.group(1)} = " + _clean_list(m.group(2)),
2168
+ # code
2169
+ # )
2170
+
2171
+ # # df[[ ... ]] selections
2172
+ # code = re.sub(
2173
+ # r"df\s*\[\s*\[([^\]]+)\]\s*\]", lambda m: "df[" + _clean_list(m.group(1)) + "]", code)
2174
+
2175
+ # return code
2176
+
2177
+
2178
+ # def patch_prefix_seaborn_calls(code: str) -> str:
2179
+ # """
2180
+ # Ensure bare seaborn calls are prefixed with `sns.`.
2181
+ # E.g., `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, etc.
2182
+ # """
2183
+ # if "sns." in code:
2184
+ # # still fix any leftover bare calls alongside prefixed ones
2185
+ # pass
2186
+
2187
+ # # functions commonly used from seaborn
2188
+ # funcs = [
2189
+ # "barplot","countplot","boxplot","violinplot","stripplot","swarmplot",
2190
+ # "histplot","kdeplot","jointplot","pairplot","heatmap","clustermap",
2191
+ # "scatterplot","lineplot","catplot","displot","lmplot"
2192
+ # ]
2193
+ # # Replace bare function calls not already qualified by a dot (e.g., obj.barplot)
2194
+ # # (?<![\w.]) ensures no preceding word char or dot; avoids touching obj.barplot or mybarplot
2195
+ # pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(", flags=re.MULTILINE)
2196
+
2197
+ # def _add_prefix(m):
2198
+ # fn = m.group(1)
2199
+ # return f"sns.{fn}("
2200
+
2201
+ # return pattern.sub(_add_prefix, code)
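A short sketch of the bare-call fix above: known seaborn plotting functions get an sns. prefix unless they are already qualified by a dot (the FUNCS tuple here is a subset of the list in the original; prefix_seaborn is an illustrative name):

import re

FUNCS = ("heatmap", "barplot", "boxplot", "scatterplot", "lineplot")
pattern = re.compile(r"(?<![\w.])(" + "|".join(FUNCS) + r")\s*\(")

def prefix_seaborn(code: str) -> str:
    return pattern.sub(lambda m: f"sns.{m.group(1)}(", code)

print(prefix_seaborn("heatmap(corr, annot=True)"))   # sns.heatmap(corr, annot=True)
print(prefix_seaborn("ax.barplot(x, y)"))            # unchanged: already qualified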
2202
+
2203
+
2204
+ # def patch_ensure_seaborn_import(code: str) -> str:
2205
+ # """
2206
+ # If seaborn is used (sns.) ensure `import seaborn as sns` exists once.
2207
+ # Also set a quiet theme for consistent visuals.
2208
+ # """
2209
+ # needs_sns = "sns." in code
2210
+ # has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
2211
+ # if needs_sns and not has_import:
2212
+ # # Insert after the first block of imports if possible, else at top
2213
+ # import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
2214
+ # inject = "import seaborn as sns\ntry:\n sns.set_theme()\nexcept Exception:\n pass\n"
2215
+ # if import_block:
2216
+ # start = import_block.end()
2217
+ # code = code[:start] + inject + code[start:]
2218
+ # else:
2219
+ # code = inject + code
2220
+ # return code
2221
+
2222
+
2223
+ # def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
2224
+ # """
2225
+ # Normalise pie-chart requests.
2226
+
2227
+ # Supports three patterns:
2228
+ # A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
2229
+ # B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
2230
+ # → one pie per facet level (grid) + counts bar of facet sizes.
2231
+ # C) Single pie when no split/facet is requested.
2232
+
2233
+ # Notes:
2234
+ # - Pie variables must be categorical (or numeric binned).
2235
+ # - Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
2236
+ # """
2237
+
2238
+ # q = (user_question or "")
2239
+ # q_low = q.lower()
2240
+
2241
+ # # Prefer explicit: df['col'].value_counts()
2242
+ # m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
2243
+ # col = m.group(1) if m else None
2244
+
2245
+ # # ---------- helpers ----------
2246
+ # def _is_cat(col):
2247
+ # return (str(df[col].dtype).startswith("category")
2248
+ # or df[col].dtype == "object"
2249
+ # or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
2250
+
2251
+ # def _cats_from_question(question: str):
2252
+ # found = []
2253
+ # for c in df.columns:
2254
+ # if c.lower() in question.lower() and _is_cat(c):
2255
+ # found.append(c)
2256
+ # # dedupe preserve order
2257
+ # seen, out = set(), []
2258
+ # for c in found:
2259
+ # if c not in seen:
2260
+ # out.append(c); seen.add(c)
2261
+ # return out
2262
+
2263
+ # def _fallback_cat():
2264
+ # cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2265
+ # if not cats: return None
2266
+ # cats.sort(key=lambda t: t[1])
2267
+ # return cats[0][0]
2268
+
2269
+ # def _infer_comp_pref(question: str) -> str:
2270
+ # ql = (question or "").lower()
2271
+ # if "heatmap" in ql or "matrix" in ql:
2272
+ # return "heatmap"
2273
+ # if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
2274
+ # return "stacked_bar_pct"
2275
+ # if "stacked" in ql:
2276
+ # return "stacked_bar"
2277
+ # if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
2278
+ # return "grouped_bar"
2279
+ # return "counts_bar"
2280
+
2281
+ # # parse threshold split like "HbA1c ≥ 6.5"
2282
+ # def _parse_split(question: str):
2283
+ # ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
2284
+ # m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
2285
+ # if not m: return None
2286
+ # col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
2287
+ # op = ops_map.get(op_raw);
2288
+ # if not op: return None
2289
+ # # case-insensitive column match
2290
+ # candidates = {c.lower(): c for c in df.columns}
2291
+ # col = candidates.get(col_raw.lower())
2292
+ # if not col: return None
2293
+ # try: val = float(val_raw)
2294
+ # except Exception: return None
2295
+ # return (col, op, val)
2296
+
2297
+ # # facet extractor: "by/ across / within each / per <col>", or "bin <col>", or named category list
2298
+ # def _extract_facet(question: str):
2299
+ # # 1) explicit "by/ across / within / per <col>"
2300
+ # for kw in [" by ", " across ", " within ", " within each ", " per "]:
2301
+ # m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
2302
+ # if m:
2303
+ # col_raw = m.group(1).strip()
2304
+ # candidates = {c.lower(): c for c in df.columns}
2305
+ # if col_raw.lower() in candidates:
2306
+ # return (candidates[col_raw.lower()], "auto")
2307
+ # # 2) "bin <col>"
2308
+ # m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
2309
+ # if m2:
2310
+ # col_raw = m2.group(1).strip()
2311
+ # candidates = {c.lower(): c for c in df.columns}
2312
+ # if col_raw.lower() in candidates:
2313
+ # return (candidates[col_raw.lower()], "bin")
2314
+ # # 3) BMI special: mentions of normal/overweight/obese imply BMI categories
2315
+ # if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
2316
+ # any(c.lower() == "bmi" for c in df.columns.str.lower()):
2317
+ # bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
2318
+ # return (bmi_col, "bmi")
2319
+ # return None
2320
+
2321
+ # def _bmi_bins(series: pd.Series):
2322
+ # # WHO cutoffs
2323
+ # bins = [-np.inf, 18.5, 25, 30, np.inf]
2324
+ # labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
2325
+ # return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
2326
+
2327
+ # wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
2328
+ # if not wants_pie:
2329
+ # return code
2330
+
2331
+ # split = _parse_split(q)
2332
+ # facet = _extract_facet(q)
2333
+ # cats = _cats_from_question(q)
2334
+ # _comp_pref = _infer_comp_pref(q)
2335
+
2336
+ # # Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
2337
+ # for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
2338
+ # if hard in df.columns and hard not in cats and hard.lower() in q_low:
2339
+ # cats.append(hard)
2340
+
2341
+ # # --------------- CASE A: threshold split (cohorts) ---------------
2342
+ # if split:
2343
+ # if not (cats or any(_is_cat(c) for c in df.columns)):
2344
+ # return code
2345
+ # if not cats:
2346
+ # pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2347
+ # pool.sort(key=lambda t: t[1])
2348
+ # cats = [t[0] for t in pool[:3]] if pool else []
2349
+ # if not cats:
2350
+ # return code
2351
+
2352
+ # split_col, op, val = split
2353
+ # cond_str = f"(df['{split_col}'] {op} {val})"
2354
+ # snippet = f"""
2355
+ # import numpy as np
2356
+ # import pandas as pd
2357
+ # import matplotlib.pyplot as plt
2358
+
2359
+ # _mask_a = ({cond_str}) & df['{split_col}'].notna()
2360
+ # _mask_b = (~({cond_str})) & df['{split_col}'].notna()
2361
+
2362
+ # _cohort_a_name = "{split_col} {op} {val}"
2363
+ # _cohort_b_name = "NOT ({split_col} {op} {val})"
2364
+
2365
+ # _cat_cols = {cats!r}
2366
+ # n = len(_cat_cols)
2367
+ # fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
2368
+ # if n == 1:
2369
+ # axes = np.array([axes])
2370
+
2371
+ # for i, col in enumerate(_cat_cols):
2372
+ # s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
2373
+ # s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})
2374
+
2375
+ # ax_a = axes[i, 0]; ax_b = axes[i, 1]
2376
+ # if len(s_a) > 0:
2377
+ # ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
2378
+ # autopct='%1.1f%%', startangle=90, counterclock=False)
2379
+ # ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')
2380
+
2381
+ # if len(s_b) > 0:
2382
+ # ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
2383
+ # autopct='%1.1f%%', startangle=90, counterclock=False)
2384
+ # ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')
2385
+
2386
+ # plt.tight_layout(); plt.show()
2387
+
2388
+ # # grouped bar complement
2389
+ # for col in _cat_cols:
2390
+ # _tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
2391
+ # .assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
2392
+ # _tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
2393
+ # _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2394
+
2395
+ # if _comp_pref == "grouped_bar":
2396
+ # ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
2397
+ # ax.set_title(f"{col} by cohort (grouped)")
2398
+ # ax.set_xlabel(col); ax.set_ylabel("Count")
2399
+ # plt.tight_layout(); plt.show()
2400
+
2401
+ # elif _comp_pref == "stacked_bar":
2402
+ # ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2403
+ # ax.set_title(f"{col} by cohort (stacked)")
2404
+ # ax.set_xlabel(col); ax.set_ylabel("Count")
2405
+ # plt.tight_layout(); plt.show()
2406
+
2407
+ # elif _comp_pref == "stacked_bar_pct":
2408
+ # _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2409
+ # ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2410
+ # ax.set_title(f"{col} by cohort (100% stacked)")
2411
+ # ax.set_xlabel(col); ax.set_ylabel("Percent")
2412
+ # plt.tight_layout(); plt.show()
2413
+
2414
+ # elif _comp_pref == "heatmap":
2415
+ # _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2416
+ # import numpy as np
2417
+ # fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
2418
+ # im = ax.imshow(_perc.values, aspect='auto')
2419
+ # ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2420
+ # ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2421
+ # ax.set_title(f"{col} by cohort — % heatmap")
2422
+ # for i in range(_perc.shape[0]):
2423
+ # for j in range(_perc.shape[1]):
2424
+ # ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2425
+ # fig.colorbar(im, ax=ax, label="%")
2426
+ # plt.tight_layout(); plt.show()
2427
+
2428
+ # else: # counts_bar (default)
2429
+ # ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
2430
+ # ax.set_title(f"{col}: total counts (both cohorts)")
2431
+ # ax.set_xlabel(col); ax.set_ylabel("Count")
2432
+ # plt.tight_layout(); plt.show()
2433
+ # """.lstrip()
2434
+ # return snippet
2435
+
2436
+ # # --------------- CASE B: facet-by (categories/bins) ---------------
2437
+ # if facet:
2438
+ # facet_col, how = facet
2439
+ # # Build facet series
2440
+ # if pd.api.types.is_numeric_dtype(df[facet_col]):
2441
+ # if how == "bmi":
2442
+ # facet_series = _bmi_bins(df[facet_col])
2443
+ # else:
2444
+ # # generic numeric bins: 3 equal-width bins by default
2445
+ # facet_series = pd.cut(df[facet_col].astype(float), bins=3)
2446
+ # else:
2447
+ # facet_series = df[facet_col].astype(str)
2448
+
2449
+ # # Choose pie dimension (categorical to count inside each facet)
2450
+ # pie_dim = None
2451
+ # for c in cats:
2452
+ # if c in df.columns and _is_cat(c):
2453
+ # pie_dim = c; break
2454
+ # if pie_dim is None:
2455
+ # pie_dim = _fallback_cat()
2456
+ # if pie_dim is None:
2457
+ # return code
2458
+
2459
+ # snippet = f"""
2460
+ # import math
+ # import numpy as np  # used below for np.array / np.ndarray
2461
+ # import pandas as pd
2462
+ # import matplotlib.pyplot as plt
2463
+
2464
+ # df = df.copy()
2465
+ # _preferred = "{facet_col}" if "{facet_col}" in df.columns else None
2466
+
2467
+ # def _select_facet_col(df, preferred=None):
2468
+ # if preferred is not None:
2469
+ # return preferred
2470
+ # # Prefer low-cardinality categoricals (readable pies/grids)
2471
+ # cat_cols = [
2472
+ # c for c in df.columns
2473
+ # if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
2474
+ # and df[c].nunique() > 1 and df[c].nunique() <= 20
2475
+ # ]
2476
+ # if cat_cols:
2477
+ # cat_cols.sort(key=lambda c: df[c].nunique())
2478
+ # return cat_cols[0]
2479
+ # # Else fall back to first usable numeric
2480
+ # num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
2481
+ # return num_cols[0] if num_cols else None
2482
+
2483
+ # _facet_col = _select_facet_col(df, _preferred)
2484
+
2485
+ # if _facet_col is None:
2486
+ # # Nothing suitable → single facet keeps pipeline alive
2487
+ # df["__facet__"] = "All"
2488
+ # else:
2489
+ # s = df[_facet_col]
2490
+ # if pd.api.types.is_numeric_dtype(s):
2491
+ # # Robust numeric binning: quantiles first, fallback to equal-width
2492
+ # uniq = pd.Series(s).dropna().nunique()
2493
+ # q = 3 if uniq < 10 else 4 if uniq < 30 else 5
2494
+ # try:
2495
+ # df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
2496
+ # except Exception:
2497
+ # df["__facet__"] = pd.cut(s.astype(float), bins=q)
2498
+ # else:
2499
+ # # Cap long tails; keep top categories
2500
+ # vc = s.astype(str).value_counts()
2501
+ # keep = vc.index[:{top_n}]
2502
+ # df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")
2503
+
2504
+ # levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
2505
+ # levels = [x for x in levels if x != "nan"]
2506
+ # levels.sort()
2507
+
2508
+ # m = len(levels)
2509
+ # cols = 3 if m >= 3 else m or 1
2510
+ # rows = int(math.ceil(m / cols))
2511
+
2512
+ # fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
2513
+ # if not isinstance(axes, (list, np.ndarray)):
2514
+ # axes = np.array([[axes]])
2515
+ # axes = axes.reshape(rows, cols)
2516
+
2517
+ # for i, lvl in enumerate(levels):
2518
+ # r, c = divmod(i, cols)
2519
+ # ax = axes[r, c]
2520
+ # s = (df.loc[df["__facet"].astype(str) == str(lvl), "{pie_dim}"]
2521
+ # .astype(str).value_counts().nlargest({top_n}))
2522
+ # if len(s) > 0:
2523
+ # ax.pie(s.values, labels=[str(x) for x in s.index],
2524
+ # autopct='%1.1f%%', startangle=90, counterclock=False)
2525
+ # ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')
2526
+
2527
+ # # hide any empty subplots
2528
+ # for j in range(m, rows*cols):
2529
+ # r, c = divmod(j, cols)
2530
+ # axes[r, c].axis("off")
2531
+
2532
+ # plt.tight_layout(); plt.show()
2533
+
2534
+ # # --- companion visual (adaptive) ---
2535
+ # _comp_pref = "{_comp_pref}"
2536
+ # # build contingency table: pie_dim x facet
2537
+ # _tab = (df[["__facet__", "{pie_dim}"]]
2538
+ # .dropna()
2539
+ # .astype({{"__facet__": str, "{pie_dim}": str}})
2540
+ # .value_counts()
2541
+ # .unstack(level="__facet__")
2542
+ # .fillna(0))
2543
+
2544
+ # # keep top categories by overall size
2545
+ # _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2546
+
2547
+ # if _comp_pref == "grouped_bar":
2548
+ # ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2549
+ # ax.set_title("{pie_dim} by {facet_col} (grouped)")
2550
+ # ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2551
+ # plt.tight_layout(); plt.show()
2552
+
2553
+ # elif _comp_pref == "stacked_bar":
2554
+ # ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2555
+ # ax.set_title("{pie_dim} by {facet_col} (stacked)")
2556
+ # ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2557
+ # plt.tight_layout(); plt.show()
2558
+
2559
+ # elif _comp_pref == "stacked_bar_pct":
2560
+ # _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100 # column-normalised to 100%
2561
+ # ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
2562
+ # ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
2563
+ # ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
2564
+ # plt.tight_layout(); plt.show()
2565
+
2566
+ # elif _comp_pref == "heatmap":
2567
+ # _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
2568
+ # import numpy as np
2569
+ # fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
2570
+ # im = ax.imshow(_perc.values, aspect='auto')
2571
+ # ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2572
+ # ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2573
+ # ax.set_title("{pie_dim} by {facet_col} — % heatmap")
2574
+ # for i in range(_perc.shape[0]):
2575
+ # for j in range(_perc.shape[1]):
2576
+ # ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2577
+ # fig.colorbar(im, ax=ax, label="%")
2578
+ # plt.tight_layout(); plt.show()
2579
+
2580
+ # else: # counts_bar (default denominators)
2581
+ # _counts = df["__facet"].value_counts()
2582
+ # ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
2583
+ # ax.set_title("Counts by {facet_col}")
2584
+ # ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2585
+ # plt.tight_layout(); plt.show()
2586
+
2587
+ # """.lstrip()
2588
+ # return snippet
2589
+
2590
+ # # --------------- CASE C: single pie ---------------
2591
+ # chosen = None
2592
+ # for c in cats:
2593
+ # if c in df.columns and _is_cat(c):
2594
+ # chosen = c; break
2595
+ # if chosen is None:
2596
+ # chosen = _fallback_cat()
2597
+
2598
+ # if chosen:
2599
+ # snippet = f"""
2600
+ # import matplotlib.pyplot as plt
2601
+ # counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
2602
+ # fig, ax = plt.subplots()
2603
+ # if len(counts) > 0:
2604
+ # ax.pie(counts.values, labels=[str(i) for i in counts.index],
2605
+ # autopct='%1.1f%%', startangle=90, counterclock=False)
2606
+ # ax.set_title('Distribution of {chosen} (top {top_n})')
2607
+ # ax.axis('equal')
2608
+ # plt.show()
2609
+ # """.lstrip()
2610
+ # return snippet
2611
+
2612
+ # # numeric last resort
2613
+ # num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
2614
+ # if num_cols:
2615
+ # col = num_cols[0]
2616
+ # snippet = f"""
2617
+ # import pandas as pd
2618
+ # import matplotlib.pyplot as plt
2619
+ # bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
2620
+ # counts = bins.value_counts().sort_index()
2621
+ # fig, ax = plt.subplots()
2622
+ # if len(counts) > 0:
2623
+ # ax.pie(counts.values, labels=[str(i) for i in counts.index],
2624
+ # autopct='%1.1f%%', startangle=90, counterclock=False)
2625
+ # ax.set_title('Distribution of {col} (binned)')
2626
+ # ax.axis('equal')
2627
+ # plt.show()
2628
+ # """.lstrip()
2629
+ # return snippet
2630
+
2631
+ # return code
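For context, the `_bmi_bins` helper defined at the top of this commented-out block reduces to a single `pd.cut` call over the WHO cut-offs; the sample values below are made up:

import pandas as pd

bmi = pd.Series([17.0, 22.5, 27.3, 31.8])
bins = [float("-inf"), 18.5, 25, 30, float("inf")]
labels = ["Underweight (<18.5)", "Normal (18.5-24.9)", "Overweight (25-29.9)", "Obese (>=30)"]
print(list(pd.cut(bmi, bins=bins, labels=labels, right=False)))
# ['Underweight (<18.5)', 'Normal (18.5-24.9)', 'Overweight (25-29.9)', 'Obese (>=30)']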
2632
+
2633
+
2634
+ # def patch_fix_seaborn_palette_calls(code: str) -> str:
2635
+ # """
2636
+ # Removes seaborn `palette=` when no `hue=` is present in the same call.
2637
+ # Fixes FutureWarning: 'Passing `palette` without assigning `hue` ...'.
2638
+ # """
2639
+ # if "sns." not in code:
2640
+ # return code
2641
+
2642
+ # # Targets common seaborn plotters
2643
+ # funcs = r"(boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
2644
+ # pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)
2645
+
2646
+ # def _fix_call(m):
2647
+ # head, inner = m.group(1), m.group(3)  # group 2 is the plotter name captured inside funcs
2648
+ # # If there's already hue=, keep as is
2649
+ # if re.search(r"(?<!\w)hue\s*=", inner):
2650
+ # return f"{head}{inner})"
2651
+ # # Otherwise remove palette=... safely (and any adjacent comma spacing)
2652
+ # inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
2653
+ # inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
2654
+ # inner2 = re.sub(r"\s*,\s*\)", ")", f"{inner2})")[:-1] # clean trailing comma before ')'
2655
+ # return f"{head}{inner2})"
2656
+
2657
+ # return pattern.sub(_fix_call, code)
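As a concrete illustration of the commented-out `patch_fix_seaborn_palette_calls` above, the core step just drops `palette=` when no `hue=` is present in the same call; the sample call is invented:

import re

call = "sns.boxplot(data=df, x='Group', y='BMI', palette='Set2')"
if not re.search(r"(?<!\w)hue\s*=", call):
    call = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", call)
print(call)  # sns.boxplot(data=df, x='Group', y='BMI')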
2658
+
2659
+ # def _norm_col_name(s: str) -> str:
2660
+ # """normalise a column name: lowercase + strip non-alphanumerics."""
2661
+ # return re.sub(r"[^a-z0-9]+", "", str(s).lower())
2662
+
2663
+
2664
+ # def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
2665
+ # """return the actual df column that matches any candidate (after normalisation)."""
2666
+ # norm_map = {_norm_col_name(c): c for c in df.columns}
2667
+ # for cand in candidates:
2668
+ # hit = norm_map.get(_norm_col_name(cand))
2669
+ # if hit is not None:
2670
+ # return hit
2671
+ # return None
2672
+
2673
+
2674
+ # def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
2675
+ # """
2676
+ # If any alias exists, materialise a canonical copy at `target` (don’t drop the original).
2677
+ # Returns (df, found_bool).
2678
+ # """
2679
+ # if target in df.columns:
2680
+ # return df, True
2681
+ # col = _first_present(df, [target, *aliases])
2682
+ # if col is None:
2683
+ # return df, False
2684
+ # df[target] = df[col]
2685
+ # return df, True
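Taken together, the three commented-out helpers above (`_norm_col_name`, `_first_present`, `_ensure_canonical_alias`) implement case- and punctuation-insensitive column matching; a minimal sketch with an invented DataFrame:

import re
import pandas as pd

def norm(s):  # same normalisation rule as _norm_col_name above
    return re.sub(r"[^a-z0-9]+", "", str(s).lower())

df = pd.DataFrame({"Smoking Status": ["Never", "Current"]})
norm_map = {norm(c): c for c in df.columns}
hit = norm_map.get(norm("smoking_status"))     # matches despite underscore vs space
if hit is not None and "Smoking_Status" not in df.columns:
    df["Smoking_Status"] = df[hit]             # canonical copy; original column kept
print(hit, list(df.columns))                   # Smoking Status ['Smoking Status', 'Smoking_Status']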
2686
+
2687
+
2688
+ # def strip_python_dotenv(code: str) -> str:
2689
+ # """
2690
+ # Remove any use of python-dotenv from generated code, including:
2691
+ # - single and multi-line 'from dotenv import ...'
2692
+ # - 'import dotenv' (with or without alias) and calls via any alias
2693
+ # - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
2694
+ # - IPython magics (%load_ext dotenv, %dotenv, %env …)
2695
+ # - shell installs like '!pip install python-dotenv'
2696
+ # """
2697
+ # original = code
2698
+
2699
+ # # 0) Kill IPython magics & shell installs referencing dotenv
2700
+ # code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
2701
+ # code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
2702
+ # code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
2703
+ # code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)
2704
+
2705
+ # # 1) Remove single-line 'from dotenv import ...'
2706
+ # code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)
2707
+
2708
+ # # 2) Remove multi-line 'from dotenv import ( ... )' blocks
2709
+ # code = re.sub(
2710
+ # r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
2711
+ # "",
2712
+ # code,
2713
+ # flags=re.MULTILINE,
2714
+ # )
2715
+
2716
+ # # 3) Remove 'import dotenv' (with optional alias). Capture alias names.
2717
+ # aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
2718
+ # code, flags=re.MULTILINE)
2719
+ # code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
2720
+ # "", code, flags=re.MULTILINE)
2721
+
2722
+ # # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
2723
+ # # e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
2724
+ # fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
2725
+ # # bare calls
2726
+ # code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2727
+ # # dotted calls with any identifier prefix (alias or module)
2728
+ # code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
2729
+ # "", code, flags=re.MULTILINE)
2730
+
2731
+ # # 5) If any alias imported earlier slipped through (method chains etc.), remove lines using that alias.
2732
+ # for al in aliases:
2733
+ # code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2734
+
2735
+ # # 6) Tidy excess blank lines
2736
+ # code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
2737
+ # return code
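A short sketch of what the dotenv stripping above does to a typical offending snippet (the input string is invented):

import re

src = "from dotenv import load_dotenv\nload_dotenv()\nimport os\nprint(os.getenv('API_KEY'))\n"
src = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", src, flags=re.MULTILINE)
src = re.sub(r"^\s*(?:load_dotenv|find_dotenv|dotenv_values)\s*\([^)]*\)\s*$", "", src, flags=re.MULTILINE)
src = re.sub(r"\n{3,}", "\n\n", src).strip("\n") + "\n"
print(src)  # import os\nprint(os.getenv('API_KEY'))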
2738
+
2739
+
2740
+ # def fix_predict_calls_records_arg(code: str) -> str:
2741
+ # """
2742
+ # If generated code calls predict_* with a list-of-dicts via .to_dict('records')
2743
+ # (or orient='records'), strip the .to_dict(...) so a DataFrame is passed instead.
2744
+ # Works line-by-line to avoid over-rewrites elsewhere.
2745
+ # Examples fixed:
2746
+ # predict_patient(X_test.iloc[:5].to_dict('records'))
2747
+ # predict_risk(df.head(3).to_dict(orient="records"))
2748
+ # → predict_patient(X_test.iloc[:5])
2749
+ # """
2750
+ # fixed_lines = []
2751
+ # for line in code.splitlines():
2752
+ # if "predict_" in line and "to_dict" in line and "records" in line:
2753
+ # line = re.sub(
2754
+ # r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)",
2755
+ # "",
2756
+ # line
2757
+ # )
2758
+ # fixed_lines.append(line)
2759
+ # return "\n".join(fixed_lines)
2760
+
2761
+
2762
+ # def fix_fstring_backslash_paths(code: str) -> str:
2763
+ # """
2764
+ # Fix bad f-strings like: f"...{out_dir\\plots\\img.png}..."
2765
+ # → f"...{os.path.join(out_dir, r'plots\\img.png')}"
2766
+ # Only touches f-strings that contain a backslash path inside {...}.
2767
+ # """
2768
+ # def _fix_line(line: str) -> str:
2769
+ # # quick check: only f-strings need scanning
2770
+ # if not (("f\"" in line) or ("f'" in line) or ("f\"\"\"" in line) or ("f'''" in line)):
2771
+ # return line
2772
+ # # {var\rest-of-path} where var can be dotted (e.g., cfg.out)
2773
+ # pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
2774
+ # def repl(m):
2775
+ # left = m.group(1)
2776
+ # right = m.group(2).strip().replace('"', '\\"')
2777
+ # return "{os.path.join(" + left + ', r"' + right + '")}'
2778
+ # return pattern.sub(repl, line)
2779
+
2780
+ # return "\n".join(_fix_line(ln) for ln in code.splitlines())
2781
+
2782
+
2783
+ # def ensure_os_import(code: str) -> str:
2784
+ # """
2785
+ # If os.path.join is used but 'import os' is missing, inject it at the top.
2786
+ # """
2787
+ # needs = "os.path.join(" in code
2788
+ # has_import_os = re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
2789
+ # has_from_os = re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
2790
+ # if needs and not (has_import_os or has_from_os):
2791
+ # return "import os\n" + code
2792
+ # return code
2793
+
2794
+
2795
+ # def fix_seaborn_boxplot_nameerror(code: str) -> str:
2796
+ # """
2797
+ # Fix bad calls like: sns.boxplot(boxplot)
2798
+ # Heuristic:
2799
+ # - If plot_df + FH_status + var exist → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
2800
+ # - Else if plot_df + var → sns.boxplot(data=plot_df, y=var, ax=ax)
2801
+ # - Else if plot_df only → sns.boxplot(data=plot_df, ax=ax)
2802
+ # - Else → sns.boxplot(ax=ax)
2803
+ # Ensures a matplotlib Axes 'ax' exists.
2804
+ # """
2805
+ # pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
2806
+ # if not pattern.search(code):
2807
+ # return code
2808
+
2809
+ # has_plot_df = re.search(r"\bplot_df\b", code) is not None
2810
+ # has_var = re.search(r"\bvar\b", code) is not None
2811
+ # has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2812
+
2813
+ # if has_plot_df and has_var and has_fh:
2814
+ # replacement = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2815
+ # elif has_plot_df and has_var:
2816
+ # replacement = "sns.boxplot(data=plot_df, y=var, ax=ax)"
2817
+ # elif has_plot_df:
2818
+ # replacement = "sns.boxplot(data=plot_df, ax=ax)"
2819
+ # else:
2820
+ # replacement = "sns.boxplot(ax=ax)"
2821
+
2822
+ # fixed = pattern.sub(replacement, code)
2823
+
2824
+ # # Ensure 'fig, ax = plt.subplots(...)' exists
2825
+ # if "ax=" in replacement and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", fixed):
2826
+ # # Insert right before the first seaborn call
2827
+ # m = re.search(r"^\s*sns\.", fixed, flags=re.MULTILINE)
2828
+ # insert_at = m.start() if m else 0
2829
+ # fixed = fixed[:insert_at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + fixed[insert_at:]
2830
+
2831
+ # return fixed
2832
+
2833
+
2834
+ # def fix_seaborn_barplot_nameerror(code: str) -> str:
2835
+ # """
2836
+ # Fix bad calls like: sns.barplot(barplot)
2837
+ # Strategy mirrors boxplot fixer: prefer data=plot_df with x/y if available,
2838
+ # otherwise degrade safely to an empty call on an existing Axes.
2839
+ # """
2840
+ # import re
2841
+ # pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
2842
+ # if not pattern.search(code):
2843
+ # return code
2844
+
2845
+ # has_plot_df = re.search(r"\bplot_df\b", code) is not None
2846
+ # has_var = re.search(r"\bvar\b", code) is not None
2847
+ # has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2848
+
2849
+ # if has_plot_df and has_var and has_fh:
2850
+ # replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2851
+ # elif has_plot_df and has_var:
2852
+ # replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
2853
+ # elif has_plot_df:
2854
+ # replacement = "sns.barplot(data=plot_df, ax=ax)"
2855
+ # else:
2856
+ # replacement = "sns.barplot(ax=ax)"
2857
+
2858
+ # # ensure an Axes 'ax' exists (no-op if already present)
2859
+ # if "ax =" not in code:
2860
+ # code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code
2861
+
2862
+ # return pattern.sub(replacement, code)
2863
+
2864
+
2865
+ # def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
2866
+ # """
2867
+ # Parses the raw text to extract and format the 'refined question',
2868
+ # 'intents (tasks)', and 'chronology of tasks' sections.
2869
+ # Args:
2870
+ # raw_text: The complete input string containing the ML pipeline structure.
2871
+ # Returns:
2872
+ # A tuple containing:
2873
+ # (formatted_question_str, formatted_intents_str, formatted_chronology_str)
2874
+ # """
2875
+ # # --- 1. Regex Pattern to Extract Sections ---
2876
+ # # The pattern uses capturing groups (?) to look for the section headers
2877
+ # # (e.g., 'refined question:') and captures all the content until the next
2878
+ # # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.
2879
+
2880
+ # pattern = re.compile(
2881
+ # r"refined question:(?P<question>.*?)"
2882
+ # r"intents \(tasks\):(?P<intents>.*?)"
2883
+ # r"Chronology of tasks:(?P<chronology>.*)",
2884
+ # re.IGNORECASE | re.DOTALL
2885
+ # )
2886
+
2887
+ # match = pattern.search(raw_text)
2888
+
2889
+ # if not match:
2890
+ # raise ValueError("Input text structure does not match the expected pattern.")
1596
2891
 
1597
- import re
1598
- out_lines = []
1599
- for ln in code.splitlines():
1600
- if "plt.show()" not in ln:
1601
- out_lines.append(ln)
1602
- continue
2892
+ # # --- 2. Extract Content ---
2893
+ # question_content = match.group('question').strip()
2894
+ # intents_content = match.group('intents').strip()
2895
+ # chronology_content = match.group('chronology').strip()
1603
2896
 
1604
- # works for:
1605
- # plt.show()
1606
- # plt.tight_layout(); plt.show()
1607
- # ... ; plt.show(); ... (multiple on one line)
1608
- indent = re.match(r"^(\s*)", ln).group(1)
1609
- parts = ln.split("plt.show()")
2897
+ # # --- 3. Formatting Functions ---
1610
2898
 
1611
- # keep whatever is before the first plt.show()
1612
- if parts[0].strip():
1613
- out_lines.append(parts[0].rstrip())
2899
+ # def format_question(content):
2900
+ # """Formats the Refined Question section."""
2901
+ # # Clean up leading/trailing whitespace and ensure clean paragraphs
2902
+ # content = content.strip().replace('\n', ' ').replace('  ', ' ')
2903
+
2904
+ # # Simple formatting using Markdown headers and bolding
2905
+ # formatted = (
2906
+ # # "## 1. Project Goal and Objectives\n\n"
2907
+ # "<b> Refined Question:</b>\n"
2908
+ # f"{content}\n"
2909
+ # )
2910
+ # return formatted
2911
+
2912
+ # def format_intents(content):
2913
+ # """Formats the Intents (Tasks) section as a structured list."""
2914
+ # # Use regex to find and format each numbered task
2915
+ # # It finds 'N. **Text** - ...' and breaks it down.
2916
+
2917
+ # tasks = []
2918
+ # # Pattern: N. **Text** - Content (including newlines, non-greedy)
2919
+ # # We need to explicitly handle the list items starting with '-' within the content
2920
+ # task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)
2921
+
2922
+ # # Split the content by lines and join tasks back into clean strings
2923
+ # raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]
2924
+
2925
+ # for task in raw_tasks:
2926
+ # # Replace the initial task number and **Heading** with a Heading 3
2927
+ # task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)
2928
+
2929
+ # # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
2930
+ # task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
2931
+ # tasks.append(task)
2932
+
2933
+ # formatted_tasks = "\n\n".join(tasks)
2934
+
2935
+ # return (
2936
+ # "\n---\n"
2937
+ # "## 2. Methodology and Tasks\n\n"
2938
+ # f"{formatted_tasks}\n"
2939
+ # )
2940
+
2941
+ # def format_chronology(content):
2942
+ # """Formats the Chronology section."""
2943
+ # # Uses the given LaTeX format
2944
+ # content = content.strip().replace(' ', ' \\rightarrow ')
2945
+ # formatted = (
2946
+ # "\n---\n"
2947
+ # "## 3. Chronology of Tasks\n"
2948
+ # f"$$\\text{{{content}}}$$"
2949
+ # )
2950
+ # return formatted
2951
+
2952
+ # # --- 4. Format and Return ---
2953
+ # formatted_question = format_question(question_content)
2954
+ # formatted_intents = format_intents(intents_content)
2955
+ # formatted_chronology = format_chronology(chronology_content)
2956
+
2957
+ # return formatted_question, formatted_intents, formatted_chronology
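The section splitting above hinges on one named-group regex; a tiny invented input shows what each group captures:

import re

raw = ("refined question: Which features best predict family history of diabetes?\n"
       "intents (tasks): 1. **EDA** - profile the dataset\n"
       "Chronology of tasks: EDA Modelling Evaluation")
pattern = re.compile(
    r"refined question:(?P<question>.*?)"
    r"intents \(tasks\):(?P<intents>.*?)"
    r"Chronology of tasks:(?P<chronology>.*)",
    re.IGNORECASE | re.DOTALL,
)
m = pattern.search(raw)
print(m.group("question").strip())    # Which features best predict family history of diabetes?
print(m.group("chronology").strip())  # EDA Modelling Evaluation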
2958
+
2959
+
2960
+ # def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
2961
+ # """Combines all formatted parts into a final report string."""
2962
+ # return (
2963
+ # "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
2964
+ # f"{formatted_question}\n"
2965
+ # f"{formatted_intents}\n"
2966
+ # f"{formatted_chronology}\n"
2967
+ # )
2968
+
2969
+
2970
+ # def fix_confusion_matrix_for_multilabel(code: str) -> str:
2971
+ # """
2972
+ # Replace ConfusionMatrixDisplay.from_estimator(...) usages with
2973
+ # from_predictions(...) which works for multi-label loops without requiring
2974
+ # the estimator to expose _estimator_type.
2975
+ # """
2976
+ # return re.sub(
2977
+ # r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
2978
+ # r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
2979
+ # code
2980
+ # )
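Applied to a typical generated line, the substitution above swaps the estimator-based call for a predictions-based one (the variable names are illustrative):

import re

call = "ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test)"
print(re.sub(
    r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
    r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
    call,
))  # ConfusionMatrixDisplay.from_predictions(y_test, clf.predict(X_test))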
2981
+
2982
+
2983
+ # def smx_auto_title_plots(ctx=None, fallback="Analysis"):
2984
+ # """
2985
+ # Ensure every Matplotlib/Seaborn Axes has a title.
2986
+ # Uses refined_question -> askai_question -> fallback.
2987
+ # Only sets a title if it's currently empty.
2988
+ # """
2989
+ # import matplotlib.pyplot as plt
2990
+
2991
+ # def _all_figures():
2992
+ # try:
2993
+ # from matplotlib._pylab_helpers import Gcf
2994
+ # return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
2995
+ # except Exception:
2996
+ # # Best effort fallback
2997
+ # nums = plt.get_fignums()
2998
+ # return [plt.figure(n) for n in nums] if nums else []
2999
+
3000
+ # # Choose a concise title
3001
+ # title = None
3002
+ # if isinstance(ctx, dict):
3003
+ # title = ctx.get("refined_question") or ctx.get("askai_question")
3004
+ # title = (str(title).strip().splitlines()[0][:120]) if title else fallback
3005
+
3006
+ # for fig in _all_figures():
3007
+ # for ax in getattr(fig, "axes", []):
3008
+ # try:
3009
+ # if not (ax.get_title() or "").strip():
3010
+ # ax.set_title(title)
3011
+ # except Exception:
3012
+ # pass
3013
+ # try:
3014
+ # fig.tight_layout()
3015
+ # except Exception:
3016
+ # pass
3017
+
3018
+
3019
+ # def patch_fix_sentinel_plot_calls(code: str) -> str:
3020
+ # """
3021
+ # Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
3022
+ # SB_barplot(barplot) -> SB_barplot()
3023
+ # SB_barplot(barplot, ...) -> SB_barplot(...)
3024
+ # sns.barplot(barplot) -> SB_barplot()
3025
+ # sns.barplot(barplot, ...) -> SB_barplot(...)
3026
+ # Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
3027
+ # """
3028
+ # names = ['histplot','boxplot','barplot','lineplot','countplot','heatmap','pairplot','scatterplot']
3029
+ # for n in names:
3030
+ # # SB_* with sentinel as the first arg (with or without trailing args)
3031
+ # code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
3032
+ # code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
3033
+ # # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
3034
+ # code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
3035
+ # code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
3036
+ # return code
3037
+
3038
+
3039
+ # def patch_rmse_calls(code: str) -> str:
3040
+ # """
3041
+ # Make RMSE robust across sklearn versions.
3042
+ # - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
3043
+ # - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
3044
+ # """
3045
+ # import re
3046
+ # # (a) Specific RMSE pattern
3047
+ # code = re.sub(
3048
+ # r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
3049
+ # r"_SMX_rmse(\1)",
3050
+ # code,
3051
+ # flags=re.DOTALL
3052
+ # )
3053
+ # # (b) Guard any other MSE calls
3054
+ # code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
3055
+ # return code
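`_SMX_rmse` and `_SMX_call` are referenced above but not defined in this hunk; a plausible minimal sketch of what the rewrite assumes (not necessarily the shipped implementation):

import numpy as np
from sklearn.metrics import mean_squared_error

def _SMX_rmse(y_true, y_pred):
    # RMSE without relying on squared=False, which newer sklearn releases no longer accept
    return float(np.sqrt(mean_squared_error(y_true, y_pred)))

def _SMX_call(fn, *args, **kwargs):
    # Retry without keyword arguments the installed sklearn version rejects
    try:
        return fn(*args, **kwargs)
    except TypeError:
        kwargs.pop("squared", None)
        return fn(*args, **kwargs)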
1614
3056
 
1615
- # for every plt.show() we removed, insert exporter at same indent
1616
- for _ in range(len(parts) - 1):
1617
- out_lines.append(indent + "_SMX_export_png()")
1618
-
1619
- # keep whatever comes after the last plt.show()
1620
- if parts[-1].strip():
1621
- out_lines.append(indent + parts[-1].lstrip())
1622
-
1623
- return "\n".join(out_lines)
1624
-
1625
-
1626
- def clean_llm_code(code: str) -> str:
1627
- """
1628
- Make LLM output safe to exec:
1629
- - If fenced blocks exist, keep the largest one (usually the real code).
1630
- - Otherwise strip any stray ``` / ```python lines.
1631
- - Remove common markdown/preamble junk.
1632
- """
1633
- code = str(code or "")
1634
-
1635
- # Special case: sometimes the OpenAI SDK object repr (e.g. ChatCompletion(...))
1636
- # is accidentally passed here as `code`. In that case, extract the actual
1637
- # Python code from the ChatCompletionMessage(content=...) field.
1638
- if "ChatCompletion(" in code and "ChatCompletionMessage" in code and "content=" in code:
1639
- try:
1640
- extracted = None
1641
-
1642
- class _ChatCompletionVisitor(ast.NodeVisitor):
1643
- def visit_Call(self, node):
1644
- nonlocal extracted
1645
- func = node.func
1646
- fname = getattr(func, "id", None) or getattr(func, "attr", None)
1647
- if fname == "ChatCompletionMessage":
1648
- for kw in node.keywords:
1649
- if kw.arg == "content" and isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str):
1650
- extracted = kw.value.value
1651
- self.generic_visit(node)
1652
-
1653
- tree = ast.parse(code, mode="exec")
1654
- _ChatCompletionVisitor().visit(tree)
1655
- if extracted:
1656
- code = extracted
1657
- except Exception:
1658
- # Best-effort regex fallback if AST parsing fails
1659
- m = re.search(r"content=(?P<q>['\\\"])(?P<body>.*?)(?P=q)", code, flags=re.S)
1660
- if m:
1661
- code = m.group("body")
1662
-
1663
- # Existing logic continues unchanged below...
1664
- # Extract fenced blocks (```python ... ``` or ``` ... ```)
1665
- blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
1666
-
1667
- if blocks:
1668
- # pick the largest block; small trailing blocks are usually garbage
1669
- largest = max(blocks, key=lambda b: len(b.strip()))
1670
- if len(largest.strip().splitlines()) >= 10:
1671
- code = largest
1672
-
1673
- # Extract fenced blocks (```python ... ``` or ``` ... ```)
1674
- blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
1675
-
1676
- if blocks:
1677
- # pick the largest block; small trailing blocks are usually garbage
1678
- largest = max(blocks, key=lambda b: len(b.strip()))
1679
- if len(largest.strip().splitlines()) >= 10:
1680
- code = largest
1681
- else:
1682
- # if no meaningful block, just remove fence markers
1683
- code = re.sub(r"^```.*?$", "", code, flags=re.M)
1684
- else:
1685
- # no complete blocks — still remove any stray fence lines
1686
- code = re.sub(r"^```.*?$", "", code, flags=re.M)
1687
-
1688
- # Strip common markdown/preamble lines
1689
- drop_prefixes = (
1690
- "here is", "here's", "below is", "sure,", "certainly",
1691
- "explanation", "note:", "```"
1692
- )
1693
- cleaned_lines = []
1694
- for ln in code.splitlines():
1695
- s = ln.strip().lower()
1696
- if any(s.startswith(p) for p in drop_prefixes):
1697
- continue
1698
- cleaned_lines.append(ln)
1699
-
1700
- return "\n".join(cleaned_lines).strip()
1701
-
1702
-
1703
- def fix_groupby_describe_slice(code: str) -> str:
1704
- """
1705
- Replaces df.groupby(...).describe()[[...] ] with a safe .agg(...)
1706
- so it works for both SeriesGroupBy and DataFrameGroupBy.
1707
- """
1708
- pat = re.compile(
1709
- r"(df\.groupby\(['\"][\w]+['\"]\)\['[\w]+['\"]\]\.describe\()\s*\[\[([^\]]+)\]\]\)",
1710
- re.MULTILINE
1711
- )
1712
- def repl(match):
1713
- inner = match.group(0)
1714
- # extract group and feature to build df.groupby('g')['f']
1715
- g = re.search(r"groupby\('([\w]+)'\)", inner).group(1)
1716
- f = re.search(r"\)\['([\w]+)'\]\.describe", inner).group(1)
1717
- return (
1718
- f"df.groupby('{g}')['{f}']"
1719
- ".agg(['count','mean','std','min','median','max'])"
1720
- ".rename(columns={'median':'50%'})"
1721
- )
1722
- return pat.sub(repl, code)
1723
-
1724
-
1725
- def fix_importance_groupby(code: str) -> str:
1726
- pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
1727
- if "importance_df" in code:
1728
- return pattern.sub("importance_df.groupby('Importance')['Importance']", code)
1729
- return code
1730
-
1731
- def inject_auto_preprocessing(code: str) -> str:
1732
- """
1733
- • Detects a RandomForestClassifier in the generated code.
1734
- • Finds the target column from `y = df['target']`.
1735
- • Prepends a fully-dedented preprocessing snippet that:
1736
- – auto-detects numeric & categorical columns
1737
- – builds a ColumnTransformer (OneHotEncoder + StandardScaler)
1738
- The dedent() call guarantees no leading-space IndentationError.
1739
- """
1740
- if "RandomForestClassifier" not in code:
1741
- return code # nothing to patch
1742
-
1743
- y_match = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
1744
- if not y_match:
1745
- return code # can't infer target safely
1746
- target = y_match.group(1)
1747
-
1748
- prep_snippet = textwrap.dedent(f"""
1749
- # ── automatic preprocessing ───────────────────────────────
1750
- num_cols = df.select_dtypes(include=['number']).columns.tolist()
1751
- cat_cols = df.select_dtypes(exclude=['number']).columns.tolist()
1752
- num_cols = [c for c in num_cols if c != '{target}']
1753
- cat_cols = [c for c in cat_cols if c != '{target}']
1754
-
1755
- from sklearn.compose import ColumnTransformer
1756
- from sklearn.preprocessing import StandardScaler, OneHotEncoder
1757
-
1758
- preproc = ColumnTransformer(
1759
- transformers=[
1760
- ('num', StandardScaler(), num_cols),
1761
- ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols),
1762
- ],
1763
- remainder='drop',
1764
- )
1765
- # ───────────────────────────────────────────────────────────
1766
- """).strip() + "\n\n"
1767
-
1768
- # simply prepend; model code that follows can wrap estimator in a Pipeline
1769
- return prep_snippet + code
1770
-
1771
-
1772
- def fix_to_datetime_errors(code: str) -> str:
1773
- """
1774
- Force every pd.to_datetime(…) call to ignore bad dates so that
1775
-
1776
- 'year 16500 is out of range' and similar issues don’t crash runs.
1777
- """
1778
- import re
1779
- # look for any pd.to_datetime( … )
1780
- pat = re.compile(r"pd\.to_datetime\(([^)]+)\)")
1781
- def repl(m):
1782
- inside = m.group(1)
1783
- # if the call already has errors=, leave it unchanged
1784
- if "errors=" in inside:
1785
- return m.group(0)
1786
- return f"pd.to_datetime({inside}, errors='coerce')"
1787
- return pat.sub(repl, code)
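The `errors='coerce'` rewrite above is easiest to see on a single call (column name invented):

import re

line = "df['Visit_Date'] = pd.to_datetime(df['Visit_Date'])"
pat = re.compile(r"pd\.to_datetime\(([^)]+)\)")
fixed = pat.sub(lambda m: m.group(0) if "errors=" in m.group(1)
                else f"pd.to_datetime({m.group(1)}, errors='coerce')", line)
print(fixed)  # df['Visit_Date'] = pd.to_datetime(df['Visit_Date'], errors='coerce')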
1788
-
1789
-
1790
- def fix_numeric_sum(code: str) -> str:
1791
- """
1792
- Make .sum(...) code safe across pandas versions by removing any
1793
- numeric_only=... argument (True/False/None) from function calls.
1794
-
1795
- This avoids errors on pandas versions where numeric_only is not
1796
- supported for Series/grouped sums, and we rely instead on explicit
1797
- numeric column selection (e.g. select_dtypes) in the generated code.
1798
- """
1799
- # Case 1: ..., numeric_only=True/False/None
1800
- code = re.sub(
1801
- r",\s*numeric_only\s*=\s*(True|False|None)",
1802
- "",
1803
- code,
1804
- flags=re.IGNORECASE,
1805
- )
1806
-
1807
- # Case 2: numeric_only=True/False/None, ... (as first argument)
1808
- code = re.sub(
1809
- r"numeric_only\s*=\s*(True|False|None)\s*,\s*",
1810
- "",
1811
- code,
1812
- flags=re.IGNORECASE,
1813
- )
1814
-
1815
- # Case 3: numeric_only=True/False/None (only argument)
1816
- code = re.sub(
1817
- r"numeric_only\s*=\s*(True|False|None)",
1818
- "",
1819
- code,
1820
- flags=re.IGNORECASE,
1821
- )
1822
-
1823
- return code
1824
-
1825
-
1826
- def fix_concat_empty_list(code: str) -> str:
1827
- """
1828
- Make pd.concat calls resilient to empty lists of objects.
1829
-
1830
- Transforms patterns like:
1831
- pd.concat(frames, ignore_index=True)
1832
- pd.concat(frames)
1833
-
1834
- into:
1835
- pd.concat(frames or [pd.DataFrame()], ignore_index=True)
1836
- pd.concat(frames or [pd.DataFrame()])
1837
-
1838
- Only triggers when the first argument is a simple variable name.
1839
- """
1840
- pattern = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
1841
-
1842
- def _repl(m):
1843
- name = m.group(1)
1844
- sep = m.group(2) # ',' or ')'
1845
- return f"pd.concat({name} or [pd.DataFrame()]{sep}"
1846
-
1847
- return pattern.sub(_repl, code)
1848
-
1849
-
1850
- def fix_numeric_aggs(code: str) -> str:
1851
- _AGG_FUNCS = ("sum", "mean")
1852
- pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
1853
- def _repl(m):
1854
- func, args = m.group(1), m.group(2) or ""
1855
- if "numeric_only" in args:
1856
- return m.group(0)
1857
- args = args.rstrip()
1858
- if args:
1859
- args += ", "
1860
- return f".{func}({args}numeric_only=True)"
1861
- return pat.sub(_repl, code)
1862
-
1863
-
1864
- def ensure_accuracy_block(code: str) -> str:
1865
- """
1866
- Inject a sensible evaluation block right after the last `<est>.fit(...)`
1867
- Classification → accuracy + weighted F1
1868
- Regression → R², RMSE, MAE
1869
- Heuristic: infer task from estimator names present in the code.
1870
- """
1871
- import re, textwrap
1872
-
1873
- # If any proper metric already exists, do nothing
1874
- if re.search(r"\b(accuracy_score|f1_score|r2_score|mean_squared_error|mean_absolute_error)\b", code):
1875
- return code
1876
-
1877
- # Find the last "<var>.fit(" occurrence to reuse the estimator variable name
1878
- m = list(re.finditer(r"(\w+)\.fit\s*\(", code))
1879
- if not m:
1880
- return code # no estimator
1881
-
1882
- var = m[-1].group(1)
1883
- # indent with same leading whitespace used on that line
1884
- indent = re.match(r"\s*", code[m[-1].start():]).group(0)
1885
-
1886
- # Detect regression by estimator names / hints in code
1887
- is_regression = bool(
1888
- re.search(
1889
- r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
1890
- r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
1891
- r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
1892
- )
1893
- or re.search(r"\bOLS\s*\(", code)
1894
- or re.search(r"\bRegressor\b", code)
1895
- )
1896
-
1897
- if is_regression:
1898
- # inject numpy import if needed for RMSE
1899
- if "import numpy as np" not in code and "np." not in code:
1900
- code = "import numpy as np\n" + code
1901
- eval_block = textwrap.dedent(f"""
1902
- {indent}# ── automatic regression evaluation ─────────
1903
- {indent}from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
1904
- {indent}y_pred = {var}.predict(X_test)
1905
- {indent}r2 = r2_score(y_test, y_pred)
1906
- {indent}rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
1907
- {indent}mae = float(mean_absolute_error(y_test, y_pred))
1908
- {indent}print(f"R²: {{r2:.4f}} | RMSE: {{rmse:.4f}} | MAE: {{mae:.4f}}")
1909
- """)
1910
- else:
1911
- eval_block = textwrap.dedent(f"""
1912
- {indent}# ── automatic classification evaluation ─────────
1913
- {indent}from sklearn.metrics import accuracy_score, f1_score
1914
- {indent}y_pred = {var}.predict(X_test)
1915
- {indent}acc = accuracy_score(y_test, y_pred)
1916
- {indent}f1 = f1_score(y_test, y_pred, average='weighted')
1917
- {indent}print(f"Accuracy: {{acc:.2%}} | F1 (weighted): {{f1:.3f}}")
1918
- """)
1919
-
1920
- insert_at = code.find("\n", m[-1].end()) + 1
1921
- return code[:insert_at] + eval_block + code[insert_at:]
1922
-
1923
-
1924
- def fix_scatter_and_summary(code: str) -> str:
1925
- """
1926
- 1. Change cmap='spectral' (any case) → cmap='Spectral'
1927
- 2. If the LLM forgets to close the parenthesis in
1928
- summary_table = ( df.describe()... <missing )>
1929
- insert the ')' right before the next 'from' or 'show('.
1930
- """
1931
- # 1️⃣ colormap case
1932
- code = re.sub(
1933
- r"cmap\s*=\s*['\"]spectral['\"]", # insensitive pattern
1934
- "cmap='Spectral'",
1935
- code,
1936
- flags=re.IGNORECASE,
1937
- )
1938
-
1939
- # 2️⃣ close summary_table = ( ... )
1940
- code = re.sub(
1941
- r"(summary_table\s*=\s*\(\s*df\.describe\([^\n]+?\n)"
1942
- r"(?=\s*(from|show\())", # look-ahead: next line starts with 'from' or 'show('
1943
- r"\1)", # keep group 1 and add ')'
1944
- code,
1945
- flags=re.MULTILINE,
1946
- )
1947
-
1948
- return code
1949
-
1950
-
1951
- def auto_format_with_black(code: str) -> str:
1952
- """
1953
- Format the generated code with Black. Falls back silently if Black
1954
- is missing or raises (so the dashboard never 500s).
1955
- """
1956
- try:
1957
- import black # make sure black is in your v-env: pip install black
1958
-
1959
- mode = black.FileMode() # default settings
1960
- return black.format_str(code, mode=mode)
1961
-
1962
- except Exception:
1963
- return code
1964
-
1965
-
1966
- def ensure_preproc_in_pipeline(code: str) -> str:
1967
- """
1968
- If code defines `preproc = ColumnTransformer(...)` but then builds
1969
- `Pipeline([('scaler', StandardScaler()), ('clf', ...)])`, replace
1970
- that stanza with `Pipeline([('prep', preproc), ('clf', ...)])`.
1971
- """
1972
- return re.sub(
1973
- r"Pipeline\(\s*\[\('scaler',\s*StandardScaler\(\)\)",
1974
- "Pipeline([('prep', preproc)",
1975
- code
1976
- )
1977
-
1978
-
1979
- def fix_plain_prints(code: str) -> str:
1980
- """
1981
- Rewrite bare `print(var)` where var looks like a dataframe/series/ndarray/etc
1982
- to go through SyntaxMatrix's smart display (so it renders in the dashboard).
1983
- Keeps string prints alone.
1984
- """
1985
-
1986
- # Skip obvious string-literal prints
1987
- new = re.sub(
1988
- r"(?m)^\s*print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
1989
- r"from syntaxmatrix.display import show\nshow(\1)",
1990
- code,
1991
- )
1992
- return new
1993
-
1994
-
1995
- def fix_print_html(code: str) -> str:
1996
- """
1997
- Ensure that HTML / DataFrame HTML are *displayed* (and captured by the kernel),
1998
- not printed as `<IPython.core.display.HTML object>` to the server console.
1999
- - Rewrites: print(HTML(...)) → display(HTML(...))
2000
- print(display(...)) → display(...)
2001
- print(df.to_html(...)) → display(HTML(df.to_html(...)))
2002
- Also prepends `from IPython.display import display, HTML` if required.
2003
- """
2004
- import re
2005
-
2006
- new = code
2007
-
2008
- # 1) print(HTML(...)) -> display(HTML(...))
2009
- new = re.sub(r"(?m)^\s*print\s*\(\s*HTML\s*\(", "display(HTML(", new)
2010
-
2011
- # 2) print(display(...)) -> display(...)
2012
- new = re.sub(r"(?m)^\s*print\s*\(\s*display\s*\(", "display(", new)
2013
-
2014
- # 3) print(<expr>.to_html(...)) -> display(HTML(<expr>.to_html(...)))
2015
- new = re.sub(
2016
- r"(?m)^\s*print\s*\(\s*([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*)\s*\.to_html\s*\(",
2017
- r"display(HTML(\1.to_html(", new
2018
- )
2019
-
2020
- # If code references HTML() or display() make sure the import exists
2021
- if ("HTML(" in new or re.search(r"\bdisplay\s*\(", new)) and \
2022
- "from IPython.display import display, HTML" not in new:
2023
- new = "from IPython.display import display, HTML\n" + new
2024
-
2025
- return new
2026
-
2027
-
2028
- def ensure_ipy_display(code: str) -> str:
2029
- """
2030
- Guarantee that the cell has proper IPython display imports so that
2031
- display(HTML(...)) produces 'display_data' events the kernel captures.
2032
- """
2033
- if "display(" in code and "from IPython.display import display, HTML" not in code:
2034
- return "from IPython.display import display, HTML\n" + code
2035
- return code
2036
-
2037
-
2038
- def drop_bad_classification_metrics(code: str, y_or_df) -> str:
2039
- """
2040
- Remove classification metrics (accuracy_score, classification_report, confusion_matrix)
2041
- if the generated cell is *regression*. We infer this from:
2042
- 1) The estimator names in the code (LinearRegression, OLS, Regressor*, etc.), OR
2043
- 2) The target dtype if we can parse y = df['...'] and have the DataFrame.
2044
- Safe across datasets and queries.
2045
- """
2046
- import re
2047
- import pandas as pd
2048
-
2049
- # 1) Heuristic by estimator names in the *code* (fast path)
2050
- regression_by_model = bool(re.search(
2051
- r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
2052
- r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
2053
- r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
2054
- ) or re.search(r"\bOLS\s*\(", code))
2055
-
2056
- is_regression = regression_by_model
2057
-
2058
- # 2) If not obvious from the model, try to infer from y dtype (if we can)
2059
- if not is_regression:
2060
- try:
2061
- # Try to parse: y = df['target']
2062
- m = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
2063
- if m and hasattr(y_or_df, "columns") and m.group(1) in getattr(y_or_df, "columns", []):
2064
- y = y_or_df[m.group(1)]
2065
- if pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
2066
- is_regression = True
2067
- else:
2068
- # If a Series was passed
2069
- y = y_or_df
2070
- if hasattr(y, "dtype") and pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
2071
- is_regression = True
2072
- except Exception:
2073
- pass
2074
-
2075
- if is_regression:
2076
- # Strip classification-only lines
2077
- for pat in (r"\n.*accuracy_score[^\n]*", r"\n.*classification_report[^\n]*", r"\n.*confusion_matrix[^\n]*"):
2078
- code = re.sub(pat, "", code, flags=re.I)
2079
-
2080
- return code
2081
-
2082
-
2083
- def force_capture_display(code: str) -> str:
2084
- """
2085
- Ensure our executor captures HTML output:
2086
- - Remove any import that would override our 'display' hook.
2087
- - Keep/allow importing HTML only.
2088
- - Handle alias cases like 'display as d'.
2089
- """
2090
- import re
2091
- new = code
2092
-
2093
- # 'from IPython.display import display, HTML' -> keep HTML only
2094
- new = re.sub(
2095
- r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*,\s*HTML\s*(?:as\s+([A-Za-z_]\w*))?\s*$",
2096
- r"from IPython.display import HTML\1", new
2097
- )
2098
-
2099
- # 'from IPython.display import display as d' -> 'd = display'
2100
- new = re.sub(
2101
- r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s+as\s+([A-Za-z_]\w+)\s*$",
2102
- r"\1 = display", new
2103
- )
2104
-
2105
- # 'from IPython.display import display' -> remove (use our injected display)
2106
- new = re.sub(
2107
- r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*$",
2108
- r"# display import removed (SMX capture active)", new
2109
- )
2110
-
2111
- # If someone does 'import IPython.display as disp' and calls disp.display(...), rewrite to display(...)
2112
- new = re.sub(
2113
- r"(?m)\bIPython\.display\.display\s*\(",
2114
- "display(", new
2115
- )
2116
- new = re.sub(
2117
- r"(?m)\b([A-Za-z_]\w*)\.display\s*\(" # handles 'disp.display(' after 'import IPython.display as disp'
2118
- r"(?=.*import\s+IPython\.display\s+as\s+\1)",
2119
- "display(", new
2120
- )
2121
- return new
2122
-
2123
-
2124
- def strip_matplotlib_show(code: str) -> str:
2125
- """Remove blocking plt.show() calls (we export base64 instead)."""
2126
- import re
2127
- return re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code)
2128
-
2129
-
2130
- def inject_display_shim(code: str) -> str:
2131
- """
2132
- Provide display()/HTML() if missing, forwarding to our executor hook.
2133
- Harmless if the names already exist.
2134
- """
2135
- shim = (
2136
- "try:\n"
2137
- " display\n"
2138
- "except NameError:\n"
2139
- " def display(obj=None, **kwargs):\n"
2140
- " __builtins__.get('_smx_display', print)(obj)\n"
2141
- "try:\n"
2142
- " HTML\n"
2143
- "except NameError:\n"
2144
- " class HTML:\n"
2145
- " def __init__(self, data): self.data = str(data)\n"
2146
- " def _repr_html_(self): return self.data\n"
2147
- "\n"
2148
- )
2149
- return shim + code
2150
-
2151
-
2152
- def strip_spurious_column_tokens(code: str) -> str:
2153
- """
2154
- Remove common stop-words ('the','whether', ...) when they appear
2155
- inside column lists, e.g.:
2156
- predictors = ['BMI','the','HbA1c']
2157
- df[['GGT','whether','BMI']]
2158
- Leaves other strings intact.
2159
- """
2160
- STOP = {
2161
- "the","whether","a","an","and","or","of","to","in","on","for","by",
2162
- "with","as","at","from","that","this","these","those","is","are","was","were",
2163
- "coef", "Coef", "coefficient", "Coefficient"
2164
- }
2165
-
2166
- def _norm(s: str) -> str:
2167
- return re.sub(r"[^a-z0-9]+", "", s.lower())
2168
-
2169
- def _clean_list(content: str) -> str:
2170
- # Rebuild a string list, keeping only non-stopword items
2171
- items = re.findall(r"(['\"])(.*?)\1", content)
2172
- if not items:
2173
- return "[" + content + "]"
2174
- keep = [f"{q}{s}{q}" for (q, s) in items if _norm(s) not in STOP]
2175
- return "[" + ", ".join(keep) + "]"
2176
-
2177
- # Variable assignments: predictors/features/columns/cols = [...]
2178
- code = re.sub(
2179
- r"(?m)\b(predictors|features|columns|cols)\s*=\s*\[([^\]]+)\]",
2180
- lambda m: f"{m.group(1)} = " + _clean_list(m.group(2)),
2181
- code
2182
- )
2183
-
2184
- # df[[ ... ]] selections
2185
- code = re.sub(
2186
- r"df\s*\[\s*\[([^\]]+)\]\s*\]", lambda m: "df[" + _clean_list(m.group(1)) + "]", code)
2187
-
2188
- return code
2189
-
2190
-
2191
- def patch_prefix_seaborn_calls(code: str) -> str:
2192
- """
2193
- Ensure bare seaborn calls are prefixed with `sns.`.
2194
- E.g., `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, etc.
2195
- """
2196
- if "sns." in code:
2197
- # still fix any leftover bare calls alongside prefixed ones
2198
- pass
2199
-
2200
- # functions commonly used from seaborn
2201
- funcs = [
2202
- "barplot","countplot","boxplot","violinplot","stripplot","swarmplot",
2203
- "histplot","kdeplot","jointplot","pairplot","heatmap","clustermap",
2204
- "scatterplot","lineplot","catplot","displot","lmplot"
2205
- ]
2206
- # Replace bare function calls not already qualified by a dot (e.g., obj.barplot)
2207
- # (?<![\w.]) ensures no preceding word char or dot; avoids touching obj.barplot or mybarplot
2208
- pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(", flags=re.MULTILINE)
2209
-
2210
- def _add_prefix(m):
2211
- fn = m.group(1)
2212
- return f"sns.{fn}("
2213
-
2214
- return pattern.sub(_add_prefix, code)
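The bare-call prefixing above only touches unqualified names, so attribute access like `ax.barplot` is left alone; a short made-up example:

import re

snippet = "barplot(data=df, x='Ethnicity', y='BMI')\nax.barplot(data=df)"
funcs = ["barplot", "countplot", "heatmap"]
pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(")
print(pattern.sub(lambda m: f"sns.{m.group(1)}(", snippet))
# sns.barplot(data=df, x='Ethnicity', y='BMI')
# ax.barplot(data=df)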
2215
-
2216
-
2217
- def patch_ensure_seaborn_import(code: str) -> str:
2218
- """
2219
- If seaborn is used (sns.) ensure `import seaborn as sns` exists once.
2220
- Also set a quiet theme for consistent visuals.
2221
- """
2222
- needs_sns = "sns." in code
2223
- has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
2224
- if needs_sns and not has_import:
2225
- # Insert after the first block of imports if possible, else at top
2226
- import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
2227
- inject = "import seaborn as sns\ntry:\n sns.set_theme()\nexcept Exception:\n pass\n"
2228
- if import_block:
2229
- start = import_block.end()
2230
- code = code[:start] + inject + code[start:]
2231
- else:
2232
- code = inject + code
2233
- return code
2234
-
2235
-
2236
- def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
2237
- """
2238
- Normalise pie-chart requests.
2239
-
2240
- Supports three patterns:
2241
- A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
2242
- B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
2243
- → one pie per facet level (grid) + counts bar of facet sizes.
2244
- C) Single pie when no split/facet is requested.
2245
-
2246
- Notes:
2247
- - Pie variables must be categorical (or numeric binned).
2248
- - Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
2249
- """
2250
-
2251
- q = (user_question or "")
2252
- q_low = q.lower()
2253
-
2254
- # Prefer explicit: df['col'].value_counts()
2255
- m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
2256
- col = m.group(1) if m else None
2257
-
2258
- # ---------- helpers ----------
2259
- def _is_cat(col):
2260
- return (str(df[col].dtype).startswith("category")
2261
- or df[col].dtype == "object"
2262
- or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
2263
-
2264
- def _cats_from_question(question: str):
2265
- found = []
2266
- for c in df.columns:
2267
- if c.lower() in question.lower() and _is_cat(c):
2268
- found.append(c)
2269
- # dedupe preserve order
2270
- seen, out = set(), []
2271
- for c in found:
2272
- if c not in seen:
2273
- out.append(c); seen.add(c)
2274
- return out
2275
-
2276
- def _fallback_cat():
2277
- cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2278
- if not cats: return None
2279
- cats.sort(key=lambda t: t[1])
2280
- return cats[0][0]
2281
-
2282
- def _infer_comp_pref(question: str) -> str:
2283
- ql = (question or "").lower()
2284
- if "heatmap" in ql or "matrix" in ql:
2285
- return "heatmap"
2286
- if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
2287
- return "stacked_bar_pct"
2288
- if "stacked" in ql:
2289
- return "stacked_bar"
2290
- if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
2291
- return "grouped_bar"
2292
- return "counts_bar"
2293
-
2294
- # parse threshold split like "HbA1c ≥ 6.5"
2295
- def _parse_split(question: str):
2296
- ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
2297
- m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
2298
- if not m: return None
2299
- col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
2300
- op = ops_map.get(op_raw);
2301
- if not op: return None
2302
- # case-insensitive column match
2303
- candidates = {c.lower(): c for c in df.columns}
2304
- col = candidates.get(col_raw.lower())
2305
- if not col: return None
2306
- try: val = float(val_raw)
2307
- except Exception: return None
2308
- return (col, op, val)
2309
-
2310
- # facet extractor: "by/ across / within each / per <col>", or "bin <col>", or named category list
2311
- def _extract_facet(question: str):
2312
- # 1) explicit "by/ across / within / per <col>"
2313
- for kw in [" by ", " across ", " within ", " within each ", " per "]:
2314
- m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
2315
- if m:
2316
- col_raw = m.group(1).strip()
2317
- candidates = {c.lower(): c for c in df.columns}
2318
- if col_raw.lower() in candidates:
2319
- return (candidates[col_raw.lower()], "auto")
2320
- # 2) "bin <col>"
2321
- m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
2322
- if m2:
2323
- col_raw = m2.group(1).strip()
2324
- candidates = {c.lower(): c for c in df.columns}
2325
- if col_raw.lower() in candidates:
2326
- return (candidates[col_raw.lower()], "bin")
2327
- # 3) BMI special: mentions of normal/overweight/obese imply BMI categories
2328
- if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
2329
- any(c.lower() == "bmi" for c in df.columns.str.lower()):
2330
- bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
2331
- return (bmi_col, "bmi")
2332
- return None
2333
-
2334
- def _bmi_bins(series: pd.Series):
2335
- # WHO cutoffs
2336
- bins = [-np.inf, 18.5, 25, 30, np.inf]
2337
- labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
2338
- return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
2339
-
2340
- wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
2341
- if not wants_pie:
2342
- return code
2343
-
2344
- split = _parse_split(q)
2345
- facet = _extract_facet(q)
2346
- cats = _cats_from_question(q)
2347
- _comp_pref = _infer_comp_pref(q)
2348
-
2349
- # Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
2350
- for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
2351
- if hard in df.columns and hard not in cats and hard.lower() in q_low:
2352
- cats.append(hard)
2353
-
2354
- # --------------- CASE A: threshold split (cohorts) ---------------
2355
- if split:
2356
- if not (cats or any(_is_cat(c) for c in df.columns)):
2357
- return code
2358
- if not cats:
2359
- pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2360
- pool.sort(key=lambda t: t[1])
2361
- cats = [t[0] for t in pool[:3]] if pool else []
2362
- if not cats:
2363
- return code
2364
-
2365
- split_col, op, val = split
2366
- cond_str = f"(df['{split_col}'] {op} {val})"
2367
- snippet = f"""
2368
- import numpy as np
2369
- import pandas as pd
2370
- import matplotlib.pyplot as plt
2371
-
2372
- _mask_a = ({cond_str}) & df['{split_col}'].notna()
2373
- _mask_b = (~({cond_str})) & df['{split_col}'].notna()
2374
-
2375
- _cohort_a_name = "{split_col} {op} {val}"
2376
- _cohort_b_name = "NOT ({split_col} {op} {val})"
2377
-
2378
- _cat_cols = {cats!r}
2379
- n = len(_cat_cols)
2380
- fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
2381
- if n == 1:
2382
- axes = np.array([axes])
2383
-
2384
- for i, col in enumerate(_cat_cols):
2385
- s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
2386
- s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})
2387
-
2388
- ax_a = axes[i, 0]; ax_b = axes[i, 1]
2389
- if len(s_a) > 0:
2390
- ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
2391
- autopct='%1.1f%%', startangle=90, counterclock=False)
2392
- ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')
2393
-
2394
- if len(s_b) > 0:
2395
- ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
2396
- autopct='%1.1f%%', startangle=90, counterclock=False)
2397
- ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')
2398
-
2399
- plt.tight_layout(); plt.show()
2400
-
2401
- # grouped bar complement
2402
- for col in _cat_cols:
2403
- _tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
2404
- .assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
2405
- _tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
2406
- _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2407
-
2408
- if _comp_pref == "grouped_bar":
2409
- ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
2410
- ax.set_title(f"{col} by cohort (grouped)")
2411
- ax.set_xlabel(col); ax.set_ylabel("Count")
2412
- plt.tight_layout(); plt.show()
2413
-
2414
- elif _comp_pref == "stacked_bar":
2415
- ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2416
- ax.set_title(f"{col} by cohort (stacked)")
2417
- ax.set_xlabel(col); ax.set_ylabel("Count")
2418
- plt.tight_layout(); plt.show()
2419
-
2420
- elif _comp_pref == "stacked_bar_pct":
2421
- _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2422
- ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2423
- ax.set_title(f"{col} by cohort (100% stacked)")
2424
- ax.set_xlabel(col); ax.set_ylabel("Percent")
2425
- plt.tight_layout(); plt.show()
2426
-
2427
- elif _comp_pref == "heatmap":
2428
- _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2429
- import numpy as np
2430
- fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
2431
- im = ax.imshow(_perc.values, aspect='auto')
2432
- ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2433
- ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2434
- ax.set_title(f"{col} by cohort — % heatmap")
2435
- for i in range(_perc.shape[0]):
2436
- for j in range(_perc.shape[1]):
2437
- ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2438
- fig.colorbar(im, ax=ax, label="%")
2439
- plt.tight_layout(); plt.show()
2440
-
2441
- else: # counts_bar (default)
2442
- ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
2443
- ax.set_title(f"{col}: total counts (both cohorts)")
2444
- ax.set_xlabel(col); ax.set_ylabel("Count")
2445
- plt.tight_layout(); plt.show()
2446
- """.lstrip()
2447
- return snippet
2448
-
2449
- # --------------- CASE B: facet-by (categories/bins) ---------------
2450
- if facet:
2451
- facet_col, how = facet
2452
- # Build facet series
2453
- if pd.api.types.is_numeric_dtype(df[facet_col]):
2454
- if how == "bmi":
2455
- facet_series = _bmi_bins(df[facet_col])
2456
- else:
2457
- # generic numeric bins: 3 equal-width bins by default
2458
- facet_series = pd.cut(df[facet_col].astype(float), bins=3)
2459
- else:
2460
- facet_series = df[facet_col].astype(str)
2461
-
2462
- # Choose pie dimension (categorical to count inside each facet)
2463
- pie_dim = None
2464
- for c in cats:
2465
- if c in df.columns and _is_cat(c):
2466
- pie_dim = c; break
2467
- if pie_dim is None:
2468
- pie_dim = _fallback_cat()
2469
- if pie_dim is None:
2470
- return code
2471
-
2472
- snippet = f"""
2473
- import math
2474
- import numpy as np
- import pandas as pd
2475
- import matplotlib.pyplot as plt
2476
-
2477
- df = df.copy()
2478
- _preferred = "{facet_col}" if "{facet_col}" in df.columns else None
2479
-
2480
- def _select_facet_col(df, preferred=None):
2481
- if preferred is not None:
2482
- return preferred
2483
- # Prefer low-cardinality categoricals (readable pies/grids)
2484
- cat_cols = [
2485
- c for c in df.columns
2486
- if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
2487
- and df[c].nunique() > 1 and df[c].nunique() <= 20
2488
- ]
2489
- if cat_cols:
2490
- cat_cols.sort(key=lambda c: df[c].nunique())
2491
- return cat_cols[0]
2492
- # Else fall back to first usable numeric
2493
- num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
2494
- return num_cols[0] if num_cols else None
2495
-
2496
- _facet_col = _select_facet_col(df, _preferred)
2497
-
2498
- if _facet_col is None:
2499
- # Nothing suitable → single facet keeps pipeline alive
2500
- df["__facet__"] = "All"
2501
- else:
2502
- s = df[_facet_col]
2503
- if pd.api.types.is_numeric_dtype(s):
2504
- # Robust numeric binning: quantiles first, fallback to equal-width
2505
- uniq = pd.Series(s).dropna().nunique()
2506
- q = 3 if uniq < 10 else 4 if uniq < 30 else 5
2507
- try:
2508
- df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
2509
- except Exception:
2510
- df["__facet__"] = pd.cut(s.astype(float), bins=q)
2511
- else:
2512
- # Cap long tails; keep top categories
2513
- vc = s.astype(str).value_counts()
2514
- keep = vc.index[:{top_n}]
2515
- df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")
2516
-
2517
- levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
2518
- levels = [x for x in levels if x != "nan"]
2519
- levels.sort()
2520
-
2521
- m = len(levels)
2522
- cols = 3 if m >= 3 else m or 1
2523
- rows = int(math.ceil(m / cols))
2524
-
2525
- fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
2526
- if not isinstance(axes, (list, np.ndarray)):
2527
- axes = np.array([[axes]])
2528
- axes = axes.reshape(rows, cols)
2529
-
2530
- for i, lvl in enumerate(levels):
2531
- r, c = divmod(i, cols)
2532
- ax = axes[r, c]
2533
- s = (df.loc[df["__facet"].astype(str) == str(lvl), "{pie_dim}"]
2534
- .astype(str).value_counts().nlargest({top_n}))
2535
- if len(s) > 0:
2536
- ax.pie(s.values, labels=[str(x) for x in s.index],
2537
- autopct='%1.1f%%', startangle=90, counterclock=False)
2538
- ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')
2539
-
2540
- # hide any empty subplots
2541
- for j in range(m, rows*cols):
2542
- r, c = divmod(j, cols)
2543
- axes[r, c].axis("off")
2544
-
2545
- plt.tight_layout(); plt.show()
2546
-
2547
- # --- companion visual (adaptive) ---
2548
- _comp_pref = "{_comp_pref}"
2549
- # build contingency table: pie_dim x facet
2550
- _tab = (df[["__facet__", "{pie_dim}"]]
2551
- .dropna()
2552
- .astype({{"__facet__": str, "{pie_dim}": str}})
2553
- .value_counts()
2554
- .unstack(level="__facet__")
2555
- .fillna(0))
2556
-
2557
- # keep top categories by overall size
2558
- _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2559
-
2560
- if _comp_pref == "grouped_bar":
2561
- ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2562
- ax.set_title("{pie_dim} by {facet_col} (grouped)")
2563
- ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2564
- plt.tight_layout(); plt.show()
2565
-
2566
- elif _comp_pref == "stacked_bar":
2567
- ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2568
- ax.set_title("{pie_dim} by {facet_col} (stacked)")
2569
- ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2570
- plt.tight_layout(); plt.show()
2571
-
2572
- elif _comp_pref == "stacked_bar_pct":
2573
- _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100 # column-normalised to 100%
2574
- ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
2575
- ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
2576
- ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
2577
- plt.tight_layout(); plt.show()
2578
-
2579
- elif _comp_pref == "heatmap":
2580
- _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
2581
- import numpy as np
2582
- fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
2583
- im = ax.imshow(_perc.values, aspect='auto')
2584
- ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2585
- ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2586
- ax.set_title("{pie_dim} by {facet_col} — % heatmap")
2587
- for i in range(_perc.shape[0]):
2588
- for j in range(_perc.shape[1]):
2589
- ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2590
- fig.colorbar(im, ax=ax, label="%")
2591
- plt.tight_layout(); plt.show()
2592
-
2593
- else: # counts_bar (default denominators)
2594
- _counts = df["__facet"].value_counts()
2595
- ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
2596
- ax.set_title("Counts by {facet_col}")
2597
- ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2598
- plt.tight_layout(); plt.show()
2599
-
2600
- """.lstrip()
2601
- return snippet
2602
-
2603
- # --------------- CASE C: single pie ---------------
2604
- chosen = None
2605
- for c in cats:
2606
- if c in df.columns and _is_cat(c):
2607
- chosen = c; break
2608
- if chosen is None:
2609
- chosen = _fallback_cat()
2610
-
2611
- if chosen:
2612
- snippet = f"""
2613
- import matplotlib.pyplot as plt
2614
- counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
2615
- fig, ax = plt.subplots()
2616
- if len(counts) > 0:
2617
- ax.pie(counts.values, labels=[str(i) for i in counts.index],
2618
- autopct='%1.1f%%', startangle=90, counterclock=False)
2619
- ax.set_title('Distribution of {chosen} (top {top_n})')
2620
- ax.axis('equal')
2621
- plt.show()
2622
- """.lstrip()
2623
- return snippet
2624
-
2625
- # numeric last resort
2626
- num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
2627
- if num_cols:
2628
- col = num_cols[0]
2629
- snippet = f"""
2630
- import pandas as pd
2631
- import matplotlib.pyplot as plt
2632
- bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
2633
- counts = bins.value_counts().sort_index()
2634
- fig, ax = plt.subplots()
2635
- if len(counts) > 0:
2636
- ax.pie(counts.values, labels=[str(i) for i in counts.index],
2637
- autopct='%1.1f%%', startangle=90, counterclock=False)
2638
- ax.set_title('Distribution of {col} (binned)')
2639
- ax.axis('equal')
2640
- plt.show()
2641
- """.lstrip()
2642
- return snippet
2643
-
2644
- return code
2645
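A minimal sketch of the WHO-style binning that _bmi_bins applies above; the sample values are illustrative.

import numpy as np
import pandas as pd

bmi = pd.Series([17.0, 22.3, 27.8, 34.1])
bins = [-np.inf, 18.5, 25, 30, np.inf]
labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
print(pd.cut(bmi, bins=bins, labels=labels, right=False))
# Expected categories, one per value: Underweight, Normal, Overweight, Obese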
-
2646
-
2647
- def patch_fix_seaborn_palette_calls(code: str) -> str:
2648
- """
2649
- Removes seaborn `palette=` when no `hue=` is present in the same call.
2650
- Fixes FutureWarning: 'Passing `palette` without assigning `hue` ...'.
2651
- """
2652
- if "sns." not in code:
2653
- return code
2654
-
2655
- # Targets common seaborn plotters
2656
- funcs = r"(boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
2657
- pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)
2658
-
2659
- def _fix_call(m):
2660
- head, inner = m.group(1), m.group(3)  # group(2) is the plotter name; group(3) holds the call arguments
2661
- # If there's already hue=, keep as is
2662
- if re.search(r"(?<!\w)hue\s*=", inner):
2663
- return f"{head}{inner})"
2664
- # Otherwise remove palette=... safely (and any adjacent comma spacing)
2665
- inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
2666
- inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
2667
- inner2 = re.sub(r"\s*,\s*\)", ")", f"{inner2})")[:-1] # clean trailing comma before ')'
2668
- return f"{head}{inner2})"
2669
-
2670
- return pattern.sub(_fix_call, code)
2671
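A minimal usage sketch for patch_fix_seaborn_palette_calls, assuming the helper is importable from syntaxmatrix.utils; the input strings are illustrative, not real pipeline output.

from syntaxmatrix.utils import patch_fix_seaborn_palette_calls

no_hue = "sns.boxplot(data=df, x='Group', y='BMI', palette='Set2')"
print(patch_fix_seaborn_palette_calls(no_hue))
# palette= is dropped because no hue= is assigned:
# sns.boxplot(data=df, x='Group', y='BMI')

with_hue = "sns.barplot(data=df, x='Group', y='BMI', hue='Sex', palette='Set2')"
print(patch_fix_seaborn_palette_calls(with_hue))
# unchanged: hue= is present, so palette= is kept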
-
2672
-
2673
- def patch_quiet_specific_warnings(code: str) -> str:
2674
- """
2675
- Inserts targeted warning filters (not blanket ignores).
2676
- - seaborn palette/hue deprecation
2677
- - python-dotenv parse chatter
2678
- """
2679
- prelude = (
2680
- "import warnings\n"
2681
- "warnings.filterwarnings(\n"
2682
- " 'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
2683
- "warnings.filterwarnings(\n"
2684
- " 'ignore', message=r'python-dotenv could not parse statement.*')\n"
2685
- )
2686
- # If warnings already imported once, just add filters; else insert full prelude.
2687
- if "import warnings" in code:
2688
- code = re.sub(
2689
- r"(import warnings[^\n]*\n)",
2690
- lambda m: m.group(1) + prelude.replace("import warnings\n", ""),
2691
- code,
2692
- count=1
2693
- )
2694
-
2695
- else:
2696
- # place after first import block if possible
2697
- m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
2698
- if m:
2699
- idx = m.end()
2700
- code = code[:idx] + prelude + code[idx:]
2701
- else:
2702
- code = prelude + code
2703
- return code
2704
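A short sketch of where patch_quiet_specific_warnings injects its filters, assuming the helper is importable from syntaxmatrix.utils; the source string is illustrative.

from syntaxmatrix.utils import patch_quiet_specific_warnings

src = "import pandas as pd\n\nprint(pd.__version__)\n"
print(patch_quiet_specific_warnings(src))
# The two targeted warnings.filterwarnings(...) calls are inserted
# directly after the first import block.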
-
2705
-
2706
- def _norm_col_name(s: str) -> str:
2707
- """normalise a column name: lowercase + strip non-alphanumerics."""
2708
- return re.sub(r"[^a-z0-9]+", "", str(s).lower())
2709
-
2710
-
2711
- def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
2712
- """return the actual df column that matches any candidate (after normalisation)."""
2713
- norm_map = {_norm_col_name(c): c for c in df.columns}
2714
- for cand in candidates:
2715
- hit = norm_map.get(_norm_col_name(cand))
2716
- if hit is not None:
2717
- return hit
2718
- return None
2719
-
2720
-
2721
- def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
2722
- """
2723
- If any alias exists, materialise a canonical copy at `target` (don’t drop the original).
2724
- Returns (df, found_bool).
2725
- """
2726
- if target in df.columns:
2727
- return df, True
2728
- col = _first_present(df, [target, *aliases])
2729
- if col is None:
2730
- return df, False
2731
- df[target] = df[col]
2732
- return df, True
2733
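A small sketch of the column-alias helpers above; the frame and alias names are made up for illustration, and the helpers are assumed importable from syntaxmatrix.utils.

import pandas as pd
from syntaxmatrix.utils import _ensure_canonical_alias, _first_present

df = pd.DataFrame({"Family History": [1, 0], "Age": [40, 55]})
print(_first_present(df, ["family_history", "fh_status"]))   # -> Family History

df, found = _ensure_canonical_alias(df, "Family_History", ["family history"])
print(found, "Family_History" in df.columns)                  # -> True True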
-
2734
-
2735
- def strip_python_dotenv(code: str) -> str:
2736
- """
2737
- Remove any use of python-dotenv from generated code, including:
2738
- - single and multi-line 'from dotenv import ...'
2739
- - 'import dotenv' (with or without alias) and calls via any alias
2740
- - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
2741
- - IPython magics (%load_ext dotenv, %dotenv, %env …)
2742
- - shell installs like '!pip install python-dotenv'
2743
- """
2744
- original = code
2745
-
2746
- # 0) Kill IPython magics & shell installs referencing dotenv
2747
- code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
2748
- code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
2749
- code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
2750
- code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)
2751
-
2752
- # 1) Remove single-line 'from dotenv import ...'
2753
- code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)
2754
-
2755
- # 2) Remove multi-line 'from dotenv import ( ... )' blocks
2756
- code = re.sub(
2757
- r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
2758
- "",
2759
- code,
2760
- flags=re.MULTILINE,
2761
- )
2762
-
2763
- # 3) Remove 'import dotenv' (with optional alias). Capture alias names.
2764
- aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
2765
- code, flags=re.MULTILINE)
2766
- code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
2767
- "", code, flags=re.MULTILINE)
2768
-
2769
- # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
2770
- # e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
2771
- fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
2772
- # bare calls
2773
- code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2774
- # dotted calls with any identifier prefix (alias or module)
2775
- code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
2776
- "", code, flags=re.MULTILINE)
2777
-
2778
- # 5) If any alias imported earlier slipped through (method chains etc.), remove lines using that alias.
2779
- for al in aliases:
2780
- code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2781
-
2782
- # 6) Tidy excess blank lines
2783
- code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
2784
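A minimal before/after sketch for strip_python_dotenv; the snippet being cleaned is illustrative.

from syntaxmatrix.utils import strip_python_dotenv

src = (
    "from dotenv import load_dotenv\n"
    "import os\n"
    "load_dotenv()\n"
    "print(os.getenv('API_KEY'))\n"
)
print(strip_python_dotenv(src))
# Only the 'import os' and print(...) lines survive; the dotenv import and call are removed.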
- return code
2785
-
2786
-
2787
- def fix_predict_calls_records_arg(code: str) -> str:
2788
- """
2789
- If generated code calls predict_* with a list-of-dicts via .to_dict('records')
2790
- (or orient='records'), strip the .to_dict(...) so a DataFrame is passed instead.
2791
- Works line-by-line to avoid over-rewrites elsewhere.
2792
- Examples fixed:
2793
- predict_patient(X_test.iloc[:5].to_dict('records'))
2794
- predict_risk(df.head(3).to_dict(orient="records"))
2795
- → predict_patient(X_test.iloc[:5])
2796
- """
2797
- fixed_lines = []
2798
- for line in code.splitlines():
2799
- if "predict_" in line and "to_dict" in line and "records" in line:
2800
- line = re.sub(
2801
- r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)",
2802
- "",
2803
- line
2804
- )
2805
- fixed_lines.append(line)
2806
- return "\n".join(fixed_lines)
2807
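A one-line sketch of fix_predict_calls_records_arg; predict_patient and X_test are assumed names from generated code.

from syntaxmatrix.utils import fix_predict_calls_records_arg

line = "preds = predict_patient(X_test.iloc[:5].to_dict('records'))"
print(fix_predict_calls_records_arg(line))
# -> preds = predict_patient(X_test.iloc[:5])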
-
2808
-
2809
- def fix_fstring_backslash_paths(code: str) -> str:
2810
- """
2811
- Fix bad f-strings like: f"...{out_dir\\plots\\img.png}..."
2812
- → f"...{os.path.join(out_dir, r'plots\\img.png')}"
2813
- Only touches f-strings that contain a backslash path inside {...}.
2814
- """
2815
- def _fix_line(line: str) -> str:
2816
- # quick check: only f-strings need scanning
2817
- if not (("f\"" in line) or ("f'" in line) or ("f\"\"\"" in line) or ("f'''" in line)):
2818
- return line
2819
- # {var\rest-of-path} where var can be dotted (e.g., cfg.out)
2820
- pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
2821
- def repl(m):
2822
- left = m.group(1)
2823
- right = m.group(2).strip().replace('"', '\\"')
2824
- return "{os.path.join(" + left + ', r"' + right + '")}'
2825
- return pattern.sub(repl, line)
2826
-
2827
- return "\n".join(_fix_line(ln) for ln in code.splitlines())
2828
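A sketch of the f-string path repair combined with ensure_os_import (defined just below); out_dir is an assumed variable name inside the generated code.

from syntaxmatrix.utils import ensure_os_import, fix_fstring_backslash_paths

bad = 'path = f"saved to {out_dir\\plots\\img.png}"'
print(ensure_os_import(fix_fstring_backslash_paths(bad)))
# The placeholder becomes {os.path.join(out_dir, r"plots\img.png")} and
# 'import os' is prepended because os.path.join is now referenced.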
-
2829
-
2830
- def ensure_os_import(code: str) -> str:
2831
- """
2832
- If os.path.join is used but 'import os' is missing, inject it at the top.
2833
- """
2834
- needs = "os.path.join(" in code
2835
- has_import_os = re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
2836
- has_from_os = re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
2837
- if needs and not (has_import_os or has_from_os):
2838
- return "import os\n" + code
2839
- return code
2840
-
2841
-
2842
- def fix_seaborn_boxplot_nameerror(code: str) -> str:
2843
- """
2844
- Fix bad calls like: sns.boxplot(boxplot)
2845
- Heuristic:
2846
- - If plot_df + FH_status + var exist → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
2847
- - Else if plot_df + var → sns.boxplot(data=plot_df, y=var, ax=ax)
2848
- - Else if plot_df only → sns.boxplot(data=plot_df, ax=ax)
2849
- - Else → sns.boxplot(ax=ax)
2850
- Ensures a matplotlib Axes 'ax' exists.
2851
- """
2852
- pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
2853
- if not pattern.search(code):
2854
- return code
2855
-
2856
- has_plot_df = re.search(r"\bplot_df\b", code) is not None
2857
- has_var = re.search(r"\bvar\b", code) is not None
2858
- has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2859
-
2860
- if has_plot_df and has_var and has_fh:
2861
- replacement = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2862
- elif has_plot_df and has_var:
2863
- replacement = "sns.boxplot(data=plot_df, y=var, ax=ax)"
2864
- elif has_plot_df:
2865
- replacement = "sns.boxplot(data=plot_df, ax=ax)"
2866
- else:
2867
- replacement = "sns.boxplot(ax=ax)"
2868
-
2869
- fixed = pattern.sub(replacement, code)
2870
-
2871
- # Ensure 'fig, ax = plt.subplots(...)' exists
2872
- if "ax=" in replacement and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", fixed):
2873
- # Insert right before the first seaborn call
2874
- m = re.search(r"^\s*sns\.", fixed, flags=re.MULTILINE)
2875
- insert_at = m.start() if m else 0
2876
- fixed = fixed[:insert_at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + fixed[insert_at:]
2877
-
2878
- return fixed
2879
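A sketch of the degenerate-call repair; plot_df, var and FH_status mirror the names the heuristic looks for, and the helper is assumed importable from syntaxmatrix.utils.

from syntaxmatrix.utils import fix_seaborn_boxplot_nameerror

broken = (
    "plot_df = df[['FH_status', var]]\n"
    "sns.boxplot(boxplot)\n"
)
print(fix_seaborn_boxplot_nameerror(broken))
# The bad call is rewritten to sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
# and a 'fig, ax = plt.subplots(figsize=(8,4))' line is inserted just before it.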
-
2880
-
2881
- def fix_seaborn_barplot_nameerror(code: str) -> str:
2882
- """
2883
- Fix bad calls like: sns.barplot(barplot)
2884
- Strategy mirrors boxplot fixer: prefer data=plot_df with x/y if available,
2885
- otherwise degrade safely to an empty call on an existing Axes.
2886
- """
2887
- import re
2888
- pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
2889
- if not pattern.search(code):
2890
- return code
2891
-
2892
- has_plot_df = re.search(r"\bplot_df\b", code) is not None
2893
- has_var = re.search(r"\bvar\b", code) is not None
2894
- has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2895
-
2896
- if has_plot_df and has_var and has_fh:
2897
- replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2898
- elif has_plot_df and has_var:
2899
- replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
2900
- elif has_plot_df:
2901
- replacement = "sns.barplot(data=plot_df, ax=ax)"
2902
- else:
2903
- replacement = "sns.barplot(ax=ax)"
2904
-
2905
- # ensure an Axes 'ax' exists (no-op if already present)
2906
- if "ax =" not in code:
2907
- code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code
2908
-
2909
- return pattern.sub(replacement, code)
2910
-
2911
-
2912
- def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
2913
- """
2914
- Parses the raw text to extract and format the 'refined question',
2915
- 'intents (tasks)', and 'chronology of tasks' sections.
2916
- Args:
2917
- raw_text: The complete input string containing the ML pipeline structure.
2918
- Returns:
2919
- A tuple containing:
2920
- (formatted_question_str, formatted_intents_str, formatted_chronology_str)
2921
- """
2922
- # --- 1. Regex Pattern to Extract Sections ---
2923
- # The pattern uses named capturing groups (?P<...>) to look for the section headers
2924
- # (e.g., 'refined question:') and captures all the content until the next
2925
- # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.
2926
-
2927
- pattern = re.compile(
2928
- r"refined question:(?P<question>.*?)"
2929
- r"intents \(tasks\):(?P<intents>.*?)"
2930
- r"Chronology of tasks:(?P<chronology>.*)",
2931
- re.IGNORECASE | re.DOTALL
2932
- )
2933
-
2934
- match = pattern.search(raw_text)
2935
-
2936
- if not match:
2937
- raise ValueError("Input text structure does not match the expected pattern.")
2938
-
2939
- # --- 2. Extract Content ---
2940
- question_content = match.group('question').strip()
2941
- intents_content = match.group('intents').strip()
2942
- chronology_content = match.group('chronology').strip()
2943
-
2944
- # --- 3. Formatting Functions ---
2945
-
2946
- def format_question(content):
2947
- """Formats the Refined Question section."""
2948
- # Clean up leading/trailing whitespace and ensure clean paragraphs
2949
- content = content.strip().replace('\n', ' ').replace('  ', ' ')
2950
-
2951
- # Simple formatting using Markdown headers and bolding
2952
- formatted = (
2953
- # "## 1. Project Goal and Objectives\n\n"
2954
- "<b> Refined Question:</b>\n"
2955
- f"{content}\n"
2956
- )
2957
- return formatted
2958
-
2959
- def format_intents(content):
2960
- """Formats the Intents (Tasks) section as a structured list."""
2961
- # Use regex to find and format each numbered task
2962
- # It finds 'N. **Text** - ...' and breaks it down.
2963
-
2964
- tasks = []
2965
- # Pattern: N. **Text** - Content (including newlines, non-greedy)
2966
- # We need to explicitly handle the list items starting with '-' within the content
2967
- task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)
2968
-
2969
- # Split the content by lines and join tasks back into clean strings
2970
- raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]
2971
-
2972
- for task in raw_tasks:
2973
- # Replace the initial task number and **Heading** with a Heading 3
2974
- task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)
2975
-
2976
- # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
2977
- task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
2978
- tasks.append(task)
2979
-
2980
- formatted_tasks = "\n\n".join(tasks)
2981
-
2982
- return (
2983
- "\n---\n"
2984
- "## 2. Methodology and Tasks\n\n"
2985
- f"{formatted_tasks}\n"
2986
- )
2987
-
2988
- def format_chronology(content):
2989
- """Formats the Chronology section."""
2990
- # Uses the given LaTeX format
2991
- content = content.strip().replace(' ', r' \rightarrow ')
2992
- formatted = (
2993
- "\n---\n"
2994
- "## 3. Chronology of Tasks\n"
2995
- f"$$\\text{{{content}}}$$"
2996
- )
2997
- return formatted
2998
-
2999
- # --- 4. Format and Return ---
3000
- formatted_question = format_question(question_content)
3001
- formatted_intents = format_intents(intents_content)
3002
- formatted_chronology = format_chronology(chronology_content)
3003
-
3004
- return formatted_question, formatted_intents, formatted_chronology
3005
-
3006
-
3007
- def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
3008
- """Combines all formatted parts into a final report string."""
3009
- return (
3010
- "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
3011
- f"{formatted_question}\n"
3012
- f"{formatted_intents}\n"
3013
- f"{formatted_chronology}\n"
3014
- )
3015
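A compact sketch of the planner text these two functions expect and produce; the raw text below is illustrative, and the helpers are assumed importable from syntaxmatrix.utils.

from syntaxmatrix.utils import generate_full_report, parse_and_format_ml_pipeline

raw = (
    "refined question: Which features best predict family history of diabetes?\n"
    "intents (tasks):\n"
    "1. **EDA** - profile the dataset\n"
    "2. **Classification** - fit and evaluate a baseline model\n"
    "Chronology of tasks: eda classification evaluation\n"
)
q, intents, chron = parse_and_format_ml_pipeline(raw)
print(generate_full_report(q, intents, chron))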
-
3016
-
3017
- def fix_confusion_matrix_for_multilabel(code: str) -> str:
3018
- """
3019
- Replace ConfusionMatrixDisplay.from_estimator(...) usages with
3020
- from_predictions(...) which works for multi-label loops without requiring
3021
- the estimator to expose _estimator_type.
3022
- """
3023
- return re.sub(
3024
- r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
3025
- r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
3026
- code
3027
- )
3028
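A one-liner showing the from_estimator to from_predictions rewrite; clf, X_test and y_test are assumed names.

from syntaxmatrix.utils import fix_confusion_matrix_for_multilabel

src = "ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test)"
print(fix_confusion_matrix_for_multilabel(src))
# -> ConfusionMatrixDisplay.from_predictions(y_test, clf.predict(X_test))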
-
3029
-
3030
- def smx_auto_title_plots(ctx=None, fallback="Analysis"):
3031
- """
3032
- Ensure every Matplotlib/Seaborn Axes has a title.
3033
- Uses refined_question -> askai_question -> fallback.
3034
- Only sets a title if it's currently empty.
3035
- """
3036
- import matplotlib.pyplot as plt
3037
-
3038
- def _all_figures():
3039
- try:
3040
- from matplotlib._pylab_helpers import Gcf
3041
- return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
3042
- except Exception:
3043
- # Best effort fallback
3044
- nums = plt.get_fignums()
3045
- return [plt.figure(n) for n in nums] if nums else []
3046
-
3047
- # Choose a concise title
3048
- title = None
3049
- if isinstance(ctx, dict):
3050
- title = ctx.get("refined_question") or ctx.get("askai_question")
3051
- title = (str(title).strip().splitlines()[0][:120]) if title else fallback
3052
-
3053
- for fig in _all_figures():
3054
- for ax in getattr(fig, "axes", []):
3055
- try:
3056
- if not (ax.get_title() or "").strip():
3057
- ax.set_title(title)
3058
- except Exception:
3059
- pass
3060
- try:
3061
- fig.tight_layout()
3062
- except Exception:
3063
- pass
3064
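A small sketch of smx_auto_title_plots on an open, untitled figure; the context dict uses the refined_question key read above, and the helper is assumed importable from syntaxmatrix.utils.

import matplotlib.pyplot as plt
from syntaxmatrix.utils import smx_auto_title_plots

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [2, 4, 6])
smx_auto_title_plots({"refined_question": "BMI trend by visit"})
print(ax.get_title())  # -> BMI trend by visit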
-
3065
-
3066
- def patch_fix_sentinel_plot_calls(code: str) -> str:
3067
- """
3068
- Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
3069
- SB_barplot(barplot) -> SB_barplot()
3070
- SB_barplot(barplot, ...) -> SB_barplot(...)
3071
- sns.barplot(barplot) -> SB_barplot()
3072
- sns.barplot(barplot, ...) -> SB_barplot(...)
3073
- Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
3074
- """
3075
- names = ['histplot','boxplot','barplot','lineplot','countplot','heatmap','pairplot','scatterplot']
3076
- for n in names:
3077
- # SB_* with sentinel as the first arg (with or without trailing args)
3078
- code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
3079
- code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
3080
- # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
3081
- code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
3082
- code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
3083
- return code
3084
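A sketch of the sentinel-argument normalisation; SB_* are the wrapper names referenced in the docstring, and the input lines are illustrative.

from syntaxmatrix.utils import patch_fix_sentinel_plot_calls

src = "sns.countplot(countplot)\nSB_heatmap(heatmap)\n"
print(patch_fix_sentinel_plot_calls(src))
# -> SB_countplot()
#    SB_heatmap()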
-
3085
-
3086
- def patch_rmse_calls(code: str) -> str:
3087
- """
3088
- Make RMSE robust across sklearn versions.
3089
- - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
3090
- - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
3091
- """
3092
- import re
3093
- # (a) Specific RMSE pattern
3094
- code = re.sub(
3095
- r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
3096
- r"_SMX_rmse(\1)",
3097
- code,
3098
- flags=re.DOTALL
3099
- )
3100
- # (b) Guard any other MSE calls
3101
- code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
3102
- return code
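A sketch of the RMSE guard; _SMX_rmse and _SMX_call are runtime helpers assumed to be provided elsewhere in the execution environment, so only the rewritten source is shown here.

from syntaxmatrix.utils import patch_rmse_calls

src = (
    "rmse = mean_squared_error(y_true, y_pred, squared=False)\n"
    "mse = mean_squared_error(y_true, y_pred)\n"
)
print(patch_rmse_calls(src))
# -> rmse = _SMX_rmse(y_true, y_pred)
#    mse = _SMX_call(mean_squared_error, y_true, y_pred)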