syntaxmatrix 2.5.6__py3-none-any.whl → 2.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. syntaxmatrix/agentic/agents.py +1220 -169
  2. syntaxmatrix/agentic/agents_orchestrer.py +326 -0
  3. syntaxmatrix/agentic/code_tools_registry.py +27 -32
  4. syntaxmatrix/commentary.py +16 -16
  5. syntaxmatrix/core.py +185 -81
  6. syntaxmatrix/db.py +460 -4
  7. syntaxmatrix/{display.py → display_html.py} +2 -6
  8. syntaxmatrix/gpt_models_latest.py +1 -1
  9. syntaxmatrix/media/__init__.py +0 -0
  10. syntaxmatrix/media/media_pixabay.py +277 -0
  11. syntaxmatrix/models.py +1 -1
  12. syntaxmatrix/page_builder_defaults.py +183 -0
  13. syntaxmatrix/page_builder_generation.py +1122 -0
  14. syntaxmatrix/page_layout_contract.py +644 -0
  15. syntaxmatrix/page_patch_publish.py +1471 -0
  16. syntaxmatrix/preface.py +142 -21
  17. syntaxmatrix/profiles.py +28 -10
  18. syntaxmatrix/routes.py +1740 -453
  19. syntaxmatrix/selftest_page_templates.py +360 -0
  20. syntaxmatrix/settings/client_items.py +28 -0
  21. syntaxmatrix/settings/model_map.py +1022 -207
  22. syntaxmatrix/settings/prompts.py +328 -130
  23. syntaxmatrix/static/assets/hero-default.svg +22 -0
  24. syntaxmatrix/static/icons/bot-icon.png +0 -0
  25. syntaxmatrix/static/icons/favicon.png +0 -0
  26. syntaxmatrix/static/icons/logo.png +0 -0
  27. syntaxmatrix/static/icons/logo3.png +0 -0
  28. syntaxmatrix/templates/admin_branding.html +104 -0
  29. syntaxmatrix/templates/admin_features.html +63 -0
  30. syntaxmatrix/templates/admin_secretes.html +108 -0
  31. syntaxmatrix/templates/dashboard.html +296 -133
  32. syntaxmatrix/templates/dataset_resize.html +535 -0
  33. syntaxmatrix/templates/edit_page.html +2535 -0
  34. syntaxmatrix/utils.py +2431 -2383
  35. {syntaxmatrix-2.5.6.dist-info → syntaxmatrix-2.6.2.dist-info}/METADATA +6 -2
  36. {syntaxmatrix-2.5.6.dist-info → syntaxmatrix-2.6.2.dist-info}/RECORD +39 -24
  37. syntaxmatrix/generate_page.py +0 -644
  38. syntaxmatrix/static/icons/hero_bg.jpg +0 -0
  39. {syntaxmatrix-2.5.6.dist-info → syntaxmatrix-2.6.2.dist-info}/WHEEL +0 -0
  40. {syntaxmatrix-2.5.6.dist-info → syntaxmatrix-2.6.2.dist-info}/licenses/LICENSE.txt +0 -0
  41. {syntaxmatrix-2.5.6.dist-info → syntaxmatrix-2.6.2.dist-info}/top_level.txt +0 -0
syntaxmatrix/utils.py CHANGED
@@ -1,96 +1,46 @@
  from __future__ import annotations
- import re, textwrap
- import pandas as pd
- import numpy as np
+ import re, textwrap, ast
+ import pandas as pd, numpy as np
  import warnings
-
  from difflib import get_close_matches
  from typing import Iterable, Tuple, Dict
  import inspect
  from sklearn.preprocessing import OneHotEncoder
-
-
- from syntaxmatrix.agentic.model_templates import (
- classification, regression, multilabel_classification,
- eda_overview, eda_correlation,
- anomaly_detection, ts_anomaly_detection,
- dimensionality_reduction, feature_selection,
- time_series_forecasting, time_series_classification,
- unknown_group_proxy_pack, viz_line,
- clustering, recommendation, topic_modelling,
- viz_pie, viz_count_bar, viz_box, viz_scatter,
- viz_stacked_bar, viz_distribution, viz_area, viz_kde,
- )
  import ast

  warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

- INJECTABLE_INTENTS = [
- "classification",
- "multilabel_classification",
- "regression",
- "anomaly_detection",
- "time_series_forecasting",
- "time_series_classification",
- "ts_anomaly_detection",
- "dimensionality_reduction",
- "feature_selection",
- "clustering",
- "eda",
- "correlation_analysis",
- "visualisation",
- "recommendation",
- "topic_modelling",
- ]
-
- def classify_ml_job(prompt: str) -> str:
+ def patch_quiet_specific_warnings(code: str) -> str:
  """
- Very-light intent classifier.
- Returns one of:
- 'stat_test' | 'time_series' | 'clustering'
- 'classification' | 'regression' | 'eda'
+ Inserts targeted warning filters (not blanket ignores).
+ - seaborn palette/hue deprecation
+ - python-dotenv parse chatter
  """
- p = prompt.lower()
-
- greetings = {"hi", "hello", "hey", "good morning", "good afternoon", "good evening", "greetings"}
- if any(p.startswith(g) or p == g for g in greetings):
- return "greeting"
-
- # Feature selection / importance intent
- if any(k in p for k in (
- "feature selection", "select k best", "selectkbest", "rfe",
- "mutual information", "feature importance", "permutation importance",
- "feature engineering suggestions"
- )):
- return "feature_selection"
-
- # Dimensionality reduction intent
- if any(k in p for k in (
- "pca", "principal component", "dimensionality reduction",
- "reduce dimension", "reduce dimensionality", "t-sne", "tsne", "umap"
- )):
- return "dimensionality_reduction"
-
- # Anomaly / outlier intent
- if any(k in p for k in (
- "anomaly", "anomalies", "outlier", "outliers", "novelty",
- "fraud", "deviation", "rare event", "rare events", "odd pattern",
- "suspicious"
- )):
- return "anomaly_detection"
-
- if any(k in p for k in ("t-test", "anova", "p-value")):
- return "stat_test"
- if "forecast" in p or "prophet" in p:
- return "time_series"
- if "cluster" in p or "kmeans" in p:
- return "clustering"
- if any(k in p for k in ("accuracy", "precision", "roc")):
- return "classification"
- if any(k in p for k in ("rmse", "r2", "mae")):
- return "regression"
-
- return "eda"
+ prelude = (
+ "import warnings\n"
+ "warnings.filterwarnings(\n"
+ " 'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
+ "warnings.filterwarnings(\n"
+ " 'ignore', message=r'python-dotenv could not parse statement.*')\n"
+ )
+ # If warnings already imported once, just add filters; else insert full prelude.
+ if "import warnings" in code:
+ code = re.sub(
+ r"(import warnings[^\n]*\n)",
+ lambda m: m.group(1) + prelude.replace("import warnings\n", ""),
+ code,
+ count=1
+ )
+
+ else:
+ # place after first import block if possible
+ m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
+ if m:
+ idx = m.end()
+ code = code[:idx] + prelude + code[idx:]
+ else:
+ code = prelude + code
+ return code


  def _indent(code: str, spaces: int = 4) -> str:
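A minimal usage sketch (not shipped in the wheel) of how the new patch_quiet_specific_warnings helper is intended to behave; the import path assumes the helper stays module-level in syntaxmatrix/utils.py:

from syntaxmatrix.utils import patch_quiet_specific_warnings

cell = (
    "import pandas as pd\n"
    "import seaborn as sns\n"
    "sns.boxplot(data=df, x='group', y='value', palette='Set2')\n"
)

# There is no `import warnings` in the cell, so the full prelude is inserted
# right after the first import statement matched by the regex: the cell gains
# `import warnings` plus the two targeted filterwarnings() calls for the
# seaborn palette FutureWarning and the python-dotenv parse message.
patched = patch_quiet_specific_warnings(cell)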
@@ -123,7 +73,7 @@ def wrap_llm_code_safe(body: str) -> str:
  " df_local = globals().get('df')\n"
  " if df_local is not None:\n"
  " import pandas as pd\n"
- " from syntaxmatrix.preface import SB_histplot, _SMX_export_png\n"
+ " from syntaxmatrix.preface import SB_histplot, SB_boxplot, SB_scatterplot, SB_heatmap, _SMX_export_png\n"
  " num_cols = df_local.select_dtypes(include=['number', 'bool']).columns.tolist()\n"
  " cat_cols = [c for c in df_local.columns if c not in num_cols]\n"
  " info = {\n"
@@ -137,11 +87,78 @@ def wrap_llm_code_safe(body: str) -> str:
  " if num_cols:\n"
  " SB_histplot()\n"
  " _SMX_export_png()\n"
+ " if len(num_cols) >= 2:\n"
+ " SB_scatterplot()\n"
+ " _SMX_export_png()\n"
+ " if num_cols and cat_cols:\n"
+ " SB_boxplot()\n"
+ " _SMX_export_png()\n"
+ " if len(num_cols) >= 2:\n"
+ " SB_heatmap()\n"
+ " _SMX_export_png()\n"
  " except Exception as _f:\n"
  " show(f\"⚠️ Fallback EDA failed: {type(_f).__name__}: {_f}\")\n"
  )


+ def fix_print_html(code: str) -> str:
+ """
+ Ensure that HTML / DataFrame HTML are *displayed* (and captured by the kernel),
+ not printed as `<IPython.core.display.HTML object>` to the server console.
+ - Rewrites: print(HTML(...)) → display(HTML(...))
+ print(display(...)) → display(...)
+ print(df.to_html(...)) → display(HTML(df.to_html(...)))
+ Also prepends `from IPython.display import display, HTML` if required.
+ """
+ import re
+
+ new = code
+
+ # 1) print(HTML(...)) -> display(HTML(...))
+ new = re.sub(r"(?m)^\s*print\s*\(\s*HTML\s*\(", "display(HTML(", new)
+
+ # 2) print(display(...)) -> display(...)
+ new = re.sub(r"(?m)^\s*print\s*\(\s*display\s*\(", "display(", new)
+
+ # 3) print(<expr>.to_html(...)) -> display(HTML(<expr>.to_html(...)))
+ new = re.sub(
+ r"(?m)^\s*print\s*\(\s*([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*)\s*\.to_html\s*\(",
+ r"display(HTML(\1.to_html(", new
+ )
+
+ # If code references HTML() or display() make sure the import exists
+ if ("HTML(" in new or re.search(r"\bdisplay\s*\(", new)) and \
+ "from IPython.display import display, HTML" not in new:
+ new = "from IPython.display import display, HTML\n" + new
+
+ return new
+
+
+ def ensure_ipy_display(code: str) -> str:
+ """
+ Guarantee that the cell has proper IPython display imports so that
+ display(HTML(...)) produces 'display_data' events the kernel captures.
+ """
+ if "display(" in code and "from IPython.display import display, HTML" not in code:
+ return "from IPython.display import display, HTML\n" + code
+ return code
+
+
+ def fix_plain_prints(code: str) -> str:
+ """
+ Rewrite bare `print(var)` where var looks like a dataframe/series/ndarray/etc
+ to go through SyntaxMatrix's smart display (so it renders in the dashboard).
+ Keeps string prints alone.
+ """
+
+ # Skip obvious string-literal prints
+ new = re.sub(
+ r"(?m)^\s*print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
+ r"from syntaxmatrix.display import show\nshow(\1)",
+ code,
+ )
+ return new
+
  def harden_ai_code(code: str) -> str:
  """
  Make any AI-generated cell resilient:
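An illustrative sketch (not part of the package) of the rewrites the new print-fix helpers perform on a generated cell, assuming fix_print_html and fix_plain_prints remain importable from syntaxmatrix.utils:

from syntaxmatrix.utils import fix_print_html, fix_plain_prints

before = (
    "print(HTML('<b>Done</b>'))\n"
    "print(df)\n"
)

after = fix_plain_prints(fix_print_html(before))
# fix_print_html turns the first line into display(HTML('<b>Done</b>')) and,
# because HTML(/display( now appear, prepends
# "from IPython.display import display, HTML".
# fix_plain_prints then rewrites the bare print(df) into
#     from syntaxmatrix.display import show
#     show(df)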
@@ -155,31 +172,6 @@ def harden_ai_code(code: str) -> str:
  # Remove any LLM-added try/except blocks (hardener adds its own)
  import re

- def strip_placeholders(code: str) -> str:
- code = re.sub(r"\bshow\(\s*\.\.\.\s*\)",
- "show('⚠ Block skipped due to an error.')",
- code)
- code = re.sub(r"\breturn\s+\.\.\.", "return None", code)
- return code
-
- def _SMX_OHE(**k):
- # normalise arg name across sklearn versions
- if "sparse" in k and "sparse_output" not in k:
- k["sparse_output"] = k.pop("sparse")
- # default behaviour we want
- k.setdefault("handle_unknown", "ignore")
- k.setdefault("sparse_output", False)
- try:
- # if running on old sklearn without sparse_output, translate back
- if "sparse_output" not in inspect.signature(OneHotEncoder).parameters:
- if "sparse_output" in k:
- k["sparse"] = k.pop("sparse_output")
- return OneHotEncoder(**k)
- except TypeError:
- # final fallback: try legacy name
- if "sparse_output" in k:
- k["sparse"] = k.pop("sparse_output")
- return OneHotEncoder(**k)

  def _strip_stray_backrefs(code: str) -> str:
  code = re.sub(r'(?m)^\s*\\\d+\s*', '', code)
@@ -229,21 +221,37 @@ def harden_ai_code(code: str) -> str:

  return pattern.sub(repl, code)

+
  def _wrap_metric_calls(code: str) -> str:
  names = [
- "r2_score","accuracy_score","precision_score","recall_score","f1_score",
- "roc_auc_score","classification_report","confusion_matrix",
- "mean_absolute_error","mean_absolute_percentage_error",
- "explained_variance_score","log_loss","average_precision_score",
- "precision_recall_fscore_support"
+ "r2_score",
+ "accuracy_score",
+ "precision_score",
+ "recall_score",
+ "f1_score",
+ "roc_auc_score",
+ "classification_report",
+ "confusion_matrix",
+ "mean_absolute_error",
+ "mean_absolute_percentage_error",
+ "explained_variance_score",
+ "log_loss",
+ "average_precision_score",
+ "precision_recall_fscore_support",
+ "mean_squared_error",
  ]
- pat = re.compile(r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\(")
+ pat = re.compile(
+ r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\("
+ )
+
  def repl(m):
  prefix = m.group(1) or "" # "", "metrics.", or "sklearn.metrics."
  name = m.group(2)
  return f"_SMX_call({prefix}{name}, "
+
  return pat.sub(repl, code)

+
  def _smx_patch_mean_squared_error_squared_kw():
  """
  sklearn<0.22 doesn't accept mean_squared_error(..., squared=False).
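A standalone sketch of the substitution _wrap_metric_calls performs, using the regex shown above; the _SMX_call wrapper itself is supplied by syntaxmatrix.preface at run time:

import re

names = ["accuracy_score", "mean_squared_error"]
pat = re.compile(r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\(")

def repl(m):
    prefix = m.group(1) or ""
    return f"_SMX_call({prefix}{m.group(2)}, "

cell = (
    "acc = accuracy_score(y_test, y_pred)\n"
    "rmse = metrics.mean_squared_error(y_test, y_pred, squared=False)\n"
)
print(pat.sub(repl, cell))
# acc = _SMX_call(accuracy_score, y_test, y_pred)
# rmse = _SMX_call(metrics.mean_squared_error, y_test, y_pred, squared=False)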
@@ -318,6 +326,8 @@ def harden_ai_code(code: str) -> str:
  needed.add("r2_score")
  if "mean_absolute_error" in code:
  needed.add("mean_absolute_error")
+ if "mean_squared_error" in code:
+ needed.add("mean_squared_error")
  # ... add others if you like ...

  if not needed:
@@ -378,7 +388,8 @@ def harden_ai_code(code: str) -> str:
  - then falls back to generic but useful EDA.

  It assumes `from syntaxmatrix.preface import *` has already been done,
- so `_SMX_OHE`, `_SMX_call`, `SB_histplot`, `_SMX_export_png` and the
+ so `_SMX_OHE`, `_SMX_call`, `SB_histplot`, `SB_boxplot`,
+ `SB_scatterplot`, `SB_heatmap`, `_SMX_export_png` and the
  patched `show()` are available.
  """
  import textwrap
@@ -488,15 +499,91 @@ def harden_ai_code(code: str) -> str:
  show(df.head(), title='Sample of data')
  show(info, title='Dataset summary')

- # Quick univariate look if we have numeric columns
+ # 1) Distribution of a numeric column
  if num_cols:
  SB_histplot()
  _SMX_export_png()
+
+ # 2) Relationship between two numeric columns
+ if len(num_cols) >= 2:
+ SB_scatterplot()
+ _SMX_export_png()
+
+ # 3) Distribution of a numeric by the first categorical column
+ if num_cols and cat_cols:
+ SB_boxplot()
+ _SMX_export_png()
+
+ # 4) Correlation heatmap across numeric columns
+ if len(num_cols) >= 2:
+ SB_heatmap()
+ _SMX_export_png()
  except Exception as _eda_e:
  show(f"⚠ EDA fallback failed: {type(_eda_e).__name__}: {_eda_e}")
  """
  )

+ def _strip_file_io_ops(code: str) -> str:
+ """
+ Remove obvious local file I/O operations in LLM code
+ so nothing writes to the container filesystem.
+ """
+ # 1) Methods like df.to_csv(...), df.to_excel(...), etc.
+ FILE_WRITE_METHODS = (
+ "to_csv", "to_excel", "to_pickle", "to_parquet",
+ "to_json", "to_hdf",
+ )
+
+ for mname in FILE_WRITE_METHODS:
+ pat = re.compile(
+ rf"(?m)^(\s*)([A-Za-z_][A-Za-z0-9_\.]*)\s*\.\s*{mname}\s*\([^)]*\)\s*$"
+ )
+
+ def _repl(match):
+ indent = match.group(1)
+ expr = match.group(2)
+ return f"{indent}# [SMX] stripped file write: {expr}.{mname}(...)"
+
+ code = pat.sub(_repl, code)
+
+ # 2) plt.savefig(...) calls
+ pat_savefig = re.compile(r"(?m)^(\s*)(plt\.savefig\s*\([^)]*\)\s*)$")
+ code = pat_savefig.sub(
+ lambda m: f"{m.group(1)}# [SMX] stripped savefig: {m.group(2).strip()}",
+ code,
+ )
+
+ # 3) with open(..., 'w'/'wb') as f:
+ pat_with_open = re.compile(
+ r"(?m)^(\s*)with\s+open\([^)]*['\"]w[b]?['\"][^)]*\)\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*:\s*$"
+ )
+
+ def _with_open_repl(match):
+ indent = match.group(1)
+ var = match.group(2)
+ return f"{indent}if False: # [SMX] file write stripped (was: with open(... as {var}))"
+
+ code = pat_with_open.sub(_with_open_repl, code)
+
+ # 4) joblib.dump(...), pickle.dump(...)
+ for mod in ("joblib", "pickle"):
+ pat = re.compile(rf"(?m)^(\s*){mod}\.dump\s*\([^)]*\)\s*$")
+ code = pat.sub(
+ lambda m: f"{m.group(1)}# [SMX] stripped {mod}.dump(...)",
+ code,
+ )
+
+ # 5) bare open(..., 'w'/'wb') calls
+ pat_open = re.compile(
+ r"(?m)^(\s*)open\([^)]*['\"]w[b]?['\"][^)]*\)\s*$"
+ )
+ code = pat_open.sub(
+ lambda m: f"{m.group(1)}# [SMX] stripped open(..., 'w'/'wb')",
+ code,
+ )
+
+ return code
+
  # Register and run patches once per execution
  for _patch in (
  _smx_patch_mean_squared_error_squared_kw,
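A before/after sketch of what the new _strip_file_io_ops pass is meant to do to a generated cell; the helper is internal to harden_ai_code, so the effect is shown as comments rather than a direct call:

before = """\
df.to_csv('results.csv')
plt.savefig('figure.png', dpi=150)
joblib.dump(model, 'model.pkl')
with open('report.txt', 'w') as fh:
    fh.write(summary)
"""
# Expected result, line by line:
#   # [SMX] stripped file write: df.to_csv(...)
#   # [SMX] stripped savefig: plt.savefig('figure.png', dpi=150)
#   # [SMX] stripped joblib.dump(...)
#   if False: # [SMX] file write stripped (was: with open(... as fh))
#       fh.write(summary)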
@@ -562,19 +649,21 @@ def harden_ai_code(code: str) -> str:
  except (SyntaxError, IndentationError):
  fixed = _fallback_snippet()

+
+ # Fix placeholder Ellipsis handlers from LLM
  fixed = re.sub(
  r"except\s+Exception\s+as\s+e:\s*\n\s*show\(\.\.\.\)",
  "except Exception as e:\n show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")",
  fixed,
  )

- # Fix placeholder Ellipsis handlers from LLM
+ # redirect that import to the real template module.
  fixed = re.sub(
- r"except\s+Exception\s+as\s+e:\s*\n\s*show\(\.\.\.\)",
- "except Exception as e:\n show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")",
+ r"from\s+syntaxmatrix\.templates\s+import\s+([^\n]+)",
+ r"from syntaxmatrix.agentic.model_templates import \1",
  fixed,
  )
-
+
  try:
  class _SMXMatmulRewriter(ast.NodeTransformer):
  def visit_BinOp(self, node):
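A small sketch of the new import redirect above, which points stale syntaxmatrix.templates imports emitted by the LLM at the real template module:

import re

cell = "from syntaxmatrix.templates import classification, regression\n"
print(re.sub(
    r"from\s+syntaxmatrix\.templates\s+import\s+([^\n]+)",
    r"from syntaxmatrix.agentic.model_templates import \1",
    cell,
))
# from syntaxmatrix.agentic.model_templates import classification, regression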
@@ -593,11 +682,25 @@ def harden_ai_code(code: str) -> str:
  # 6) Final safety wrapper
  fixed = fixed.replace("\t", " ")
  fixed = textwrap.dedent(fixed).strip("\n")
+
  fixed = _ensure_metrics_imports(fixed)
  fixed = _strip_stray_backrefs(fixed)
  fixed = _wrap_metric_calls(fixed)
  fixed = _fix_unexpected_indent(fixed)
  fixed = _patch_feature_coef_dataframe(fixed)
+ fixed = _strip_file_io_ops(fixed)
+
+ metric_defaults = "\n".join([
+ "acc = None",
+ "accuracy = None",
+ "r2 = None",
+ "mae = None",
+ "rmse = None",
+ "f1 = None",
+ "precision = None",
+ "recall = None",
+ ]) + "\n"
+ fixed = metric_defaults + fixed

  # Import shared preface helpers once and wrap the LLM body safely
  header = "from syntaxmatrix.preface import *\n\n"
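A minimal end-to-end usage sketch, assuming harden_ai_code remains the module-level entry point in syntaxmatrix/utils.py:

from syntaxmatrix.utils import harden_ai_code

llm_cell = (
    "model.fit(X_train, y_train)\n"
    "rmse = mean_squared_error(y_test, model.predict(X_test), squared=False)\n"
    "df.to_csv('predictions.csv')\n"
)

hardened = harden_ai_code(llm_cell)
# The returned cell gains the metric defaults (acc/r2/rmse/... = None) and the
# `from syntaxmatrix.preface import *` header, routes mean_squared_error
# through _SMX_call, and comments out the df.to_csv(...) write before the body
# is wrapped by wrap_llm_code_safe.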
@@ -605,822 +708,822 @@ def harden_ai_code(code: str) -> str:
605
708
  return wrapped
606
709
 
607
710
 
608
- def indent_code(code: str, spaces: int = 4) -> str:
609
- pad = " " * spaces
610
- return "\n".join(pad + line for line in code.splitlines())
711
+ # def indent_code(code: str, spaces: int = 4) -> str:
712
+ # pad = " " * spaces
713
+ # return "\n".join(pad + line for line in code.splitlines())
611
714
 
612
715
 
613
- def fix_boxplot_placeholder(code: str) -> str:
614
- # Replace invalid 'sns.boxplot(boxplot)' with a safe call using df/group_label/m
615
- return re.sub(
616
- r"sns\.boxplot\(\s*boxplot\s*\)",
617
- "sns.boxplot(x=group_label, y=m, data=df.loc[df[m].notnull()], showfliers=False)",
618
- code
619
- )
716
+ # def fix_boxplot_placeholder(code: str) -> str:
717
+ # # Replace invalid 'sns.boxplot(boxplot)' with a safe call using df/group_label/m
718
+ # return re.sub(
719
+ # r"sns\.boxplot\(\s*boxplot\s*\)",
720
+ # "sns.boxplot(x=group_label, y=m, data=df.loc[df[m].notnull()], showfliers=False)",
721
+ # code
722
+ # )
620
723
 
621
724
 
622
- def relax_required_columns(code: str) -> str:
623
- # Remove hard failure on required_cols; keep a soft filter instead
624
- return re.sub(
625
- r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
626
- "required_cols = [c for c in df.columns]\n",
627
- code,
628
- flags=re.S
629
- )
725
+ # def relax_required_columns(code: str) -> str:
726
+ # # Remove hard failure on required_cols; keep a soft filter instead
727
+ # return re.sub(
728
+ # r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
729
+ # "required_cols = [c for c in df.columns]\n",
730
+ # code,
731
+ # flags=re.S
732
+ # )
630
733
 
631
734
 
632
- def make_numeric_vars_dynamic(code: str) -> str:
633
- # Replace any static numeric_vars list with a dynamic selection
634
- return re.sub(
635
- r"numeric_vars\s*=\s*\[.*?\]",
636
- "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
637
- code,
638
- flags=re.S
639
- )
735
+ # def make_numeric_vars_dynamic(code: str) -> str:
736
+ # # Replace any static numeric_vars list with a dynamic selection
737
+ # return re.sub(
738
+ # r"numeric_vars\s*=\s*\[.*?\]",
739
+ # "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
740
+ # code,
741
+ # flags=re.S
742
+ # )
640
743
 
641
744
 
642
- def auto_inject_template(code: str, intents, df) -> str:
643
- """If the LLM forgot the core logic, prepend a skeleton block."""
745
+ # def auto_inject_template(code: str, intents, df) -> str:
746
+ # """If the LLM forgot the core logic, prepend a skeleton block."""
644
747
 
645
- has_fit = ".fit(" in code
646
- has_plot = any(k in code for k in ("plt.", "sns.", ".plot(", ".hist("))
748
+ # has_fit = ".fit(" in code
749
+ # has_plot = any(k in code for k in ("plt.", "sns.", ".plot(", ".hist("))
647
750
 
648
- UNKNOWN_TOKENS = {
649
- "unknown","not reported","not_reported","not known","n/a","na",
650
- "none","nan","missing","unreported","unspecified","null","-",""
651
- }
652
-
653
- # --- Safe template caller: passes only supported kwargs, falls back cleanly ---
654
- def _call_template(func, df, **hints):
655
- import inspect
656
- try:
657
- params = inspect.signature(func).parameters
658
- kw = {k: v for k, v in hints.items() if k in params}
659
- try:
660
- return func(df, **kw)
661
- except TypeError:
662
- # In case the template changed its signature at runtime
663
- return func(df)
664
- except Exception:
665
- # Absolute safety net
666
- try:
667
- return func(df)
668
- except Exception:
669
- # As a last resort, return empty code so we don't 500
670
- return ""
671
-
672
- def _guess_classification_target(df: pd.DataFrame) -> str | None:
673
- cols = list(df.columns)
674
-
675
- # Helper: does this column look like a sensible label?
676
- def _is_reasonable_class_col(s: pd.Series, col_name: str) -> bool:
677
- try:
678
- nunq = s.dropna().nunique()
679
- except Exception:
680
- return False
681
- # need at least 2 classes, but not hundreds
682
- if nunq < 2 or nunq > 20:
683
- return False
684
- bad_name_keys = ("id", "identifier", "index", "uuid", "key")
685
- name = str(col_name).lower()
686
- if any(k in name for k in bad_name_keys):
687
- return False
688
- return True
689
-
690
- # 1) columns whose names look like labels
691
- label_keys = ("target", "label", "outcome", "class", "y", "status")
692
- name_candidates: list[str] = []
693
- for key in label_keys:
694
- for c in cols:
695
- if key in str(c).lower():
696
- name_candidates.append(c)
697
- if name_candidates:
698
- break # keep the earliest matching key-group
699
-
700
- # prioritise name-based candidates that also look like proper label columns
701
- for c in name_candidates:
702
- if _is_reasonable_class_col(df[c], c):
703
- return c
704
- if name_candidates:
705
- # fall back to the first name-based candidate if none passed the shape test
706
- return name_candidates[0]
707
-
708
- # 2) any column with a small number of distinct values (likely a class label)
709
- for c in cols:
710
- s = df[c]
711
- if _is_reasonable_class_col(s, c):
712
- return c
713
-
714
- # Nothing suitable found
715
- return None
716
-
717
- def _guess_regression_target(df: pd.DataFrame) -> str | None:
718
- num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
719
- if not num_cols:
720
- return None
721
- # Avoid obvious ID-like columns
722
- bad_keys = ("id", "identifier", "index")
723
- candidates = [c for c in num_cols if not any(k in str(c).lower() for k in bad_keys)]
724
- return (candidates or num_cols)[-1]
751
+ # UNKNOWN_TOKENS = {
752
+ # "unknown","not reported","not_reported","not known","n/a","na",
753
+ # "none","nan","missing","unreported","unspecified","null","-",""
754
+ # }
755
+
756
+ # # --- Safe template caller: passes only supported kwargs, falls back cleanly ---
757
+ # def _call_template(func, df, **hints):
758
+ # import inspect
759
+ # try:
760
+ # params = inspect.signature(func).parameters
761
+ # kw = {k: v for k, v in hints.items() if k in params}
762
+ # try:
763
+ # return func(df, **kw)
764
+ # except TypeError:
765
+ # # In case the template changed its signature at runtime
766
+ # return func(df)
767
+ # except Exception:
768
+ # # Absolute safety net
769
+ # try:
770
+ # return func(df)
771
+ # except Exception:
772
+ # # As a last resort, return empty code so we don't 500
773
+ # return ""
774
+
775
+ # def _guess_classification_target(df: pd.DataFrame) -> str | None:
776
+ # cols = list(df.columns)
777
+
778
+ # # Helper: does this column look like a sensible label?
779
+ # def _is_reasonable_class_col(s: pd.Series, col_name: str) -> bool:
780
+ # try:
781
+ # nunq = s.dropna().nunique()
782
+ # except Exception:
783
+ # return False
784
+ # # need at least 2 classes, but not hundreds
785
+ # if nunq < 2 or nunq > 20:
786
+ # return False
787
+ # bad_name_keys = ("id", "identifier", "index", "uuid", "key")
788
+ # name = str(col_name).lower()
789
+ # if any(k in name for k in bad_name_keys):
790
+ # return False
791
+ # return True
792
+
793
+ # # 1) columns whose names look like labels
794
+ # label_keys = ("target", "label", "outcome", "class", "y", "status")
795
+ # name_candidates: list[str] = []
796
+ # for key in label_keys:
797
+ # for c in cols:
798
+ # if key in str(c).lower():
799
+ # name_candidates.append(c)
800
+ # if name_candidates:
801
+ # break # keep the earliest matching key-group
802
+
803
+ # # prioritise name-based candidates that also look like proper label columns
804
+ # for c in name_candidates:
805
+ # if _is_reasonable_class_col(df[c], c):
806
+ # return c
807
+ # if name_candidates:
808
+ # # fall back to the first name-based candidate if none passed the shape test
809
+ # return name_candidates[0]
810
+
811
+ # # 2) any column with a small number of distinct values (likely a class label)
812
+ # for c in cols:
813
+ # s = df[c]
814
+ # if _is_reasonable_class_col(s, c):
815
+ # return c
816
+
817
+ # # Nothing suitable found
818
+ # return None
819
+
820
+ # def _guess_regression_target(df: pd.DataFrame) -> str | None:
821
+ # num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
822
+ # if not num_cols:
823
+ # return None
824
+ # # Avoid obvious ID-like columns
825
+ # bad_keys = ("id", "identifier", "index")
826
+ # candidates = [c for c in num_cols if not any(k in str(c).lower() for k in bad_keys)]
827
+ # return (candidates or num_cols)[-1]
725
828
 
726
- def _guess_time_col(df: pd.DataFrame) -> str | None:
727
- # Prefer actual datetime dtype
728
- dt_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.datetime64)]
729
- if dt_cols:
730
- return dt_cols[0]
731
-
732
- # Fallback: name-based hints
733
- name_keys = ["date", "time", "timestamp", "datetime", "ds", "period"]
734
- for c in df.columns:
735
- name = str(c).lower()
736
- if any(k in name for k in name_keys):
737
- return c
738
- return None
739
-
740
- def _guess_entity_col(df: pd.DataFrame) -> str | None:
741
- # Typical sequence IDs: id, patient, subject, device, series, entity
742
- keys = ["id", "patient", "subject", "device", "series", "entity"]
743
- candidates = []
744
- for c in df.columns:
745
- name = str(c).lower()
746
- if any(k in name for k in keys):
747
- candidates.append(c)
748
- return candidates[0] if candidates else None
749
-
750
- def _guess_ts_class_target(df: pd.DataFrame) -> str | None:
751
- # Try label-like names first
752
- keys = ["target", "label", "class", "outcome", "y"]
753
- for key in keys:
754
- for c in df.columns:
755
- if key in str(c).lower():
756
- return c
757
-
758
- # Fallback: any column with few distinct values (e.g. <= 10)
759
- for c in df.columns:
760
- s = df[c]
761
- # avoid obvious IDs
762
- if any(k in str(c).lower() for k in ["id", "index"]):
763
- continue
764
- try:
765
- nunq = s.dropna().nunique()
766
- except Exception:
767
- continue
768
- if 1 < nunq <= 10:
769
- return c
770
-
771
- return None
772
-
773
- def _guess_multilabel_cols(df: pd.DataFrame) -> list[str]:
774
- cols = list(df.columns)
775
- lbl_like = [c for c in cols if str(c).startswith(("LBL_", "lbl_"))]
776
- # also include boolean/binary columns with suitable names
777
- for c in cols:
778
- s = df[c]
779
- try:
780
- nunq = s.dropna().nunique()
781
- except Exception:
782
- continue
783
- if nunq in (2,) and c not in lbl_like:
784
- # avoid obvious IDs
785
- if not any(k in str(c).lower() for k in ("id","index","uuid","identifier")):
786
- lbl_like.append(c)
787
- # keep at most, say, 12 to avoid accidental flood
788
- return lbl_like[:12]
789
-
790
- def _find_unknownish_column(df: pd.DataFrame) -> str | None:
791
- # Search categorical-like columns for any 'unknown-like' values or high missingness
792
- candidates = []
793
- for c in df.columns:
794
- s = df[c]
795
- # focus on object/category/boolean-ish or low-card columns
796
- if not (pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= 20):
797
- continue
798
- try:
799
- vals = s.astype(str).str.strip().str.lower()
800
- except Exception:
801
- continue
802
- # score: presence of unknown tokens + missing rate
803
- token_hit = int(vals.isin(UNKNOWN_TOKENS).any())
804
- miss_rate = s.isna().mean()
805
- name_bonus = int(any(k in str(c).lower() for k in ("status","history","report","known","flag")))
806
- score = 3*token_hit + 2*name_bonus + miss_rate
807
- if token_hit or miss_rate > 0.05 or name_bonus:
808
- candidates.append((score, c))
809
- if not candidates:
810
- return None
811
- candidates.sort(reverse=True)
812
- return candidates[0][1]
813
-
814
- def _guess_numeric_cols(df: pd.DataFrame, max_n: int = 6) -> list[str]:
815
- cols = [c for c in df.select_dtypes(include=[np.number, "bool"]).columns if not any(k in str(c).lower() for k in ("id","identifier","index","uuid"))]
816
- # prefer non-constant columns
817
- scored = []
818
- for c in cols:
819
- try:
820
- v = df[c].dropna()
821
- var = float(v.var()) if len(v) else 0.0
822
- scored.append((var, c))
823
- except Exception:
824
- continue
825
- scored.sort(reverse=True)
826
- return [c for _, c in scored[:max_n]]
827
-
828
- def _guess_categorical_cols(df: pd.DataFrame, exclude: set[str] | None = None, max_card: int = 12, max_n: int = 5) -> list[str]:
829
- exclude = exclude or set()
830
- picks = []
831
- for c in df.columns:
832
- if c in exclude:
833
- continue
834
- s = df[c]
835
- if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= max_card:
836
- nunq = s.dropna().nunique()
837
- if 2 <= nunq <= max_card and not any(k in str(c).lower() for k in ("id","identifier","index","uuid")):
838
- picks.append((nunq, c))
839
- picks.sort(reverse=True)
840
- return [c for _, c in picks[:max_n]]
841
-
842
- def _guess_outcome_col(df: pd.DataFrame, exclude: set[str] | None = None) -> str | None:
843
- exclude = exclude or set()
844
- # name hints first
845
- name_keys = ("outcome","target","label","risk","score","result","prevalence","positivity")
846
- for c in df.columns:
847
- if c in exclude:
848
- continue
849
- name = str(c).lower()
850
- if any(k in name for k in name_keys) and pd.api.types.is_numeric_dtype(df[c]):
851
- return c
852
- # fallback: any binary numeric
853
- for c in df.select_dtypes(include=[np.number, "bool"]).columns:
854
- if c in exclude:
855
- continue
856
- try:
857
- if df[c].dropna().nunique() == 2:
858
- return c
859
- except Exception:
860
- continue
861
- return None
829
+ # def _guess_time_col(df: pd.DataFrame) -> str | None:
830
+ # # Prefer actual datetime dtype
831
+ # dt_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.datetime64)]
832
+ # if dt_cols:
833
+ # return dt_cols[0]
834
+
835
+ # # Fallback: name-based hints
836
+ # name_keys = ["date", "time", "timestamp", "datetime", "ds", "period"]
837
+ # for c in df.columns:
838
+ # name = str(c).lower()
839
+ # if any(k in name for k in name_keys):
840
+ # return c
841
+ # return None
842
+
843
+ # def _guess_entity_col(df: pd.DataFrame) -> str | None:
844
+ # # Typical sequence IDs: id, patient, subject, device, series, entity
845
+ # keys = ["id", "patient", "subject", "device", "series", "entity"]
846
+ # candidates = []
847
+ # for c in df.columns:
848
+ # name = str(c).lower()
849
+ # if any(k in name for k in keys):
850
+ # candidates.append(c)
851
+ # return candidates[0] if candidates else None
852
+
853
+ # def _guess_ts_class_target(df: pd.DataFrame) -> str | None:
854
+ # # Try label-like names first
855
+ # keys = ["target", "label", "class", "outcome", "y"]
856
+ # for key in keys:
857
+ # for c in df.columns:
858
+ # if key in str(c).lower():
859
+ # return c
860
+
861
+ # # Fallback: any column with few distinct values (e.g. <= 10)
862
+ # for c in df.columns:
863
+ # s = df[c]
864
+ # # avoid obvious IDs
865
+ # if any(k in str(c).lower() for k in ["id", "index"]):
866
+ # continue
867
+ # try:
868
+ # nunq = s.dropna().nunique()
869
+ # except Exception:
870
+ # continue
871
+ # if 1 < nunq <= 10:
872
+ # return c
873
+
874
+ # return None
875
+
876
+ # def _guess_multilabel_cols(df: pd.DataFrame) -> list[str]:
877
+ # cols = list(df.columns)
878
+ # lbl_like = [c for c in cols if str(c).startswith(("LBL_", "lbl_"))]
879
+ # # also include boolean/binary columns with suitable names
880
+ # for c in cols:
881
+ # s = df[c]
882
+ # try:
883
+ # nunq = s.dropna().nunique()
884
+ # except Exception:
885
+ # continue
886
+ # if nunq in (2,) and c not in lbl_like:
887
+ # # avoid obvious IDs
888
+ # if not any(k in str(c).lower() for k in ("id","index","uuid","identifier")):
889
+ # lbl_like.append(c)
890
+ # # keep at most, say, 12 to avoid accidental flood
891
+ # return lbl_like[:12]
892
+
893
+ # def _find_unknownish_column(df: pd.DataFrame) -> str | None:
894
+ # # Search categorical-like columns for any 'unknown-like' values or high missingness
895
+ # candidates = []
896
+ # for c in df.columns:
897
+ # s = df[c]
898
+ # # focus on object/category/boolean-ish or low-card columns
899
+ # if not (pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= 20):
900
+ # continue
901
+ # try:
902
+ # vals = s.astype(str).str.strip().str.lower()
903
+ # except Exception:
904
+ # continue
905
+ # # score: presence of unknown tokens + missing rate
906
+ # token_hit = int(vals.isin(UNKNOWN_TOKENS).any())
907
+ # miss_rate = s.isna().mean()
908
+ # name_bonus = int(any(k in str(c).lower() for k in ("status","history","report","known","flag")))
909
+ # score = 3*token_hit + 2*name_bonus + miss_rate
910
+ # if token_hit or miss_rate > 0.05 or name_bonus:
911
+ # candidates.append((score, c))
912
+ # if not candidates:
913
+ # return None
914
+ # candidates.sort(reverse=True)
915
+ # return candidates[0][1]
916
+
917
+ # def _guess_numeric_cols(df: pd.DataFrame, max_n: int = 6) -> list[str]:
918
+ # cols = [c for c in df.select_dtypes(include=[np.number, "bool"]).columns if not any(k in str(c).lower() for k in ("id","identifier","index","uuid"))]
919
+ # # prefer non-constant columns
920
+ # scored = []
921
+ # for c in cols:
922
+ # try:
923
+ # v = df[c].dropna()
924
+ # var = float(v.var()) if len(v) else 0.0
925
+ # scored.append((var, c))
926
+ # except Exception:
927
+ # continue
928
+ # scored.sort(reverse=True)
929
+ # return [c for _, c in scored[:max_n]]
930
+
931
+ # def _guess_categorical_cols(df: pd.DataFrame, exclude: set[str] | None = None, max_card: int = 12, max_n: int = 5) -> list[str]:
932
+ # exclude = exclude or set()
933
+ # picks = []
934
+ # for c in df.columns:
935
+ # if c in exclude:
936
+ # continue
937
+ # s = df[c]
938
+ # if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= max_card:
939
+ # nunq = s.dropna().nunique()
940
+ # if 2 <= nunq <= max_card and not any(k in str(c).lower() for k in ("id","identifier","index","uuid")):
941
+ # picks.append((nunq, c))
942
+ # picks.sort(reverse=True)
943
+ # return [c for _, c in picks[:max_n]]
944
+
945
+ # def _guess_outcome_col(df: pd.DataFrame, exclude: set[str] | None = None) -> str | None:
946
+ # exclude = exclude or set()
947
+ # # name hints first
948
+ # name_keys = ("outcome","target","label","risk","score","result","prevalence","positivity")
949
+ # for c in df.columns:
950
+ # if c in exclude:
951
+ # continue
952
+ # name = str(c).lower()
953
+ # if any(k in name for k in name_keys) and pd.api.types.is_numeric_dtype(df[c]):
954
+ # return c
955
+ # # fallback: any binary numeric
956
+ # for c in df.select_dtypes(include=[np.number, "bool"]).columns:
957
+ # if c in exclude:
958
+ # continue
959
+ # try:
960
+ # if df[c].dropna().nunique() == 2:
961
+ # return c
962
+ # except Exception:
963
+ # continue
964
+ # return None
862
965
 
863
966
 
864
- def _pick_viz_template(signal: str):
865
- s = signal.lower()
967
+ # def _pick_viz_template(signal: str):
968
+ # s = signal.lower()
866
969
 
867
- # explicit chart requests
868
- if any(k in s for k in ("pie", "donut")):
869
- return viz_pie
970
+ # # explicit chart requests
971
+ # if any(k in s for k in ("pie", "donut")):
972
+ # return viz_pie
870
973
 
871
- if any(k in s for k in ("stacked", "100% stacked", "composition", "proportion", "share by")):
872
- return viz_stacked_bar
974
+ # if any(k in s for k in ("stacked", "100% stacked", "composition", "proportion", "share by")):
975
+ # return viz_stacked_bar
873
976
 
874
- if any(k in s for k in ("distribution", "hist", "histogram", "bins")):
875
- return viz_distribution
977
+ # if any(k in s for k in ("distribution", "hist", "histogram", "bins")):
978
+ # return viz_distribution
876
979
 
877
- if any(k in s for k in ("kde", "density")):
878
- return viz_kde
980
+ # if any(k in s for k in ("kde", "density")):
981
+ # return viz_kde
879
982
 
880
- # these three you asked about
881
- if any(k in s for k in ("box", "boxplot", "violin", "spread", "outlier")):
882
- return viz_box
983
+ # # these three you asked about
984
+ # if any(k in s for k in ("box", "boxplot", "violin", "spread", "outlier")):
985
+ # return viz_box
883
986
 
884
- if any(k in s for k in ("scatter", "relationship", "vs ", "correlate")):
885
- return viz_scatter
987
+ # if any(k in s for k in ("scatter", "relationship", "vs ", "correlate")):
988
+ # return viz_scatter
886
989
 
887
- if any(k in s for k in ("count", "counts", "frequency", "bar chart", "barplot")):
888
- return viz_count_bar
990
+ # if any(k in s for k in ("count", "counts", "frequency", "bar chart", "barplot")):
991
+ # return viz_count_bar
889
992
 
890
- if any(k in s for k in ("area", "trend", "over time", "time series")):
891
- return viz_area
993
+ # if any(k in s for k in ("area", "trend", "over time", "time series")):
994
+ # return viz_area
892
995
 
893
- # fallback
894
- return viz_line
996
+ # # fallback
997
+ # return viz_line
895
998
 
896
- for intent in intents:
999
+ # for intent in intents:
897
1000
 
898
- if intent not in INJECTABLE_INTENTS:
899
- return code
1001
+ # if intent not in INJECTABLE_INTENTS:
1002
+ # return code
900
1003
 
901
- # Correlation analysis
902
- if intent == "correlation_analysis" and not has_fit:
903
- return eda_correlation(df) + "\n\n" + code
904
-
905
- # Generic visualisation (keyword-based)
906
- if intent == "visualisation" and not has_fit and not has_plot:
907
- rq = str(globals().get("refined_question", ""))
908
- # aq = str(globals().get("askai_question", ""))
909
- signal = rq + "\n" + str(intents) + "\n" + code
910
- tpl = _pick_viz_template(signal)
911
- return tpl(df) + "\n\n" + code
1004
+ # # Correlation analysis
1005
+ # if intent == "correlation_analysis" and not has_fit:
1006
+ # return eda_correlation(df) + "\n\n" + code
1007
+
1008
+ # # Generic visualisation (keyword-based)
1009
+ # if intent == "visualisation" and not has_fit and not has_plot:
1010
+ # rq = str(globals().get("refined_question", ""))
1011
+ # # aq = str(globals().get("askai_question", ""))
1012
+ # signal = rq + "\n" + str(intents) + "\n" + code
1013
+ # tpl = _pick_viz_template(signal)
1014
+ # return tpl(df) + "\n\n" + code
912
1015
 
913
- if intent == "clustering" and not has_fit:
914
- return clustering(df) + "\n\n" + code
1016
+ # if intent == "clustering" and not has_fit:
1017
+ # return clustering(df) + "\n\n" + code
915
1018
 
916
- if intent == "recommendation" and not has_fit:
917
- return recommendation(df) + "\\n\\n" + code
1019
+ # if intent == "recommendation" and not has_fit:
1020
+ # return recommendation(df) + "\\n\\n" + code
918
1021
 
919
- if intent == "topic_modelling" and not has_fit:
920
- return topic_modelling(df) + "\\n\\n" + code
1022
+ # if intent == "topic_modelling" and not has_fit:
1023
+ # return topic_modelling(df) + "\\n\\n" + code
921
1024
 
922
- if intent == "eda" and not has_fit:
923
- return code + "\n\nSB_heatmap(df.corr())" # Inject heatmap if 'eda' intent
924
-
925
- # --- Classification ------------------------------------------------
926
- if intent == "classification" and not has_fit:
927
- target = _guess_classification_target(df)
928
- if target:
929
- return classification(df) + "\n\n" + code
930
- # return _call_template(classification, df, target) + "\n\n" + code
931
-
932
- # --- Regression ----------------------------------------------------
933
- if intent == "regression" and not has_fit:
934
- target = _guess_regression_target(df)
935
- if target:
936
- return regression(df) + "\n\n" + code
937
- # return _call_template(regression, df, target) + "\n\n" + code
938
-
939
- # --- Anomaly detection --------------------------------------------
940
- if intent == "anomaly_detection":
941
- uses_anomaly = any(k in code for k in ("IsolationForest", "LocalOutlierFactor", "OneClassSVM"))
942
- if not uses_anomaly:
943
- return anomaly_detection(df) + "\n\n" + code
944
-
945
- # --- Time-series anomaly detection --------------------------------
946
- if intent == "ts_anomaly_detection":
947
- uses_ts = "STL(" in code or "seasonal_decompose(" in code
948
- if not uses_ts:
949
- return ts_anomaly_detection(df) + "\n\n" + code
1025
+ # if intent == "eda" and not has_fit:
1026
+ # return code + "\n\nSB_heatmap(df.corr())" # Inject heatmap if 'eda' intent
1027
+
1028
+ # # --- Classification ------------------------------------------------
1029
+ # if intent == "classification" and not has_fit:
1030
+ # target = _guess_classification_target(df)
1031
+ # if target:
1032
+ # return classification(df) + "\n\n" + code
1033
+ # # return _call_template(classification, df, target) + "\n\n" + code
1034
+
1035
+ # # --- Regression ----------------------------------------------------
1036
+ # if intent == "regression" and not has_fit:
1037
+ # target = _guess_regression_target(df)
1038
+ # if target:
1039
+ # return regression(df) + "\n\n" + code
1040
+ # # return _call_template(regression, df, target) + "\n\n" + code
1041
+
1042
+ # # --- Anomaly detection --------------------------------------------
1043
+ # if intent == "anomaly_detection":
1044
+ # uses_anomaly = any(k in code for k in ("IsolationForest", "LocalOutlierFactor", "OneClassSVM"))
1045
+ # if not uses_anomaly:
1046
+ # return anomaly_detection(df) + "\n\n" + code
1047
+
1048
+ # # --- Time-series anomaly detection --------------------------------
1049
+ # if intent == "ts_anomaly_detection":
1050
+ # uses_ts = "STL(" in code or "seasonal_decompose(" in code
1051
+ # if not uses_ts:
1052
+ # return ts_anomaly_detection(df) + "\n\n" + code
950
1053
 
951
- # --- Time-series classification --------------------------------
952
- if intent == "time_series_classification" and not has_fit:
953
- time_col = _guess_time_col(df)
954
- entity_col = _guess_entity_col(df)
955
- target_col = _guess_ts_class_target(df)
956
-
957
- # If we can't confidently identify these, do NOT inject anything
958
- if time_col and entity_col and target_col:
959
- return time_series_classification(df, entity_col, time_col, target_col) + "\n\n" + code
960
-
961
- # --- Dimensionality reduction --------------------------------------
962
- if intent == "dimensionality_reduction":
963
- uses_dr = any(k in code for k in ("PCA(", "TSNE("))
964
- if not uses_dr:
965
- return dimensionality_reduction(df) + "\n\n" + code
966
-
967
- # --- Feature selection ---------------------------------------------
968
- if intent == "feature_selection":
969
- uses_fs = any(k in code for k in (
970
- "mutual_info_", "permutation_importance(", "SelectKBest(", "RFE("
971
- ))
972
- if not uses_fs:
973
- return feature_selection(df) + "\n\n" + code
974
-
975
- # --- EDA / correlation / visualisation -----------------------------
976
- if intent in ("eda", "correlation_analysis", "visualisation") and not has_plot:
977
- if intent == "correlation_analysis":
978
- return eda_correlation(df) + "\n\n" + code
979
- else:
980
- return eda_overview(df) + "\n\n" + code
981
-
982
- # --- Time-series forecasting ---------------------------------------
983
- if intent == "time_series_forecasting" and not has_fit:
984
- uses_ts_forecast = any(k in code for k in (
985
- "ARIMA", "ExponentialSmoothing", "forecast", "predict("
986
- ))
987
- if not uses_ts_forecast:
988
- return time_series_forecasting(df) + "\n\n" + code
1054
+ # # --- Time-series classification --------------------------------
1055
+ # if intent == "time_series_classification" and not has_fit:
1056
+ # time_col = _guess_time_col(df)
1057
+ # entity_col = _guess_entity_col(df)
1058
+ # target_col = _guess_ts_class_target(df)
1059
+
1060
+ # # If we can't confidently identify these, do NOT inject anything
1061
+ # if time_col and entity_col and target_col:
1062
+ # return time_series_classification(df, entity_col, time_col, target_col) + "\n\n" + code
1063
+
1064
+ # # --- Dimensionality reduction --------------------------------------
1065
+ # if intent == "dimensionality_reduction":
1066
+ # uses_dr = any(k in code for k in ("PCA(", "TSNE("))
1067
+ # if not uses_dr:
1068
+ # return dimensionality_reduction(df) + "\n\n" + code
1069
+
1070
+ # # --- Feature selection ---------------------------------------------
1071
+ # if intent == "feature_selection":
1072
+ # uses_fs = any(k in code for k in (
1073
+ # "mutual_info_", "permutation_importance(", "SelectKBest(", "RFE("
1074
+ # ))
1075
+ # if not uses_fs:
1076
+ # return feature_selection(df) + "\n\n" + code
1077
+
1078
+ # # --- EDA / correlation / visualisation -----------------------------
1079
+ # if intent in ("eda", "correlation_analysis", "visualisation") and not has_plot:
1080
+ # if intent == "correlation_analysis":
1081
+ # return eda_correlation(df) + "\n\n" + code
1082
+ # else:
1083
+ # return eda_overview(df) + "\n\n" + code
1084
+
1085
+ # # --- Time-series forecasting ---------------------------------------
1086
+ # if intent == "time_series_forecasting" and not has_fit:
1087
+ # uses_ts_forecast = any(k in code for k in (
1088
+ # "ARIMA", "ExponentialSmoothing", "forecast", "predict("
1089
+ # ))
1090
+ # if not uses_ts_forecast:
1091
+ # return time_series_forecasting(df) + "\n\n" + code
989
1092
 
990
- # --- Multi-label classification -----------------------------------
991
- if intent in ("multilabel_classification",) and not has_fit:
992
- label_cols = _guess_multilabel_cols(df)
993
- if len(label_cols) >= 2:
994
- return multilabel_classification(df, label_cols) + "\n\n" + code
995
-
996
- group_col = _find_unknownish_column(df)
997
- if group_col:
998
- num_cols = _guess_numeric_cols(df)
999
- cat_cols = _guess_categorical_cols(df, exclude={group_col})
1000
- outcome_col = None # generic; let template skip if not present
1001
- tpl = unknown_group_proxy_pack(df, group_col, UNKNOWN_TOKENS, num_cols, cat_cols, outcome_col)
1002
-
1003
- # Return template + guarded (repaired) LLM code, so it never crashes
1004
- repaired = make_numeric_vars_dynamic(relax_required_columns(fix_boxplot_placeholder(code)))
1005
- return tpl + "\n\n" + wrap_llm_code_safe(repaired)
1093
+ # # --- Multi-label classification -----------------------------------
1094
+ # if intent in ("multilabel_classification",) and not has_fit:
1095
+ # label_cols = _guess_multilabel_cols(df)
1096
+ # if len(label_cols) >= 2:
1097
+ # return multilabel_classification(df, label_cols) + "\n\n" + code
1098
+
1099
+ # group_col = _find_unknownish_column(df)
1100
+ # if group_col:
1101
+ # num_cols = _guess_numeric_cols(df)
1102
+ # cat_cols = _guess_categorical_cols(df, exclude={group_col})
1103
+ # outcome_col = None # generic; let template skip if not present
1104
+ # tpl = unknown_group_proxy_pack(df, group_col, UNKNOWN_TOKENS, num_cols, cat_cols, outcome_col)
1105
+
1106
+ # # Return template + guarded (repaired) LLM code, so it never crashes
1107
+ # repaired = make_numeric_vars_dynamic(relax_required_columns(fix_boxplot_placeholder(code)))
1108
+ # return tpl + "\n\n" + wrap_llm_code_safe(repaired)
1006
1109
 
1007
- return code
1008
-
1009
-
1010
- def fix_values_sum_numeric_only_bug(code: str) -> str:
1011
- """
1012
- If a previous pass injected numeric_only=True into a NumPy-style sum,
1013
- e.g. .values.sum(numeric_only=True), strip it and canonicalize to .to_numpy().sum().
1014
- """
1015
- # .values.sum(numeric_only=True, ...)
1016
- code = re.sub(
1017
- r"\.values\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1018
- ".to_numpy().sum()",
1019
- code,
1020
- flags=re.IGNORECASE,
1021
- )
1022
- # .to_numpy().sum(numeric_only=True, ...)
1023
- code = re.sub(
1024
- r"\.to_numpy\(\)\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1025
- ".to_numpy().sum()",
1026
- code,
1027
- flags=re.IGNORECASE,
1028
- )
1029
- return code
1030
-
1031
-
1032
- def strip_describe_slice(code: str) -> str:
1033
- """
1034
- Remove any pattern like df.groupby(...).describe()[[ ... ]] because
1035
- slicing a SeriesGroupBy.describe() causes AttributeError.
1036
- We leave the plain .describe() in place (harmless) and let our own
1037
- table patcher add the correct .agg() table afterwards.
1038
- """
1039
- pat = re.compile(
1040
- r"(df\.groupby\([^)]+\)\[[^\]]+\]\.describe\()\s*\[[^\]]+\]\)",
1041
- flags=re.DOTALL,
1042
- )
1043
- return pat.sub(r"\1)", code)
1044
-
1045
-
1046
- def remove_plt_show(code: str) -> str:
1047
- """Removes all plt.show() calls from the generated code string."""
1048
- return "\n".join(line for line in code.splitlines() if "plt.show()" not in line)
1049
-
1050
-
1051
- def patch_plot_with_table(code: str) -> str:
1052
- """
1053
- ▸ strips every `plt.show()` (avoids warnings)
1054
- ▸ converts the *last* Matplotlib / Seaborn figure to PNG-HTML so it is
1055
- rendered in the dashboard
1056
- ▸ appends a summary-stats table **after** the plot
1057
- """
1058
- # 0. drop plt.show()
1059
- lines = [ln for ln in code.splitlines() if "plt.show()" not in ln]
1060
-
1061
- # 1. locate the last plotting line
1062
- plot_kw = ['plt.', 'sns.', '.plot(', '.boxplot(', '.hist(']
1063
- last_plot = max((i for i,l in enumerate(lines) if any(k in l for k in plot_kw)), default=-1)
1064
- if last_plot == -1:
1065
- return "\n".join(lines) # nothing to do
1066
-
1067
- whole = "\n".join(lines)
1068
-
1069
- # 2. detect group / feature (if any)
1070
- group, feature = None, None
1071
- xm = re.search(r"x\s*=\s*['\"](\w+)['\"]", whole)
1072
- ym = re.search(r"y\s*=\s*['\"](\w+)['\"]", whole)
1073
- if xm and ym:
1074
- group, feature = xm.group(1), ym.group(1)
1075
- else:
1076
- cm = re.search(r"column\s*=\s*['\"](\w+)['\"].*by\s*=\s*['\"](\w+)['\"]", whole)
1077
- if cm:
1078
- feature, group = cm.group(1), cm.group(2)
1079
-
1080
- # 3. code that captures current fig → PNG → HTML
1081
- img_block = textwrap.dedent("""
1082
- import io, base64
1083
- buf = io.BytesIO()
1084
- plt.savefig(buf, format='png', bbox_inches='tight')
1085
- buf.seek(0)
1086
- img_b64 = base64.b64encode(buf.read()).decode('utf-8')
1087
- from IPython.display import display, HTML
1088
- display(HTML(f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%;">'))
1089
- plt.close()
1090
- """)
1091
-
1092
- # 4. build summary-table code
1093
- if group and feature:
1094
- tbl_block = (
1095
- f"summary_table = (\n"
1096
- f" df.groupby('{group}')['{feature}']\n"
1097
- f" .agg(['count','mean','std','min','median','max'])\n"
1098
- f" .rename(columns={{'median':'50%'}})\n"
1099
- f")\n"
1100
- )
1101
- elif ym:
1102
- feature = ym.group(1)
1103
- tbl_block = (
1104
- f"summary_table = (\n"
1105
- f" df['{feature}']\n"
1106
- f" .agg(['count','mean','std','min','median','max'])\n"
1107
- f" .rename(columns={{'median':'50%'}})\n"
1108
- f")\n"
1109
- )
1110
+ # return code
1111
+
1112
+
1113
+ # def fix_values_sum_numeric_only_bug(code: str) -> str:
1114
+ # """
1115
+ # If a previous pass injected numeric_only=True into a NumPy-style sum,
1116
+ # e.g. .values.sum(numeric_only=True), strip it and canonicalize to .to_numpy().sum().
1117
+ # """
1118
+ # # .values.sum(numeric_only=True, ...)
1119
+ # code = re.sub(
1120
+ # r"\.values\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1121
+ # ".to_numpy().sum()",
1122
+ # code,
1123
+ # flags=re.IGNORECASE,
1124
+ # )
1125
+ # # .to_numpy().sum(numeric_only=True, ...)
1126
+ # code = re.sub(
1127
+ # r"\.to_numpy\(\)\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
1128
+ # ".to_numpy().sum()",
1129
+ # code,
1130
+ # flags=re.IGNORECASE,
1131
+ # )
1132
+ # return code
1133
+
1134
+
1135
+ # def strip_describe_slice(code: str) -> str:
1136
+ # """
1137
+ # Remove any pattern like df.groupby(...).describe()[[ ... ]] because
1138
+ # slicing a SeriesGroupBy.describe() causes AttributeError.
1139
+ # We leave the plain .describe() in place (harmless) and let our own
1140
+ # table patcher add the correct .agg() table afterwards.
1141
+ # """
1142
+ # pat = re.compile(
1143
+ # r"(df\.groupby\([^)]+\)\[[^\]]+\]\.describe\()\s*\[[^\]]+\]\)",
1144
+ # flags=re.DOTALL,
1145
+ # )
1146
+ # return pat.sub(r"\1)", code)
1147
+
1148
+
1149
+ # def remove_plt_show(code: str) -> str:
1150
+ # """Removes all plt.show() calls from the generated code string."""
1151
+ # return "\n".join(line for line in code.splitlines() if "plt.show()" not in line)
1152
+
1153
+
1154
+ # def patch_plot_with_table(code: str) -> str:
1155
+ # """
1156
+ # ▸ strips every `plt.show()` (avoids warnings)
1157
+ # ▸ converts the *last* Matplotlib / Seaborn figure to PNG-HTML so it is
1158
+ # rendered in the dashboard
1159
+ # ▸ appends a summary-stats table **after** the plot
1160
+ # """
1161
+ # # 0. drop plt.show()
1162
+ # lines = [ln for ln in code.splitlines() if "plt.show()" not in ln]
1163
+
1164
+ # # 1. locate the last plotting line
1165
+ # plot_kw = ['plt.', 'sns.', '.plot(', '.boxplot(', '.hist(']
1166
+ # last_plot = max((i for i,l in enumerate(lines) if any(k in l for k in plot_kw)), default=-1)
1167
+ # if last_plot == -1:
1168
+ # return "\n".join(lines) # nothing to do
1169
+
1170
+ # whole = "\n".join(lines)
1171
+
1172
+ # # 2. detect group / feature (if any)
1173
+ # group, feature = None, None
1174
+ # xm = re.search(r"x\s*=\s*['\"](\w+)['\"]", whole)
1175
+ # ym = re.search(r"y\s*=\s*['\"](\w+)['\"]", whole)
1176
+ # if xm and ym:
1177
+ # group, feature = xm.group(1), ym.group(1)
1178
+ # else:
1179
+ # cm = re.search(r"column\s*=\s*['\"](\w+)['\"].*by\s*=\s*['\"](\w+)['\"]", whole)
1180
+ # if cm:
1181
+ # feature, group = cm.group(1), cm.group(2)
1182
+
1183
+ # # 3. code that captures current fig → PNG → HTML
1184
+ # img_block = textwrap.dedent("""
1185
+ # import io, base64
1186
+ # buf = io.BytesIO()
1187
+ # plt.savefig(buf, format='png', bbox_inches='tight')
1188
+ # buf.seek(0)
1189
+ # img_b64 = base64.b64encode(buf.read()).decode('utf-8')
1190
+ # from IPython.display import display, HTML
1191
+ # display(HTML(f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%;">'))
1192
+ # plt.close()
1193
+ # """)
1194
+
1195
+ # # 4. build summary-table code
1196
+ # if group and feature:
1197
+ # tbl_block = (
1198
+ # f"summary_table = (\n"
1199
+ # f" df.groupby('{group}')['{feature}']\n"
1200
+ # f" .agg(['count','mean','std','min','median','max'])\n"
1201
+ # f" .rename(columns={{'median':'50%'}})\n"
1202
+ # f")\n"
1203
+ # )
1204
+ # elif ym:
1205
+ # feature = ym.group(1)
1206
+ # tbl_block = (
1207
+ # f"summary_table = (\n"
1208
+ # f" df['{feature}']\n"
1209
+ # f" .agg(['count','mean','std','min','median','max'])\n"
1210
+ # f" .rename(columns={{'median':'50%'}})\n"
1211
+ # f")\n"
1212
+ # )
1110
1213
 
1111
- # 3️⃣ grid-search results
1112
- elif "GridSearchCV(" in code:
1113
- tbl_block = textwrap.dedent("""
1114
- # build tidy CV-results table
1115
- cv_df = (
1116
- pd.DataFrame(grid_search.cv_results_)
1117
- .loc[:, ['param_n_estimators', 'param_max_depth',
1118
- 'mean_test_score', 'std_test_score']]
1119
- .rename(columns={
1120
- 'param_n_estimators': 'n_estimators',
1121
- 'param_max_depth': 'max_depth',
1122
- 'mean_test_score': 'mean_cv_accuracy',
1123
- 'std_test_score': 'std'
1124
- })
1125
- .sort_values('mean_cv_accuracy', ascending=False)
1126
- .reset_index(drop=True)
1127
- )
1128
- summary_table = cv_df
1129
- """).strip() + "\n"
1130
- else:
1131
- tbl_block = (
1132
- "summary_table = (\n"
1133
- " df.describe().T[['count','mean','std','min','50%','max']]\n"
1134
- ")\n"
1135
- )
1136
-
1137
- tbl_block += "show(summary_table, title='Summary Statistics')"
1138
-
1139
- # 5. inject image-export block, then table block, after the plot
1140
- patched = (
1141
- lines[:last_plot+1]
1142
- + img_block.splitlines()
1143
- + tbl_block.splitlines()
1144
- + lines[last_plot+1:]
1145
- )
1146
- patched_code = "\n".join(patched)
1147
- # ⬇️ strip every accidental left-indent so top-level lines are flush‐left
1148
- return textwrap.dedent(patched_code)
1149
-
1150
-
1151
- def refine_eda_question(raw_question, df=None, max_points=1000):
1152
- """
1153
- Rewrites user's EDA question to avoid classic mistakes:
1154
- - For line plots and scatter: recommend aggregation or sampling if large.
1155
- - For histograms/bar: clarify which variable to plot and bin count.
1156
- - For correlation: suggest a heatmap.
1157
- - For counts: direct request for df.shape[0].
1158
- df (optional): pass DataFrame for column inspection.
1159
- """
1160
-
1161
- # --- SPECIFIC PEARSON CORRELATION DETECTION ----------------------
1162
- pc = re.match(
1163
- r".*\bpearson\b.*\bcorrelation\b.*between\s+(\w+)\s+(and|vs)\s+(\w+)",
1164
- raw_question, re.I
1165
- )
1166
- if pc:
1167
- col1, col2 = pc.group(1), pc.group(3)
1168
- # Return an instruction that preserves the exact intent
1169
- return (
1170
- f"Compute the Pearson correlation coefficient (r) and p-value "
1171
- f"between {col1} and {col2}. "
1172
- f"Print a short interpretation."
1173
- )
1174
- # -----------------------------------------------------------------
1175
- # ── Detect "predict <column>" intent ──────────────────────────────
1176
- c = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", raw_question, re.I)
1177
- if c:
1178
- target = c.group(1)
1179
- raw_question += (
1180
- f" IMPORTANT: do NOT recreate or overwrite the existing target column "
1181
- f"“{target}”. Use it as-is for y = df['{target}']."
1182
- )
1183
-
1184
- q = raw_question.strip()
1185
- # REMOVE explicit summary-table instructions
1186
- # ── strip any “table” request: “…table of …”, “…include table…”, “…with a table…”
1187
- q = re.sub(r"\b(include|with|and)\b[^.]*\btable[s]?\b[^.]*", "", q, flags=re.I).strip()
1188
- q = re.sub(r"\s*,\s*$", "", q) # drop trailing comma, if any
1189
-
1190
- ql = q.lower()
1191
-
1192
- # ── NEW: if the text contains an exact column name, leave it alone ──
1193
- if df is not None:
1194
- for col in df.columns:
1195
- if col.lower() in ql:
1196
- return q
1197
-
1198
- modelling_keywords = (
1199
- "random forest", "gradient-boost", "tree-based model",
1200
- "feature importance", "feature importances",
1201
- "overall accuracy", "train a model", "predict "
1202
- )
1203
- if any(k in ql for k in modelling_keywords):
1204
- return q
1205
-
1206
- # 1. Line plots: average if plotting raw numeric vs numeric
1207
- if "line plot" in ql and any(word in ql for word in ["over", "by", "vs"]):
1208
- match = re.search(r'line plot of ([\w_]+) (over|by|vs) ([\w_]+)', ql)
1209
- if match:
1210
- y, _, x = match.groups()
1211
- return f"Show me the average {y} by {x} as a line plot."
1212
-
1213
- # 2. Scatter plots: sample if too large
1214
- if "scatter" in ql or "scatter plot" in ql:
1215
- if df is not None and df.shape[0] > max_points:
1216
- return q + " (use only a random sample of 1000 points to avoid overplotting)"
1217
- else:
1218
- return q
1219
-
1220
- # 3. Histogram: specify bins and column
1221
- if "histogram" in ql:
1222
- match = re.search(r'histogram of ([\w_]+)', ql)
1223
- if match:
1224
- col = match.group(1)
1225
- return f"Show me a histogram of {col} using 20 bins."
1226
-
1227
- # Special case: histogram for column with most missing values
1228
- if "most missing" in ql:
1229
- return (
1230
- "Show a histogram for the column with the most missing values. "
1231
- "First, select the column using: "
1232
- "column_with_most_missing = df.isnull().sum().idxmax(); "
1233
- "then plot its histogram with: "
1234
- "df[column_with_most_missing].hist()"
1235
- )
1236
-
1237
- # 4. Bar plot: show top N
1238
- if "bar plot" in ql or "bar chart" in ql:
1239
- match = re.search(r'bar (plot|chart) of ([\w_]+)', ql)
1240
- if match:
1241
- col = match.group(2)
1242
- return f"Show me a bar plot of the top 10 {col} values."
1243
-
1244
- # 5. Correlation or heatmap
1245
- if "correlation" in ql:
1246
- return (
1247
- "Show a correlation heatmap for all numeric columns only. "
1248
- "Use: correlation_matrix = df.select_dtypes(include='number').corr()"
1249
- )
1250
-
1251
-
1252
- # 6. Counts/size
1253
- if "how many record" in ql or "row count" in ql or "number of rows" in ql:
1254
- return "How many rows are in the dataset?"
1255
-
1256
- # 7. General best-practices fallback: add axis labels/titles
1257
- if "plot" in ql:
1258
- return q + " (make sure the axes are labeled and the plot is readable)"
1214
+ # # 3️⃣ grid-search results
1215
+ # elif "GridSearchCV(" in code:
1216
+ # tbl_block = textwrap.dedent("""
1217
+ # # build tidy CV-results table
1218
+ # cv_df = (
1219
+ # pd.DataFrame(grid_search.cv_results_)
1220
+ # .loc[:, ['param_n_estimators', 'param_max_depth',
1221
+ # 'mean_test_score', 'std_test_score']]
1222
+ # .rename(columns={
1223
+ # 'param_n_estimators': 'n_estimators',
1224
+ # 'param_max_depth': 'max_depth',
1225
+ # 'mean_test_score': 'mean_cv_accuracy',
1226
+ # 'std_test_score': 'std'
1227
+ # })
1228
+ # .sort_values('mean_cv_accuracy', ascending=False)
1229
+ # .reset_index(drop=True)
1230
+ # )
1231
+ # summary_table = cv_df
1232
+ # """).strip() + "\n"
1233
+ # else:
1234
+ # tbl_block = (
1235
+ # "summary_table = (\n"
1236
+ # " df.describe().T[['count','mean','std','min','50%','max']]\n"
1237
+ # ")\n"
1238
+ # )
1239
+
1240
+ # tbl_block += "show(summary_table, title='Summary Statistics')"
1241
+
1242
+ # # 5. inject image-export block, then table block, after the plot
1243
+ # patched = (
1244
+ # lines[:last_plot+1]
1245
+ # + img_block.splitlines()
1246
+ # + tbl_block.splitlines()
1247
+ # + lines[last_plot+1:]
1248
+ # )
1249
+ # patched_code = "\n".join(patched)
1250
+ # # ⬇️ strip every accidental left-indent so top-level lines are flush‐left
1251
+ # return textwrap.dedent(patched_code)
1252
+
1253
+
1254
+ # def refine_eda_question(raw_question, df=None, max_points=1000):
1255
+ # """
1256
+ # Rewrites user's EDA question to avoid classic mistakes:
1257
+ # - For line plots and scatter: recommend aggregation or sampling if large.
1258
+ # - For histograms/bar: clarify which variable to plot and bin count.
1259
+ # - For correlation: suggest a heatmap.
1260
+ # - For counts: direct request for df.shape[0].
1261
+ # df (optional): pass DataFrame for column inspection.
1262
+ # """
1263
+
1264
+ # # --- SPECIFIC PEARSON CORRELATION DETECTION ----------------------
1265
+ # pc = re.match(
1266
+ # r".*\bpearson\b.*\bcorrelation\b.*between\s+(\w+)\s+(and|vs)\s+(\w+)",
1267
+ # raw_question, re.I
1268
+ # )
1269
+ # if pc:
1270
+ # col1, col2 = pc.group(1), pc.group(3)
1271
+ # # Return an instruction that preserves the exact intent
1272
+ # return (
1273
+ # f"Compute the Pearson correlation coefficient (r) and p-value "
1274
+ # f"between {col1} and {col2}. "
1275
+ # f"Print a short interpretation."
1276
+ # )
1277
+ # # -----------------------------------------------------------------
1278
+ # # ── Detect "predict <column>" intent ──────────────────────────────
1279
+ # c = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", raw_question, re.I)
1280
+ # if c:
1281
+ # target = c.group(1)
1282
+ # raw_question += (
1283
+ # f" IMPORTANT: do NOT recreate or overwrite the existing target column "
1284
+ # f"“{target}”. Use it as-is for y = df['{target}']."
1285
+ # )
1286
+
1287
+ # q = raw_question.strip()
1288
+ # # REMOVE explicit summary-table instructions
1289
+ # # ── strip any “table” request: “…table of …”, “…include table…”, “…with a table…”
1290
+ # q = re.sub(r"\b(include|with|and)\b[^.]*\btable[s]?\b[^.]*", "", q, flags=re.I).strip()
1291
+ # q = re.sub(r"\s*,\s*$", "", q) # drop trailing comma, if any
1292
+
1293
+ # ql = q.lower()
1294
+
1295
+ # # ── NEW: if the text contains an exact column name, leave it alone ──
1296
+ # if df is not None:
1297
+ # for col in df.columns:
1298
+ # if col.lower() in ql:
1299
+ # return q
1300
+
1301
+ # modelling_keywords = (
1302
+ # "random forest", "gradient-boost", "tree-based model",
1303
+ # "feature importance", "feature importances",
1304
+ # "overall accuracy", "train a model", "predict "
1305
+ # )
1306
+ # if any(k in ql for k in modelling_keywords):
1307
+ # return q
1308
+
1309
+ # # 1. Line plots: average if plotting raw numeric vs numeric
1310
+ # if "line plot" in ql and any(word in ql for word in ["over", "by", "vs"]):
1311
+ # match = re.search(r'line plot of ([\w_]+) (over|by|vs) ([\w_]+)', ql)
1312
+ # if match:
1313
+ # y, _, x = match.groups()
1314
+ # return f"Show me the average {y} by {x} as a line plot."
1315
+
1316
+ # # 2. Scatter plots: sample if too large
1317
+ # if "scatter" in ql or "scatter plot" in ql:
1318
+ # if df is not None and df.shape[0] > max_points:
1319
+ # return q + " (use only a random sample of 1000 points to avoid overplotting)"
1320
+ # else:
1321
+ # return q
1322
+
1323
+ # # 3. Histogram: specify bins and column
1324
+ # if "histogram" in ql:
1325
+ # match = re.search(r'histogram of ([\w_]+)', ql)
1326
+ # if match:
1327
+ # col = match.group(1)
1328
+ # return f"Show me a histogram of {col} using 20 bins."
1329
+
1330
+ # # Special case: histogram for column with most missing values
1331
+ # if "most missing" in ql:
1332
+ # return (
1333
+ # "Show a histogram for the column with the most missing values. "
1334
+ # "First, select the column using: "
1335
+ # "column_with_most_missing = df.isnull().sum().idxmax(); "
1336
+ # "then plot its histogram with: "
1337
+ # "df[column_with_most_missing].hist()"
1338
+ # )
1339
+
1340
+ # # 4. Bar plot: show top N
1341
+ # if "bar plot" in ql or "bar chart" in ql:
1342
+ # match = re.search(r'bar (plot|chart) of ([\w_]+)', ql)
1343
+ # if match:
1344
+ # col = match.group(2)
1345
+ # return f"Show me a bar plot of the top 10 {col} values."
1346
+
1347
+ # # 5. Correlation or heatmap
1348
+ # if "correlation" in ql:
1349
+ # return (
1350
+ # "Show a correlation heatmap for all numeric columns only. "
1351
+ # "Use: correlation_matrix = df.select_dtypes(include='number').corr()"
1352
+ # )
1353
+
1354
+
1355
+ # # 6. Counts/size
1356
+ # if "how many record" in ql or "row count" in ql or "number of rows" in ql:
1357
+ # return "How many rows are in the dataset?"
1358
+
1359
+ # # 7. General best-practices fallback: add axis labels/titles
1360
+ # if "plot" in ql:
1361
+ # return q + " (make sure the axes are labeled and the plot is readable)"
1259
1362
 
1260
- # 8.
1261
- if (("how often" in ql or "count" in ql or "frequency" in ql) and "category" in ql) or ("value_counts" in q):
1262
- match = re.search(r'(?:categories? in |bar plot of |bar chart of )([\w_]+)', ql)
1263
- col = match.group(1) if match else None
1264
- if col:
1265
- return (
1266
- f"Show a bar plot of the counts of {col} using: "
1267
- f"df['{col}'].value_counts().plot(kind='bar'); "
1268
- "add axis labels and a title, then plt.show()."
1269
- )
1363
+ # # 8.
1364
+ # if (("how often" in ql or "count" in ql or "frequency" in ql) and "category" in ql) or ("value_counts" in q):
1365
+ # match = re.search(r'(?:categories? in |bar plot of |bar chart of )([\w_]+)', ql)
1366
+ # col = match.group(1) if match else None
1367
+ # if col:
1368
+ # return (
1369
+ # f"Show a bar plot of the counts of {col} using: "
1370
+ # f"df['{col}'].value_counts().plot(kind='bar'); "
1371
+ # "add axis labels and a title, then plt.show()."
1372
+ # )
1270
1373
 
1271
- if ("mean" in ql and "median" in ql and "standard deviation" in ql) or ("summary statistics" in ql):
1272
- return (
1273
- "Show a table of the mean, median, and standard deviation for all numeric columns. "
1274
- "Use: tbl = df.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}); display(tbl)"
1275
- )
1374
+ # if ("mean" in ql and "median" in ql and "standard deviation" in ql) or ("summary statistics" in ql):
1375
+ # return (
1376
+ # "Show a table of the mean, median, and standard deviation for all numeric columns. "
1377
+ # "Use: tbl = df.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}); display(tbl)"
1378
+ # )
1276
1379
 
1277
- # 9. Fallback: return the raw question
1278
- return q
1380
+ # # 9. Fallback: return the raw question
1381
+ # return q
1279
1382
 
1280
1383
 
1281
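For context, the now-disabled refine_eda_question rewrote free-text EDA prompts into safer, more explicit instructions before code generation. A minimal sketch of two of its rules; the function name and sample question here are illustrative, not part of the package:

    def refine(question: str) -> str:
        q = question.strip()
        ql = q.lower()
        # correlation questions -> numeric-only heatmap instruction
        if "correlation" in ql:
            return ("Show a correlation heatmap for all numeric columns only. "
                    "Use: correlation_matrix = df.select_dtypes(include='number').corr()")
        # row-count questions -> direct df.shape[0] request
        if "row count" in ql or "number of rows" in ql:
            return "How many rows are in the dataset?"
        return q

    print(refine("plot the correlation please"))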
- def patch_plot_code(code, df, user_question=None):
1384
+ # def patch_plot_code(code, df, user_question=None):
1282
1385
 
1283
- # ── Early guard: abort nicely if the generated code references columns that
1284
- # do not exist in the DataFrame. This prevents KeyError crashes.
1285
- import re
1386
+ # # ── Early guard: abort nicely if the generated code references columns that
1387
+ # # do not exist in the DataFrame. This prevents KeyError crashes.
1388
+ # import re
1286
1389
 
1287
1390
 
1288
- # ── Detect columns referenced in the code ──────────────────────────
1289
- col_refs = re.findall(r"df\[['\"](\w+)['\"]\]", code)
1290
-
1291
- # Columns that will be newly CREATED (appear left of '=')
1292
- new_cols = re.findall(r"df\[['\"](\w+)['\"]\]\s*=", code)
1293
-
1294
- missing_cols = [
1295
- col for col in col_refs
1296
- if col not in df.columns and col not in new_cols
1297
- ]
1298
-
1299
- if missing_cols:
1300
- cols_list = ", ".join(missing_cols)
1301
- warning = (
1302
- f"show('⚠️ Warning: code references missing column(s): \"{cols_list}\". "
1303
- "These must either exist in df or be created earlier in the code; "
1304
- "otherwise you may see a KeyError.')\n"
1305
- )
1306
- # Prepend the warning but keep the original code so it can still run
1307
- code = warning + code
1391
+ # # ── Detect columns referenced in the code ──────────────────────────
1392
+ # col_refs = re.findall(r"df\[['\"](\w+)['\"]\]", code)
1393
+
1394
+ # # Columns that will be newly CREATED (appear left of '=')
1395
+ # new_cols = re.findall(r"df\[['\"](\w+)['\"]\]\s*=", code)
1396
+
1397
+ # missing_cols = [
1398
+ # col for col in col_refs
1399
+ # if col not in df.columns and col not in new_cols
1400
+ # ]
1401
+
1402
+ # if missing_cols:
1403
+ # cols_list = ", ".join(missing_cols)
1404
+ # warning = (
1405
+ # f"show('⚠️ Warning: code references missing column(s): \"{cols_list}\". "
1406
+ # "These must either exist in df or be created earlier in the code; "
1407
+ # "otherwise you may see a KeyError.')\n"
1408
+ # )
1409
+ # # Prepend the warning but keep the original code so it can still run
1410
+ # code = warning + code
1308
1411
 
1309
- # 1. For line plots (auto-aggregate)
1310
- m_l = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
1311
- if m_l:
1312
- x, y = m_l.groups()
1313
- if pd.api.types.is_numeric_dtype(df[x]) and pd.api.types.is_numeric_dtype(df[y]) and df[x].nunique() > 20:
1314
- return (
1315
- f"agg_df = df.groupby('{x}')['{y}'].mean().reset_index()\n"
1316
- f"plt.plot(agg_df['{x}'], agg_df['{y}'], marker='o')\n"
1317
- f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('Average {y} by {x}')\nplt.show()"
1318
- )
1319
-
1320
- # 2. For scatter plots: sample to 1000 points max
1321
- m_s = re.search(r"plt\.scatter\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
1322
- if m_s:
1323
- x, y = m_s.groups()
1324
- if len(df) > 1000:
1325
- return (
1326
- f"samp = df.sample(1000, random_state=42)\n"
1327
- f"plt.scatter(samp['{x}'], samp['{y}'])\n"
1328
- f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('{y} vs {x} (sampled)')\nplt.show()"
1329
- )
1330
-
1331
- # 3. For histograms: use bins=20 for numeric, value_counts for categorical
1332
- m_h = re.search(r"plt\.hist\(\s*df\[['\"](\w+)['\"]\]", code)
1333
- if m_h:
1334
- col = m_h.group(1)
1335
- if pd.api.types.is_numeric_dtype(df[col]):
1336
- return (
1337
- f"plt.hist(df['{col}'], bins=20, edgecolor='black')\n"
1338
- f"plt.xlabel('{col}')\nplt.ylabel('Frequency')\nplt.title('Histogram of {col}')\nplt.show()"
1339
- )
1340
- else:
1341
- # If categorical, show bar plot of value counts
1342
- return (
1343
- f"df['{col}'].value_counts().plot(kind='bar')\n"
1344
- f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Counts of {col}')\nplt.show()"
1345
- )
1346
-
1347
- # 4. For bar plots: show only top 20
1348
- m_b = re.search(r"(?:df\[['\"](\w+)['\"]\]\.value_counts\(\).plot\(kind=['\"]bar['\"]\))", code)
1349
- if m_b:
1350
- col = m_b.group(1)
1351
- if df[col].nunique() > 20:
1352
- return (
1353
- f"topN = df['{col}'].value_counts().head(20)\n"
1354
- f"topN.plot(kind='bar')\n"
1355
- f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Top 20 {col} Categories')\nplt.show()"
1356
- )
1357
-
1358
- # 5. For any DataFrame plot with len(df)>10000, sample before plotting!
1359
- if "df.plot" in code and len(df) > 10000:
1360
- return (
1361
- f"samp = df.sample(1000, random_state=42)\n"
1362
- + code.replace("df.", "samp.")
1363
- )
1412
+ # # 1. For line plots (auto-aggregate)
1413
+ # m_l = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
1414
+ # if m_l:
1415
+ # x, y = m_l.groups()
1416
+ # if pd.api.types.is_numeric_dtype(df[x]) and pd.api.types.is_numeric_dtype(df[y]) and df[x].nunique() > 20:
1417
+ # return (
1418
+ # f"agg_df = df.groupby('{x}')['{y}'].mean().reset_index()\n"
1419
+ # f"plt.plot(agg_df['{x}'], agg_df['{y}'], marker='o')\n"
1420
+ # f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('Average {y} by {x}')\nplt.show()"
1421
+ # )
1422
+
1423
+ # # 2. For scatter plots: sample to 1000 points max
1424
+ # m_s = re.search(r"plt\.scatter\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
1425
+ # if m_s:
1426
+ # x, y = m_s.groups()
1427
+ # if len(df) > 1000:
1428
+ # return (
1429
+ # f"samp = df.sample(1000, random_state=42)\n"
1430
+ # f"plt.scatter(samp['{x}'], samp['{y}'])\n"
1431
+ # f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('{y} vs {x} (sampled)')\nplt.show()"
1432
+ # )
1433
+
1434
+ # # 3. For histograms: use bins=20 for numeric, value_counts for categorical
1435
+ # m_h = re.search(r"plt\.hist\(\s*df\[['\"](\w+)['\"]\]", code)
1436
+ # if m_h:
1437
+ # col = m_h.group(1)
1438
+ # if pd.api.types.is_numeric_dtype(df[col]):
1439
+ # return (
1440
+ # f"plt.hist(df['{col}'], bins=20, edgecolor='black')\n"
1441
+ # f"plt.xlabel('{col}')\nplt.ylabel('Frequency')\nplt.title('Histogram of {col}')\nplt.show()"
1442
+ # )
1443
+ # else:
1444
+ # # If categorical, show bar plot of value counts
1445
+ # return (
1446
+ # f"df['{col}'].value_counts().plot(kind='bar')\n"
1447
+ # f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Counts of {col}')\nplt.show()"
1448
+ # )
1449
+
1450
+ # # 4. For bar plots: show only top 20
1451
+ # m_b = re.search(r"(?:df\[['\"](\w+)['\"]\]\.value_counts\(\).plot\(kind=['\"]bar['\"]\))", code)
1452
+ # if m_b:
1453
+ # col = m_b.group(1)
1454
+ # if df[col].nunique() > 20:
1455
+ # return (
1456
+ # f"topN = df['{col}'].value_counts().head(20)\n"
1457
+ # f"topN.plot(kind='bar')\n"
1458
+ # f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Top 20 {col} Categories')\nplt.show()"
1459
+ # )
1460
+
1461
+ # # 5. For any DataFrame plot with len(df)>10000, sample before plotting!
1462
+ # if "df.plot" in code and len(df) > 10000:
1463
+ # return (
1464
+ # f"samp = df.sample(1000, random_state=42)\n"
1465
+ # + code.replace("df.", "samp.")
1466
+ # )
1364
1467
 
1365
- # ── Block assignment to an existing target column ────────────────
1366
- #*******************************************************
1367
- target_match = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", user_question or "", re.I)
1368
- if target_match:
1369
- target = target_match.group(1)
1370
-
1371
- # pattern for an assignment to that target
1372
- assign_pat = rf"df\[['\"]{re.escape(target)}['\"]\]\s*="
1373
- assign_line = re.search(assign_pat + r".*", code)
1374
- if assign_line:
1375
- # runtime check: keep the assignment **only if** the column is absent
1376
- guard = (
1377
- f"if '{target}' in df.columns:\n"
1378
- f" print('⚠️ {target} already exists – overwrite skipped.');\n"
1379
- f"else:\n"
1380
- f" {assign_line.group(0)}"
1381
- )
1382
- # remove original assignment line and insert guarded block
1383
- code = code.replace(assign_line.group(0), guard, 1)
1384
- # ***************************************************
1468
+ # # ── Block assignment to an existing target column ────────────────
1469
+ # #*******************************************************
1470
+ # target_match = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", user_question or "", re.I)
1471
+ # if target_match:
1472
+ # target = target_match.group(1)
1473
+
1474
+ # # pattern for an assignment to that target
1475
+ # assign_pat = rf"df\[['\"]{re.escape(target)}['\"]\]\s*="
1476
+ # assign_line = re.search(assign_pat + r".*", code)
1477
+ # if assign_line:
1478
+ # # runtime check: keep the assignment **only if** the column is absent
1479
+ # guard = (
1480
+ # f"if '{target}' in df.columns:\n"
1481
+ # f" print('⚠️ {target} already exists – overwrite skipped.');\n"
1482
+ # f"else:\n"
1483
+ # f" {assign_line.group(0)}"
1484
+ # )
1485
+ # # remove original assignment line and insert guarded block
1486
+ # code = code.replace(assign_line.group(0), guard, 1)
1487
+ # # ***************************************************
1385
1488
 
1386
- # 6. Grouped bar plot for two categoricals
1387
- # Grouped bar plot for two categoricals (.value_counts().unstack() or .groupby().size().unstack())
1388
- if ".value_counts().unstack()" in code or ".groupby(" in code and ".size().unstack()" in code:
1389
- # Try to infer columns from user question if possible:
1390
- group, cat = None, None
1391
- if user_question:
1392
- # crude parse for "counts of X for each Y"
1393
- m = re.search(r"counts? of (\w+) for each (\w+)", user_question)
1394
- if m:
1395
- cat, group = m.groups()
1396
- if not (cat and group):
1397
- # fallback: use two most frequent categoricals
1398
- categoricals = [col for col in df.columns if pd.api.types.is_categorical_dtype(df[col]) or df[col].dtype == "object"]
1399
- if len(categoricals) >= 2:
1400
- cat, group = categoricals[:2]
1401
- else:
1402
- # fallback: any
1403
- cat, group = df.columns[:2]
1404
- return (
1405
- f"import pandas as pd\n"
1406
- f"import matplotlib.pyplot as plt\n"
1407
- f"ct = pd.crosstab(df['{group}'], df['{cat}'])\n"
1408
- f"ct.plot(kind='bar')\n"
1409
- f"plt.title('Counts of {cat} for each {group}')\n"
1410
- f"plt.xlabel('{group}')\nplt.ylabel('Count')\nplt.xticks(rotation=0)\nplt.show()"
1411
- )
1412
-
1413
- # Fallback: Return original code
1414
- return code
1415
-
1416
-
1417
- def ensure_matplotlib_title(code, title_var="refined_question"):
1418
- import re
1419
- makes_plot = re.search(r"\b(plt\.(plot|scatter|bar|hist)|ax\.(plot|scatter|bar|hist))\b", code)
1420
- has_title = re.search(r"\b(plt\.title|ax\.set_title)\s*\(", code)
1421
- if makes_plot and not has_title:
1422
- code += f"\ntry:\n plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
1423
- return code
1489
+ # # 6. Grouped bar plot for two categoricals
1490
+ # # Grouped bar plot for two categoricals (.value_counts().unstack() or .groupby().size().unstack())
1491
+ # if ".value_counts().unstack()" in code or ".groupby(" in code and ".size().unstack()" in code:
1492
+ # # Try to infer columns from user question if possible:
1493
+ # group, cat = None, None
1494
+ # if user_question:
1495
+ # # crude parse for "counts of X for each Y"
1496
+ # m = re.search(r"counts? of (\w+) for each (\w+)", user_question)
1497
+ # if m:
1498
+ # cat, group = m.groups()
1499
+ # if not (cat and group):
1500
+ # # fallback: use two most frequent categoricals
1501
+ # categoricals = [col for col in df.columns if pd.api.types.is_categorical_dtype(df[col]) or df[col].dtype == "object"]
1502
+ # if len(categoricals) >= 2:
1503
+ # cat, group = categoricals[:2]
1504
+ # else:
1505
+ # # fallback: any
1506
+ # cat, group = df.columns[:2]
1507
+ # return (
1508
+ # f"import pandas as pd\n"
1509
+ # f"import matplotlib.pyplot as plt\n"
1510
+ # f"ct = pd.crosstab(df['{group}'], df['{cat}'])\n"
1511
+ # f"ct.plot(kind='bar')\n"
1512
+ # f"plt.title('Counts of {cat} for each {group}')\n"
1513
+ # f"plt.xlabel('{group}')\nplt.ylabel('Count')\nplt.xticks(rotation=0)\nplt.show()"
1514
+ # )
1515
+
1516
+ # # Fallback: Return original code
1517
+ # return code
1518
+
1519
+
1520
+ # def ensure_matplotlib_title(code, title_var="refined_question"):
1521
+ # import re
1522
+ # makes_plot = re.search(r"\b(plt\.(plot|scatter|bar|hist)|ax\.(plot|scatter|bar|hist))\b", code)
1523
+ # has_title = re.search(r"\b(plt\.title|ax\.set_title)\s*\(", code)
1524
+ # if makes_plot and not has_title:
1525
+ # code += f"\ntry:\n plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
1526
+ # return code
1424
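The helper above (ensure_matplotlib_title, now commented out) appended a guarded plt.title(...) when generated plotting code had none. A small string-level sketch of that append, assuming a refined_question variable exists in the scope of the executed snippet:

    snippet = "plt.plot([1, 2, 3])\n"
    if "plt.title" not in snippet:
        # mirror of the guard the helper appended; refined_question is assumed to exist at run time
        snippet += "try:\n    plt.title(str(refined_question)[:120])\nexcept Exception: pass\n"
    print(snippet)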
1527
 
1425
1528
 
1426
1529
  def ensure_output(code: str) -> str:
@@ -1499,1510 +1602,1455 @@ def ensure_output(code: str) -> str:
1499
1602
  return "\n".join(lines)
1500
1603
 
1501
1604
 
1502
- def get_plotting_imports(code):
1503
- imports = []
1504
- if "plt." in code and "import matplotlib.pyplot as plt" not in code:
1505
- imports.append("import matplotlib.pyplot as plt")
1506
- if "sns." in code and "import seaborn as sns" not in code:
1507
- imports.append("import seaborn as sns")
1508
- if "px." in code and "import plotly.express as px" not in code:
1509
- imports.append("import plotly.express as px")
1510
- if "pd." in code and "import pandas as pd" not in code:
1511
- imports.append("import pandas as pd")
1512
- if "np." in code and "import numpy as np" not in code:
1513
- imports.append("import numpy as np")
1514
- if "display(" in code and "from IPython.display import display" not in code:
1515
- imports.append("from IPython.display import display")
1516
- # Optionally, add more as you see usage (e.g., import scipy, statsmodels, etc)
1517
- if imports:
1518
- code = "\n".join(imports) + "\n\n" + code
1519
- return code
1520
-
1521
-
1522
- def patch_pairplot(code, df):
1523
- if "sns.pairplot" in code:
1524
- # Always assign and print pairgrid
1525
- code = re.sub(r"sns\.pairplot\((.+)\)", r"pairgrid = sns.pairplot(\1)", code)
1526
- if "plt.show()" not in code:
1527
- code += "\nplt.show()"
1528
- if "print(pairgrid)" not in code:
1529
- code += "\nprint(pairgrid)"
1530
- return code
1531
-
1532
-
1533
- def ensure_image_output(code: str) -> str:
1534
- """
1535
- Replace each plt.show() with an indented _SMX_export_png() call.
1536
- This keeps block indentation valid and still renders images in the dashboard.
1537
- """
1538
- if "plt.show()" not in code:
1539
- return code
1540
-
1541
- import re
1542
- out_lines = []
1543
- for ln in code.splitlines():
1544
- if "plt.show()" not in ln:
1545
- out_lines.append(ln)
1546
- continue
1547
-
1548
- # works for:
1549
- # plt.show()
1550
- # plt.tight_layout(); plt.show()
1551
- # ... ; plt.show(); ... (multiple on one line)
1552
- indent = re.match(r"^(\s*)", ln).group(1)
1553
- parts = ln.split("plt.show()")
1554
-
1555
- # keep whatever is before the first plt.show()
1556
- if parts[0].strip():
1557
- out_lines.append(parts[0].rstrip())
1558
-
1559
- # for every plt.show() we removed, insert exporter at same indent
1560
- for _ in range(len(parts) - 1):
1561
- out_lines.append(indent + "_SMX_export_png()")
1562
-
1563
- # keep whatever comes after the last plt.show()
1564
- if parts[-1].strip():
1565
- out_lines.append(indent + parts[-1].lstrip())
1566
-
1567
- return "\n".join(out_lines)
1568
-
1569
-
1570
- def clean_llm_code(code: str) -> str:
1571
- """
1572
- Make LLM output safe to exec:
1573
- - If fenced blocks exist, keep the largest one (usually the real code).
1574
- - Otherwise strip any stray ``` / ```python lines.
1575
- - Remove common markdown/preamble junk.
1576
- """
1577
- code = str(code or "")
1578
-
1579
- # Extract fenced blocks (```python ... ``` or ``` ... ```)
1580
- blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
1581
-
1582
- if blocks:
1583
- # pick the largest block; small trailing blocks are usually garbage
1584
- largest = max(blocks, key=lambda b: len(b.strip()))
1585
- if len(largest.strip().splitlines()) >= 10:
1586
- code = largest
1587
- else:
1588
- # if no meaningful block, just remove fence markers
1589
- code = re.sub(r"^```.*?$", "", code, flags=re.M)
1590
- else:
1591
- # no complete blocks — still remove any stray fence lines
1592
- code = re.sub(r"^```.*?$", "", code, flags=re.M)
1593
-
1594
- # Strip common markdown/preamble lines
1595
- drop_prefixes = (
1596
- "here is", "here's", "below is", "sure,", "certainly",
1597
- "explanation", "note:", "```"
1598
- )
1599
- cleaned_lines = []
1600
- for ln in code.splitlines():
1601
- s = ln.strip().lower()
1602
- if any(s.startswith(p) for p in drop_prefixes):
1603
- continue
1604
- cleaned_lines.append(ln)
1605
-
1606
- return "\n".join(cleaned_lines).strip()
1607
-
1608
-
1609
- def fix_groupby_describe_slice(code: str) -> str:
1610
- """
1611
- Replaces df.groupby(...).describe()[[...] ] with a safe .agg(...)
1612
- so it works for both SeriesGroupBy and DataFrameGroupBy.
1613
- """
1614
- pat = re.compile(
1615
- r"(df\.groupby\(['\"][\w]+['\"]\)\['[\w]+['\"]\]\.describe\()\s*\[\[([^\]]+)\]\]\)",
1616
- re.MULTILINE
1617
- )
1618
- def repl(match):
1619
- inner = match.group(0)
1620
- # extract group and feature to build df.groupby('g')['f']
1621
- g = re.search(r"groupby\('([\w]+)'\)", inner).group(1)
1622
- f = re.search(r"\)\['([\w]+)'\]\.describe", inner).group(1)
1623
- return (
1624
- f"df.groupby('{g}')['{f}']"
1625
- ".agg(['count','mean','std','min','median','max'])"
1626
- ".rename(columns={'median':'50%'})"
1627
- )
1628
- return pat.sub(repl, code)
1629
-
1630
-
1631
- def fix_importance_groupby(code: str) -> str:
1632
- pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
1633
- if "importance_df" in code:
1634
- return pattern.sub("importance_df.groupby('Importance')['Importance']", code)
1635
- return code
1636
-
1637
- def inject_auto_preprocessing(code: str) -> str:
1638
- """
1639
- Detects a RandomForestClassifier in the generated code.
1640
- Finds the target column from `y = df['target']`.
1641
- • Prepends a fully-dedented preprocessing snippet that:
1642
- – auto-detects numeric & categorical columns
1643
- – builds a ColumnTransformer (OneHotEncoder + StandardScaler)
1644
- The dedent() call guarantees no leading-space IndentationError.
1645
- """
1646
- if "RandomForestClassifier" not in code:
1647
- return code # nothing to patch
1648
-
1649
- y_match = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
1650
- if not y_match:
1651
- return code # can't infer target safely
1652
- target = y_match.group(1)
1653
-
1654
- prep_snippet = textwrap.dedent(f"""
1655
- # ── automatic preprocessing ───────────────────────────────
1656
- num_cols = df.select_dtypes(include=['number']).columns.tolist()
1657
- cat_cols = df.select_dtypes(exclude=['number']).columns.tolist()
1658
- num_cols = [c for c in num_cols if c != '{target}']
1659
- cat_cols = [c for c in cat_cols if c != '{target}']
1660
-
1661
- from sklearn.compose import ColumnTransformer
1662
- from sklearn.preprocessing import StandardScaler, OneHotEncoder
1663
-
1664
- preproc = ColumnTransformer(
1665
- transformers=[
1666
- ('num', StandardScaler(), num_cols),
1667
- ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols),
1668
- ],
1669
- remainder='drop',
1670
- )
1671
- # ───────────────────────────────────────────────────────────
1672
- """).strip() + "\n\n"
1673
-
1674
- # simply prepend; model code that follows can wrap estimator in a Pipeline
1675
- return prep_snippet + code
1676
-
1677
-
1678
- def fix_to_datetime_errors(code: str) -> str:
1679
- """
1680
- Force every pd.to_datetime(…) call to ignore bad dates so that
1605
+ # def get_plotting_imports(code):
1606
+ # imports = []
1607
+ # if "plt." in code and "import matplotlib.pyplot as plt" not in code:
1608
+ # imports.append("import matplotlib.pyplot as plt")
1609
+ # if "sns." in code and "import seaborn as sns" not in code:
1610
+ # imports.append("import seaborn as sns")
1611
+ # if "px." in code and "import plotly.express as px" not in code:
1612
+ # imports.append("import plotly.express as px")
1613
+ # if "pd." in code and "import pandas as pd" not in code:
1614
+ # imports.append("import pandas as pd")
1615
+ # if "np." in code and "import numpy as np" not in code:
1616
+ # imports.append("import numpy as np")
1617
+ # if "display(" in code and "from IPython.display import display" not in code:
1618
+ # imports.append("from IPython.display import display")
1619
+ # # Optionally, add more as you see usage (e.g., import scipy, statsmodels, etc)
1620
+ # if imports:
1621
+ # code = "\n".join(imports) + "\n\n" + code
1622
+ # return code
1623
+
1624
+
1625
+ # def patch_pairplot(code, df):
1626
+ # if "sns.pairplot" in code:
1627
+ # # Always assign and print pairgrid
1628
+ # code = re.sub(r"sns\.pairplot\((.+)\)", r"pairgrid = sns.pairplot(\1)", code)
1629
+ # if "plt.show()" not in code:
1630
+ # code += "\nplt.show()"
1631
+ # if "print(pairgrid)" not in code:
1632
+ # code += "\nprint(pairgrid)"
1633
+ # return code
1634
+
1635
+
1636
+ # def ensure_image_output(code: str) -> str:
1637
+ # """
1638
+ # Replace each plt.show() with an indented _SMX_export_png() call.
1639
+ # This keeps block indentation valid and still renders images in the dashboard.
1640
+ # """
1641
+ # if "plt.show()" not in code:
1642
+ # return code
1643
+
1644
+ # import re
1645
+ # out_lines = []
1646
+ # for ln in code.splitlines():
1647
+ # if "plt.show()" not in ln:
1648
+ # out_lines.append(ln)
1649
+ # continue
1650
+
1651
+ # # works for:
1652
+ # # plt.show()
1653
+ # # plt.tight_layout(); plt.show()
1654
+ # # ... ; plt.show(); ... (multiple on one line)
1655
+ # indent = re.match(r"^(\s*)", ln).group(1)
1656
+ # parts = ln.split("plt.show()")
1657
+
1658
+ # # keep whatever is before the first plt.show()
1659
+ # if parts[0].strip():
1660
+ # out_lines.append(parts[0].rstrip())
1661
+
1662
+ # # for every plt.show() we removed, insert exporter at same indent
1663
+ # for _ in range(len(parts) - 1):
1664
+ # out_lines.append(indent + "_SMX_export_png()")
1665
+
1666
+ # # keep whatever comes after the last plt.show()
1667
+ # if parts[-1].strip():
1668
+ # out_lines.append(indent + parts[-1].lstrip())
1669
+
1670
+ # return "\n".join(out_lines)
1671
+
1672
+
1673
+ # def clean_llm_code(code: str) -> str:
1674
+ # """
1675
+ # Make LLM output safe to exec:
1676
+ # - If fenced blocks exist, keep the largest one (usually the real code).
1677
+ # - Otherwise strip any stray ``` / ```python lines.
1678
+ # - Remove common markdown/preamble junk.
1679
+ # """
1680
+ # code = str(code or "")
1681
+
1682
+ # # Special case: sometimes the OpenAI SDK object repr (e.g. ChatCompletion(...))
1683
+ # # is accidentally passed here as `code`. In that case, extract the actual
1684
+ # # Python code from the ChatCompletionMessage(content=...) field.
1685
+ # if "ChatCompletion(" in code and "ChatCompletionMessage" in code and "content=" in code:
1686
+ # try:
1687
+ # extracted = None
1688
+
1689
+ # class _ChatCompletionVisitor(ast.NodeVisitor):
1690
+ # def visit_Call(self, node):
1691
+ # nonlocal extracted
1692
+ # func = node.func
1693
+ # fname = getattr(func, "id", None) or getattr(func, "attr", None)
1694
+ # if fname == "ChatCompletionMessage":
1695
+ # for kw in node.keywords:
1696
+ # if kw.arg == "content" and isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str):
1697
+ # extracted = kw.value.value
1698
+ # self.generic_visit(node)
1699
+
1700
+ # tree = ast.parse(code, mode="exec")
1701
+ # _ChatCompletionVisitor().visit(tree)
1702
+ # if extracted:
1703
+ # code = extracted
1704
+ # except Exception:
1705
+ # # Best-effort regex fallback if AST parsing fails
1706
+ # m = re.search(r"content=(?P<q>['\\\"])(?P<body>.*?)(?P=q)", code, flags=re.S)
1707
+ # if m:
1708
+ # code = m.group("body")
1709
+
1710
+ # # Existing logic continues unchanged below...
1711
+ # # Extract fenced blocks (```python ... ``` or ``` ... ```)
1712
+ # blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
1713
+
1714
+ # if blocks:
1715
+ # # pick the largest block; small trailing blocks are usually garbage
1716
+ # largest = max(blocks, key=lambda b: len(b.strip()))
1717
+ # if len(largest.strip().splitlines()) >= 10:
1718
+ # code = largest
1719
+
1720
+ # # Extract fenced blocks (```python ... ``` or ``` ... ```)
1721
+ # blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
1722
+
1723
+ # if blocks:
1724
+ # # pick the largest block; small trailing blocks are usually garbage
1725
+ # largest = max(blocks, key=lambda b: len(b.strip()))
1726
+ # if len(largest.strip().splitlines()) >= 10:
1727
+ # code = largest
1728
+ # else:
1729
+ # # if no meaningful block, just remove fence markers
1730
+ # code = re.sub(r"^```.*?$", "", code, flags=re.M)
1731
+ # else:
1732
+ # # no complete blocks — still remove any stray fence lines
1733
+ # code = re.sub(r"^```.*?$", "", code, flags=re.M)
1734
+
1735
+ # # Strip common markdown/preamble lines
1736
+ # drop_prefixes = (
1737
+ # "here is", "here's", "below is", "sure,", "certainly",
1738
+ # "explanation", "note:", "```"
1739
+ # )
1740
+ # cleaned_lines = []
1741
+ # for ln in code.splitlines():
1742
+ # s = ln.strip().lower()
1743
+ # if any(s.startswith(p) for p in drop_prefixes):
1744
+ # continue
1745
+ # cleaned_lines.append(ln)
1746
+
1747
+ # return "\n".join(cleaned_lines).strip()
1748
+
1749
+
1750
+ # def fix_groupby_describe_slice(code: str) -> str:
1751
+ # """
1752
+ # Replaces df.groupby(...).describe()[[...] ] with a safe .agg(...)
1753
+ # so it works for both SeriesGroupBy and DataFrameGroupBy.
1754
+ # """
1755
+ # pat = re.compile(
1756
+ # r"(df\.groupby\(['\"][\w]+['\"]\)\['[\w]+['\"]\]\.describe\()\s*\[\[([^\]]+)\]\]\)",
1757
+ # re.MULTILINE
1758
+ # )
1759
+ # def repl(match):
1760
+ # inner = match.group(0)
1761
+ # # extract group and feature to build df.groupby('g')['f']
1762
+ # g = re.search(r"groupby\('([\w]+)'\)", inner).group(1)
1763
+ # f = re.search(r"\)\['([\w]+)'\]\.describe", inner).group(1)
1764
+ # return (
1765
+ # f"df.groupby('{g}')['{f}']"
1766
+ # ".agg(['count','mean','std','min','median','max'])"
1767
+ # ".rename(columns={'median':'50%'})"
1768
+ # )
1769
+ # return pat.sub(repl, code)
1770
+
1771
+
1772
+ # def fix_importance_groupby(code: str) -> str:
1773
+ # pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
1774
+ # if "importance_df" in code:
1775
+ # return pattern.sub("importance_df.groupby('Importance')['Importance']", code)
1776
+ # return code
1777
+
1778
+ # def inject_auto_preprocessing(code: str) -> str:
1779
+ # """
1780
+ # • Detects a RandomForestClassifier in the generated code.
1781
+ # • Finds the target column from `y = df['target']`.
1782
+ # • Prepends a fully-dedented preprocessing snippet that:
1783
+ # – auto-detects numeric & categorical columns
1784
+ # – builds a ColumnTransformer (OneHotEncoder + StandardScaler)
1785
+ # The dedent() call guarantees no leading-space IndentationError.
1786
+ # """
1787
+ # if "RandomForestClassifier" not in code:
1788
+ # return code # nothing to patch
1789
+
1790
+ # y_match = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
1791
+ # if not y_match:
1792
+ # return code # can't infer target safely
1793
+ # target = y_match.group(1)
1794
+
1795
+ # prep_snippet = textwrap.dedent(f"""
1796
+ # # ── automatic preprocessing ───────────────────────────────
1797
+ # num_cols = df.select_dtypes(include=['number']).columns.tolist()
1798
+ # cat_cols = df.select_dtypes(exclude=['number']).columns.tolist()
1799
+ # num_cols = [c for c in num_cols if c != '{target}']
1800
+ # cat_cols = [c for c in cat_cols if c != '{target}']
1801
+
1802
+ # from sklearn.compose import ColumnTransformer
1803
+ # from sklearn.preprocessing import StandardScaler, OneHotEncoder
1804
+
1805
+ # preproc = ColumnTransformer(
1806
+ # transformers=[
1807
+ # ('num', StandardScaler(), num_cols),
1808
+ # ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols),
1809
+ # ],
1810
+ # remainder='drop',
1811
+ # )
1812
+ # # ───────────────────────────────────────────────────────────
1813
+ # """).strip() + "\n\n"
1814
+
1815
+ # # simply prepend; model code that follows can wrap estimator in a Pipeline
1816
+ # return prep_snippet + code
1817
+
1818
+
1819
+ # def fix_to_datetime_errors(code: str) -> str:
1820
+ # """
1821
+ # Force every pd.to_datetime(…) call to ignore bad dates so that
1681
1822
 
1682
- 'year 16500 is out of range' and similar issues don’t crash runs.
1683
- """
1684
- import re
1685
- # look for any pd.to_datetime( … )
1686
- pat = re.compile(r"pd\.to_datetime\(([^)]+)\)")
1687
- def repl(m):
1688
- inside = m.group(1)
1689
- # if the call already has errors=, leave it unchanged
1690
- if "errors=" in inside:
1691
- return m.group(0)
1692
- return f"pd.to_datetime({inside}, errors='coerce')"
1693
- return pat.sub(repl, code)
1694
-
1695
-
1696
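The effect of the coercion enforced by fix_to_datetime_errors can be seen directly; the sample values below are illustrative:

    import pandas as pd

    raw = pd.Series(["2024-01-31", "not a date", "16500-01-01"])
    # errors='coerce' turns unparseable or out-of-range dates into NaT instead of raising
    print(pd.to_datetime(raw, errors="coerce"))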
- def fix_numeric_sum(code: str) -> str:
1697
- """
1698
- Make .sum(...) code safe across pandas versions by removing any
1699
- numeric_only=... argument (True/False/None) from function calls.
1700
-
1701
- This avoids errors on pandas versions where numeric_only is not
1702
- supported for Series/grouped sums, and we rely instead on explicit
1703
- numeric column selection (e.g. select_dtypes) in the generated code.
1704
- """
1705
- # Case 1: ..., numeric_only=True/False/None
1706
- code = re.sub(
1707
- r",\s*numeric_only\s*=\s*(True|False|None)",
1708
- "",
1709
- code,
1710
- flags=re.IGNORECASE,
1711
- )
1712
-
1713
- # Case 2: numeric_only=True/False/None, ... (as first argument)
1714
- code = re.sub(
1715
- r"numeric_only\s*=\s*(True|False|None)\s*,\s*",
1716
- "",
1717
- code,
1718
- flags=re.IGNORECASE,
1719
- )
1720
-
1721
- # Case 3: numeric_only=True/False/None (only argument)
1722
- code = re.sub(
1723
- r"numeric_only\s*=\s*(True|False|None)",
1724
- "",
1725
- code,
1726
- flags=re.IGNORECASE,
1727
- )
1728
-
1729
- return code
1730
-
1731
-
1732
- def fix_concat_empty_list(code: str) -> str:
1733
- """
1734
- Make pd.concat calls resilient to empty lists of objects.
1735
-
1736
- Transforms patterns like:
1737
- pd.concat(frames, ignore_index=True)
1738
- pd.concat(frames)
1739
-
1740
- into:
1741
- pd.concat(frames or [pd.DataFrame()], ignore_index=True)
1742
- pd.concat(frames or [pd.DataFrame()])
1743
-
1744
- Only triggers when the first argument is a simple variable name.
1745
- """
1746
- pattern = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
1747
-
1748
- def _repl(m):
1749
- name = m.group(1)
1750
- sep = m.group(2) # ',' or ')'
1751
- return f"pd.concat({name} or [pd.DataFrame()]{sep}"
1752
-
1753
- return pattern.sub(_repl, code)
1754
-
1755
-
1756
- def fix_numeric_aggs(code: str) -> str:
1757
- _AGG_FUNCS = ("sum", "mean")
1758
- pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
1759
- def _repl(m):
1760
- func, args = m.group(1), m.group(2) or ""
1761
- if "numeric_only" in args:
1762
- return m.group(0)
1763
- args = args.rstrip()
1764
- if args:
1765
- args += ", "
1766
- return f".{func}({args}numeric_only=True)"
1767
- return pat.sub(_repl, code)
1768
-
1769
-
1770
- def ensure_accuracy_block(code: str) -> str:
1771
- """
1772
- Inject a sensible evaluation block right after the last `<est>.fit(...)`
1773
- Classification → accuracy + weighted F1
1774
- Regression → R², RMSE, MAE
1775
- Heuristic: infer task from estimator names present in the code.
1776
- """
1777
- import re, textwrap
1778
-
1779
- # If any proper metric already exists, do nothing
1780
- if re.search(r"\b(accuracy_score|f1_score|r2_score|mean_squared_error|mean_absolute_error)\b", code):
1781
- return code
1782
-
1783
- # Find the last "<var>.fit(" occurrence to reuse the estimator variable name
1784
- m = list(re.finditer(r"(\w+)\.fit\s*\(", code))
1785
- if not m:
1786
- return code # no estimator
1787
-
1788
- var = m[-1].group(1)
1789
- # indent with same leading whitespace used on that line
1790
- indent = re.match(r"\s*", code[m[-1].start():]).group(0)
1791
-
1792
- # Detect regression by estimator names / hints in code
1793
- is_regression = bool(
1794
- re.search(
1795
- r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
1796
- r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
1797
- r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
1798
- )
1799
- or re.search(r"\bOLS\s*\(", code)
1800
- or re.search(r"\bRegressor\b", code)
1801
- )
1802
-
1803
- if is_regression:
1804
- # inject numpy import if needed for RMSE
1805
- if "import numpy as np" not in code and "np." not in code:
1806
- code = "import numpy as np\n" + code
1807
- eval_block = textwrap.dedent(f"""
1808
- {indent}# ── automatic regression evaluation ─────────
1809
- {indent}from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
1810
- {indent}y_pred = {var}.predict(X_test)
1811
- {indent}r2 = r2_score(y_test, y_pred)
1812
- {indent}rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
1813
- {indent}mae = float(mean_absolute_error(y_test, y_pred))
1814
- {indent}print(f"R²: {{r2:.4f}} | RMSE: {{rmse:.4f}} | MAE: {{mae:.4f}}")
1815
- """)
1816
- else:
1817
- eval_block = textwrap.dedent(f"""
1818
- {indent}# ── automatic classification evaluation ─────────
1819
- {indent}from sklearn.metrics import accuracy_score, f1_score
1820
- {indent}y_pred = {var}.predict(X_test)
1821
- {indent}acc = accuracy_score(y_test, y_pred)
1822
- {indent}f1 = f1_score(y_test, y_pred, average='weighted')
1823
- {indent}print(f"Accuracy: {{acc:.2%}} | F1 (weighted): {{f1:.3f}}")
1824
- """)
1825
-
1826
- insert_at = code.find("\n", m[-1].end()) + 1
1827
- return code[:insert_at] + eval_block + code[insert_at:]
1828
-
1829
-
1830
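The regression branch of ensure_accuracy_block injects the standard metric trio; in isolation it amounts to the following, with toy arrays standing in for y_test and y_pred:

    import numpy as np
    from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

    y_test = np.array([3.0, 5.0, 7.5, 9.0])
    y_pred = np.array([2.8, 5.4, 7.0, 9.3])
    r2 = r2_score(y_test, y_pred)
    rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
    mae = float(mean_absolute_error(y_test, y_pred))
    print(f"R²: {r2:.4f} | RMSE: {rmse:.4f} | MAE: {mae:.4f}")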
- def fix_scatter_and_summary(code: str) -> str:
1831
- """
1832
- 1. Change cmap='spectral' (any case) → cmap='Spectral'
1833
- 2. If the LLM forgets to close the parenthesis in
1834
- summary_table = ( df.describe()... <missing )>
1835
- insert the ')' right before the next 'from' or 'show('.
1836
- """
1837
- # 1️⃣ colormap case
1838
- code = re.sub(
1839
- r"cmap\s*=\s*['\"]spectral['\"]", # insensitive pattern
1840
- "cmap='Spectral'",
1841
- code,
1842
- flags=re.IGNORECASE,
1843
- )
1844
-
1845
- # 2️⃣ close summary_table = ( ... )
1846
- code = re.sub(
1847
- r"(summary_table\s*=\s*\(\s*df\.describe\([^\n]+?\n)"
1848
- r"(?=\s*(from|show\())", # look-ahead: next line starts with 'from' or 'show('
1849
- r"\1)", # keep group 1 and add ')'
1850
- code,
1851
- flags=re.MULTILINE,
1852
- )
1853
-
1854
- return code
1855
-
1856
-
1857
- def auto_format_with_black(code: str) -> str:
1858
- """
1859
- Format the generated code with Black. Falls back silently if Black
1860
- is missing or raises (so the dashboard never 500s).
1861
- """
1862
- try:
1863
- import black # make sure black is in your v-env: pip install black
1864
-
1865
- mode = black.FileMode() # default settings
1866
- return black.format_str(code, mode=mode)
1867
-
1868
- except Exception:
1869
- return code
1870
-
1871
-
1872
- def ensure_preproc_in_pipeline(code: str) -> str:
1873
- """
1874
- If code defines `preproc = ColumnTransformer(...)` but then builds
1875
- `Pipeline([('scaler', StandardScaler()), ('clf', ...)])`, replace
1876
- that stanza with `Pipeline([('prep', preproc), ('clf', ...)])`.
1877
- """
1878
- return re.sub(
1879
- r"Pipeline\(\s*\[\('scaler',\s*StandardScaler\(\)\)",
1880
- "Pipeline([('prep', preproc)",
1881
- code
1882
- )
1883
-
1884
-
1885
- def fix_plain_prints(code: str) -> str:
1886
- """
1887
- Rewrite bare `print(var)` where var looks like a dataframe/series/ndarray/etc
1888
- to go through SyntaxMatrix's smart display (so it renders in the dashboard).
1889
- Keeps string prints alone.
1890
- """
1891
-
1892
- # Skip obvious string-literal prints
1893
- new = re.sub(
1894
- r"(?m)^\s*print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
1895
- r"from syntaxmatrix.display import show\nshow(\1)",
1896
- code,
1897
- )
1898
- return new
1899
-
1900
-
1901
- def fix_print_html(code: str) -> str:
1902
- """
1903
- Ensure that HTML / DataFrame HTML are *displayed* (and captured by the kernel),
1904
- not printed as `<IPython.core.display.HTML object>` to the server console.
1905
- - Rewrites: print(HTML(...)) → display(HTML(...))
1906
- print(display(...)) → display(...)
1907
- print(df.to_html(...)) display(HTML(df.to_html(...)))
1908
- Also prepends `from IPython.display import display, HTML` if required.
1909
- """
1910
- import re
1911
-
1912
- new = code
1913
-
1914
- # 1) print(HTML(...)) -> display(HTML(...))
1915
- new = re.sub(r"(?m)^\s*print\s*\(\s*HTML\s*\(", "display(HTML(", new)
1916
-
1917
- # 2) print(display(...)) -> display(...)
1918
- new = re.sub(r"(?m)^\s*print\s*\(\s*display\s*\(", "display(", new)
1919
-
1920
- # 3) print(<expr>.to_html(...)) -> display(HTML(<expr>.to_html(...)))
1921
- new = re.sub(
1922
- r"(?m)^\s*print\s*\(\s*([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*)\s*\.to_html\s*\(",
1923
- r"display(HTML(\1.to_html(", new
1924
- )
1925
-
1926
- # If code references HTML() or display() make sure the import exists
1927
- if ("HTML(" in new or re.search(r"\bdisplay\s*\(", new)) and \
1928
- "from IPython.display import display, HTML" not in new:
1929
- new = "from IPython.display import display, HTML\n" + new
1930
-
1931
- return new
1932
-
1933
-
1934
- def ensure_ipy_display(code: str) -> str:
1935
- """
1936
- Guarantee that the cell has proper IPython display imports so that
1937
- display(HTML(...)) produces 'display_data' events the kernel captures.
1938
- """
1939
- if "display(" in code and "from IPython.display import display, HTML" not in code:
1940
- return "from IPython.display import display, HTML\n" + code
1941
- return code
1942
-
1943
-
1944
- def drop_bad_classification_metrics(code: str, y_or_df) -> str:
1945
- """
1946
- Remove classification metrics (accuracy_score, classification_report, confusion_matrix)
1947
- if the generated cell is *regression*. We infer this from:
1948
- 1) The estimator names in the code (LinearRegression, OLS, Regressor*, etc.), OR
1949
- 2) The target dtype if we can parse y = df['...'] and have the DataFrame.
1950
- Safe across datasets and queries.
1951
- """
1952
- import re
1953
- import pandas as pd
1954
-
1955
- # 1) Heuristic by estimator names in the *code* (fast path)
1956
- regression_by_model = bool(re.search(
1957
- r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
1958
- r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
1959
- r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
1960
- ) or re.search(r"\bOLS\s*\(", code))
1961
-
1962
- is_regression = regression_by_model
1963
-
1964
- # 2) If not obvious from the model, try to infer from y dtype (if we can)
1965
- if not is_regression:
1966
- try:
1967
- # Try to parse: y = df['target']
1968
- m = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
1969
- if m and hasattr(y_or_df, "columns") and m.group(1) in getattr(y_or_df, "columns", []):
1970
- y = y_or_df[m.group(1)]
1971
- if pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
1972
- is_regression = True
1973
- else:
1974
- # If a Series was passed
1975
- y = y_or_df
1976
- if hasattr(y, "dtype") and pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
1977
- is_regression = True
1978
- except Exception:
1979
- pass
1980
-
1981
- if is_regression:
1982
- # Strip classification-only lines
1983
- for pat in (r"\n.*accuracy_score[^\n]*", r"\n.*classification_report[^\n]*", r"\n.*confusion_matrix[^\n]*"):
1984
- code = re.sub(pat, "", code, flags=re.I)
1985
-
1986
- return code
1987
-
1988
-
1989
- def force_capture_display(code: str) -> str:
1990
- """
1991
- Ensure our executor captures HTML output:
1992
- - Remove any import that would override our 'display' hook.
1993
- - Keep/allow importing HTML only.
1994
- - Handle alias cases like 'display as d'.
1995
- """
1996
- import re
1997
- new = code
1998
-
1999
- # 'from IPython.display import display, HTML' -> keep HTML only
2000
- new = re.sub(
2001
- r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*,\s*HTML\s*(?:as\s+([A-Za-z_]\w*))?\s*$",
2002
- r"from IPython.display import HTML\1", new
2003
- )
2004
-
2005
- # 'from IPython.display import display as d' -> 'd = display'
2006
- new = re.sub(
2007
- r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s+as\s+([A-Za-z_]\w+)\s*$",
2008
- r"\1 = display", new
2009
- )
2010
-
2011
- # 'from IPython.display import display' -> remove (use our injected display)
2012
- new = re.sub(
2013
- r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*$",
2014
- r"# display import removed (SMX capture active)", new
2015
- )
2016
-
2017
- # If someone does 'import IPython.display as disp' and calls disp.display(...), rewrite to display(...)
2018
- new = re.sub(
2019
- r"(?m)\bIPython\.display\.display\s*\(",
2020
- "display(", new
2021
- )
2022
- new = re.sub(
2023
- r"(?m)\b([A-Za-z_]\w*)\.display\s*\(" # handles 'disp.display(' after 'import IPython.display as disp'
2024
- r"(?=.*import\s+IPython\.display\s+as\s+\1)",
2025
- "display(", new
2026
- )
2027
- return new
2028
-
2029
-
2030
- def strip_matplotlib_show(code: str) -> str:
2031
- """Remove blocking plt.show() calls (we export base64 instead)."""
2032
- import re
2033
- return re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code)
2034
-
2035
-
2036
- def inject_display_shim(code: str) -> str:
2037
- """
2038
- Provide display()/HTML() if missing, forwarding to our executor hook.
2039
- Harmless if the names already exist.
2040
- """
2041
- shim = (
2042
- "try:\n"
2043
- " display\n"
2044
- "except NameError:\n"
2045
- " def display(obj=None, **kwargs):\n"
2046
- " __builtins__.get('_smx_display', print)(obj)\n"
2047
- "try:\n"
2048
- " HTML\n"
2049
- "except NameError:\n"
2050
- " class HTML:\n"
2051
- " def __init__(self, data): self.data = str(data)\n"
2052
- " def _repr_html_(self): return self.data\n"
2053
- "\n"
2054
- )
2055
- return shim + code
2056
-
2057
-
2058
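One detail worth noting when reusing the shim above: at exec time `__builtins__` may be bound to the builtins module rather than a dict, in which case it has no .get method. A slightly more defensive variant of the display fallback, written here as ordinary code rather than the embedded string, might look like this (the `_smx_display` hook name is taken from the shim; the demo call is illustrative):

    import builtins

    def display(obj=None, **kwargs):
        # prefer an executor-registered hook, otherwise fall back to print
        hook = getattr(builtins, "_smx_display", None) or print
        hook(obj)

    display("shim ready")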
- def strip_spurious_column_tokens(code: str) -> str:
2059
- """
2060
- Remove common stop-words ('the','whether', ...) when they appear
2061
- inside column lists, e.g.:
2062
- predictors = ['BMI','the','HbA1c']
2063
- df[['GGT','whether','BMI']]
2064
- Leaves other strings intact.
2065
- """
2066
- STOP = {
2067
- "the","whether","a","an","and","or","of","to","in","on","for","by",
2068
- "with","as","at","from","that","this","these","those","is","are","was","were",
2069
- "coef", "Coef", "coefficient", "Coefficient"
2070
- }
2071
-
2072
- def _norm(s: str) -> str:
2073
- return re.sub(r"[^a-z0-9]+", "", s.lower())
2074
-
2075
- def _clean_list(content: str) -> str:
2076
- # Rebuild a string list, keeping only non-stopword items
2077
- items = re.findall(r"(['\"])(.*?)\1", content)
2078
- if not items:
2079
- return "[" + content + "]"
2080
- keep = [f"{q}{s}{q}" for (q, s) in items if _norm(s) not in STOP]
2081
- return "[" + ", ".join(keep) + "]"
2082
-
2083
- # Variable assignments: predictors/features/columns/cols = [...]
2084
- code = re.sub(
2085
- r"(?m)\b(predictors|features|columns|cols)\s*=\s*\[([^\]]+)\]",
2086
- lambda m: f"{m.group(1)} = " + _clean_list(m.group(2)),
2087
- code
2088
- )
2089
-
2090
- # df[[ ... ]] selections
2091
- code = re.sub(
2092
- r"df\s*\[\s*\[([^\]]+)\]\s*\]", lambda m: "df[" + _clean_list(m.group(1)) + "]", code)
2093
-
2094
- return code
2095
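A compact sketch of the list-cleaning idea, using a toy stop-word set and a made-up column list:

import re

STOP = {"the", "whether"}

def demo_clean_list(src: str) -> str:
    # Keep only the quoted items that are not stop-words.
    items = re.findall(r"(['\"])(.*?)\1", src)
    kept = [f"{q}{s}{q}" for q, s in items if s.lower() not in STOP]
    return "[" + ", ".join(kept) + "]"

print(demo_clean_list("'BMI','the','HbA1c'"))  # ['BMI', 'HbA1c']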
-
2096
-
2097
- def patch_prefix_seaborn_calls(code: str) -> str:
2098
- """
2099
- Ensure bare seaborn calls are prefixed with `sns.`.
2100
- E.g., `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, etc.
2101
- """
2102
- if "sns." in code:
2103
- # still fix any leftover bare calls alongside prefixed ones
2104
- pass
2105
-
2106
- # functions commonly used from seaborn
2107
- funcs = [
2108
- "barplot","countplot","boxplot","violinplot","stripplot","swarmplot",
2109
- "histplot","kdeplot","jointplot","pairplot","heatmap","clustermap",
2110
- "scatterplot","lineplot","catplot","displot","lmplot"
2111
- ]
2112
- # Replace bare function calls not already qualified by a dot (e.g., obj.barplot)
2113
- # (?<![\w.]) ensures no preceding word char or dot; avoids touching obj.barplot or mybarplot
2114
- pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(", flags=re.MULTILINE)
2115
-
2116
- def _add_prefix(m):
2117
- fn = m.group(1)
2118
- return f"sns.{fn}("
2119
-
2120
- return pattern.sub(_add_prefix, code)
2121
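The prefixing rule can be exercised in isolation; a short sketch with a hypothetical code fragment (note that dotted calls such as ax.barplot are deliberately left alone):

import re

funcs = ["barplot", "heatmap"]
pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(")
print(pattern.sub(lambda m: f"sns.{m.group(1)}(", "heatmap(corr); ax.barplot(x)"))
# Only the bare call is rewritten: sns.heatmap(corr); ax.barplot(x)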
-
2122
-
2123
- def patch_ensure_seaborn_import(code: str) -> str:
2124
- """
2125
- If seaborn is used (sns.) ensure `import seaborn as sns` exists once.
2126
- Also set a quiet theme for consistent visuals.
2127
- """
2128
- needs_sns = "sns." in code
2129
- has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
2130
- if needs_sns and not has_import:
2131
- # Insert after the first block of imports if possible, else at top
2132
- import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
2133
- inject = "import seaborn as sns\ntry:\n sns.set_theme()\nexcept Exception:\n pass\n"
2134
- if import_block:
2135
- start = import_block.end()
2136
- code = code[:start] + inject + code[start:]
2137
- else:
2138
- code = inject + code
2139
- return code
2140
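The import-injection check is essentially an idempotent guard; a minimal sketch, with the insertion simplified to a plain prepend:

import re

def demo_ensure_sns_import(code: str) -> str:
    # Prepend the import only when seaborn is referenced but never imported.
    if "sns." in code and not re.search(r"(?m)^\s*import\s+seaborn\s+as\s+sns\s*$", code):
        return "import seaborn as sns\n" + code
    return code

print(demo_ensure_sns_import("sns.histplot(df['BMI'])"))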
-
2141
-
2142
- def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
2143
- """
2144
- Normalise pie-chart requests.
2145
-
2146
- Supports three patterns:
2147
- A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
2148
- B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
2149
- → one pie per facet level (grid) + counts bar of facet sizes.
2150
- C) Single pie when no split/facet is requested.
2151
-
2152
- Notes:
2153
- - Pie variables must be categorical (or numeric binned).
2154
- - Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
2155
- """
2156
-
2157
- q = (user_question or "")
2158
- q_low = q.lower()
2159
-
2160
- # Prefer explicit: df['col'].value_counts()
2161
- m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
2162
- col = m.group(1) if m else None
2163
-
2164
- # ---------- helpers ----------
2165
- def _is_cat(col):
2166
- return (str(df[col].dtype).startswith("category")
2167
- or df[col].dtype == "object"
2168
- or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
2169
-
2170
- def _cats_from_question(question: str):
2171
- found = []
2172
- for c in df.columns:
2173
- if c.lower() in question.lower() and _is_cat(c):
2174
- found.append(c)
2175
- # dedupe preserve order
2176
- seen, out = set(), []
2177
- for c in found:
2178
- if c not in seen:
2179
- out.append(c); seen.add(c)
2180
- return out
2181
-
2182
- def _fallback_cat():
2183
- cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2184
- if not cats: return None
2185
- cats.sort(key=lambda t: t[1])
2186
- return cats[0][0]
1823
+ # 'year 16500 is out of range' and similar issues don’t crash runs.
1824
+ # """
1825
+ # import re
1826
+ # # look for any pd.to_datetime( … )
1827
+ # pat = re.compile(r"pd\.to_datetime\(([^)]+)\)")
1828
+ # def repl(m):
1829
+ # inside = m.group(1)
1830
+ # # if the call already has errors=, leave it unchanged
1831
+ # if "errors=" in inside:
1832
+ # return m.group(0)
1833
+ # return f"pd.to_datetime({inside}, errors='coerce')"
1834
+ # return pat.sub(repl, code)
1835
+
1836
+
1837
+ # def fix_numeric_sum(code: str) -> str:
1838
+ # """
1839
+ # Make .sum(...) code safe across pandas versions by removing any
1840
+ # numeric_only=... argument (True/False/None) from function calls.
1841
+
1842
+ # This avoids errors on pandas versions where numeric_only is not
1843
+ # supported for Series/grouped sums, and we rely instead on explicit
1844
+ # numeric column selection (e.g. select_dtypes) in the generated code.
1845
+ # """
1846
+ # # Case 1: ..., numeric_only=True/False/None
1847
+ # code = re.sub(
1848
+ # r",\s*numeric_only\s*=\s*(True|False|None)",
1849
+ # "",
1850
+ # code,
1851
+ # flags=re.IGNORECASE,
1852
+ # )
1853
+
1854
+ # # Case 2: numeric_only=True/False/None, ... (as first argument)
1855
+ # code = re.sub(
1856
+ # r"numeric_only\s*=\s*(True|False|None)\s*,\s*",
1857
+ # "",
1858
+ # code,
1859
+ # flags=re.IGNORECASE,
1860
+ # )
1861
+
1862
+ # # Case 3: numeric_only=True/False/None (only argument)
1863
+ # code = re.sub(
1864
+ # r"numeric_only\s*=\s*(True|False|None)",
1865
+ # "",
1866
+ # code,
1867
+ # flags=re.IGNORECASE,
1868
+ # )
1869
+
1870
+ # return code
1871
+
1872
+
1873
+ # def fix_concat_empty_list(code: str) -> str:
1874
+ # """
1875
+ # Make pd.concat calls resilient to empty lists of objects.
1876
+
1877
+ # Transforms patterns like:
1878
+ # pd.concat(frames, ignore_index=True)
1879
+ # pd.concat(frames)
1880
+
1881
+ # into:
1882
+ # pd.concat(frames or [pd.DataFrame()], ignore_index=True)
1883
+ # pd.concat(frames or [pd.DataFrame()])
1884
+
1885
+ # Only triggers when the first argument is a simple variable name.
1886
+ # """
1887
+ # pattern = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
1888
+
1889
+ # def _repl(m):
1890
+ # name = m.group(1)
1891
+ # sep = m.group(2) # ',' or ')'
1892
+ # return f"pd.concat({name} or [pd.DataFrame()]{sep}"
1893
+
1894
+ # return pattern.sub(_repl, code)
1895
+
1896
+
1897
+ # def fix_numeric_aggs(code: str) -> str:
1898
+ # _AGG_FUNCS = ("sum", "mean")
1899
+ # pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
1900
+ # def _repl(m):
1901
+ # func, args = m.group(1), m.group(2) or ""
1902
+ # if "numeric_only" in args:
1903
+ # return m.group(0)
1904
+ # args = args.rstrip()
1905
+ # if args:
1906
+ # args += ", "
1907
+ # return f".{func}({args}numeric_only=True)"
1908
+ # return pat.sub(_repl, code)
1909
+
1910
+
1911
+ # def ensure_accuracy_block(code: str) -> str:
1912
+ # """
1913
+ # Inject a sensible evaluation block right after the last `<est>.fit(...)`
1914
+ # Classification → accuracy + weighted F1
1915
+ # Regression → R², RMSE, MAE
1916
+ # Heuristic: infer task from estimator names present in the code.
1917
+ # """
1918
+ # import re, textwrap
1919
+
1920
+ # # If any proper metric already exists, do nothing
1921
+ # if re.search(r"\b(accuracy_score|f1_score|r2_score|mean_squared_error|mean_absolute_error)\b", code):
1922
+ # return code
1923
+
1924
+ # # Find the last "<var>.fit(" occurrence to reuse the estimator variable name
1925
+ # m = list(re.finditer(r"(\w+)\.fit\s*\(", code))
1926
+ # if not m:
1927
+ # return code # no estimator
1928
+
1929
+ # var = m[-1].group(1)
1930
+ # # indent with same leading whitespace used on that line
1931
+ # indent = re.match(r"\s*", code[m[-1].start():]).group(0)
1932
+
1933
+ # # Detect regression by estimator names / hints in code
1934
+ # is_regression = bool(
1935
+ # re.search(
1936
+ # r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
1937
+ # r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
1938
+ # r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
1939
+ # )
1940
+ # or re.search(r"\bOLS\s*\(", code)
1941
+ # or re.search(r"\bRegressor\b", code)
1942
+ # )
1943
+
1944
+ # if is_regression:
1945
+ # # inject numpy import if needed for RMSE
1946
+ # if "import numpy as np" not in code and "np." not in code:
1947
+ # code = "import numpy as np\n" + code
1948
+ # eval_block = textwrap.dedent(f"""
1949
+ # {indent}# ── automatic regression evaluation ─────────
1950
+ # {indent}from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
1951
+ # {indent}y_pred = {var}.predict(X_test)
1952
+ # {indent}r2 = r2_score(y_test, y_pred)
1953
+ # {indent}rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
1954
+ # {indent}mae = float(mean_absolute_error(y_test, y_pred))
1955
+ # {indent}print(f"R²: {{r2:.4f}} | RMSE: {{rmse:.4f}} | MAE: {{mae:.4f}}")
1956
+ # """)
1957
+ # else:
1958
+ # eval_block = textwrap.dedent(f"""
1959
+ # {indent}# ── automatic classification evaluation ─────────
1960
+ # {indent}from sklearn.metrics import accuracy_score, f1_score
1961
+ # {indent}y_pred = {var}.predict(X_test)
1962
+ # {indent}acc = accuracy_score(y_test, y_pred)
1963
+ # {indent}f1 = f1_score(y_test, y_pred, average='weighted')
1964
+ # {indent}print(f"Accuracy: {{acc:.2%}} | F1 (weighted): {{f1:.3f}}")
1965
+ # """)
1966
+
1967
+ # insert_at = code.find("\n", m[-1].end()) + 1
1968
+ # return code[:insert_at] + eval_block + code[insert_at:]
1969
+
1970
+
1971
+ # def fix_scatter_and_summary(code: str) -> str:
1972
+ # """
1973
+ # 1. Change cmap='spectral' (any case) → cmap='Spectral'
1974
+ # 2. If the LLM forgets to close the parenthesis in
1975
+ # summary_table = ( df.describe()... <missing )>
1976
+ # insert the ')' right before the next 'from' or 'show('.
1977
+ # """
1978
+ # # 1️⃣ colormap case
1979
+ # code = re.sub(
1980
+ # r"cmap\s*=\s*['\"]spectral['\"]", # insensitive pattern
1981
+ # "cmap='Spectral'",
1982
+ # code,
1983
+ # flags=re.IGNORECASE,
1984
+ # )
1985
+
1986
+ # # 2️⃣ close summary_table = ( ... )
1987
+ # code = re.sub(
1988
+ # r"(summary_table\s*=\s*\(\s*df\.describe\([^\n]+?\n)"
1989
+ # r"(?=\s*(from|show\())", # look-ahead: next line starts with 'from' or 'show('
1990
+ # r"\1)", # keep group 1 and add ')'
1991
+ # code,
1992
+ # flags=re.MULTILINE,
1993
+ # )
1994
+
1995
+ # return code
1996
+
1997
+
1998
+ # def auto_format_with_black(code: str) -> str:
1999
+ # """
2000
+ # Format the generated code with Black. Falls back silently if Black
2001
+ # is missing or raises (so the dashboard never 500s).
2002
+ # """
2003
+ # try:
2004
+ # import black # make sure black is in your v-env: pip install black
2005
+
2006
+ # mode = black.FileMode() # default settings
2007
+ # return black.format_str(code, mode=mode)
2008
+
2009
+ # except Exception:
2010
+ # return code
2011
+
2012
+
2013
+ # def ensure_preproc_in_pipeline(code: str) -> str:
2014
+ # """
2015
+ # If code defines `preproc = ColumnTransformer(...)` but then builds
2016
+ # `Pipeline([('scaler', StandardScaler()), ('clf', ...)])`, replace
2017
+ # that stanza with `Pipeline([('prep', preproc), ('clf', ...)])`.
2018
+ # """
2019
+ # return re.sub(
2020
+ # r"Pipeline\(\s*\[\('scaler',\s*StandardScaler\(\)\)",
2021
+ # "Pipeline([('prep', preproc)",
2022
+ # code
2023
+ # )
2024
+
2025
+ # def drop_bad_classification_metrics(code: str, y_or_df) -> str:
2026
+ # """
2027
+ # Remove classification metrics (accuracy_score, classification_report, confusion_matrix)
2028
+ # if the generated cell is *regression*. We infer this from:
2029
+ # 1) The estimator names in the code (LinearRegression, OLS, Regressor*, etc.), OR
2030
+ # 2) The target dtype if we can parse y = df['...'] and have the DataFrame.
2031
+ # Safe across datasets and queries.
2032
+ # """
2033
+ # import re
2034
+ # import pandas as pd
2035
+
2036
+ # # 1) Heuristic by estimator names in the *code* (fast path)
2037
+ # regression_by_model = bool(re.search(
2038
+ # r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
2039
+ # r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
2040
+ # r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
2041
+ # ) or re.search(r"\bOLS\s*\(", code))
2042
+
2043
+ # is_regression = regression_by_model
2044
+
2045
+ # # 2) If not obvious from the model, try to infer from y dtype (if we can)
2046
+ # if not is_regression:
2047
+ # try:
2048
+ # # Try to parse: y = df['target']
2049
+ # m = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
2050
+ # if m and hasattr(y_or_df, "columns") and m.group(1) in getattr(y_or_df, "columns", []):
2051
+ # y = y_or_df[m.group(1)]
2052
+ # if pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
2053
+ # is_regression = True
2054
+ # else:
2055
+ # # If a Series was passed
2056
+ # y = y_or_df
2057
+ # if hasattr(y, "dtype") and pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
2058
+ # is_regression = True
2059
+ # except Exception:
2060
+ # pass
2061
+
2062
+ # if is_regression:
2063
+ # # Strip classification-only lines
2064
+ # for pat in (r"\n.*accuracy_score[^\n]*", r"\n.*classification_report[^\n]*", r"\n.*confusion_matrix[^\n]*"):
2065
+ # code = re.sub(pat, "", code, flags=re.I)
2066
+
2067
+ # return code
2068
+
2069
+
2070
+ # def force_capture_display(code: str) -> str:
2071
+ # """
2072
+ # Ensure our executor captures HTML output:
2073
+ # - Remove any import that would override our 'display' hook.
2074
+ # - Keep/allow importing HTML only.
2075
+ # - Handle alias cases like 'display as d'.
2076
+ # """
2077
+ # import re
2078
+ # new = code
2079
+
2080
+ # # 'from IPython.display import display, HTML' -> keep HTML only
2081
+ # new = re.sub(
2082
+ # r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*,\s*HTML\s*(?:as\s+([A-Za-z_]\w*))?\s*$",
2083
+ # r"from IPython.display import HTML\1", new
2084
+ # )
2085
+
2086
+ # # 'from IPython.display import display as d' -> 'd = display'
2087
+ # new = re.sub(
2088
+ # r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s+as\s+([A-Za-z_]\w+)\s*$",
2089
+ # r"\1 = display", new
2090
+ # )
2091
+
2092
+ # # 'from IPython.display import display' -> remove (use our injected display)
2093
+ # new = re.sub(
2094
+ # r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*$",
2095
+ # r"# display import removed (SMX capture active)", new
2096
+ # )
2097
+
2098
+ # # If someone does 'import IPython.display as disp' and calls disp.display(...), rewrite to display(...)
2099
+ # new = re.sub(
2100
+ # r"(?m)\bIPython\.display\.display\s*\(",
2101
+ # "display(", new
2102
+ # )
2103
+ # new = re.sub(
2104
+ # r"(?m)\b([A-Za-z_]\w*)\.display\s*\(" # handles 'disp.display(' after 'import IPython.display as disp'
2105
+ # r"(?=.*import\s+IPython\.display\s+as\s+\1)",
2106
+ # "display(", new
2107
+ # )
2108
+ # return new
2109
+
2110
+
2111
+ # def strip_matplotlib_show(code: str) -> str:
2112
+ # """Remove blocking plt.show() calls (we export base64 instead)."""
2113
+ # import re
2114
+ # return re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code)
2115
+
2116
+
2117
+ # def inject_display_shim(code: str) -> str:
2118
+ # """
2119
+ # Provide display()/HTML() if missing, forwarding to our executor hook.
2120
+ # Harmless if the names already exist.
2121
+ # """
2122
+ # shim = (
2123
+ # "try:\n"
2124
+ # " display\n"
2125
+ # "except NameError:\n"
2126
+ # " def display(obj=None, **kwargs):\n"
2127
+ # " __builtins__.get('_smx_display', print)(obj)\n"
2128
+ # "try:\n"
2129
+ # " HTML\n"
2130
+ # "except NameError:\n"
2131
+ # " class HTML:\n"
2132
+ # " def __init__(self, data): self.data = str(data)\n"
2133
+ # " def _repr_html_(self): return self.data\n"
2134
+ # "\n"
2135
+ # )
2136
+ # return shim + code
2137
+
2138
+
2139
+ # def strip_spurious_column_tokens(code: str) -> str:
2140
+ # """
2141
+ # Remove common stop-words ('the','whether', ...) when they appear
2142
+ # inside column lists, e.g.:
2143
+ # predictors = ['BMI','the','HbA1c']
2144
+ # df[['GGT','whether','BMI']]
2145
+ # Leaves other strings intact.
2146
+ # """
2147
+ # STOP = {
2148
+ # "the","whether","a","an","and","or","of","to","in","on","for","by",
2149
+ # "with","as","at","from","that","this","these","those","is","are","was","were",
2150
+ # "coef", "Coef", "coefficient", "Coefficient"
2151
+ # }
2152
+
2153
+ # def _norm(s: str) -> str:
2154
+ # return re.sub(r"[^a-z0-9]+", "", s.lower())
2155
+
2156
+ # def _clean_list(content: str) -> str:
2157
+ # # Rebuild a string list, keeping only non-stopword items
2158
+ # items = re.findall(r"(['\"])(.*?)\1", content)
2159
+ # if not items:
2160
+ # return "[" + content + "]"
2161
+ # keep = [f"{q}{s}{q}" for (q, s) in items if _norm(s) not in STOP]
2162
+ # return "[" + ", ".join(keep) + "]"
2163
+
2164
+ # # Variable assignments: predictors/features/columns/cols = [...]
2165
+ # code = re.sub(
2166
+ # r"(?m)\b(predictors|features|columns|cols)\s*=\s*\[([^\]]+)\]",
2167
+ # lambda m: f"{m.group(1)} = " + _clean_list(m.group(2)),
2168
+ # code
2169
+ # )
2170
+
2171
+ # # df[[ ... ]] selections
2172
+ # code = re.sub(
2173
+ # r"df\s*\[\s*\[([^\]]+)\]\s*\]", lambda m: "df[" + _clean_list(m.group(1)) + "]", code)
2174
+
2175
+ # return code
2176
+
2177
+
2178
+ # def patch_prefix_seaborn_calls(code: str) -> str:
2179
+ # """
2180
+ # Ensure bare seaborn calls are prefixed with `sns.`.
2181
+ # E.g., `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, etc.
2182
+ # """
2183
+ # if "sns." in code:
2184
+ # # still fix any leftover bare calls alongside prefixed ones
2185
+ # pass
2186
+
2187
+ # # functions commonly used from seaborn
2188
+ # funcs = [
2189
+ # "barplot","countplot","boxplot","violinplot","stripplot","swarmplot",
2190
+ # "histplot","kdeplot","jointplot","pairplot","heatmap","clustermap",
2191
+ # "scatterplot","lineplot","catplot","displot","lmplot"
2192
+ # ]
2193
+ # # Replace bare function calls not already qualified by a dot (e.g., obj.barplot)
2194
+ # # (?<![\w.]) ensures no preceding word char or dot; avoids touching obj.barplot or mybarplot
2195
+ # pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(", flags=re.MULTILINE)
2196
+
2197
+ # def _add_prefix(m):
2198
+ # fn = m.group(1)
2199
+ # return f"sns.{fn}("
2200
+
2201
+ # return pattern.sub(_add_prefix, code)
2202
+
2203
+
2204
+ # def patch_ensure_seaborn_import(code: str) -> str:
2205
+ # """
2206
+ # If seaborn is used (sns.) ensure `import seaborn as sns` exists once.
2207
+ # Also set a quiet theme for consistent visuals.
2208
+ # """
2209
+ # needs_sns = "sns." in code
2210
+ # has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
2211
+ # if needs_sns and not has_import:
2212
+ # # Insert after the first block of imports if possible, else at top
2213
+ # import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
2214
+ # inject = "import seaborn as sns\ntry:\n sns.set_theme()\nexcept Exception:\n pass\n"
2215
+ # if import_block:
2216
+ # start = import_block.end()
2217
+ # code = code[:start] + inject + code[start:]
2218
+ # else:
2219
+ # code = inject + code
2220
+ # return code
2221
+
2222
+
2223
+ # def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
2224
+ # """
2225
+ # Normalise pie-chart requests.
2226
+
2227
+ # Supports three patterns:
2228
+ # A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
2229
+ # B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
2230
+ # → one pie per facet level (grid) + counts bar of facet sizes.
2231
+ # C) Single pie when no split/facet is requested.
2232
+
2233
+ # Notes:
2234
+ # - Pie variables must be categorical (or numeric binned).
2235
+ # - Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
2236
+ # """
2237
+
2238
+ # q = (user_question or "")
2239
+ # q_low = q.lower()
2240
+
2241
+ # # Prefer explicit: df['col'].value_counts()
2242
+ # m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
2243
+ # col = m.group(1) if m else None
2244
+
2245
+ # # ---------- helpers ----------
2246
+ # def _is_cat(col):
2247
+ # return (str(df[col].dtype).startswith("category")
2248
+ # or df[col].dtype == "object"
2249
+ # or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
2250
+
2251
+ # def _cats_from_question(question: str):
2252
+ # found = []
2253
+ # for c in df.columns:
2254
+ # if c.lower() in question.lower() and _is_cat(c):
2255
+ # found.append(c)
2256
+ # # dedupe preserve order
2257
+ # seen, out = set(), []
2258
+ # for c in found:
2259
+ # if c not in seen:
2260
+ # out.append(c); seen.add(c)
2261
+ # return out
2262
+
2263
+ # def _fallback_cat():
2264
+ # cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2265
+ # if not cats: return None
2266
+ # cats.sort(key=lambda t: t[1])
2267
+ # return cats[0][0]
2187
2268
 
2188
- def _infer_comp_pref(question: str) -> str:
2189
- ql = (question or "").lower()
2190
- if "heatmap" in ql or "matrix" in ql:
2191
- return "heatmap"
2192
- if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
2193
- return "stacked_bar_pct"
2194
- if "stacked" in ql:
2195
- return "stacked_bar"
2196
- if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
2197
- return "grouped_bar"
2198
- return "counts_bar"
2199
-
2200
- # parse threshold split like "HbA1c ≥ 6.5"
2201
- def _parse_split(question: str):
2202
- ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
2203
- m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
2204
- if not m: return None
2205
- col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
2206
- op = ops_map.get(op_raw);
2207
- if not op: return None
2208
- # case-insensitive column match
2209
- candidates = {c.lower(): c for c in df.columns}
2210
- col = candidates.get(col_raw.lower())
2211
- if not col: return None
2212
- try: val = float(val_raw)
2213
- except Exception: return None
2214
- return (col, op, val)
2215
-
2216
- # facet extractor: "by/ across / within each / per <col>", or "bin <col>", or named category list
2217
- def _extract_facet(question: str):
2218
- # 1) explicit "by/ across / within / per <col>"
2219
- for kw in [" by ", " across ", " within ", " within each ", " per "]:
2220
- m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
2221
- if m:
2222
- col_raw = m.group(1).strip()
2223
- candidates = {c.lower(): c for c in df.columns}
2224
- if col_raw.lower() in candidates:
2225
- return (candidates[col_raw.lower()], "auto")
2226
- # 2) "bin <col>"
2227
- m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
2228
- if m2:
2229
- col_raw = m2.group(1).strip()
2230
- candidates = {c.lower(): c for c in df.columns}
2231
- if col_raw.lower() in candidates:
2232
- return (candidates[col_raw.lower()], "bin")
2233
- # 3) BMI special: mentions of normal/overweight/obese imply BMI categories
2234
- if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
2235
- any(c.lower() == "bmi" for c in df.columns.str.lower()):
2236
- bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
2237
- return (bmi_col, "bmi")
2238
- return None
2239
-
2240
- def _bmi_bins(series: pd.Series):
2241
- # WHO cutoffs
2242
- bins = [-np.inf, 18.5, 25, 30, np.inf]
2243
- labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
2244
- return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
2245
-
2246
- wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
2247
- if not wants_pie:
2248
- return code
2249
-
2250
- split = _parse_split(q)
2251
- facet = _extract_facet(q)
2252
- cats = _cats_from_question(q)
2253
- _comp_pref = _infer_comp_pref(q)
2254
-
2255
- # Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
2256
- for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
2257
- if hard in df.columns and hard not in cats and hard.lower() in q_low:
2258
- cats.append(hard)
2259
-
2260
- # --------------- CASE A: threshold split (cohorts) ---------------
2261
- if split:
2262
- if not (cats or any(_is_cat(c) for c in df.columns)):
2263
- return code
2264
- if not cats:
2265
- pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2266
- pool.sort(key=lambda t: t[1])
2267
- cats = [t[0] for t in pool[:3]] if pool else []
2268
- if not cats:
2269
- return code
2270
-
2271
- split_col, op, val = split
2272
- cond_str = f"(df['{split_col}'] {op} {val})"
2273
- snippet = f"""
2274
- import numpy as np
2275
- import pandas as pd
2276
- import matplotlib.pyplot as plt
2277
-
2278
- _mask_a = ({cond_str}) & df['{split_col}'].notna()
2279
- _mask_b = (~({cond_str})) & df['{split_col}'].notna()
2280
-
2281
- _cohort_a_name = "{split_col} {op} {val}"
2282
- _cohort_b_name = "NOT ({split_col} {op} {val})"
2283
-
2284
- _cat_cols = {cats!r}
2285
- n = len(_cat_cols)
2286
- fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
2287
- if n == 1:
2288
- axes = np.array([axes])
2289
-
2290
- for i, col in enumerate(_cat_cols):
2291
- s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
2292
- s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})
2293
-
2294
- ax_a = axes[i, 0]; ax_b = axes[i, 1]
2295
- if len(s_a) > 0:
2296
- ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
2297
- autopct='%1.1f%%', startangle=90, counterclock=False)
2298
- ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')
2299
-
2300
- if len(s_b) > 0:
2301
- ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
2302
- autopct='%1.1f%%', startangle=90, counterclock=False)
2303
- ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')
2304
-
2305
- plt.tight_layout(); plt.show()
2306
-
2307
- # grouped bar complement
2308
- for col in _cat_cols:
2309
- _tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
2310
- .assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
2311
- _tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
2312
- _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2313
-
2314
- if _comp_pref == "grouped_bar":
2315
- ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
2316
- ax.set_title(f"{col} by cohort (grouped)")
2317
- ax.set_xlabel(col); ax.set_ylabel("Count")
2318
- plt.tight_layout(); plt.show()
2319
-
2320
- elif _comp_pref == "stacked_bar":
2321
- ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2322
- ax.set_title(f"{col} by cohort (stacked)")
2323
- ax.set_xlabel(col); ax.set_ylabel("Count")
2324
- plt.tight_layout(); plt.show()
2325
-
2326
- elif _comp_pref == "stacked_bar_pct":
2327
- _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2328
- ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2329
- ax.set_title(f"{col} by cohort (100% stacked)")
2330
- ax.set_xlabel(col); ax.set_ylabel("Percent")
2331
- plt.tight_layout(); plt.show()
2332
-
2333
- elif _comp_pref == "heatmap":
2334
- _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2335
- import numpy as np
2336
- fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
2337
- im = ax.imshow(_perc.values, aspect='auto')
2338
- ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2339
- ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2340
- ax.set_title(f"{col} by cohort — % heatmap")
2341
- for i in range(_perc.shape[0]):
2342
- for j in range(_perc.shape[1]):
2343
- ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2344
- fig.colorbar(im, ax=ax, label="%")
2345
- plt.tight_layout(); plt.show()
2346
-
2347
- else: # counts_bar (default)
2348
- ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
2349
- ax.set_title(f"{col}: total counts (both cohorts)")
2350
- ax.set_xlabel(col); ax.set_ylabel("Count")
2351
- plt.tight_layout(); plt.show()
2352
- """.lstrip()
2353
- return snippet
2354
-
2355
- # --------------- CASE B: facet-by (categories/bins) ---------------
2356
- if facet:
2357
- facet_col, how = facet
2358
- # Build facet series
2359
- if pd.api.types.is_numeric_dtype(df[facet_col]):
2360
- if how == "bmi":
2361
- facet_series = _bmi_bins(df[facet_col])
2362
- else:
2363
- # generic numeric bins: 3 equal-width bins by default
2364
- facet_series = pd.cut(df[facet_col].astype(float), bins=3)
2365
- else:
2366
- facet_series = df[facet_col].astype(str)
2367
-
2368
- # Choose pie dimension (categorical to count inside each facet)
2369
- pie_dim = None
2370
- for c in cats:
2371
- if c in df.columns and _is_cat(c):
2372
- pie_dim = c; break
2373
- if pie_dim is None:
2374
- pie_dim = _fallback_cat()
2375
- if pie_dim is None:
2376
- return code
2377
-
2378
- snippet = f"""
2379
- import math
2380
- import pandas as pd
2381
- import matplotlib.pyplot as plt
2382
-
2383
- df = df.copy()
2384
- _preferred = "{facet_col}" if "{facet_col}" in df.columns else None
2385
-
2386
- def _select_facet_col(df, preferred=None):
2387
- if preferred is not None:
2388
- return preferred
2389
- # Prefer low-cardinality categoricals (readable pies/grids)
2390
- cat_cols = [
2391
- c for c in df.columns
2392
- if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
2393
- and df[c].nunique() > 1 and df[c].nunique() <= 20
2394
- ]
2395
- if cat_cols:
2396
- cat_cols.sort(key=lambda c: df[c].nunique())
2397
- return cat_cols[0]
2398
- # Else fall back to first usable numeric
2399
- num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
2400
- return num_cols[0] if num_cols else None
2401
-
2402
- _facet_col = _select_facet_col(df, _preferred)
2403
-
2404
- if _facet_col is None:
2405
- # Nothing suitable → single facet keeps pipeline alive
2406
- df["__facet__"] = "All"
2407
- else:
2408
- s = df[_facet_col]
2409
- if pd.api.types.is_numeric_dtype(s):
2410
- # Robust numeric binning: quantiles first, fallback to equal-width
2411
- uniq = pd.Series(s).dropna().nunique()
2412
- q = 3 if uniq < 10 else 4 if uniq < 30 else 5
2413
- try:
2414
- df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
2415
- except Exception:
2416
- df["__facet__"] = pd.cut(s.astype(float), bins=q)
2417
- else:
2418
- # Cap long tails; keep top categories
2419
- vc = s.astype(str).value_counts()
2420
- keep = vc.index[:{top_n}]
2421
- df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")
2422
-
2423
- levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
2424
- levels = [x for x in levels if x != "nan"]
2425
- levels.sort()
2426
-
2427
- m = len(levels)
2428
- cols = 3 if m >= 3 else m or 1
2429
- rows = int(math.ceil(m / cols))
2430
-
2431
- fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
2432
- if not isinstance(axes, (list, np.ndarray)):
2433
- axes = np.array([[axes]])
2434
- axes = axes.reshape(rows, cols)
2435
-
2436
- for i, lvl in enumerate(levels):
2437
- r, c = divmod(i, cols)
2438
- ax = axes[r, c]
2439
- s = (df.loc[df["__facet"].astype(str) == str(lvl), "{pie_dim}"]
2440
- .astype(str).value_counts().nlargest({top_n}))
2441
- if len(s) > 0:
2442
- ax.pie(s.values, labels=[str(x) for x in s.index],
2443
- autopct='%1.1f%%', startangle=90, counterclock=False)
2444
- ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')
2445
-
2446
- # hide any empty subplots
2447
- for j in range(m, rows*cols):
2448
- r, c = divmod(j, cols)
2449
- axes[r, c].axis("off")
2450
-
2451
- plt.tight_layout(); plt.show()
2452
-
2453
- # --- companion visual (adaptive) ---
2454
- _comp_pref = "{_comp_pref}"
2455
- # build contingency table: pie_dim x facet
2456
- _tab = (df[["__facet__", "{pie_dim}"]]
2457
- .dropna()
2458
- .astype({{"__facet__": str, "{pie_dim}": str}})
2459
- .value_counts()
2460
- .unstack(level="__facet__")
2461
- .fillna(0))
2462
-
2463
- # keep top categories by overall size
2464
- _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2465
-
2466
- if _comp_pref == "grouped_bar":
2467
- ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2468
- ax.set_title("{pie_dim} by {facet_col} (grouped)")
2469
- ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2470
- plt.tight_layout(); plt.show()
2471
-
2472
- elif _comp_pref == "stacked_bar":
2473
- ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2474
- ax.set_title("{pie_dim} by {facet_col} (stacked)")
2475
- ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2476
- plt.tight_layout(); plt.show()
2477
-
2478
- elif _comp_pref == "stacked_bar_pct":
2479
- _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100 # column-normalised to 100%
2480
- ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
2481
- ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
2482
- ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
2483
- plt.tight_layout(); plt.show()
2484
-
2485
- elif _comp_pref == "heatmap":
2486
- _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
2487
- import numpy as np
2488
- fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
2489
- im = ax.imshow(_perc.values, aspect='auto')
2490
- ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2491
- ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2492
- ax.set_title("{pie_dim} by {facet_col} — % heatmap")
2493
- for i in range(_perc.shape[0]):
2494
- for j in range(_perc.shape[1]):
2495
- ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2496
- fig.colorbar(im, ax=ax, label="%")
2497
- plt.tight_layout(); plt.show()
2498
-
2499
- else: # counts_bar (default denominators)
2500
- _counts = df["__facet"].value_counts()
2501
- ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
2502
- ax.set_title("Counts by {facet_col}")
2503
- ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2504
- plt.tight_layout(); plt.show()
2505
-
2506
- """.lstrip()
2507
- return snippet
2508
-
2509
- # --------------- CASE C: single pie ---------------
2510
- chosen = None
2511
- for c in cats:
2512
- if c in df.columns and _is_cat(c):
2513
- chosen = c; break
2514
- if chosen is None:
2515
- chosen = _fallback_cat()
2516
-
2517
- if chosen:
2518
- snippet = f"""
2519
- import matplotlib.pyplot as plt
2520
- counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
2521
- fig, ax = plt.subplots()
2522
- if len(counts) > 0:
2523
- ax.pie(counts.values, labels=[str(i) for i in counts.index],
2524
- autopct='%1.1f%%', startangle=90, counterclock=False)
2525
- ax.set_title('Distribution of {chosen} (top {top_n})')
2526
- ax.axis('equal')
2527
- plt.show()
2528
- """.lstrip()
2529
- return snippet
2530
-
2531
- # numeric last resort
2532
- num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
2533
- if num_cols:
2534
- col = num_cols[0]
2535
- snippet = f"""
2536
- import pandas as pd
2537
- import matplotlib.pyplot as plt
2538
- bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
2539
- counts = bins.value_counts().sort_index()
2540
- fig, ax = plt.subplots()
2541
- if len(counts) > 0:
2542
- ax.pie(counts.values, labels=[str(i) for i in counts.index],
2543
- autopct='%1.1f%%', startangle=90, counterclock=False)
2544
- ax.set_title('Distribution of {col} (binned)')
2545
- ax.axis('equal')
2546
- plt.show()
2547
- """.lstrip()
2548
- return snippet
2549
-
2550
- return code
2551
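For reference, the WHO-style BMI binning used by the facet path above comes down to one pd.cut call; a self-contained sketch with made-up values:

import numpy as np
import pandas as pd

bmi = pd.Series([17.0, 22.5, 27.3, 31.8])
bins = [-np.inf, 18.5, 25, 30, np.inf]
labels = ["Underweight (<18.5)", "Normal (18.5-24.9)", "Overweight (25-29.9)", "Obese (>=30)"]
print(pd.cut(bmi, bins=bins, labels=labels, right=False))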
-
2552
-
2553
- def patch_fix_seaborn_palette_calls(code: str) -> str:
2554
- """
2555
- Removes seaborn `palette=` when no `hue=` is present in the same call.
2556
- Fixes FutureWarning: 'Passing `palette` without assigning `hue` ...'.
2557
- """
2558
- if "sns." not in code:
2559
- return code
2560
-
2561
- # Targets common seaborn plotters
2562
- funcs = r"(boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
2563
- pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)
2564
-
2565
- def _fix_call(m):
2566
- head, inner = m.group(1), m.group(2)
2567
- # If there's already hue=, keep as is
2568
- if re.search(r"(?<!\w)hue\s*=", inner):
2569
- return f"{head}{inner})"
2570
- # Otherwise remove palette=... safely (and any adjacent comma spacing)
2571
- inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
2572
- inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
2573
- inner2 = re.sub(r"\s*,\s*\)", ")", f"{inner2})")[:-1] # clean trailing comma before ')'
2574
- return f"{head}{inner2})"
2575
-
2576
- return pattern.sub(_fix_call, code)
2577
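A side note on the seaborn behaviour being patched: on recent seaborn releases (0.13 or later), the same FutureWarning can also be avoided at the call site by assigning hue to the plotted variable and hiding the redundant legend. A hedged sketch with a toy frame:

import pandas as pd
import seaborn as sns

toy = pd.DataFrame({"group": ["a", "a", "b", "b"], "value": [1, 2, 3, 4]})
# Assigning hue to the x variable (and dropping the legend) is the
# forward-compatible alternative to passing palette without hue.
ax = sns.barplot(data=toy, x="group", y="value", hue="group", palette="muted", legend=False)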
-
2578
-
2579
- def patch_quiet_specific_warnings(code: str) -> str:
2580
- """
2581
- Inserts targeted warning filters (not blanket ignores).
2582
- - seaborn palette/hue deprecation
2583
- - python-dotenv parse chatter
2584
- """
2585
- prelude = (
2586
- "import warnings\n"
2587
- "warnings.filterwarnings(\n"
2588
- " 'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
2589
- "warnings.filterwarnings(\n"
2590
- " 'ignore', message=r'python-dotenv could not parse statement.*')\n"
2591
- )
2592
- # If warnings already imported once, just add filters; else insert full prelude.
2593
- if "import warnings" in code:
2594
- code = re.sub(
2595
- r"(import warnings[^\n]*\n)",
2596
- lambda m: m.group(1) + prelude.replace("import warnings\n", ""),
2597
- code,
2598
- count=1
2599
- )
2600
-
2601
- else:
2602
- # place after first import block if possible
2603
- m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
2604
- if m:
2605
- idx = m.end()
2606
- code = code[:idx] + prelude + code[idx:]
2607
- else:
2608
- code = prelude + code
2609
- return code
2610
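Written out directly, the prelude assembled above amounts to two targeted filters (same messages, no blanket ignore):

import warnings

warnings.filterwarnings(
    "ignore",
    message=r".*Passing `palette` without assigning `hue`.*",
    category=FutureWarning,
)
warnings.filterwarnings("ignore", message=r"python-dotenv could not parse statement.*")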
-
2611
-
2612
- def _norm_col_name(s: str) -> str:
2613
- """normalise a column name: lowercase + strip non-alphanumerics."""
2614
- return re.sub(r"[^a-z0-9]+", "", str(s).lower())
2615
-
2616
-
2617
- def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
2618
- """return the actual df column that matches any candidate (after normalisation)."""
2619
- norm_map = {_norm_col_name(c): c for c in df.columns}
2620
- for cand in candidates:
2621
- hit = norm_map.get(_norm_col_name(cand))
2622
- if hit is not None:
2623
- return hit
2624
- return None
2625
-
2626
-
2627
- def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
2628
- """
2629
- If any alias exists, materialise a canonical copy at `target` (don’t drop the original).
2630
- Returns (df, found_bool).
2631
- """
2632
- if target in df.columns:
2633
- return df, True
2634
- col = _first_present(df, [target, *aliases])
2635
- if col is None:
2636
- return df, False
2637
- df[target] = df[col]
2638
- return df, True
2639
-
2640
-
2641
- def strip_python_dotenv(code: str) -> str:
2642
- """
2643
- Remove any use of python-dotenv from generated code, including:
2644
- - single and multi-line 'from dotenv import ...'
2645
- - 'import dotenv' (with or without alias) and calls via any alias
2646
- - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
2647
- - IPython magics (%load_ext dotenv, %dotenv, %env …)
2648
- - shell installs like '!pip install python-dotenv'
2649
- """
2650
- original = code
2651
-
2652
- # 0) Kill IPython magics & shell installs referencing dotenv
2653
- code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
2654
- code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
2655
- code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
2656
- code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)
2657
-
2658
- # 1) Remove single-line 'from dotenv import ...'
2659
- code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)
2660
-
2661
- # 2) Remove multi-line 'from dotenv import ( ... )' blocks
2662
- code = re.sub(
2663
- r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
2664
- "",
2665
- code,
2666
- flags=re.MULTILINE,
2667
- )
2668
-
2669
- # 3) Remove 'import dotenv' (with optional alias). Capture alias names.
2670
- aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
2671
- code, flags=re.MULTILINE)
2672
- code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
2673
- "", code, flags=re.MULTILINE)
2674
-
2675
- # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
2676
- # e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
2677
- fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
2678
- # bare calls
2679
- code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2680
- # dotted calls with any identifier prefix (alias or module)
2681
- code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
2682
- "", code, flags=re.MULTILINE)
2683
-
2684
- # 5) If any alias imported earlier slipped through (method chains etc.), remove lines using that alias.
2685
- for al in aliases:
2686
- code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2687
-
2688
- # 6) Tidy excess blank lines
2689
- code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
2690
- return code
2691
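A minimal end-to-end sketch of the stripping idea on a made-up cell (only two of the patterns above are needed for this input):

import re

cell = "from dotenv import load_dotenv\nload_dotenv()\nprint('ready')"
cell = re.sub(r"(?m)^\s*from\s+dotenv\s+import\s+.*$", "", cell)
cell = re.sub(r"(?m)^\s*load_dotenv\s*\([^)]*\)\s*$", "", cell)
print(cell.strip())  # -> print('ready')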
-
2692
-
2693
- def fix_predict_calls_records_arg(code: str) -> str:
2694
- """
2695
- If generated code calls predict_* with a list-of-dicts via .to_dict('records')
2696
- (or orient='records'), strip the .to_dict(...) so a DataFrame is passed instead.
2697
- Works line-by-line to avoid over-rewrites elsewhere.
2698
- Examples fixed:
2699
- predict_patient(X_test.iloc[:5].to_dict('records'))
2700
- predict_risk(df.head(3).to_dict(orient="records"))
2701
- → predict_patient(X_test.iloc[:5])
2702
- """
2703
- fixed_lines = []
2704
- for line in code.splitlines():
2705
- if "predict_" in line and "to_dict" in line and "records" in line:
2706
- line = re.sub(
2707
- r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)",
2708
- "",
2709
- line
2710
- )
2711
- fixed_lines.append(line)
2712
- return "\n".join(fixed_lines)
2713
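The rewrite itself is a single regex; a short sketch on one of the example lines from the docstring:

import re

line = "predict_patient(X_test.iloc[:5].to_dict('records'))"
# Strip the .to_dict('records') so a DataFrame slice is passed instead.
fixed = re.sub(r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)", "", line)
print(fixed)  # predict_patient(X_test.iloc[:5])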
-
2714
-
2715
- def fix_fstring_backslash_paths(code: str) -> str:
2716
- """
2717
- Fix bad f-strings like: f"...{out_dir\\plots\\img.png}..."
2718
- f"...{os.path.join(out_dir, r'plots\\img.png')}"
2719
- Only touches f-strings that contain a backslash path inside {...}.
2720
- """
2721
- def _fix_line(line: str) -> str:
2722
- # quick check: only f-strings need scanning
2723
- if not (("f\"" in line) or ("f'" in line) or ("f\"\"\"" in line) or ("f'''" in line)):
2724
- return line
2725
- # {var\rest-of-path} where var can be dotted (e.g., cfg.out)
2726
- pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
2727
- def repl(m):
2728
- left = m.group(1)
2729
- right = m.group(2).strip().replace('"', '\\"')
2730
- return "{os.path.join(" + left + ', r"' + right + '")}'
2731
- return pattern.sub(repl, line)
2732
-
2733
- return "\n".join(_fix_line(ln) for ln in code.splitlines())
2734
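As background, the simplest portable fix is to build the path outside the f-string; a small sketch (out_dir is a made-up variable):

import os

out_dir = "results"
# Joining outside the f-string sidesteps the backslash restriction that breaks
# expressions like f"{out_dir\plots\img.png}" on Python versions before 3.12.
img_path = os.path.join(out_dir, "plots", "img.png")
print(f"saved to {img_path}")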
-
2735
-
2736
- def ensure_os_import(code: str) -> str:
2737
- """
2738
- If os.path.join is used but 'import os' is missing, inject it at the top.
2739
- """
2740
- needs = "os.path.join(" in code
2741
- has_import_os = re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
2742
- has_from_os = re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
2743
- if needs and not (has_import_os or has_from_os):
2744
- return "import os\n" + code
2745
- return code
2746
-
2747
-
2748
- def fix_seaborn_boxplot_nameerror(code: str) -> str:
2749
- """
2750
- Fix bad calls like: sns.boxplot(boxplot)
2751
- Heuristic:
2752
- - If plot_df + FH_status + var exist → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
2753
- - Else if plot_df + var → sns.boxplot(data=plot_df, y=var, ax=ax)
2754
- - Else if plot_df only → sns.boxplot(data=plot_df, ax=ax)
2755
- - Else sns.boxplot(ax=ax)
2756
- Ensures a matplotlib Axes 'ax' exists.
2757
- """
2758
- pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
2759
- if not pattern.search(code):
2760
- return code
2761
-
2762
- has_plot_df = re.search(r"\bplot_df\b", code) is not None
2763
- has_var = re.search(r"\bvar\b", code) is not None
2764
- has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2765
-
2766
- if has_plot_df and has_var and has_fh:
2767
- replacement = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2768
- elif has_plot_df and has_var:
2769
- replacement = "sns.boxplot(data=plot_df, y=var, ax=ax)"
2770
- elif has_plot_df:
2771
- replacement = "sns.boxplot(data=plot_df, ax=ax)"
2772
- else:
2773
- replacement = "sns.boxplot(ax=ax)"
2774
-
2775
- fixed = pattern.sub(replacement, code)
2776
-
2777
- # Ensure 'fig, ax = plt.subplots(...)' exists
2778
- if "ax=" in replacement and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", fixed):
2779
- # Insert right before the first seaborn call
2780
- m = re.search(r"^\s*sns\.", fixed, flags=re.MULTILINE)
2781
- insert_at = m.start() if m else 0
2782
- fixed = fixed[:insert_at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + fixed[insert_at:]
2783
-
2784
- return fixed
2785
-
2786
-
2787
- def fix_seaborn_barplot_nameerror(code: str) -> str:
2788
- """
2789
- Fix bad calls like: sns.barplot(barplot)
2790
- Strategy mirrors boxplot fixer: prefer data=plot_df with x/y if available,
2791
- otherwise degrade safely to an empty call on an existing Axes.
2792
- """
2793
- import re
2794
- pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
2795
- if not pattern.search(code):
2796
- return code
2797
-
2798
- has_plot_df = re.search(r"\bplot_df\b", code) is not None
2799
- has_var = re.search(r"\bvar\b", code) is not None
2800
- has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2801
-
2802
- if has_plot_df and has_var and has_fh:
2803
- replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2804
- elif has_plot_df and has_var:
2805
- replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
2806
- elif has_plot_df:
2807
- replacement = "sns.barplot(data=plot_df, ax=ax)"
2808
- else:
2809
- replacement = "sns.barplot(ax=ax)"
2810
-
2811
- # ensure an Axes 'ax' exists (no-op if already present)
2812
- if "ax =" not in code:
2813
- code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code
2814
-
2815
- return pattern.sub(replacement, code)
2816
-
2817
-
2818
- def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
2819
- """
2820
- Parses the raw text to extract and format the 'refined question',
2821
- 'intents (tasks)', and 'chronology of tasks' sections.
2822
- Args:
2823
- raw_text: The complete input string containing the ML pipeline structure.
2824
- Returns:
2825
- A tuple containing:
2826
- (formatted_question_str, formatted_intents_str, formatted_chronology_str)
2827
- """
2828
- # --- 1. Regex Pattern to Extract Sections ---
2829
- # The pattern uses capturing groups (?) to look for the section headers
2830
- # (e.g., 'refined question:') and captures all the content until the next
2831
- # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.
2269
+ # def _infer_comp_pref(question: str) -> str:
2270
+ # ql = (question or "").lower()
2271
+ # if "heatmap" in ql or "matrix" in ql:
2272
+ # return "heatmap"
2273
+ # if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
2274
+ # return "stacked_bar_pct"
2275
+ # if "stacked" in ql:
2276
+ # return "stacked_bar"
2277
+ # if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
2278
+ # return "grouped_bar"
2279
+ # return "counts_bar"
2280
+
2281
+ # # parse threshold split like "HbA1c ≥ 6.5"
2282
+ # def _parse_split(question: str):
2283
+ # ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
2284
+ # m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
2285
+ # if not m: return None
2286
+ # col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
2287
+ # op = ops_map.get(op_raw);
2288
+ # if not op: return None
2289
+ # # case-insensitive column match
2290
+ # candidates = {c.lower(): c for c in df.columns}
2291
+ # col = candidates.get(col_raw.lower())
2292
+ # if not col: return None
2293
+ # try: val = float(val_raw)
2294
+ # except Exception: return None
2295
+ # return (col, op, val)
2296
+
2297
+ # # facet extractor: "by/ across / within each / per <col>", or "bin <col>", or named category list
2298
+ # def _extract_facet(question: str):
2299
+ # # 1) explicit "by/ across / within / per <col>"
2300
+ # for kw in [" by ", " across ", " within ", " within each ", " per "]:
2301
+ # m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
2302
+ # if m:
2303
+ # col_raw = m.group(1).strip()
2304
+ # candidates = {c.lower(): c for c in df.columns}
2305
+ # if col_raw.lower() in candidates:
2306
+ # return (candidates[col_raw.lower()], "auto")
2307
+ # # 2) "bin <col>"
2308
+ # m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
2309
+ # if m2:
2310
+ # col_raw = m2.group(1).strip()
2311
+ # candidates = {c.lower(): c for c in df.columns}
2312
+ # if col_raw.lower() in candidates:
2313
+ # return (candidates[col_raw.lower()], "bin")
2314
+ # # 3) BMI special: mentions of normal/overweight/obese imply BMI categories
2315
+ # if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
2316
+ # any(c.lower() == "bmi" for c in df.columns.str.lower()):
2317
+ # bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
2318
+ # return (bmi_col, "bmi")
2319
+ # return None
2320
+
2321
+ # def _bmi_bins(series: pd.Series):
2322
+ # # WHO cutoffs
2323
+ # bins = [-np.inf, 18.5, 25, 30, np.inf]
2324
+ # labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
2325
+ # return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
2326
+
2327
+ # wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
2328
+ # if not wants_pie:
2329
+ # return code
2330
+
2331
+ # split = _parse_split(q)
2332
+ # facet = _extract_facet(q)
2333
+ # cats = _cats_from_question(q)
2334
+ # _comp_pref = _infer_comp_pref(q)
2335
+
2336
+ # # Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
2337
+ # for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
2338
+ # if hard in df.columns and hard not in cats and hard.lower() in q_low:
2339
+ # cats.append(hard)
2340
+
2341
+ # # --------------- CASE A: threshold split (cohorts) ---------------
2342
+ # if split:
2343
+ # if not (cats or any(_is_cat(c) for c in df.columns)):
2344
+ # return code
2345
+ # if not cats:
2346
+ # pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
2347
+ # pool.sort(key=lambda t: t[1])
2348
+ # cats = [t[0] for t in pool[:3]] if pool else []
2349
+ # if not cats:
2350
+ # return code
2351
+
2352
+ # split_col, op, val = split
2353
+ # cond_str = f"(df['{split_col}'] {op} {val})"
2354
+ # snippet = f"""
2355
+ # import numpy as np
2356
+ # import pandas as pd
2357
+ # import matplotlib.pyplot as plt
2358
+
2359
+ # _mask_a = ({cond_str}) & df['{split_col}'].notna()
2360
+ # _mask_b = (~({cond_str})) & df['{split_col}'].notna()
2361
+
2362
+ # _cohort_a_name = "{split_col} {op} {val}"
2363
+ # _cohort_b_name = "NOT ({split_col} {op} {val})"
2364
+
2365
+ # _cat_cols = {cats!r}
2366
+ # n = len(_cat_cols)
2367
+ # fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
2368
+ # if n == 1:
2369
+ # axes = np.array([axes])
2370
+
2371
+ # for i, col in enumerate(_cat_cols):
2372
+ # s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
2373
+ # s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})
2374
+
2375
+ # ax_a = axes[i, 0]; ax_b = axes[i, 1]
2376
+ # if len(s_a) > 0:
2377
+ # ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
2378
+ # autopct='%1.1f%%', startangle=90, counterclock=False)
2379
+ # ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')
2380
+
2381
+ # if len(s_b) > 0:
2382
+ # ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
2383
+ # autopct='%1.1f%%', startangle=90, counterclock=False)
2384
+ # ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')
2385
+
2386
+ # plt.tight_layout(); plt.show()
2387
+
2388
+ # # grouped bar complement
2389
+ # for col in _cat_cols:
2390
+ # _tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
2391
+ # .assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
2392
+ # _tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
2393
+ # _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2394
+
2395
+ # if _comp_pref == "grouped_bar":
2396
+ # ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
2397
+ # ax.set_title(f"{col} by cohort (grouped)")
2398
+ # ax.set_xlabel(col); ax.set_ylabel("Count")
2399
+ # plt.tight_layout(); plt.show()
2400
+
2401
+ # elif _comp_pref == "stacked_bar":
2402
+ # ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2403
+ # ax.set_title(f"{col} by cohort (stacked)")
2404
+ # ax.set_xlabel(col); ax.set_ylabel("Count")
2405
+ # plt.tight_layout(); plt.show()
2406
+
2407
+ # elif _comp_pref == "stacked_bar_pct":
2408
+ # _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2409
+ # ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
2410
+ # ax.set_title(f"{col} by cohort (100% stacked)")
2411
+ # ax.set_xlabel(col); ax.set_ylabel("Percent")
2412
+ # plt.tight_layout(); plt.show()
2413
+
2414
+ # elif _comp_pref == "heatmap":
2415
+ # _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
2416
+ # import numpy as np
2417
+ # fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
2418
+ # im = ax.imshow(_perc.values, aspect='auto')
2419
+ # ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2420
+ # ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2421
+ # ax.set_title(f"{col} by cohort — % heatmap")
2422
+ # for i in range(_perc.shape[0]):
2423
+ # for j in range(_perc.shape[1]):
2424
+ # ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2425
+ # fig.colorbar(im, ax=ax, label="%")
2426
+ # plt.tight_layout(); plt.show()
2427
+
2428
+ # else: # counts_bar (default)
2429
+ # ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
2430
+ # ax.set_title(f"{col}: total counts (both cohorts)")
2431
+ # ax.set_xlabel(col); ax.set_ylabel("Count")
2432
+ # plt.tight_layout(); plt.show()
2433
+ # """.lstrip()
2434
+ # return snippet
2435
+
2436
+ # # --------------- CASE B: facet-by (categories/bins) ---------------
2437
+ # if facet:
2438
+ # facet_col, how = facet
2439
+ # # Build facet series
2440
+ # if pd.api.types.is_numeric_dtype(df[facet_col]):
2441
+ # if how == "bmi":
2442
+ # facet_series = _bmi_bins(df[facet_col])
2443
+ # else:
2444
+ # # generic numeric bins: 3 equal-width bins by default
2445
+ # facet_series = pd.cut(df[facet_col].astype(float), bins=3)
2446
+ # else:
2447
+ # facet_series = df[facet_col].astype(str)
2448
+
2449
+ # # Choose pie dimension (categorical to count inside each facet)
2450
+ # pie_dim = None
2451
+ # for c in cats:
2452
+ # if c in df.columns and _is_cat(c):
2453
+ # pie_dim = c; break
2454
+ # if pie_dim is None:
2455
+ # pie_dim = _fallback_cat()
2456
+ # if pie_dim is None:
2457
+ # return code
2458
+
2459
+ # snippet = f"""
2460
+ # import math
2461
+ # import pandas as pd
2462
+ # import numpy as np  # needed for the axes-array handling below
+ # import matplotlib.pyplot as plt
2463
+
2464
+ # df = df.copy()
2465
+ # _preferred = "{facet_col}" if "{facet_col}" in df.columns else None
2466
+
2467
+ # def _select_facet_col(df, preferred=None):
2468
+ # if preferred is not None:
2469
+ # return preferred
2470
+ # # Prefer low-cardinality categoricals (readable pies/grids)
2471
+ # cat_cols = [
2472
+ # c for c in df.columns
2473
+ # if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
2474
+ # and df[c].nunique() > 1 and df[c].nunique() <= 20
2475
+ # ]
2476
+ # if cat_cols:
2477
+ # cat_cols.sort(key=lambda c: df[c].nunique())
2478
+ # return cat_cols[0]
2479
+ # # Else fall back to first usable numeric
2480
+ # num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
2481
+ # return num_cols[0] if num_cols else None
2482
+
2483
+ # _facet_col = _select_facet_col(df, _preferred)
2484
+
2485
+ # if _facet_col is None:
2486
+ # # Nothing suitable → single facet keeps pipeline alive
2487
+ # df["__facet__"] = "All"
2488
+ # else:
2489
+ # s = df[_facet_col]
2490
+ # if pd.api.types.is_numeric_dtype(s):
2491
+ # # Robust numeric binning: quantiles first, fallback to equal-width
2492
+ # uniq = pd.Series(s).dropna().nunique()
2493
+ # q = 3 if uniq < 10 else 4 if uniq < 30 else 5
2494
+ # try:
2495
+ # df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
2496
+ # except Exception:
2497
+ # df["__facet__"] = pd.cut(s.astype(float), bins=q)
2498
+ # else:
2499
+ # # Cap long tails; keep top categories
2500
+ # vc = s.astype(str).value_counts()
2501
+ # keep = vc.index[:{top_n}]
2502
+ # df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")
2503
+
2504
+ # levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
2505
+ # levels = [x for x in levels if x != "nan"]
2506
+ # levels.sort()
2507
+
2508
+ # m = len(levels)
2509
+ # cols = 3 if m >= 3 else m or 1
2510
+ # rows = int(math.ceil(m / cols))
2511
+
2512
+ # fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
2513
+ # if not isinstance(axes, (list, np.ndarray)):
2514
+ # axes = np.array([[axes]])
2515
+ # axes = axes.reshape(rows, cols)
2516
+
2517
+ # for i, lvl in enumerate(levels):
2518
+ # r, c = divmod(i, cols)
2519
+ # ax = axes[r, c]
2520
+ # s = (df.loc[df["__facet"].astype(str) == str(lvl), "{pie_dim}"]
2521
+ # .astype(str).value_counts().nlargest({top_n}))
2522
+ # if len(s) > 0:
2523
+ # ax.pie(s.values, labels=[str(x) for x in s.index],
2524
+ # autopct='%1.1f%%', startangle=90, counterclock=False)
2525
+ # ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')
2526
+
2527
+ # # hide any empty subplots
2528
+ # for j in range(m, rows*cols):
2529
+ # r, c = divmod(j, cols)
2530
+ # axes[r, c].axis("off")
2531
+
2532
+ # plt.tight_layout(); plt.show()
2533
+
2534
+ # # --- companion visual (adaptive) ---
2535
+ # _comp_pref = "{_comp_pref}"
2536
+ # # build contingency table: pie_dim x facet
2537
+ # _tab = (df[["__facet__", "{pie_dim}"]]
2538
+ # .dropna()
2539
+ # .astype({{"__facet__": str, "{pie_dim}": str}})
2540
+ # .value_counts()
2541
+ # .unstack(level="__facet__")
2542
+ # .fillna(0))
2543
+
2544
+ # # keep top categories by overall size
2545
+ # _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
2546
+
2547
+ # if _comp_pref == "grouped_bar":
2548
+ # ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2549
+ # ax.set_title("{pie_dim} by {facet_col} (grouped)")
2550
+ # ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2551
+ # plt.tight_layout(); plt.show()
2552
+
2553
+ # elif _comp_pref == "stacked_bar":
2554
+ # ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
2555
+ # ax.set_title("{pie_dim} by {facet_col} (stacked)")
2556
+ # ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2557
+ # plt.tight_layout(); plt.show()
2558
+
2559
+ # elif _comp_pref == "stacked_bar_pct":
2560
+ # _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100 # column-normalised to 100%
2561
+ # ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
2562
+ # ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
2563
+ # ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
2564
+ # plt.tight_layout(); plt.show()
2565
+
2566
+ # elif _comp_pref == "heatmap":
2567
+ # _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
2568
+ # import numpy as np
2569
+ # fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
2570
+ # im = ax.imshow(_perc.values, aspect='auto')
2571
+ # ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
2572
+ # ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
2573
+ # ax.set_title("{pie_dim} by {facet_col} — % heatmap")
2574
+ # for i in range(_perc.shape[0]):
2575
+ # for j in range(_perc.shape[1]):
2576
+ # ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
2577
+ # fig.colorbar(im, ax=ax, label="%")
2578
+ # plt.tight_layout(); plt.show()
2579
+
2580
+ # else: # counts_bar (default denominators)
2581
+ # _counts = df["__facet"].value_counts()
2582
+ # ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
2583
+ # ax.set_title("Counts by {facet_col}")
2584
+ # ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
2585
+ # plt.tight_layout(); plt.show()
2586
+
2587
+ # """.lstrip()
2588
+ # return snippet
2589
+
2590
+ # # --------------- CASE C: single pie ---------------
2591
+ # chosen = None
2592
+ # for c in cats:
2593
+ # if c in df.columns and _is_cat(c):
2594
+ # chosen = c; break
2595
+ # if chosen is None:
2596
+ # chosen = _fallback_cat()
2597
+
2598
+ # if chosen:
2599
+ # snippet = f"""
2600
+ # import matplotlib.pyplot as plt
2601
+ # counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
2602
+ # fig, ax = plt.subplots()
2603
+ # if len(counts) > 0:
2604
+ # ax.pie(counts.values, labels=[str(i) for i in counts.index],
2605
+ # autopct='%1.1f%%', startangle=90, counterclock=False)
2606
+ # ax.set_title('Distribution of {chosen} (top {top_n})')
2607
+ # ax.axis('equal')
2608
+ # plt.show()
2609
+ # """.lstrip()
2610
+ # return snippet
2611
+
2612
+ # # numeric last resort
2613
+ # num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
2614
+ # if num_cols:
2615
+ # col = num_cols[0]
2616
+ # snippet = f"""
2617
+ # import pandas as pd
2618
+ # import matplotlib.pyplot as plt
2619
+ # bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
2620
+ # counts = bins.value_counts().sort_index()
2621
+ # fig, ax = plt.subplots()
2622
+ # if len(counts) > 0:
2623
+ # ax.pie(counts.values, labels=[str(i) for i in counts.index],
2624
+ # autopct='%1.1f%%', startangle=90, counterclock=False)
2625
+ # ax.set_title('Distribution of {col} (binned)')
2626
+ # ax.axis('equal')
2627
+ # plt.show()
2628
+ # """.lstrip()
2629
+ # return snippet
2630
+
2631
+ # return code
2632
+
2633
+
2634
+ # def patch_fix_seaborn_palette_calls(code: str) -> str:
2635
+ # """
2636
+ # Removes seaborn `palette=` when no `hue=` is present in the same call.
2637
+ # Fixes FutureWarning: 'Passing `palette` without assigning `hue` ...'.
2638
+ # """
2639
+ # if "sns." not in code:
2640
+ # return code
2641
+
2642
+ # # Targets common seaborn plotters
2643
+ # funcs = r"(boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
2644
+ # pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)
2645
+
2646
+ # def _fix_call(m):
2647
+ # head, inner = m.group(1), m.group(2)
2648
+ # # If there's already hue=, keep as is
2649
+ # if re.search(r"(?<!\w)hue\s*=", inner):
2650
+ # return f"{head}{inner})"
2651
+ # # Otherwise remove palette=... safely (and any adjacent comma spacing)
2652
+ # inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
2653
+ # inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
2654
+ # inner2 = re.sub(r"\s*,\s*\)", ")", f"{inner2})")[:-1] # clean trailing comma before ')'
2655
+ # return f"{head}{inner2})"
2656
+
2657
+ # return pattern.sub(_fix_call, code)
2658
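+
+ # Illustrative sketch only (the input string is made up, not package code):
+ # >>> patch_fix_seaborn_palette_calls("sns.boxplot(data=df, x='g', y='v', palette='Set2')")
+ # "sns.boxplot(data=df, x='g', y='v')"
+ # Calls that already pass hue= are returned unchanged.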
+
2659
+ # def _norm_col_name(s: str) -> str:
2660
+ # """normalise a column name: lowercase + strip non-alphanumerics."""
2661
+ # return re.sub(r"[^a-z0-9]+", "", str(s).lower())
2662
+
2663
+
2664
+ # def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
2665
+ # """return the actual df column that matches any candidate (after normalisation)."""
2666
+ # norm_map = {_norm_col_name(c): c for c in df.columns}
2667
+ # for cand in candidates:
2668
+ # hit = norm_map.get(_norm_col_name(cand))
2669
+ # if hit is not None:
2670
+ # return hit
2671
+ # return None
2672
+
2673
+
2674
+ # def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
2675
+ # """
2676
+ # If any alias exists, materialise a canonical copy at `target` (don’t drop the original).
2677
+ # Returns (df, found_bool).
2678
+ # """
2679
+ # if target in df.columns:
2680
+ # return df, True
2681
+ # col = _first_present(df, [target, *aliases])
2682
+ # if col is None:
2683
+ # return df, False
2684
+ # df[target] = df[col]
2685
+ # return df, True
2686
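+
+ # Illustrative sketch only (column names are hypothetical): given a frame with a
+ # column "Blood Pressure", _ensure_canonical_alias(df, "blood_pressure", ["bp"])
+ # would add df["blood_pressure"] as a copy of that column and return (df, True);
+ # with no normalised match it returns (df, False) and leaves the frame untouched.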
+
2687
+
2688
+ # def strip_python_dotenv(code: str) -> str:
2689
+ # """
2690
+ # Remove any use of python-dotenv from generated code, including:
2691
+ # - single and multi-line 'from dotenv import ...'
2692
+ # - 'import dotenv' (with or without alias) and calls via any alias
2693
+ # - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
2694
+ # - IPython magics (%load_ext dotenv, %dotenv, %env …)
2695
+ # - shell installs like '!pip install python-dotenv'
2696
+ # """
2697
+ # original = code
2698
+
2699
+ # # 0) Kill IPython magics & shell installs referencing dotenv
2700
+ # code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
2701
+ # code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
2702
+ # code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
2703
+ # code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)
2704
+
2705
+ # # 1) Remove single-line 'from dotenv import ...'
2706
+ # code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)
2707
+
2708
+ # # 2) Remove multi-line 'from dotenv import ( ... )' blocks
2709
+ # code = re.sub(
2710
+ # r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
2711
+ # "",
2712
+ # code,
2713
+ # flags=re.MULTILINE,
2714
+ # )
2715
+
2716
+ # # 3) Remove 'import dotenv' (with optional alias). Capture alias names.
2717
+ # aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
2718
+ # code, flags=re.MULTILINE)
2719
+ # code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
2720
+ # "", code, flags=re.MULTILINE)
2721
+
2722
+ # # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
2723
+ # # e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
2724
+ # fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
2725
+ # # bare calls
2726
+ # code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2727
+ # # dotted calls with any identifier prefix (alias or module)
2728
+ # code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
2729
+ # "", code, flags=re.MULTILINE)
2730
+
2731
+ # # 5) If any alias imported earlier slipped through (method chains etc.), remove lines using that alias.
2732
+ # for al in aliases:
2733
+ # code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
2734
+
2735
+ # # 6) Tidy excess blank lines
2736
+ # code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
2737
+ # return code
2738
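+
+ # Illustrative sketch only (the input snippet is made up):
+ # >>> strip_python_dotenv("from dotenv import load_dotenv\nload_dotenv()\nimport os\n")
+ # 'import os\n'
+ # The dotenv import and the load_dotenv() call are dropped; ordinary os usage is kept.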
+
2739
+
2740
+ # def fix_predict_calls_records_arg(code: str) -> str:
2741
+ # """
2742
+ # If generated code calls predict_* with a list-of-dicts via .to_dict('records')
2743
+ # (or orient='records'), strip the .to_dict(...) so a DataFrame is passed instead.
2744
+ # Works line-by-line to avoid over-rewrites elsewhere.
2745
+ # Examples fixed:
2746
+ # predict_patient(X_test.iloc[:5].to_dict('records'))
2747
+ # predict_risk(df.head(3).to_dict(orient="records"))
2748
+ # → predict_patient(X_test.iloc[:5])
2749
+ # """
2750
+ # fixed_lines = []
2751
+ # for line in code.splitlines():
2752
+ # if "predict_" in line and "to_dict" in line and "records" in line:
2753
+ # line = re.sub(
2754
+ # r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)",
2755
+ # "",
2756
+ # line
2757
+ # )
2758
+ # fixed_lines.append(line)
2759
+ # return "\n".join(fixed_lines)
2760
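+
+ # Illustrative sketch only (predict_patient/X_test are hypothetical names):
+ # >>> fix_predict_calls_records_arg("preds = predict_patient(X_test.iloc[:5].to_dict('records'))")
+ # 'preds = predict_patient(X_test.iloc[:5])'
+ # The wrapper then receives a DataFrame rather than a list of dicts.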
+
2761
+
2762
+ # def fix_fstring_backslash_paths(code: str) -> str:
2763
+ # """
2764
+ # Fix bad f-strings like: f"...{out_dir\\plots\\img.png}..."
2765
+ # f"...{os.path.join(out_dir, r'plots\\img.png')}"
2766
+ # Only touches f-strings that contain a backslash path inside {...}.
2767
+ # """
2768
+ # def _fix_line(line: str) -> str:
2769
+ # # quick check: only f-strings need scanning
2770
+ # if not (("f\"" in line) or ("f'" in line) or ("f\"\"\"" in line) or ("f'''" in line)):
2771
+ # return line
2772
+ # # {var\rest-of-path} where var can be dotted (e.g., cfg.out)
2773
+ # pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
2774
+ # def repl(m):
2775
+ # left = m.group(1)
2776
+ # right = m.group(2).strip().replace('"', '\\"')
2777
+ # return "{os.path.join(" + left + ', r"' + right + '")}'
2778
+ # return pattern.sub(repl, line)
2779
+
2780
+ # return "\n".join(_fix_line(ln) for ln in code.splitlines())
2781
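+
+ # Illustrative sketch only (out_dir is a hypothetical variable name): an f-string
+ # containing {out_dir\plots\roc.png} is rewritten so the braces hold
+ # os.path.join(out_dir, r"plots\roc.png") instead; ensure_os_import() below is
+ # intended to run afterwards so the generated code actually has os available.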
+
2782
+
2783
+ # def ensure_os_import(code: str) -> str:
2784
+ # """
2785
+ # If os.path.join is used but 'import os' is missing, inject it at the top.
2786
+ # """
2787
+ # needs = "os.path.join(" in code
2788
+ # has_import_os = re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
2789
+ # has_from_os = re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
2790
+ # if needs and not (has_import_os or has_from_os):
2791
+ # return "import os\n" + code
2792
+ # return code
2793
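+
+ # Illustrative sketch only: given 'p = os.path.join(out_dir, "roc.png")' with no
+ # import statement, ensure_os_import() prepends "import os\n"; code that already
+ # imports os (or uses "from os import ...") is returned unchanged.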
+
2794
+
2795
+ # def fix_seaborn_boxplot_nameerror(code: str) -> str:
2796
+ # """
2797
+ # Fix bad calls like: sns.boxplot(boxplot)
2798
+ # Heuristic:
2799
+ # - If plot_df + FH_status + var exist → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
2800
+ # - Else if plot_df + var → sns.boxplot(data=plot_df, y=var, ax=ax)
2801
+ # - Else if plot_df only → sns.boxplot(data=plot_df, ax=ax)
2802
+ # - Else → sns.boxplot(ax=ax)
2803
+ # Ensures a matplotlib Axes 'ax' exists.
2804
+ # """
2805
+ # pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
2806
+ # if not pattern.search(code):
2807
+ # return code
2808
+
2809
+ # has_plot_df = re.search(r"\bplot_df\b", code) is not None
2810
+ # has_var = re.search(r"\bvar\b", code) is not None
2811
+ # has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2812
+
2813
+ # if has_plot_df and has_var and has_fh:
2814
+ # replacement = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2815
+ # elif has_plot_df and has_var:
2816
+ # replacement = "sns.boxplot(data=plot_df, y=var, ax=ax)"
2817
+ # elif has_plot_df:
2818
+ # replacement = "sns.boxplot(data=plot_df, ax=ax)"
2819
+ # else:
2820
+ # replacement = "sns.boxplot(ax=ax)"
2821
+
2822
+ # fixed = pattern.sub(replacement, code)
2823
+
2824
+ # # Ensure 'fig, ax = plt.subplots(...)' exists
2825
+ # if "ax=" in replacement and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", fixed):
2826
+ # # Insert right before the first seaborn call
2827
+ # m = re.search(r"^\s*sns\.", fixed, flags=re.MULTILINE)
2828
+ # insert_at = m.start() if m else 0
2829
+ # fixed = fixed[:insert_at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + fixed[insert_at:]
2830
+
2831
+ # return fixed
2832
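+
+ # Illustrative sketch only (plot_df/var/FH_status are the names the heuristic scans for):
+ # fix_seaborn_boxplot_nameerror("sns.boxplot(boxplot)") returns "sns.boxplot(ax=ax)"
+ # with a "fig, ax = plt.subplots(figsize=(8,4))" line inserted before the call, while a
+ # script that defines plot_df and var gets sns.boxplot(data=plot_df, y=var, ax=ax).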
+
2833
+
2834
+ # def fix_seaborn_barplot_nameerror(code: str) -> str:
2835
+ # """
2836
+ # Fix bad calls like: sns.barplot(barplot)
2837
+ # Strategy mirrors boxplot fixer: prefer data=plot_df with x/y if available,
2838
+ # otherwise degrade safely to an empty call on an existing Axes.
2839
+ # """
2840
+ # import re
2841
+ # pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
2842
+ # if not pattern.search(code):
2843
+ # return code
2844
+
2845
+ # has_plot_df = re.search(r"\bplot_df\b", code) is not None
2846
+ # has_var = re.search(r"\bvar\b", code) is not None
2847
+ # has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
2848
+
2849
+ # if has_plot_df and has_var and has_fh:
2850
+ # replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
2851
+ # elif has_plot_df and has_var:
2852
+ # replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
2853
+ # elif has_plot_df:
2854
+ # replacement = "sns.barplot(data=plot_df, ax=ax)"
2855
+ # else:
2856
+ # replacement = "sns.barplot(ax=ax)"
2857
+
2858
+ # # ensure an Axes 'ax' exists (no-op if already present)
2859
+ # if "ax =" not in code:
2860
+ # code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code
2861
+
2862
+ # return pattern.sub(replacement, code)
2863
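+
+ # Illustrative sketch only: the barplot variant mirrors the boxplot fixer above, e.g.
+ # "sns.barplot(barplot)" inside code that defines plot_df, var and FH_status becomes
+ # "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"; when no "ax =" is present,
+ # a matplotlib fig/ax pair is prepended first.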
+
2864
+
2865
+ # def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
2866
+ # """
2867
+ # Parses the raw text to extract and format the 'refined question',
2868
+ # 'intents (tasks)', and 'chronology of tasks' sections.
2869
+ # Args:
2870
+ # raw_text: The complete input string containing the ML pipeline structure.
2871
+ # Returns:
2872
+ # A tuple containing:
2873
+ # (formatted_question_str, formatted_intents_str, formatted_chronology_str)
2874
+ # """
2875
+ # # --- 1. Regex Pattern to Extract Sections ---
2876
+ # # The pattern uses capturing groups (?) to look for the section headers
2877
+ # # (e.g., 'refined question:') and captures all the content until the next
2878
+ # # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.
2832
2879
 
2833
- pattern = re.compile(
2834
- r"refined question:(?P<question>.*?)"
2835
- r"intents \(tasks\):(?P<intents>.*?)"
2836
- r"Chronology of tasks:(?P<chronology>.*)",
2837
- re.IGNORECASE | re.DOTALL
2838
- )
2880
+ # pattern = re.compile(
2881
+ # r"refined question:(?P<question>.*?)"
2882
+ # r"intents \(tasks\):(?P<intents>.*?)"
2883
+ # r"Chronology of tasks:(?P<chronology>.*)",
2884
+ # re.IGNORECASE | re.DOTALL
2885
+ # )
2839
2886
 
2840
- match = pattern.search(raw_text)
2887
+ # match = pattern.search(raw_text)
2841
2888
 
2842
- if not match:
2843
- raise ValueError("Input text structure does not match the expected pattern.")
2889
+ # if not match:
2890
+ # raise ValueError("Input text structure does not match the expected pattern.")
2844
2891
 
2845
- # --- 2. Extract Content ---
2846
- question_content = match.group('question').strip()
2847
- intents_content = match.group('intents').strip()
2848
- chronology_content = match.group('chronology').strip()
2892
+ # # --- 2. Extract Content ---
2893
+ # question_content = match.group('question').strip()
2894
+ # intents_content = match.group('intents').strip()
2895
+ # chronology_content = match.group('chronology').strip()
2849
2896
 
2850
- # --- 3. Formatting Functions ---
2897
+ # # --- 3. Formatting Functions ---
2851
2898
 
2852
- def format_question(content):
2853
- """Formats the Refined Question section."""
2854
- # Clean up leading/trailing whitespace and ensure clean paragraphs
2855
- content = content.strip().replace('\n', ' ').replace(' ', ' ')
2899
+ # def format_question(content):
2900
+ # """Formats the Refined Question section."""
2901
+ # # Clean up leading/trailing whitespace and ensure clean paragraphs
2902
+ # content = content.strip().replace('\n', ' ').replace(' ', ' ')
2856
2903
 
2857
- # Simple formatting using Markdown headers and bolding
2858
- formatted = (
2859
- # "## 1. Project Goal and Objectives\n\n"
2860
- "<b> Refined Question:</b>\n"
2861
- f"{content}\n"
2862
- )
2863
- return formatted
2864
-
2865
- def format_intents(content):
2866
- """Formats the Intents (Tasks) section as a structured list."""
2867
- # Use regex to find and format each numbered task
2868
- # It finds 'N. **Text** - ...' and breaks it down.
2904
+ # # Simple formatting using Markdown headers and bolding
2905
+ # formatted = (
2906
+ # # "## 1. Project Goal and Objectives\n\n"
2907
+ # "<b> Refined Question:</b>\n"
2908
+ # f"{content}\n"
2909
+ # )
2910
+ # return formatted
2911
+
2912
+ # def format_intents(content):
2913
+ # """Formats the Intents (Tasks) section as a structured list."""
2914
+ # # Use regex to find and format each numbered task
2915
+ # # It finds 'N. **Text** - ...' and breaks it down.
2869
2916
 
2870
- tasks = []
2871
- # Pattern: N. **Text** - Content (including newlines, non-greedy)
2872
- # We need to explicitly handle the list items starting with '-' within the content
2873
- task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)
2917
+ # tasks = []
2918
+ # # Pattern: N. **Text** - Content (including newlines, non-greedy)
2919
+ # # We need to explicitly handle the list items starting with '-' within the content
2920
+ # task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)
2874
2921
 
2875
- # Split the content by lines and join tasks back into clean strings
2876
- raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]
2922
+ # # Split the content by lines and join tasks back into clean strings
2923
+ # raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]
2877
2924
 
2878
- for task in raw_tasks:
2879
- # Replace the initial task number and **Heading** with a Heading 3
2880
- task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)
2925
+ # for task in raw_tasks:
2926
+ # # Replace the initial task number and **Heading** with a Heading 3
2927
+ # task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)
2881
2928
 
2882
- # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
2883
- task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
2884
- tasks.append(task)
2929
+ # # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
2930
+ # task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
2931
+ # tasks.append(task)
2885
2932
 
2886
- formatted_tasks = "\n\n".join(tasks)
2933
+ # formatted_tasks = "\n\n".join(tasks)
2934
+
2935
+ # return (
2936
+ # "\n---\n"
2937
+ # "## 2. Methodology and Tasks\n\n"
2938
+ # f"{formatted_tasks}\n"
2939
+ # )
2940
+
2941
+ # def format_chronology(content):
2942
+ # """Formats the Chronology section."""
2943
+ # # Uses the given LaTeX format
2944
+ # content = content.strip().replace(' ', ' \rightarrow ')
2945
+ # formatted = (
2946
+ # "\n---\n"
2947
+ # "## 3. Chronology of Tasks\n"
2948
+ # f"$$\\text{{{content}}}$$"
2949
+ # )
2950
+ # return formatted
2951
+
2952
+ # # --- 4. Format and Return ---
2953
+ # formatted_question = format_question(question_content)
2954
+ # formatted_intents = format_intents(intents_content)
2955
+ # formatted_chronology = format_chronology(chronology_content)
2956
+
2957
+ # return formatted_question, formatted_intents, formatted_chronology
2958
+
2959
+
2960
+ # def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
2961
+ # """Combines all formatted parts into a final report string."""
2962
+ # return (
2963
+ # "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
2964
+ # f"{formatted_question}\n"
2965
+ # f"{formatted_intents}\n"
2966
+ # f"{formatted_chronology}\n"
2967
+ # )
2968
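+
+ # Illustrative sketch only (the raw text is a made-up planner reply):
+ # q, i, c = parse_and_format_ml_pipeline(
+ #     "refined question: Can BMI predict diabetes?\n"
+ #     "intents (tasks):\n1. **EDA** - profile the columns\n"
+ #     "Chronology of tasks: EDA Modelling Evaluation")
+ # report = generate_full_report(q, i, c)
+ # The three parts come back as the question block, the Markdown task list and the
+ # LaTeX chronology line, which generate_full_report() stitches into one report string.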
+
2969
+
2970
+ # def fix_confusion_matrix_for_multilabel(code: str) -> str:
2971
+ # """
2972
+ # Replace ConfusionMatrixDisplay.from_estimator(...) usages with
2973
+ # from_predictions(...) which works for multi-label loops without requiring
2974
+ # the estimator to expose _estimator_type.
2975
+ # """
2976
+ # return re.sub(
2977
+ # r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
2978
+ # r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
2979
+ # code
2980
+ # )
2981
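+
+ # Illustrative sketch only (clf/X_test/y_test are hypothetical names):
+ # >>> fix_confusion_matrix_for_multilabel("ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test)")
+ # 'ConfusionMatrixDisplay.from_predictions(y_test, clf.predict(X_test))'
+ # The rewritten form also works for per-label estimators inside multi-label loops.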
+
2982
+
2983
+ # def smx_auto_title_plots(ctx=None, fallback="Analysis"):
2984
+ # """
2985
+ # Ensure every Matplotlib/Seaborn Axes has a title.
2986
+ # Uses refined_question -> askai_question -> fallback.
2987
+ # Only sets a title if it's currently empty.
2988
+ # """
2989
+ # import matplotlib.pyplot as plt
2990
+
2991
+ # def _all_figures():
2992
+ # try:
2993
+ # from matplotlib._pylab_helpers import Gcf
2994
+ # return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
2995
+ # except Exception:
2996
+ # # Best effort fallback
2997
+ # nums = plt.get_fignums()
2998
+ # return [plt.figure(n) for n in nums] if nums else []
2999
+
3000
+ # # Choose a concise title
3001
+ # title = None
3002
+ # if isinstance(ctx, dict):
3003
+ # title = ctx.get("refined_question") or ctx.get("askai_question")
3004
+ # title = (str(title).strip().splitlines()[0][:120]) if title else fallback
3005
+
3006
+ # for fig in _all_figures():
3007
+ # for ax in getattr(fig, "axes", []):
3008
+ # try:
3009
+ # if not (ax.get_title() or "").strip():
3010
+ # ax.set_title(title)
3011
+ # except Exception:
3012
+ # pass
3013
+ # try:
3014
+ # fig.tight_layout()
3015
+ # except Exception:
3016
+ # pass
3017
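+
+ # Illustrative sketch only: after a generated plotting cell has run,
+ # smx_auto_title_plots({"refined_question": "BMI by cohort"}) walks every open
+ # figure and sets "BMI by cohort" on any Axes whose title is still empty.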
+
3018
+
3019
+ # def patch_fix_sentinel_plot_calls(code: str) -> str:
3020
+ # """
3021
+ # Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
3022
+ # SB_barplot(barplot) -> SB_barplot()
3023
+ # SB_barplot(barplot, ...) -> SB_barplot(...)
3024
+ # sns.barplot(barplot) -> SB_barplot()
3025
+ # sns.barplot(barplot, ...) -> SB_barplot(...)
3026
+ # Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
3027
+ # """
3028
+ # names = ['histplot','boxplot','barplot','lineplot','countplot','heatmap','pairplot','scatterplot']
3029
+ # for n in names:
3030
+ # # SB_* with sentinel as the first arg (with or without trailing args)
3031
+ # code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
3032
+ # code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
3033
+ # # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
3034
+ # code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
3035
+ # code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
3036
+ # return code
3037
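+
+ # Illustrative sketch only (the SB_* wrappers are assumed to exist in the execution namespace):
+ # >>> patch_fix_sentinel_plot_calls("sns.barplot(barplot)")
+ # 'SB_barplot()'
+ # Sentinel first arguments on SB_* calls are stripped the same way.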
+
3038
+
3039
+ # def patch_rmse_calls(code: str) -> str:
3040
+ # """
3041
+ # Make RMSE robust across sklearn versions.
3042
+ # - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
3043
+ # - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
3044
+ # """
3045
+ # import re
3046
+ # # (a) Specific RMSE pattern
3047
+ # code = re.sub(
3048
+ # r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
3049
+ # r"_SMX_rmse(\1)",
3050
+ # code,
3051
+ # flags=re.DOTALL
3052
+ # )
3053
+ # # (b) Guard any other MSE calls
3054
+ # code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
3055
+ # return code
2887
3056
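+
+ # Illustrative sketch only (_SMX_rmse and _SMX_call are helpers assumed to be
+ # injected into the execution namespace by the runner):
+ # >>> patch_rmse_calls("mean_squared_error(y_true, y_pred, squared=False)")
+ # '_SMX_rmse(y_true, y_pred)'
+ # >>> patch_rmse_calls("mean_squared_error(y_true, y_pred)")
+ # '_SMX_call(mean_squared_error, y_true, y_pred)'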
 
2888
- return (
2889
- "\n---\n"
2890
- "## 2. Methodology and Tasks\n\n"
2891
- f"{formatted_tasks}\n"
2892
- )
2893
-
2894
- def format_chronology(content):
2895
- """Formats the Chronology section."""
2896
- # Uses the given LaTeX format
2897
- content = content.strip().replace(' ', ' \rightarrow ')
2898
- formatted = (
2899
- "\n---\n"
2900
- "## 3. Chronology of Tasks\n"
2901
- f"$$\\text{{{content}}}$$"
2902
- )
2903
- return formatted
2904
-
2905
- # --- 4. Format and Return ---
2906
- formatted_question = format_question(question_content)
2907
- formatted_intents = format_intents(intents_content)
2908
- formatted_chronology = format_chronology(chronology_content)
2909
-
2910
- return formatted_question, formatted_intents, formatted_chronology
2911
-
2912
-
2913
- def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
2914
- """Combines all formatted parts into a final report string."""
2915
- return (
2916
- "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
2917
- f"{formatted_question}\n"
2918
- f"{formatted_intents}\n"
2919
- f"{formatted_chronology}\n"
2920
- )
2921
-
2922
-
2923
- def fix_confusion_matrix_for_multilabel(code: str) -> str:
2924
- """
2925
- Replace ConfusionMatrixDisplay.from_estimator(...) usages with
2926
- from_predictions(...) which works for multi-label loops without requiring
2927
- the estimator to expose _estimator_type.
2928
- """
2929
- return re.sub(
2930
- r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
2931
- r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
2932
- code
2933
- )
2934
-
2935
-
2936
- def smx_auto_title_plots(ctx=None, fallback="Analysis"):
2937
- """
2938
- Ensure every Matplotlib/Seaborn Axes has a title.
2939
- Uses refined_question -> askai_question -> fallback.
2940
- Only sets a title if it's currently empty.
2941
- """
2942
- import matplotlib.pyplot as plt
2943
-
2944
- def _all_figures():
2945
- try:
2946
- from matplotlib._pylab_helpers import Gcf
2947
- return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
2948
- except Exception:
2949
- # Best effort fallback
2950
- nums = plt.get_fignums()
2951
- return [plt.figure(n) for n in nums] if nums else []
2952
-
2953
- # Choose a concise title
2954
- title = None
2955
- if isinstance(ctx, dict):
2956
- title = ctx.get("refined_question") or ctx.get("askai_question")
2957
- title = (str(title).strip().splitlines()[0][:120]) if title else fallback
2958
-
2959
- for fig in _all_figures():
2960
- for ax in getattr(fig, "axes", []):
2961
- try:
2962
- if not (ax.get_title() or "").strip():
2963
- ax.set_title(title)
2964
- except Exception:
2965
- pass
2966
- try:
2967
- fig.tight_layout()
2968
- except Exception:
2969
- pass
2970
-
2971
-
2972
- def patch_fix_sentinel_plot_calls(code: str) -> str:
2973
- """
2974
- Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
2975
- SB_barplot(barplot) -> SB_barplot()
2976
- SB_barplot(barplot, ...) -> SB_barplot(...)
2977
- sns.barplot(barplot) -> SB_barplot()
2978
- sns.barplot(barplot, ...) -> SB_barplot(...)
2979
- Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
2980
- """
2981
- names = ['histplot','boxplot','barplot','lineplot','countplot','heatmap','pairplot','scatterplot']
2982
- for n in names:
2983
- # SB_* with sentinel as the first arg (with or without trailing args)
2984
- code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
2985
- code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
2986
- # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
2987
- code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
2988
- code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
2989
- return code
2990
-
2991
-
2992
- def patch_rmse_calls(code: str) -> str:
2993
- """
2994
- Make RMSE robust across sklearn versions.
2995
- - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
2996
- - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
2997
- """
2998
- import re
2999
- # (a) Specific RMSE pattern
3000
- code = re.sub(
3001
- r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
3002
- r"_SMX_rmse(\1)",
3003
- code,
3004
- flags=re.DOTALL
3005
- )
3006
- # (b) Guard any other MSE calls
3007
- code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
3008
- return code