syntaxmatrix 2.5.8.2__py3-none-any.whl → 2.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- syntaxmatrix/agentic/agents.py +1149 -54
- syntaxmatrix/agentic/agents_orchestrer.py +326 -0
- syntaxmatrix/agentic/code_tools_registry.py +27 -32
- syntaxmatrix/commentary.py +16 -16
- syntaxmatrix/core.py +145 -75
- syntaxmatrix/db.py +416 -4
- syntaxmatrix/{display.py → display_html.py} +2 -6
- syntaxmatrix/gpt_models_latest.py +1 -1
- syntaxmatrix/media/__init__.py +0 -0
- syntaxmatrix/media/media_pixabay.py +277 -0
- syntaxmatrix/models.py +1 -1
- syntaxmatrix/page_builder_defaults.py +183 -0
- syntaxmatrix/page_builder_generation.py +1122 -0
- syntaxmatrix/page_layout_contract.py +644 -0
- syntaxmatrix/page_patch_publish.py +1471 -0
- syntaxmatrix/preface.py +128 -8
- syntaxmatrix/profiles.py +26 -13
- syntaxmatrix/routes.py +1475 -429
- syntaxmatrix/selftest_page_templates.py +360 -0
- syntaxmatrix/settings/client_items.py +28 -0
- syntaxmatrix/settings/model_map.py +1022 -208
- syntaxmatrix/settings/prompts.py +328 -130
- syntaxmatrix/static/assets/hero-default.svg +22 -0
- syntaxmatrix/static/icons/bot-icon.png +0 -0
- syntaxmatrix/static/icons/favicon.png +0 -0
- syntaxmatrix/static/icons/logo.png +0 -0
- syntaxmatrix/static/icons/logo3.png +0 -0
- syntaxmatrix/templates/admin_branding.html +104 -0
- syntaxmatrix/templates/admin_secretes.html +108 -0
- syntaxmatrix/templates/dashboard.html +116 -72
- syntaxmatrix/templates/edit_page.html +2535 -0
- syntaxmatrix/utils.py +2365 -2411
- {syntaxmatrix-2.5.8.2.dist-info → syntaxmatrix-2.6.1.dist-info}/METADATA +6 -2
- {syntaxmatrix-2.5.8.2.dist-info → syntaxmatrix-2.6.1.dist-info}/RECORD +37 -24
- syntaxmatrix/generate_page.py +0 -644
- syntaxmatrix/static/icons/hero_bg.jpg +0 -0
- {syntaxmatrix-2.5.8.2.dist-info → syntaxmatrix-2.6.1.dist-info}/WHEEL +0 -0
- {syntaxmatrix-2.5.8.2.dist-info → syntaxmatrix-2.6.1.dist-info}/licenses/LICENSE.txt +0 -0
- {syntaxmatrix-2.5.8.2.dist-info → syntaxmatrix-2.6.1.dist-info}/top_level.txt +0 -0
syntaxmatrix/utils.py
CHANGED
@@ -6,89 +6,41 @@ from difflib import get_close_matches
 from typing import Iterable, Tuple, Dict
 import inspect
 from sklearn.preprocessing import OneHotEncoder
-
-
-from syntaxmatrix.agentic.model_templates import (
-    classification, regression, multilabel_classification,
-    eda_overview, eda_correlation,
-    anomaly_detection, ts_anomaly_detection,
-    dimensionality_reduction, feature_selection,
-    time_series_forecasting, time_series_classification,
-    unknown_group_proxy_pack, viz_line,
-    clustering, recommendation, topic_modelling,
-    viz_pie, viz_count_bar, viz_box, viz_scatter,
-    viz_stacked_bar, viz_distribution, viz_area, viz_kde,
-)
 import ast
 
 warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
 
-
-    "classification",
-    "multilabel_classification",
-    "regression",
-    "anomaly_detection",
-    "time_series_forecasting",
-    "time_series_classification",
-    "ts_anomaly_detection",
-    "dimensionality_reduction",
-    "feature_selection",
-    "clustering",
-    "eda",
-    "correlation_analysis",
-    "visualisation",
-    "recommendation",
-    "topic_modelling",
-]
-
-def classify_ml_job(prompt: str) -> str:
+def patch_quiet_specific_warnings(code: str) -> str:
     """
-    …
-    'classification' | 'regression' | 'eda'
+    Inserts targeted warning filters (not blanket ignores).
+    - seaborn palette/hue deprecation
+    - python-dotenv parse chatter
     """
-    …
-    ))
-    …
-        "suspicious"
-    )):
-        return "anomaly_detection"
-
-    if any(k in p for k in ("t-test", "anova", "p-value")):
-        return "stat_test"
-    if "forecast" in p or "prophet" in p:
-        return "time_series"
-    if "cluster" in p or "kmeans" in p:
-        return "clustering"
-    if any(k in p for k in ("accuracy", "precision", "roc")):
-        return "classification"
-    if any(k in p for k in ("rmse", "r2", "mae")):
-        return "regression"
-
-    return "eda"
+    prelude = (
+        "import warnings\n"
+        "warnings.filterwarnings(\n"
+        "    'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
+        "warnings.filterwarnings(\n"
+        "    'ignore', message=r'python-dotenv could not parse statement.*')\n"
+    )
+    # If warnings already imported once, just add filters; else insert full prelude.
+    if "import warnings" in code:
+        code = re.sub(
+            r"(import warnings[^\n]*\n)",
+            lambda m: m.group(1) + prelude.replace("import warnings\n", ""),
+            code,
+            count=1
+        )
+
+    else:
+        # place after first import block if possible
+        m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
+        if m:
+            idx = m.end()
+            code = code[:idx] + prelude + code[idx:]
+        else:
+            code = prelude + code
+    return code
 
 
 def _indent(code: str, spaces: int = 4) -> str:
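
The deleted keyword router `classify_ml_job` gives way to this prelude-splicing helper. A minimal standalone sketch of the same insert-after-import logic, using only the stdlib (`quiet_warnings` and the abridged `PRELUDE` are illustrative names, not part of the package):

import re

PRELUDE = (
    "import warnings\n"
    "warnings.filterwarnings(\n"
    "    'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
)

def quiet_warnings(code: str) -> str:
    # Mirror of the hunk's logic: splice the filters right after an existing
    # `import warnings`; otherwise prepend the full prelude.
    if "import warnings" in code:
        return re.sub(
            r"(import warnings[^\n]*\n)",
            lambda m: m.group(1) + PRELUDE.replace("import warnings\n", ""),
            code,
            count=1,
        )
    return PRELUDE + code

cell = "import warnings\nimport seaborn as sns\nsns.histplot([1, 2, 2, 3])\n"
print(quiet_warnings(cell))

The filters are message-targeted rather than blanket `warnings.filterwarnings('ignore')`, so genuine warnings from generated cells still surface.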
@@ -121,7 +73,7 @@ def wrap_llm_code_safe(body: str) -> str:
         "        df_local = globals().get('df')\n"
         "        if df_local is not None:\n"
         "            import pandas as pd\n"
-        "            from syntaxmatrix.preface import SB_histplot, _SMX_export_png\n"
+        "            from syntaxmatrix.preface import SB_histplot, SB_boxplot, SB_scatterplot, SB_heatmap, _SMX_export_png\n"
         "            num_cols = df_local.select_dtypes(include=['number', 'bool']).columns.tolist()\n"
         "            cat_cols = [c for c in df_local.columns if c not in num_cols]\n"
         "            info = {\n"
@@ -135,11 +87,78 @@ def wrap_llm_code_safe(body: str) -> str:
         "            if num_cols:\n"
         "                SB_histplot()\n"
         "                _SMX_export_png()\n"
+        "            if len(num_cols) >= 2:\n"
+        "                SB_scatterplot()\n"
+        "                _SMX_export_png()\n"
+        "            if num_cols and cat_cols:\n"
+        "                SB_boxplot()\n"
+        "                _SMX_export_png()\n"
+        "            if len(num_cols) >= 2:\n"
+        "                SB_heatmap()\n"
+        "                _SMX_export_png()\n"
         "    except Exception as _f:\n"
         "        show(f\"⚠️ Fallback EDA failed: {type(_f).__name__}: {_f}\")\n"
     )
 
 
+def fix_print_html(code: str) -> str:
+    """
+    Ensure that HTML / DataFrame HTML are *displayed* (and captured by the kernel),
+    not printed as `<IPython.core.display.HTML object>` to the server console.
+    - Rewrites: print(HTML(...))       → display(HTML(...))
+                print(display(...))    → display(...)
+                print(df.to_html(...)) → display(HTML(df.to_html(...)))
+    Also prepends `from IPython.display import display, HTML` if required.
+    """
+    import re
+
+    new = code
+
+    # 1) print(HTML(...)) -> display(HTML(...))
+    new = re.sub(r"(?m)^\s*print\s*\(\s*HTML\s*\(", "display(HTML(", new)
+
+    # 2) print(display(...)) -> display(...)
+    new = re.sub(r"(?m)^\s*print\s*\(\s*display\s*\(", "display(", new)
+
+    # 3) print(<expr>.to_html(...)) -> display(HTML(<expr>.to_html(...)))
+    new = re.sub(
+        r"(?m)^\s*print\s*\(\s*([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*)\s*\.to_html\s*\(",
+        r"display(HTML(\1.to_html(", new
+    )
+
+    # If code references HTML() or display() make sure the import exists
+    if ("HTML(" in new or re.search(r"\bdisplay\s*\(", new)) and \
+       "from IPython.display import display, HTML" not in new:
+        new = "from IPython.display import display, HTML\n" + new
+
+    return new
+
+
+def ensure_ipy_display(code: str) -> str:
+    """
+    Guarantee that the cell has proper IPython display imports so that
+    display(HTML(...)) produces 'display_data' events the kernel captures.
+    """
+    if "display(" in code and "from IPython.display import display, HTML" not in code:
+        return "from IPython.display import display, HTML\n" + code
+    return code
+
+
+def fix_plain_prints(code: str) -> str:
+    """
+    Rewrite bare `print(var)` where var looks like a dataframe/series/ndarray/etc
+    to go through SyntaxMatrix's smart display (so it renders in the dashboard).
+    Keeps string prints alone.
+    """
+
+    # Skip obvious string-literal prints
+    new = re.sub(
+        r"(?m)^\s*print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
+        r"from syntaxmatrix.display import show\nshow(\1)",
+        code,
+    )
+    return new
+
 def harden_ai_code(code: str) -> str:
     """
     Make any AI-generated cell resilient:
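
A quick way to see what the new `fix_print_html` does is to run its rewrites on a one-line cell. The sketch below re-applies the same regexes standalone, so it runs without syntaxmatrix installed (`fix_print_html_demo` is an illustrative name):

import re

def fix_print_html_demo(code: str) -> str:
    # Same three rewrites as the new helper above, re-applied standalone.
    code = re.sub(r"(?m)^\s*print\s*\(\s*HTML\s*\(", "display(HTML(", code)
    code = re.sub(r"(?m)^\s*print\s*\(\s*display\s*\(", "display(", code)
    code = re.sub(
        r"(?m)^\s*print\s*\(\s*([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*)\s*\.to_html\s*\(",
        r"display(HTML(\1.to_html(",
        code,
    )
    # Prepend the import only when the rewritten cell actually needs it.
    if ("HTML(" in code or re.search(r"\bdisplay\s*\(", code)) \
            and "from IPython.display import display, HTML" not in code:
        code = "from IPython.display import display, HTML\n" + code
    return code

print(fix_print_html_demo("print(HTML(df.head().to_html()))"))
# from IPython.display import display, HTML
# display(HTML(df.head().to_html()))

The first rule fires here, so the rewritten call stays balanced; the point is that the kernel now emits a 'display_data' event instead of printing the repr to the server console.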
@@ -202,21 +221,37 @@ def harden_ai_code(code: str) -> str:
 
     return pattern.sub(repl, code)
 
+
 def _wrap_metric_calls(code: str) -> str:
     names = [
-        "r2_score",
-        "
-        "
-        "
-        "
+        "r2_score",
+        "accuracy_score",
+        "precision_score",
+        "recall_score",
+        "f1_score",
+        "roc_auc_score",
+        "classification_report",
+        "confusion_matrix",
+        "mean_absolute_error",
+        "mean_absolute_percentage_error",
+        "explained_variance_score",
+        "log_loss",
+        "average_precision_score",
+        "precision_recall_fscore_support",
+        "mean_squared_error",
     ]
-    pat = re.compile(
+    pat = re.compile(
+        r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\("
+    )
+
     def repl(m):
         prefix = m.group(1) or ""  # "", "metrics.", or "sklearn.metrics."
         name = m.group(2)
         return f"_SMX_call({prefix}{name}, "
+
     return pat.sub(repl, code)
 
+
 def _smx_patch_mean_squared_error_squared_kw():
     """
     sklearn<0.22 doesn't accept mean_squared_error(..., squared=False).
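
`_wrap_metric_calls` now builds one alternation over the full metric list. A standalone sketch with an abridged `NAMES` list (the `_SMX_call` shim itself lives in syntaxmatrix.preface and is not shown in this diff):

import re

NAMES = ["r2_score", "accuracy_score", "mean_squared_error"]  # abridged

pat = re.compile(r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(NAMES) + r"))\s*\(")

def wrap_metric_calls(code: str) -> str:
    # Rewrite e.g. `accuracy_score(y, p)` to `_SMX_call(accuracy_score, y, p)`
    # so a shim can normalise arguments and absorb sklearn version quirks.
    def repl(m):
        prefix = m.group(1) or ""  # "", "metrics.", or "sklearn.metrics."
        return f"_SMX_call({prefix}{m.group(2)}, "
    return pat.sub(repl, code)

print(wrap_metric_calls("acc = metrics.accuracy_score(y_true, y_pred)"))
# acc = _SMX_call(metrics.accuracy_score, y_true, y_pred)

Matching the optional `sklearn.metrics.` / `metrics.` prefix as group 1 lets the shim keep whatever qualification style the generated cell used.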
@@ -291,6 +326,8 @@ def harden_ai_code(code: str) -> str:
         needed.add("r2_score")
     if "mean_absolute_error" in code:
         needed.add("mean_absolute_error")
+    if "mean_squared_error" in code:
+        needed.add("mean_squared_error")
     # ... add others if you like ...
 
     if not needed:
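
The body of `_ensure_metrics_imports` is not shown in this hunk; only the `needed`-set checks are. A hypothetical standalone sketch of that needed-set idea (`ensure_metric_imports` and `known` are invented names for illustration):

def ensure_metric_imports(code: str) -> str:
    # Scan the cell for metric names it uses and prepend one sklearn import
    # covering all of them, if none is present yet.
    known = ("r2_score", "mean_absolute_error", "mean_squared_error")
    needed = sorted(n for n in known if n in code)
    if needed and "from sklearn.metrics import" not in code:
        code = f"from sklearn.metrics import {', '.join(needed)}\n" + code
    return code

print(ensure_metric_imports("rmse = mean_squared_error(y, p) ** 0.5"))
# from sklearn.metrics import mean_squared_error
# rmse = mean_squared_error(y, p) ** 0.5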
@@ -351,7 +388,8 @@ def harden_ai_code(code: str) -> str:
       - then falls back to generic but useful EDA.
 
     It assumes `from syntaxmatrix.preface import *` has already been done,
-    so `_SMX_OHE`, `_SMX_call`, `SB_histplot`, `
+    so `_SMX_OHE`, `_SMX_call`, `SB_histplot`, `SB_boxplot`,
+    `SB_scatterplot`, `SB_heatmap`, `_SMX_export_png` and the
     patched `show()` are available.
     """
     import textwrap
@@ -461,10 +499,25 @@ def harden_ai_code(code: str) -> str:
     show(df.head(), title='Sample of data')
     show(info, title='Dataset summary')
 
-    #
+    # 1) Distribution of a numeric column
     if num_cols:
         SB_histplot()
         _SMX_export_png()
+
+    # 2) Relationship between two numeric columns
+    if len(num_cols) >= 2:
+        SB_scatterplot()
+        _SMX_export_png()
+
+    # 3) Distribution of a numeric by the first categorical column
+    if num_cols and cat_cols:
+        SB_boxplot()
+        _SMX_export_png()
+
+    # 4) Correlation heatmap across numeric columns
+    if len(num_cols) >= 2:
+        SB_heatmap()
+        _SMX_export_png()
 except Exception as _eda_e:
     show(f"⚠ EDA fallback failed: {type(_eda_e).__name__}: {_eda_e}")
 """
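
`SB_histplot` / `SB_scatterplot` / `SB_boxplot` / `SB_heatmap` come from syntaxmatrix.preface and are not defined in this diff; the sketch below uses plain pandas/matplotlib to illustrate the same column guards (histogram whenever a numeric column exists, scatter and heatmap only with two or more, boxplot only when a categorical grouping is present):

import pandas as pd
import matplotlib
matplotlib.use("Agg")  # headless, as a server-side dashboard kernel would run
import matplotlib.pyplot as plt

df = pd.DataFrame({
    "age": [34, 51, 29, 42, 38],
    "income": [28_000, 54_000, 31_000, 47_000, 40_000],
    "group": ["a", "b", "a", "b", "a"],
})
num_cols = df.select_dtypes(include=["number", "bool"]).columns.tolist()
cat_cols = [c for c in df.columns if c not in num_cols]

if num_cols:                       # 1) histogram needs any numeric column
    df[num_cols[0]].plot(kind="hist")
    plt.savefig("hist.png"); plt.close()
if len(num_cols) >= 2:             # 2) scatter needs two numeric columns
    df.plot.scatter(x=num_cols[0], y=num_cols[1])
    plt.savefig("scatter.png"); plt.close()
if num_cols and cat_cols:          # 3) boxplot needs a categorical grouping
    df.boxplot(column=num_cols[0], by=cat_cols[0])
    plt.savefig("box.png"); plt.close()
if len(num_cols) >= 2:             # 4) correlation heatmap over numeric columns
    plt.imshow(df[num_cols].corr().to_numpy())
    plt.colorbar()
    plt.savefig("heatmap.png"); plt.close()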
@@ -596,19 +649,21 @@ def harden_ai_code(code: str) -> str:
     except (SyntaxError, IndentationError):
         fixed = _fallback_snippet()
 
+
+    # Fix placeholder Ellipsis handlers from LLM
     fixed = re.sub(
         r"except\s+Exception\s+as\s+e:\s*\n\s*show\(\.\.\.\)",
         "except Exception as e:\n    show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")",
         fixed,
     )
 
-    #
+    # redirect that import to the real template module.
     fixed = re.sub(
-        r"
-        "
+        r"from\s+syntaxmatrix\.templates\s+import\s+([^\n]+)",
+        r"from syntaxmatrix.agentic.model_templates import \1",
         fixed,
     )
-
+
     try:
         class _SMXMatmulRewriter(ast.NodeTransformer):
             def visit_BinOp(self, node):
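
The import-redirect regex that this hunk moves earlier (the next hunk removes its old copy) can be exercised standalone:

import re

def redirect_template_imports(code: str) -> str:
    # If the model invents "from syntaxmatrix.templates import viz_count_bar",
    # point the import at the real template module instead.
    return re.sub(
        r"from\s+syntaxmatrix\.templates\s+import\s+([^\n]+)",
        r"from syntaxmatrix.agentic.model_templates import \1",
        code,
    )

print(redirect_template_imports("from syntaxmatrix.templates import viz_count_bar"))
# from syntaxmatrix.agentic.model_templates import viz_count_bar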
@@ -628,14 +683,6 @@ def harden_ai_code(code: str) -> str:
     fixed = fixed.replace("\t", "    ")
     fixed = textwrap.dedent(fixed).strip("\n")
 
-    # Normalise any mistaken template imports the LLM may have invented.
-    # If the model writes "from syntaxmatrix.templates import viz_count_bar",
-    # redirect that import to the real template module.
-    fixed = re.sub(
-        r"from\s+syntaxmatrix\.templates\s+import\s+([^\n]+)",
-        r"from syntaxmatrix.agentic.model_templates import \1",
-        fixed,
-    )
     fixed = _ensure_metrics_imports(fixed)
     fixed = _strip_stray_backrefs(fixed)
     fixed = _wrap_metric_calls(fixed)
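
The tab-normalisation and dedent pair kept at the top of this hunk can be seen in isolation:

import textwrap

cell = "\tdf = load()\n\tdf.describe()\n"
fixed = cell.replace("\t", "    ")          # tabs -> four spaces
fixed = textwrap.dedent(fixed).strip("\n")  # drop the shared left margin
print(fixed)
# df = load()
# df.describe()

Normalising tabs first matters: textwrap.dedent only strips a margin that is byte-identical across lines, so mixed tab/space indents would otherwise survive.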
@@ -661,822 +708,822 @@ def harden_ai_code(code: str) -> str:
|
|
|
661
708
|
return wrapped
|
|
662
709
|
|
|
663
710
|
|
|
664
|
-
def indent_code(code: str, spaces: int = 4) -> str:
|
|
665
|
-
|
|
666
|
-
|
|
711
|
+
# def indent_code(code: str, spaces: int = 4) -> str:
|
|
712
|
+
# pad = " " * spaces
|
|
713
|
+
# return "\n".join(pad + line for line in code.splitlines())
|
|
667
714
|
|
|
668
715
|
|
|
669
|
-
def fix_boxplot_placeholder(code: str) -> str:
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
716
|
+
# def fix_boxplot_placeholder(code: str) -> str:
|
|
717
|
+
# # Replace invalid 'sns.boxplot(boxplot)' with a safe call using df/group_label/m
|
|
718
|
+
# return re.sub(
|
|
719
|
+
# r"sns\.boxplot\(\s*boxplot\s*\)",
|
|
720
|
+
# "sns.boxplot(x=group_label, y=m, data=df.loc[df[m].notnull()], showfliers=False)",
|
|
721
|
+
# code
|
|
722
|
+
# )
|
|
676
723
|
|
|
677
724
|
|
|
678
|
-
def relax_required_columns(code: str) -> str:
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
725
|
+
# def relax_required_columns(code: str) -> str:
|
|
726
|
+
# # Remove hard failure on required_cols; keep a soft filter instead
|
|
727
|
+
# return re.sub(
|
|
728
|
+
# r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
|
|
729
|
+
# "required_cols = [c for c in df.columns]\n",
|
|
730
|
+
# code,
|
|
731
|
+
# flags=re.S
|
|
732
|
+
# )
|
|
686
733
|
|
|
687
734
|
|
|
688
|
-
def make_numeric_vars_dynamic(code: str) -> str:
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
735
|
+
# def make_numeric_vars_dynamic(code: str) -> str:
|
|
736
|
+
# # Replace any static numeric_vars list with a dynamic selection
|
|
737
|
+
# return re.sub(
|
|
738
|
+
# r"numeric_vars\s*=\s*\[.*?\]",
|
|
739
|
+
# "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
|
|
740
|
+
# code,
|
|
741
|
+
# flags=re.S
|
|
742
|
+
# )
|
|
696
743
|
|
|
697
744
|
|
|
698
|
-
def auto_inject_template(code: str, intents, df) -> str:
|
|
699
|
-
|
|
745
|
+
# def auto_inject_template(code: str, intents, df) -> str:
|
|
746
|
+
# """If the LLM forgot the core logic, prepend a skeleton block."""
|
|
700
747
|
|
|
701
|
-
|
|
702
|
-
|
|
748
|
+
# has_fit = ".fit(" in code
|
|
749
|
+
# has_plot = any(k in code for k in ("plt.", "sns.", ".plot(", ".hist("))
|
|
703
750
|
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
751
|
+
# UNKNOWN_TOKENS = {
|
|
752
|
+
# "unknown","not reported","not_reported","not known","n/a","na",
|
|
753
|
+
# "none","nan","missing","unreported","unspecified","null","-",""
|
|
754
|
+
# }
|
|
755
|
+
|
|
756
|
+
# # --- Safe template caller: passes only supported kwargs, falls back cleanly ---
|
|
757
|
+
# def _call_template(func, df, **hints):
|
|
758
|
+
# import inspect
|
|
759
|
+
# try:
|
|
760
|
+
# params = inspect.signature(func).parameters
|
|
761
|
+
# kw = {k: v for k, v in hints.items() if k in params}
|
|
762
|
+
# try:
|
|
763
|
+
# return func(df, **kw)
|
|
764
|
+
# except TypeError:
|
|
765
|
+
# # In case the template changed its signature at runtime
|
|
766
|
+
# return func(df)
|
|
767
|
+
# except Exception:
|
|
768
|
+
# # Absolute safety net
|
|
769
|
+
# try:
|
|
770
|
+
# return func(df)
|
|
771
|
+
# except Exception:
|
|
772
|
+
# # As a last resort, return empty code so we don't 500
|
|
773
|
+
# return ""
|
|
774
|
+
|
|
775
|
+
# def _guess_classification_target(df: pd.DataFrame) -> str | None:
|
|
776
|
+
# cols = list(df.columns)
|
|
777
|
+
|
|
778
|
+
# # Helper: does this column look like a sensible label?
|
|
779
|
+
# def _is_reasonable_class_col(s: pd.Series, col_name: str) -> bool:
|
|
780
|
+
# try:
|
|
781
|
+
# nunq = s.dropna().nunique()
|
|
782
|
+
# except Exception:
|
|
783
|
+
# return False
|
|
784
|
+
# # need at least 2 classes, but not hundreds
|
|
785
|
+
# if nunq < 2 or nunq > 20:
|
|
786
|
+
# return False
|
|
787
|
+
# bad_name_keys = ("id", "identifier", "index", "uuid", "key")
|
|
788
|
+
# name = str(col_name).lower()
|
|
789
|
+
# if any(k in name for k in bad_name_keys):
|
|
790
|
+
# return False
|
|
791
|
+
# return True
|
|
792
|
+
|
|
793
|
+
# # 1) columns whose names look like labels
|
|
794
|
+
# label_keys = ("target", "label", "outcome", "class", "y", "status")
|
|
795
|
+
# name_candidates: list[str] = []
|
|
796
|
+
# for key in label_keys:
|
|
797
|
+
# for c in cols:
|
|
798
|
+
# if key in str(c).lower():
|
|
799
|
+
# name_candidates.append(c)
|
|
800
|
+
# if name_candidates:
|
|
801
|
+
# break # keep the earliest matching key-group
|
|
802
|
+
|
|
803
|
+
# # prioritise name-based candidates that also look like proper label columns
|
|
804
|
+
# for c in name_candidates:
|
|
805
|
+
# if _is_reasonable_class_col(df[c], c):
|
|
806
|
+
# return c
|
|
807
|
+
# if name_candidates:
|
|
808
|
+
# # fall back to the first name-based candidate if none passed the shape test
|
|
809
|
+
# return name_candidates[0]
|
|
810
|
+
|
|
811
|
+
# # 2) any column with a small number of distinct values (likely a class label)
|
|
812
|
+
# for c in cols:
|
|
813
|
+
# s = df[c]
|
|
814
|
+
# if _is_reasonable_class_col(s, c):
|
|
815
|
+
# return c
|
|
816
|
+
|
|
817
|
+
# # Nothing suitable found
|
|
818
|
+
# return None
|
|
819
|
+
|
|
820
|
+
# def _guess_regression_target(df: pd.DataFrame) -> str | None:
|
|
821
|
+
# num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
|
|
822
|
+
# if not num_cols:
|
|
823
|
+
# return None
|
|
824
|
+
# # Avoid obvious ID-like columns
|
|
825
|
+
# bad_keys = ("id", "identifier", "index")
|
|
826
|
+
# candidates = [c for c in num_cols if not any(k in str(c).lower() for k in bad_keys)]
|
|
827
|
+
# return (candidates or num_cols)[-1]
|
|
781
828
|
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
829
|
+
# def _guess_time_col(df: pd.DataFrame) -> str | None:
|
|
830
|
+
# # Prefer actual datetime dtype
|
|
831
|
+
# dt_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.datetime64)]
|
|
832
|
+
# if dt_cols:
|
|
833
|
+
# return dt_cols[0]
|
|
834
|
+
|
|
835
|
+
# # Fallback: name-based hints
|
|
836
|
+
# name_keys = ["date", "time", "timestamp", "datetime", "ds", "period"]
|
|
837
|
+
# for c in df.columns:
|
|
838
|
+
# name = str(c).lower()
|
|
839
|
+
# if any(k in name for k in name_keys):
|
|
840
|
+
# return c
|
|
841
|
+
# return None
|
|
842
|
+
|
|
843
|
+
# def _guess_entity_col(df: pd.DataFrame) -> str | None:
|
|
844
|
+
# # Typical sequence IDs: id, patient, subject, device, series, entity
|
|
845
|
+
# keys = ["id", "patient", "subject", "device", "series", "entity"]
|
|
846
|
+
# candidates = []
|
|
847
|
+
# for c in df.columns:
|
|
848
|
+
# name = str(c).lower()
|
|
849
|
+
# if any(k in name for k in keys):
|
|
850
|
+
# candidates.append(c)
|
|
851
|
+
# return candidates[0] if candidates else None
|
|
852
|
+
|
|
853
|
+
# def _guess_ts_class_target(df: pd.DataFrame) -> str | None:
|
|
854
|
+
# # Try label-like names first
|
|
855
|
+
# keys = ["target", "label", "class", "outcome", "y"]
|
|
856
|
+
# for key in keys:
|
|
857
|
+
# for c in df.columns:
|
|
858
|
+
# if key in str(c).lower():
|
|
859
|
+
# return c
|
|
860
|
+
|
|
861
|
+
# # Fallback: any column with few distinct values (e.g. <= 10)
|
|
862
|
+
# for c in df.columns:
|
|
863
|
+
# s = df[c]
|
|
864
|
+
# # avoid obvious IDs
|
|
865
|
+
# if any(k in str(c).lower() for k in ["id", "index"]):
|
|
866
|
+
# continue
|
|
867
|
+
# try:
|
|
868
|
+
# nunq = s.dropna().nunique()
|
|
869
|
+
# except Exception:
|
|
870
|
+
# continue
|
|
871
|
+
# if 1 < nunq <= 10:
|
|
872
|
+
# return c
|
|
873
|
+
|
|
874
|
+
# return None
|
|
875
|
+
|
|
876
|
+
# def _guess_multilabel_cols(df: pd.DataFrame) -> list[str]:
|
|
877
|
+
# cols = list(df.columns)
|
|
878
|
+
# lbl_like = [c for c in cols if str(c).startswith(("LBL_", "lbl_"))]
|
|
879
|
+
# # also include boolean/binary columns with suitable names
|
|
880
|
+
# for c in cols:
|
|
881
|
+
# s = df[c]
|
|
882
|
+
# try:
|
|
883
|
+
# nunq = s.dropna().nunique()
|
|
884
|
+
# except Exception:
|
|
885
|
+
# continue
|
|
886
|
+
# if nunq in (2,) and c not in lbl_like:
|
|
887
|
+
# # avoid obvious IDs
|
|
888
|
+
# if not any(k in str(c).lower() for k in ("id","index","uuid","identifier")):
|
|
889
|
+
# lbl_like.append(c)
|
|
890
|
+
# # keep at most, say, 12 to avoid accidental flood
|
|
891
|
+
# return lbl_like[:12]
|
|
892
|
+
|
|
893
|
+
# def _find_unknownish_column(df: pd.DataFrame) -> str | None:
|
|
894
|
+
# # Search categorical-like columns for any 'unknown-like' values or high missingness
|
|
895
|
+
# candidates = []
|
|
896
|
+
# for c in df.columns:
|
|
897
|
+
# s = df[c]
|
|
898
|
+
# # focus on object/category/boolean-ish or low-card columns
|
|
899
|
+
# if not (pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= 20):
|
|
900
|
+
# continue
|
|
901
|
+
# try:
|
|
902
|
+
# vals = s.astype(str).str.strip().str.lower()
|
|
903
|
+
# except Exception:
|
|
904
|
+
# continue
|
|
905
|
+
# # score: presence of unknown tokens + missing rate
|
|
906
|
+
# token_hit = int(vals.isin(UNKNOWN_TOKENS).any())
|
|
907
|
+
# miss_rate = s.isna().mean()
|
|
908
|
+
# name_bonus = int(any(k in str(c).lower() for k in ("status","history","report","known","flag")))
|
|
909
|
+
# score = 3*token_hit + 2*name_bonus + miss_rate
|
|
910
|
+
# if token_hit or miss_rate > 0.05 or name_bonus:
|
|
911
|
+
# candidates.append((score, c))
|
|
912
|
+
# if not candidates:
|
|
913
|
+
# return None
|
|
914
|
+
# candidates.sort(reverse=True)
|
|
915
|
+
# return candidates[0][1]
|
|
916
|
+
|
|
917
|
+
# def _guess_numeric_cols(df: pd.DataFrame, max_n: int = 6) -> list[str]:
|
|
918
|
+
# cols = [c for c in df.select_dtypes(include=[np.number, "bool"]).columns if not any(k in str(c).lower() for k in ("id","identifier","index","uuid"))]
|
|
919
|
+
# # prefer non-constant columns
|
|
920
|
+
# scored = []
|
|
921
|
+
# for c in cols:
|
|
922
|
+
# try:
|
|
923
|
+
# v = df[c].dropna()
|
|
924
|
+
# var = float(v.var()) if len(v) else 0.0
|
|
925
|
+
# scored.append((var, c))
|
|
926
|
+
# except Exception:
|
|
927
|
+
# continue
|
|
928
|
+
# scored.sort(reverse=True)
|
|
929
|
+
# return [c for _, c in scored[:max_n]]
|
|
930
|
+
|
|
931
|
+
# def _guess_categorical_cols(df: pd.DataFrame, exclude: set[str] | None = None, max_card: int = 12, max_n: int = 5) -> list[str]:
|
|
932
|
+
# exclude = exclude or set()
|
|
933
|
+
# picks = []
|
|
934
|
+
# for c in df.columns:
|
|
935
|
+
# if c in exclude:
|
|
936
|
+
# continue
|
|
937
|
+
# s = df[c]
|
|
938
|
+
# if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= max_card:
|
|
939
|
+
# nunq = s.dropna().nunique()
|
|
940
|
+
# if 2 <= nunq <= max_card and not any(k in str(c).lower() for k in ("id","identifier","index","uuid")):
|
|
941
|
+
# picks.append((nunq, c))
|
|
942
|
+
# picks.sort(reverse=True)
|
|
943
|
+
# return [c for _, c in picks[:max_n]]
|
|
944
|
+
|
|
945
|
+
# def _guess_outcome_col(df: pd.DataFrame, exclude: set[str] | None = None) -> str | None:
|
|
946
|
+
# exclude = exclude or set()
|
|
947
|
+
# # name hints first
|
|
948
|
+
# name_keys = ("outcome","target","label","risk","score","result","prevalence","positivity")
|
|
949
|
+
# for c in df.columns:
|
|
950
|
+
# if c in exclude:
|
|
951
|
+
# continue
|
|
952
|
+
# name = str(c).lower()
|
|
953
|
+
# if any(k in name for k in name_keys) and pd.api.types.is_numeric_dtype(df[c]):
|
|
954
|
+
# return c
|
|
955
|
+
# # fallback: any binary numeric
|
|
956
|
+
# for c in df.select_dtypes(include=[np.number, "bool"]).columns:
|
|
957
|
+
# if c in exclude:
|
|
958
|
+
# continue
|
|
959
|
+
# try:
|
|
960
|
+
# if df[c].dropna().nunique() == 2:
|
|
961
|
+
# return c
|
|
962
|
+
# except Exception:
|
|
963
|
+
# continue
|
|
964
|
+
# return None
|
|
918
965
|
|
|
919
966
|
|
|
920
|
-
|
|
921
|
-
|
|
967
|
+
# def _pick_viz_template(signal: str):
|
|
968
|
+
# s = signal.lower()
|
|
922
969
|
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
970
|
+
# # explicit chart requests
|
|
971
|
+
# if any(k in s for k in ("pie", "donut")):
|
|
972
|
+
# return viz_pie
|
|
926
973
|
|
|
927
|
-
|
|
928
|
-
|
|
974
|
+
# if any(k in s for k in ("stacked", "100% stacked", "composition", "proportion", "share by")):
|
|
975
|
+
# return viz_stacked_bar
|
|
929
976
|
|
|
930
|
-
|
|
931
|
-
|
|
977
|
+
# if any(k in s for k in ("distribution", "hist", "histogram", "bins")):
|
|
978
|
+
# return viz_distribution
|
|
932
979
|
|
|
933
|
-
|
|
934
|
-
|
|
980
|
+
# if any(k in s for k in ("kde", "density")):
|
|
981
|
+
# return viz_kde
|
|
935
982
|
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
983
|
+
# # these three you asked about
|
|
984
|
+
# if any(k in s for k in ("box", "boxplot", "violin", "spread", "outlier")):
|
|
985
|
+
# return viz_box
|
|
939
986
|
|
|
940
|
-
|
|
941
|
-
|
|
987
|
+
# if any(k in s for k in ("scatter", "relationship", "vs ", "correlate")):
|
|
988
|
+
# return viz_scatter
|
|
942
989
|
|
|
943
|
-
|
|
944
|
-
|
|
990
|
+
# if any(k in s for k in ("count", "counts", "frequency", "bar chart", "barplot")):
|
|
991
|
+
# return viz_count_bar
|
|
945
992
|
|
|
946
|
-
|
|
947
|
-
|
|
993
|
+
# if any(k in s for k in ("area", "trend", "over time", "time series")):
|
|
994
|
+
# return viz_area
|
|
948
995
|
|
|
949
|
-
|
|
950
|
-
|
|
996
|
+
# # fallback
|
|
997
|
+
# return viz_line
|
|
951
998
|
|
|
952
|
-
|
|
999
|
+
# for intent in intents:
|
|
953
1000
|
|
|
954
|
-
|
|
955
|
-
|
|
1001
|
+
# if intent not in INJECTABLE_INTENTS:
|
|
1002
|
+
# return code
|
|
956
1003
|
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
1004
|
+
# # Correlation analysis
|
|
1005
|
+
# if intent == "correlation_analysis" and not has_fit:
|
|
1006
|
+
# return eda_correlation(df) + "\n\n" + code
|
|
1007
|
+
|
|
1008
|
+
# # Generic visualisation (keyword-based)
|
|
1009
|
+
# if intent == "visualisation" and not has_fit and not has_plot:
|
|
1010
|
+
# rq = str(globals().get("refined_question", ""))
|
|
1011
|
+
# # aq = str(globals().get("askai_question", ""))
|
|
1012
|
+
# signal = rq + "\n" + str(intents) + "\n" + code
|
|
1013
|
+
# tpl = _pick_viz_template(signal)
|
|
1014
|
+
# return tpl(df) + "\n\n" + code
|
|
968
1015
|
|
|
969
|
-
|
|
970
|
-
|
|
1016
|
+
# if intent == "clustering" and not has_fit:
|
|
1017
|
+
# return clustering(df) + "\n\n" + code
|
|
971
1018
|
|
|
972
|
-
|
|
973
|
-
|
|
1019
|
+
# if intent == "recommendation" and not has_fit:
|
|
1020
|
+
# return recommendation(df) + "\\n\\n" + code
|
|
974
1021
|
|
|
975
|
-
|
|
976
|
-
|
|
1022
|
+
# if intent == "topic_modelling" and not has_fit:
|
|
1023
|
+
# return topic_modelling(df) + "\\n\\n" + code
|
|
977
1024
|
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1025
|
+
# if intent == "eda" and not has_fit:
|
|
1026
|
+
# return code + "\n\nSB_heatmap(df.corr())" # Inject heatmap if 'eda' intent
|
|
1027
|
+
|
|
1028
|
+
# # --- Classification ------------------------------------------------
|
|
1029
|
+
# if intent == "classification" and not has_fit:
|
|
1030
|
+
# target = _guess_classification_target(df)
|
|
1031
|
+
# if target:
|
|
1032
|
+
# return classification(df) + "\n\n" + code
|
|
1033
|
+
# # return _call_template(classification, df, target) + "\n\n" + code
|
|
1034
|
+
|
|
1035
|
+
# # --- Regression ----------------------------------------------------
|
|
1036
|
+
# if intent == "regression" and not has_fit:
|
|
1037
|
+
# target = _guess_regression_target(df)
|
|
1038
|
+
# if target:
|
|
1039
|
+
# return regression(df) + "\n\n" + code
|
|
1040
|
+
# # return _call_template(regression, df, target) + "\n\n" + code
|
|
1041
|
+
|
|
1042
|
+
# # --- Anomaly detection --------------------------------------------
|
|
1043
|
+
# if intent == "anomaly_detection":
|
|
1044
|
+
# uses_anomaly = any(k in code for k in ("IsolationForest", "LocalOutlierFactor", "OneClassSVM"))
|
|
1045
|
+
# if not uses_anomaly:
|
|
1046
|
+
# return anomaly_detection(df) + "\n\n" + code
|
|
1047
|
+
|
|
1048
|
+
# # --- Time-series anomaly detection --------------------------------
|
|
1049
|
+
# if intent == "ts_anomaly_detection":
|
|
1050
|
+
# uses_ts = "STL(" in code or "seasonal_decompose(" in code
|
|
1051
|
+
# if not uses_ts:
|
|
1052
|
+
# return ts_anomaly_detection(df) + "\n\n" + code
|
|
1006
1053
|
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1054
|
+
# # --- Time-series classification --------------------------------
|
|
1055
|
+
# if intent == "time_series_classification" and not has_fit:
|
|
1056
|
+
# time_col = _guess_time_col(df)
|
|
1057
|
+
# entity_col = _guess_entity_col(df)
|
|
1058
|
+
# target_col = _guess_ts_class_target(df)
|
|
1059
|
+
|
|
1060
|
+
# # If we can't confidently identify these, do NOT inject anything
|
|
1061
|
+
# if time_col and entity_col and target_col:
|
|
1062
|
+
# return time_series_classification(df, entity_col, time_col, target_col) + "\n\n" + code
|
|
1063
|
+
|
|
1064
|
+
# # --- Dimensionality reduction --------------------------------------
|
|
1065
|
+
# if intent == "dimensionality_reduction":
|
|
1066
|
+
# uses_dr = any(k in code for k in ("PCA(", "TSNE("))
|
|
1067
|
+
# if not uses_dr:
|
|
1068
|
+
# return dimensionality_reduction(df) + "\n\n" + code
|
|
1069
|
+
|
|
1070
|
+
# # --- Feature selection ---------------------------------------------
|
|
1071
|
+
# if intent == "feature_selection":
|
|
1072
|
+
# uses_fs = any(k in code for k in (
|
|
1073
|
+
# "mutual_info_", "permutation_importance(", "SelectKBest(", "RFE("
|
|
1074
|
+
# ))
|
|
1075
|
+
# if not uses_fs:
|
|
1076
|
+
# return feature_selection(df) + "\n\n" + code
|
|
1077
|
+
|
|
1078
|
+
# # --- EDA / correlation / visualisation -----------------------------
|
|
1079
|
+
# if intent in ("eda", "correlation_analysis", "visualisation") and not has_plot:
|
|
1080
|
+
# if intent == "correlation_analysis":
|
|
1081
|
+
# return eda_correlation(df) + "\n\n" + code
|
|
1082
|
+
# else:
|
|
1083
|
+
# return eda_overview(df) + "\n\n" + code
|
|
1084
|
+
|
|
1085
|
+
# # --- Time-series forecasting ---------------------------------------
|
|
1086
|
+
# if intent == "time_series_forecasting" and not has_fit:
|
|
1087
|
+
# uses_ts_forecast = any(k in code for k in (
|
|
1088
|
+
# "ARIMA", "ExponentialSmoothing", "forecast", "predict("
|
|
1089
|
+
# ))
|
|
1090
|
+
# if not uses_ts_forecast:
|
|
1091
|
+
# return time_series_forecasting(df) + "\n\n" + code
|
|
1045
1092
|
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1093
|
+
# # --- Multi-label classification -----------------------------------
|
|
1094
|
+
# if intent in ("multilabel_classification",) and not has_fit:
|
|
1095
|
+
# label_cols = _guess_multilabel_cols(df)
|
|
1096
|
+
# if len(label_cols) >= 2:
|
|
1097
|
+
# return multilabel_classification(df, label_cols) + "\n\n" + code
|
|
1098
|
+
|
|
1099
|
+
# group_col = _find_unknownish_column(df)
|
|
1100
|
+
# if group_col:
|
|
1101
|
+
# num_cols = _guess_numeric_cols(df)
|
|
1102
|
+
# cat_cols = _guess_categorical_cols(df, exclude={group_col})
|
|
1103
|
+
# outcome_col = None # generic; let template skip if not present
|
|
1104
|
+
# tpl = unknown_group_proxy_pack(df, group_col, UNKNOWN_TOKENS, num_cols, cat_cols, outcome_col)
|
|
1105
|
+
|
|
1106
|
+
# # Return template + guarded (repaired) LLM code, so it never crashes
|
|
1107
|
+
# repaired = make_numeric_vars_dynamic(relax_required_columns(fix_boxplot_placeholder(code)))
|
|
1108
|
+
# return tpl + "\n\n" + wrap_llm_code_safe(repaired)
|
|
1062
1109
|
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
def fix_values_sum_numeric_only_bug(code: str) -> str:
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
def strip_describe_slice(code: str) -> str:
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
def remove_plt_show(code: str) -> str:
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
def patch_plot_with_table(code: str) -> str:
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1110
|
+
# return code
|
|
1111
|
+
|
|
1112
|
+
|
|
1113
|
+
# def fix_values_sum_numeric_only_bug(code: str) -> str:
|
|
1114
|
+
# """
|
|
1115
|
+
# If a previous pass injected numeric_only=True into a NumPy-style sum,
|
|
1116
|
+
# e.g. .values.sum(numeric_only=True), strip it and canonicalize to .to_numpy().sum().
|
|
1117
|
+
# """
|
|
1118
|
+
# # .values.sum(numeric_only=True, ...)
|
|
1119
|
+
# code = re.sub(
|
|
1120
|
+
# r"\.values\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
|
|
1121
|
+
# ".to_numpy().sum()",
|
|
1122
|
+
# code,
|
|
1123
|
+
# flags=re.IGNORECASE,
|
|
1124
|
+
# )
|
|
1125
|
+
# # .to_numpy().sum(numeric_only=True, ...)
|
|
1126
|
+
# code = re.sub(
|
|
1127
|
+
# r"\.to_numpy\(\)\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
|
|
1128
|
+
# ".to_numpy().sum()",
|
|
1129
|
+
# code,
|
|
1130
|
+
# flags=re.IGNORECASE,
|
|
1131
|
+
# )
|
|
1132
|
+
# return code
|
|
1133
|
+
|
|
1134
|
+
|
|
1135
|
+
# def strip_describe_slice(code: str) -> str:
|
|
1136
|
+
# """
|
|
1137
|
+
# Remove any pattern like df.groupby(...).describe()[[ ... ]] because
|
|
1138
|
+
# slicing a SeriesGroupBy.describe() causes AttributeError.
|
|
1139
|
+
# We leave the plain .describe() in place (harmless) and let our own
|
|
1140
|
+
# table patcher add the correct .agg() table afterwards.
|
|
1141
|
+
# """
|
|
1142
|
+
# pat = re.compile(
|
|
1143
|
+
# r"(df\.groupby\([^)]+\)\[[^\]]+\]\.describe\()\s*\[[^\]]+\]\)",
|
|
1144
|
+
# flags=re.DOTALL,
|
|
1145
|
+
# )
|
|
1146
|
+
# return pat.sub(r"\1)", code)
|
|
1147
|
+
|
|
1148
|
+
|
|
1149
|
+
# def remove_plt_show(code: str) -> str:
|
|
1150
|
+
# """Removes all plt.show() calls from the generated code string."""
|
|
1151
|
+
# return "\n".join(line for line in code.splitlines() if "plt.show()" not in line)
|
|
1152
|
+
|
|
1153
|
+
|
|
1154
|
+
# def patch_plot_with_table(code: str) -> str:
|
|
1155
|
+
# """
|
|
1156
|
+
# ▸ strips every `plt.show()` (avoids warnings)
|
|
1157
|
+
# ▸ converts the *last* Matplotlib / Seaborn figure to PNG-HTML so it is
|
|
1158
|
+
# rendered in the dashboard
|
|
1159
|
+
# ▸ appends a summary-stats table **after** the plot
|
|
1160
|
+
# """
|
|
1161
|
+
# # 0. drop plt.show()
|
|
1162
|
+
# lines = [ln for ln in code.splitlines() if "plt.show()" not in ln]
|
|
1163
|
+
|
|
1164
|
+
# # 1. locate the last plotting line
|
|
1165
|
+
# plot_kw = ['plt.', 'sns.', '.plot(', '.boxplot(', '.hist(']
|
|
1166
|
+
# last_plot = max((i for i,l in enumerate(lines) if any(k in l for k in plot_kw)), default=-1)
|
|
1167
|
+
# if last_plot == -1:
|
|
1168
|
+
# return "\n".join(lines) # nothing to do
|
|
1169
|
+
|
|
1170
|
+
# whole = "\n".join(lines)
|
|
1171
|
+
|
|
1172
|
+
# # 2. detect group / feature (if any)
|
|
1173
|
+
# group, feature = None, None
|
|
1174
|
+
# xm = re.search(r"x\s*=\s*['\"](\w+)['\"]", whole)
|
|
1175
|
+
# ym = re.search(r"y\s*=\s*['\"](\w+)['\"]", whole)
|
|
1176
|
+
# if xm and ym:
|
|
1177
|
+
# group, feature = xm.group(1), ym.group(1)
|
|
1178
|
+
# else:
|
|
1179
|
+
# cm = re.search(r"column\s*=\s*['\"](\w+)['\"].*by\s*=\s*['\"](\w+)['\"]", whole)
|
|
1180
|
+
# if cm:
|
|
1181
|
+
# feature, group = cm.group(1), cm.group(2)
|
|
1182
|
+
|
|
1183
|
+
# # 3. code that captures current fig → PNG → HTML
|
|
1184
|
+
# img_block = textwrap.dedent("""
|
|
1185
|
+
# import io, base64
|
|
1186
|
+
# buf = io.BytesIO()
|
|
1187
|
+
# plt.savefig(buf, format='png', bbox_inches='tight')
|
|
1188
|
+
# buf.seek(0)
|
|
1189
|
+
# img_b64 = base64.b64encode(buf.read()).decode('utf-8')
|
|
1190
|
+
# from IPython.display import display, HTML
|
|
1191
|
+
# display(HTML(f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%;">'))
|
|
1192
|
+
# plt.close()
|
|
1193
|
+
# """)
|
|
1194
|
+
|
|
1195
|
+
# # 4. build summary-table code
|
|
1196
|
+
# if group and feature:
|
|
1197
|
+
# tbl_block = (
|
|
1198
|
+
# f"summary_table = (\n"
|
|
1199
|
+
# f" df.groupby('{group}')['{feature}']\n"
|
|
1200
|
+
# f" .agg(['count','mean','std','min','median','max'])\n"
|
|
1201
|
+
# f" .rename(columns={{'median':'50%'}})\n"
|
|
1202
|
+
# f")\n"
|
|
1203
|
+
# )
|
|
1204
|
+
# elif ym:
|
|
1205
|
+
# feature = ym.group(1)
|
|
1206
|
+
# tbl_block = (
|
|
1207
|
+
# f"summary_table = (\n"
|
|
1208
|
+
# f" df['{feature}']\n"
|
|
1209
|
+
# f" .agg(['count','mean','std','min','median','max'])\n"
|
|
1210
|
+
# f" .rename(columns={{'median':'50%'}})\n"
|
|
1211
|
+
# f")\n"
|
|
1212
|
+
# )
|
|
1166
1213
|
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
def refine_eda_question(raw_question, df=None, max_points=1000):
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1214
|
+
# # 3️⃣ grid-search results
|
|
1215
|
+
# elif "GridSearchCV(" in code:
|
|
1216
|
+
# tbl_block = textwrap.dedent("""
|
|
1217
|
+
# # build tidy CV-results table
|
|
1218
|
+
# cv_df = (
|
|
1219
|
+
# pd.DataFrame(grid_search.cv_results_)
|
|
1220
|
+
# .loc[:, ['param_n_estimators', 'param_max_depth',
|
|
1221
|
+
# 'mean_test_score', 'std_test_score']]
|
|
1222
|
+
# .rename(columns={
|
|
1223
|
+
# 'param_n_estimators': 'n_estimators',
|
|
1224
|
+
# 'param_max_depth': 'max_depth',
|
|
1225
|
+
# 'mean_test_score': 'mean_cv_accuracy',
|
|
1226
|
+
# 'std_test_score': 'std'
|
|
1227
|
+
# })
|
|
1228
|
+
# .sort_values('mean_cv_accuracy', ascending=False)
|
|
1229
|
+
# .reset_index(drop=True)
|
|
1230
|
+
# )
|
|
1231
|
+
# summary_table = cv_df
|
|
1232
|
+
# """).strip() + "\n"
|
|
1233
|
+
# else:
|
|
1234
|
+
# tbl_block = (
|
|
1235
|
+
# "summary_table = (\n"
|
|
1236
|
+
# " df.describe().T[['count','mean','std','min','50%','max']]\n"
|
|
1237
|
+
# ")\n"
|
|
1238
|
+
# )
|
|
1239
|
+
|
|
1240
|
+
# tbl_block += "show(summary_table, title='Summary Statistics')"
|
|
1241
|
+
|
|
1242
|
+
# # 5. inject image-export block, then table block, after the plot
|
|
1243
|
+
# patched = (
|
|
1244
|
+
# lines[:last_plot+1]
|
|
1245
|
+
# + img_block.splitlines()
|
|
1246
|
+
# + tbl_block.splitlines()
|
|
1247
|
+
# + lines[last_plot+1:]
|
|
1248
|
+
# )
|
|
1249
|
+
# patched_code = "\n".join(patched)
|
|
1250
|
+
# # ⬇️ strip every accidental left-indent so top-level lines are flush‐left
|
|
1251
|
+
# return textwrap.dedent(patched_code)
|
|
1252
|
+
|
|
1253
|
+
|
|
1254
|
+
# def refine_eda_question(raw_question, df=None, max_points=1000):
|
|
1255
|
+
# """
|
|
1256
|
+
# Rewrites user's EDA question to avoid classic mistakes:
|
|
1257
|
+
# - For line plots and scatter: recommend aggregation or sampling if large.
|
|
1258
|
+
# - For histograms/bar: clarify which variable to plot and bin count.
|
|
1259
|
+
# - For correlation: suggest a heatmap.
|
|
1260
|
+
# - For counts: direct request for df.shape[0].
|
|
1261
|
+
# df (optional): pass DataFrame for column inspection.
|
|
1262
|
+
# """
|
|
1263
|
+
|
|
1264
|
+
# # --- SPECIFIC PEARSON CORRELATION DETECTION ----------------------
|
|
1265
|
+
# pc = re.match(
|
|
1266
|
+
# r".*\bpearson\b.*\bcorrelation\b.*between\s+(\w+)\s+(and|vs)\s+(\w+)",
|
|
1267
|
+
# raw_question, re.I
|
|
1268
|
+
# )
|
|
1269
|
+
# if pc:
|
|
1270
|
+
# col1, col2 = pc.group(1), pc.group(3)
|
|
1271
|
+
# # Return an instruction that preserves the exact intent
|
|
1272
|
+
# return (
|
|
1273
|
+
# f"Compute the Pearson correlation coefficient (r) and p-value "
|
|
1274
|
+
# f"between {col1} and {col2}. "
|
|
1275
|
+
# f"Print a short interpretation."
|
|
1276
|
+
# )
|
|
1277
|
+
# # -----------------------------------------------------------------
|
|
1278
|
+
# # ── Detect "predict <column>" intent ──────────────────────────────
|
|
1279
|
+
# c = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", raw_question, re.I)
|
|
1280
|
+
# if c:
|
|
1281
|
+
# target = c.group(1)
|
|
1282
|
+
# raw_question += (
|
|
1283
|
+
# f" IMPORTANT: do NOT recreate or overwrite the existing target column "
|
|
1284
|
+
# f"“{target}”. Use it as-is for y = df['{target}']."
|
|
1285
|
+
# )
|
|
1286
|
+
|
|
1287
|
+
# q = raw_question.strip()
|
|
1288
|
+
# # REMOVE explicit summary-table instructions
|
|
1289
|
+
# # ── strip any “table” request: “…table of …”, “…include table…”, “…with a table…”
|
|
1290
|
+
# q = re.sub(r"\b(include|with|and)\b[^.]*\btable[s]?\b[^.]*", "", q, flags=re.I).strip()
|
|
1291
|
+
# q = re.sub(r"\s*,\s*$", "", q) # drop trailing comma, if any
|
|
1292
|
+
|
|
1293
|
+
# ql = q.lower()
|
|
1294
|
+
|
|
1295
|
+
# # ── NEW: if the text contains an exact column name, leave it alone ──
|
|
1296
|
+
# if df is not None:
|
|
1297
|
+
# for col in df.columns:
|
|
1298
|
+
# if col.lower() in ql:
|
|
1299
|
+
# return q
|
|
1300
|
+
|
|
1301
|
+
# modelling_keywords = (
|
|
1302
|
+
# "random forest", "gradient-boost", "tree-based model",
|
|
1303
|
+
# "feature importance", "feature importances",
|
|
1304
|
+
# "overall accuracy", "train a model", "predict "
|
|
1305
|
+
# )
|
|
1306
|
+
# if any(k in ql for k in modelling_keywords):
|
|
1307
|
+
# return q
|
|
1308
|
+
|
|
1309
|
+
# # 1. Line plots: average if plotting raw numeric vs numeric
|
|
1310
|
+
# if "line plot" in ql and any(word in ql for word in ["over", "by", "vs"]):
|
|
1311
|
+
# match = re.search(r'line plot of ([\w_]+) (over|by|vs) ([\w_]+)', ql)
|
|
1312
|
+
# if match:
|
|
1313
|
+
# y, _, x = match.groups()
|
|
1314
|
+
# return f"Show me the average {y} by {x} as a line plot."
|
|
1315
|
+
|
|
1316
|
+
# # 2. Scatter plots: sample if too large
|
|
1317
|
+
# if "scatter" in ql or "scatter plot" in ql:
|
|
1318
|
+
# if df is not None and df.shape[0] > max_points:
|
|
1319
|
+
# return q + " (use only a random sample of 1000 points to avoid overplotting)"
|
|
1320
|
+
# else:
|
|
1321
|
+
# return q
|
|
1322
|
+
|
|
1323
|
+
# # 3. Histogram: specify bins and column
|
|
1324
|
+
# if "histogram" in ql:
|
|
1325
|
+
# match = re.search(r'histogram of ([\w_]+)', ql)
|
|
1326
|
+
# if match:
|
|
1327
|
+
# col = match.group(1)
|
|
1328
|
+
# return f"Show me a histogram of {col} using 20 bins."
|
|
1329
|
+
|
|
1330
|
+
# # Special case: histogram for column with most missing values
|
|
1331
|
+
# if "most missing" in ql:
|
|
1332
|
+
# return (
|
|
1333
|
+
# "Show a histogram for the column with the most missing values. "
|
|
1334
|
+
# "First, select the column using: "
|
|
1335
|
+
# "column_with_most_missing = df.isnull().sum().idxmax(); "
|
|
1336
|
+
# "then plot its histogram with: "
|
|
1337
|
+
# "df[column_with_most_missing].hist()"
|
|
1338
|
+
# )
|
|
1339
|
+
|
|
1340
|
+
# # 4. Bar plot: show top N
|
|
1341
|
+
# if "bar plot" in ql or "bar chart" in ql:
|
|
1342
|
+
# match = re.search(r'bar (plot|chart) of ([\w_]+)', ql)
|
|
1343
|
+
# if match:
|
|
1344
|
+
# col = match.group(2)
|
|
1345
|
+
# return f"Show me a bar plot of the top 10 {col} values."
|
|
1346
|
+
|
|
1347
|
+
# # 5. Correlation or heatmap
|
|
1348
|
+
# if "correlation" in ql:
|
|
1349
|
+
# return (
|
|
1350
|
+
# "Show a correlation heatmap for all numeric columns only. "
|
|
1351
|
+
# "Use: correlation_matrix = df.select_dtypes(include='number').corr()"
|
|
1352
|
+
# )
|
|
1353
|
+
|
|
1354
|
+
|
|
1355
|
+
# # 6. Counts/size
|
|
1356
|
+
# if "how many record" in ql or "row count" in ql or "number of rows" in ql:
|
|
1357
|
+
# return "How many rows are in the dataset?"
|
|
1358
|
+
|
|
1359
|
+
# # 7. General best-practices fallback: add axis labels/titles
|
|
1360
|
+
# if "plot" in ql:
|
|
1361
|
+
# return q + " (make sure the axes are labeled and the plot is readable)"
|
|
1315
1362
|
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
|
|
1319
|
-
|
|
1320
|
-
|
|
1321
|
-
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1363
|
+
# # 8.
|
|
1364
|
+
# if (("how often" in ql or "count" in ql or "frequency" in ql) and "category" in ql) or ("value_counts" in q):
|
|
1365
|
+
# match = re.search(r'(?:categories? in |bar plot of |bar chart of )([\w_]+)', ql)
|
|
1366
|
+
# col = match.group(1) if match else None
|
|
1367
|
+
# if col:
|
|
1368
|
+
# return (
|
|
1369
|
+
# f"Show a bar plot of the counts of {col} using: "
|
|
1370
|
+
# f"df['{col}'].value_counts().plot(kind='bar'); "
|
|
1371
|
+
# "add axis labels and a title, then plt.show()."
|
|
1372
|
+
# )
|
|
1326
1373
|
|
|
1327
|
-
|
|
1328
|
-
|
|
1329
|
-
|
|
1330
|
-
|
|
1331
|
-
|
|
1374
|
+
# if ("mean" in ql and "median" in ql and "standard deviation" in ql) or ("summary statistics" in ql):
|
|
1375
|
+
# return (
|
|
1376
|
+
# "Show a table of the mean, median, and standard deviation for all numeric columns. "
|
|
1377
|
+
# "Use: tbl = df.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}); display(tbl)"
|
|
1378
|
+
# )
|
|
1332
1379
|
|
|
1333
|
-
|
|
1334
|
-
|
|
1380
|
+
# # 9. Fallback: return the raw question
|
|
1381
|
+
# return q
|
|
1335
1382
|
|
|
1336
1383
|
|
|
1337
|
-
def patch_plot_code(code, df, user_question=None):
|
|
1384
|
+
# def patch_plot_code(code, df, user_question=None):
|
|
1338
1385
|
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1386
|
+
# # ── Early guard: abort nicely if the generated code references columns that
|
|
1387
|
+
# # do not exist in the DataFrame. This prevents KeyError crashes.
|
|
1388
|
+
# import re
|
|
1342
1389
|
|
|
1343
1390
|
|
|
1344
|
-
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
|
|
1356
|
-
|
|
1357
|
-
|
|
1358
|
-
|
|
1359
|
-
|
|
1360
|
-
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1391
|
+
# # ── Detect columns referenced in the code ──────────────────────────
|
|
1392
|
+
# col_refs = re.findall(r"df\[['\"](\w+)['\"]\]", code)
|
|
1393
|
+
|
|
1394
|
+
# # Columns that will be newly CREATED (appear left of '=')
|
|
1395
|
+
# new_cols = re.findall(r"df\[['\"](\w+)['\"]\]\s*=", code)
|
|
1396
|
+
|
|
1397
|
+
# missing_cols = [
|
|
1398
|
+
# col for col in col_refs
|
|
1399
|
+
# if col not in df.columns and col not in new_cols
|
|
1400
|
+
# ]
|
|
1401
|
+
|
|
1402
|
+
# if missing_cols:
|
|
1403
|
+
# cols_list = ", ".join(missing_cols)
|
|
1404
|
+
# warning = (
|
|
1405
|
+
# f"show('⚠️ Warning: code references missing column(s): \"{cols_list}\". "
|
|
1406
|
+
# "These must either exist in df or be created earlier in the code; "
|
|
1407
|
+
# "otherwise you may see a KeyError.')\n"
|
|
1408
|
+
# )
|
|
1409
|
+
# # Prepend the warning but keep the original code so it can still run
|
|
1410
|
+
# code = warning + code
|
|
1364
1411
|
|
-    …
+#     # 1. For line plots (auto-aggregate)
+#     m_l = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
+#     if m_l:
+#         x, y = m_l.groups()
+#         if pd.api.types.is_numeric_dtype(df[x]) and pd.api.types.is_numeric_dtype(df[y]) and df[x].nunique() > 20:
+#             return (
+#                 f"agg_df = df.groupby('{x}')['{y}'].mean().reset_index()\n"
+#                 f"plt.plot(agg_df['{x}'], agg_df['{y}'], marker='o')\n"
+#                 f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('Average {y} by {x}')\nplt.show()"
+#             )
+
+#     # 2. For scatter plots: sample to 1000 points max
+#     m_s = re.search(r"plt\.scatter\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
+#     if m_s:
+#         x, y = m_s.groups()
+#         if len(df) > 1000:
+#             return (
+#                 f"samp = df.sample(1000, random_state=42)\n"
+#                 f"plt.scatter(samp['{x}'], samp['{y}'])\n"
+#                 f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('{y} vs {x} (sampled)')\nplt.show()"
+#             )
+
+#     # 3. For histograms: use bins=20 for numeric, value_counts for categorical
+#     m_h = re.search(r"plt\.hist\(\s*df\[['\"](\w+)['\"]\]", code)
+#     if m_h:
+#         col = m_h.group(1)
+#         if pd.api.types.is_numeric_dtype(df[col]):
+#             return (
+#                 f"plt.hist(df['{col}'], bins=20, edgecolor='black')\n"
+#                 f"plt.xlabel('{col}')\nplt.ylabel('Frequency')\nplt.title('Histogram of {col}')\nplt.show()"
+#             )
+#         else:
+#             # If categorical, show bar plot of value counts
+#             return (
+#                 f"df['{col}'].value_counts().plot(kind='bar')\n"
+#                 f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Counts of {col}')\nplt.show()"
+#             )
+
+#     # 4. For bar plots: show only top 20
+#     m_b = re.search(r"(?:df\[['\"](\w+)['\"]\]\.value_counts\(\).plot\(kind=['\"]bar['\"]\))", code)
+#     if m_b:
+#         col = m_b.group(1)
+#         if df[col].nunique() > 20:
+#             return (
+#                 f"topN = df['{col}'].value_counts().head(20)\n"
+#                 f"topN.plot(kind='bar')\n"
+#                 f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Top 20 {col} Categories')\nplt.show()"
+#             )
+
+#     # 5. For any DataFrame plot with len(df)>10000, sample before plotting!
+#     if "df.plot" in code and len(df) > 10000:
+#         return (
+#             f"samp = df.sample(1000, random_state=42)\n"
+#             + code.replace("df.", "samp.")
+#         )

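A runnable sketch of rule 1 above (auto-aggregating a noisy line plot); the DataFrame and code string are invented for illustration:

    import re
    import pandas as pd

    df = pd.DataFrame({"age": range(100), "glucose": range(100)})
    code = "plt.plot(df['age'], df['glucose'])"

    m = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
    if m:
        x, y = m.groups()
        if pd.api.types.is_numeric_dtype(df[x]) and df[x].nunique() > 20:
            # Rewrite into a groupby-mean aggregation before plotting
            code = (f"agg_df = df.groupby('{x}')['{y}'].mean().reset_index()\n"
                    f"plt.plot(agg_df['{x}'], agg_df['{y}'], marker='o')")
    print(code)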
-    …
+#     # ── Block assignment to an existing target column ────────────────
+#     #*******************************************************
+#     target_match = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", user_question or "", re.I)
+#     if target_match:
+#         target = target_match.group(1)
+
+#         # pattern for an assignment to that target
+#         assign_pat = rf"df\[['\"]{re.escape(target)}['\"]\]\s*="
+#         assign_line = re.search(assign_pat + r".*", code)
+#         if assign_line:
+#             # runtime check: keep the assignment **only if** the column is absent
+#             guard = (
+#                 f"if '{target}' in df.columns:\n"
+#                 f"    print('⚠️ {target} already exists – overwrite skipped.');\n"
+#                 f"else:\n"
+#                 f"    {assign_line.group(0)}"
+#             )
+#             # remove original assignment line and insert guarded block
+#             code = code.replace(assign_line.group(0), guard, 1)
+#     # ***************************************************

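A sketch of the overwrite guard for a predicted target column; the question and code strings are made up, and `model`/`X` are hypothetical names inside the generated code:

    import re

    user_question = "predict Outcome from the lab values"
    code = "df['Outcome'] = model.predict(X)"

    t = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", user_question, re.I)
    if t:
        target = t.group(1)
        line = re.search(rf"df\[['\"]{re.escape(target)}['\"]\]\s*=.*", code)
        if line:
            # Wrap the assignment so an existing column is never overwritten
            guard = (f"if '{target}' in df.columns:\n"
                     f"    print('{target} already exists - overwrite skipped.')\n"
                     f"else:\n"
                     f"    {line.group(0)}")
            code = code.replace(line.group(0), guard, 1)
    print(code)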
-    …
-def ensure_matplotlib_title(code, title_var="refined_question"):
-    …
+#     # 6. Grouped bar plot for two categoricals
+#     # Grouped bar plot for two categoricals (.value_counts().unstack() or .groupby().size().unstack())
+#     if ".value_counts().unstack()" in code or ".groupby(" in code and ".size().unstack()" in code:
+#         # Try to infer columns from user question if possible:
+#         group, cat = None, None
+#         if user_question:
+#             # crude parse for "counts of X for each Y"
+#             m = re.search(r"counts? of (\w+) for each (\w+)", user_question)
+#             if m:
+#                 cat, group = m.groups()
+#         if not (cat and group):
+#             # fallback: use two most frequent categoricals
+#             categoricals = [col for col in df.columns if pd.api.types.is_categorical_dtype(df[col]) or df[col].dtype == "object"]
+#             if len(categoricals) >= 2:
+#                 cat, group = categoricals[:2]
+#             else:
+#                 # fallback: any
+#                 cat, group = df.columns[:2]
+#         return (
+#             f"import pandas as pd\n"
+#             f"import matplotlib.pyplot as plt\n"
+#             f"ct = pd.crosstab(df['{group}'], df['{cat}'])\n"
+#             f"ct.plot(kind='bar')\n"
+#             f"plt.title('Counts of {cat} for each {group}')\n"
+#             f"plt.xlabel('{group}')\nplt.ylabel('Count')\nplt.xticks(rotation=0)\nplt.show()"
+#         )
+
+#     # Fallback: Return original code
+#     return code
+
+
+# def ensure_matplotlib_title(code, title_var="refined_question"):
+#     import re
+#     makes_plot = re.search(r"\b(plt\.(plot|scatter|bar|hist)|ax\.(plot|scatter|bar|hist))\b", code)
+#     has_title = re.search(r"\b(plt\.title|ax\.set_title)\s*\(", code)
+#     if makes_plot and not has_title:
+#         code += f"\ntry:\n    plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
+#     return code


 def ensure_output(code: str) -> str:
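The title patcher above appends a guarded plt.title() only when plotting calls exist but no title does; a minimal sketch of the same check:

    import re

    def ensure_title(code: str, title: str) -> str:
        makes_plot = re.search(r"\bplt\.(plot|scatter|bar|hist)\b", code)
        has_title = re.search(r"\bplt\.title\s*\(", code)
        if makes_plot and not has_title:
            code += f"\nplt.title({title!r})"  # append a title using the refined question
        return code

    print(ensure_title("plt.hist(df['bmi'])", "Histogram of BMI"))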
@@ -1555,1548 +1602,1455 @@ def ensure_output(code: str) -> str:
     return "\n".join(lines)


-def get_plotting_imports(code):
-    …
-def patch_pairplot(code, df):
-    …
-def ensure_image_output(code: str) -> str:
-    …
+# def get_plotting_imports(code):
+#     imports = []
+#     if "plt." in code and "import matplotlib.pyplot as plt" not in code:
+#         imports.append("import matplotlib.pyplot as plt")
+#     if "sns." in code and "import seaborn as sns" not in code:
+#         imports.append("import seaborn as sns")
+#     if "px." in code and "import plotly.express as px" not in code:
+#         imports.append("import plotly.express as px")
+#     if "pd." in code and "import pandas as pd" not in code:
+#         imports.append("import pandas as pd")
+#     if "np." in code and "import numpy as np" not in code:
+#         imports.append("import numpy as np")
+#     if "display(" in code and "from IPython.display import display" not in code:
+#         imports.append("from IPython.display import display")
+#     # Optionally, add more as you see usage (e.g., import scipy, statsmodels, etc)
+#     if imports:
+#         code = "\n".join(imports) + "\n\n" + code
+#     return code
+
+
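A compact sketch of the same import-injection idea — scan for library prefixes and prepend any imports that are missing; the prefix table is a trimmed stand-in:

    PREFIX_IMPORTS = {
        "plt.": "import matplotlib.pyplot as plt",
        "sns.": "import seaborn as sns",
        "pd.":  "import pandas as pd",
        "np.":  "import numpy as np",
    }

    def add_missing_imports(code: str) -> str:
        needed = [imp for prefix, imp in PREFIX_IMPORTS.items()
                  if prefix in code and imp not in code]
        return "\n".join(needed) + "\n\n" + code if needed else code

    print(add_missing_imports("plt.plot(np.arange(5))"))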
+# def patch_pairplot(code, df):
+#     if "sns.pairplot" in code:
+#         # Always assign and print pairgrid
+#         code = re.sub(r"sns\.pairplot\((.+)\)", r"pairgrid = sns.pairplot(\1)", code)
+#         if "plt.show()" not in code:
+#             code += "\nplt.show()"
+#         if "print(pairgrid)" not in code:
+#             code += "\nprint(pairgrid)"
+#     return code
+
+
+# def ensure_image_output(code: str) -> str:
+#     """
+#     Replace each plt.show() with an indented _SMX_export_png() call.
+#     This keeps block indentation valid and still renders images in the dashboard.
+#     """
+#     if "plt.show()" not in code:
+#         return code
+
+#     import re
+#     out_lines = []
+#     for ln in code.splitlines():
+#         if "plt.show()" not in ln:
+#             out_lines.append(ln)
+#             continue
+
+#         # works for:
+#         #   plt.show()
+#         #   plt.tight_layout(); plt.show()
+#         #   ... ; plt.show(); ...   (multiple on one line)
+#         indent = re.match(r"^(\s*)", ln).group(1)
+#         parts = ln.split("plt.show()")
+
+#         # keep whatever is before the first plt.show()
+#         if parts[0].strip():
+#             out_lines.append(parts[0].rstrip())
+
+#         # for every plt.show() we removed, insert exporter at same indent
+#         for _ in range(len(parts) - 1):
+#             out_lines.append(indent + "_SMX_export_png()")
+
+#         # keep whatever comes after the last plt.show()
+#         if parts[-1].strip():
+#             out_lines.append(indent + parts[-1].lstrip())
+
+#     return "\n".join(out_lines)
+
+
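The exporter rewrite walks each line and swaps every plt.show() for an export call at the same indentation. A self-contained sketch; `_export_png` is a hypothetical stand-in for the package's exporter hook:

    import re

    def replace_show(code: str, exporter: str = "_export_png()") -> str:
        out = []
        for ln in code.splitlines():
            if "plt.show()" not in ln:
                out.append(ln)
                continue
            indent = re.match(r"^(\s*)", ln).group(1)
            parts = ln.split("plt.show()")
            if parts[0].strip():
                out.append(parts[0].rstrip())            # keep code before the call
            out.extend(indent + exporter for _ in parts[:-1])
            if parts[-1].strip():
                out.append(indent + parts[-1].lstrip())  # keep code after the call
        return "\n".join(out)

    print(replace_show("    plt.tight_layout(); plt.show()"))
    # "    plt.tight_layout();"  then  "    _export_png()" at the same indent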
+# def clean_llm_code(code: str) -> str:
+#     """
+#     Make LLM output safe to exec:
+#     - If fenced blocks exist, keep the largest one (usually the real code).
+#     - Otherwise strip any stray ``` / ```python lines.
+#     - Remove common markdown/preamble junk.
+#     """
+#     code = str(code or "")
+
+#     # Special case: sometimes the OpenAI SDK object repr (e.g. ChatCompletion(...))
+#     # is accidentally passed here as `code`. In that case, extract the actual
+#     # Python code from the ChatCompletionMessage(content=...) field.
+#     if "ChatCompletion(" in code and "ChatCompletionMessage" in code and "content=" in code:
+#         try:
+#             extracted = None
+
+#             class _ChatCompletionVisitor(ast.NodeVisitor):
+#                 def visit_Call(self, node):
+#                     nonlocal extracted
+#                     func = node.func
+#                     fname = getattr(func, "id", None) or getattr(func, "attr", None)
+#                     if fname == "ChatCompletionMessage":
+#                         for kw in node.keywords:
+#                             if kw.arg == "content" and isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str):
+#                                 extracted = kw.value.value
+#                     self.generic_visit(node)
+
+#             tree = ast.parse(code, mode="exec")
+#             _ChatCompletionVisitor().visit(tree)
+#             if extracted:
+#                 code = extracted
+#         except Exception:
+#             # Best-effort regex fallback if AST parsing fails
+#             m = re.search(r"content=(?P<q>['\\\"])(?P<body>.*?)(?P=q)", code, flags=re.S)
+#             if m:
+#                 code = m.group("body")
+
+#     # Existing logic continues unchanged below...
+#     # Extract fenced blocks (```python ... ``` or ``` ... ```)
+#     blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
+
+#     if blocks:
+#         # pick the largest block; small trailing blocks are usually garbage
+#         largest = max(blocks, key=lambda b: len(b.strip()))
+#         if len(largest.strip().splitlines()) >= 10:
+#             code = largest
+#         else:
+#             # if no meaningful block, just remove fence markers
+#             code = re.sub(r"^```.*?$", "", code, flags=re.M)
+#     else:
+#         # no complete blocks — still remove any stray fence lines
+#         code = re.sub(r"^```.*?$", "", code, flags=re.M)
+
+#     # Strip common markdown/preamble lines
+#     drop_prefixes = (
+#         "here is", "here's", "below is", "sure,", "certainly",
+#         "explanation", "note:", "```"
+#     )
+#     cleaned_lines = []
+#     for ln in code.splitlines():
+#         s = ln.strip().lower()
+#         if any(s.startswith(p) for p in drop_prefixes):
+#             continue
+#         cleaned_lines.append(ln)
+
+#     return "\n".join(cleaned_lines).strip()
+
+
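Fence extraction in miniature — keep the largest ```python block if one exists, otherwise strip stray fence lines; the sample model reply is invented:

    import re

    reply = ("Here is the code:\n```python\nimport pandas as pd\n"
             "print(pd.__version__)\n```\nHope that helps!")

    blocks = re.findall(r"```(?:python)?\s*(.*?)```", reply, flags=re.I | re.S)
    if blocks:
        code = max(blocks, key=lambda b: len(b.strip()))  # biggest block wins
    else:
        code = re.sub(r"^```.*?$", "", reply, flags=re.M)  # just drop fence markers
    print(code)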
+# def fix_groupby_describe_slice(code: str) -> str:
+#     """
+#     Replaces df.groupby(...).describe()[[...]] with a safe .agg(...)
+#     so it works for both SeriesGroupBy and DataFrameGroupBy.
+#     """
+#     pat = re.compile(
+#         r"(df\.groupby\(['\"][\w]+['\"]\)\['[\w]+['\"]\]\.describe\()\s*\[\[([^\]]+)\]\]\)",
+#         re.MULTILINE
+#     )
+#     def repl(match):
+#         inner = match.group(0)
+#         # extract group and feature to build df.groupby('g')['f']
+#         g = re.search(r"groupby\('([\w]+)'\)", inner).group(1)
+#         f = re.search(r"\)\['([\w]+)'\]\.describe", inner).group(1)
+#         return (
+#             f"df.groupby('{g}')['{f}']"
+#             ".agg(['count','mean','std','min','median','max'])"
+#             ".rename(columns={'median':'50%'})"
+#         )
+#     return pat.sub(repl, code)
+
+
+# def fix_importance_groupby(code: str) -> str:
+#     pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
+#     if "importance_df" in code:
+#         return pattern.sub("importance_df.groupby('Importance')['Importance']", code)
+#     return code
+
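The describe-slice rewrite exists because `.describe()[[...]]` column slicing behaves differently on SeriesGroupBy and DataFrameGroupBy, while the substituted `.agg(...)` form is stable. Run on its own, the target expression looks like this (toy data):

    import pandas as pd

    df = pd.DataFrame({"group": ["a", "a", "b", "b"], "hba1c": [5.1, 6.4, 7.2, 6.9]})

    # Equivalent of df.groupby('group')['hba1c'].describe()[['count', 'mean', ...]]
    stats = (df.groupby("group")["hba1c"]
               .agg(["count", "mean", "std", "min", "median", "max"])
               .rename(columns={"median": "50%"}))
    print(stats)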
+# def inject_auto_preprocessing(code: str) -> str:
+#     """
+#     • Detects a RandomForestClassifier in the generated code.
+#     • Finds the target column from `y = df['target']`.
+#     • Prepends a fully-dedented preprocessing snippet that:
+#         – auto-detects numeric & categorical columns
+#         – builds a ColumnTransformer (OneHotEncoder + StandardScaler)
+#     The dedent() call guarantees no leading-space IndentationError.
+#     """
+#     if "RandomForestClassifier" not in code:
+#         return code  # nothing to patch
+
+#     y_match = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
+#     if not y_match:
+#         return code  # can't infer target safely
+#     target = y_match.group(1)
+
+#     prep_snippet = textwrap.dedent(f"""
+#         # ── automatic preprocessing ───────────────────────────────
+#         num_cols = df.select_dtypes(include=['number']).columns.tolist()
+#         cat_cols = df.select_dtypes(exclude=['number']).columns.tolist()
+#         num_cols = [c for c in num_cols if c != '{target}']
+#         cat_cols = [c for c in cat_cols if c != '{target}']
+
+#         from sklearn.compose import ColumnTransformer
+#         from sklearn.preprocessing import StandardScaler, OneHotEncoder
+
+#         preproc = ColumnTransformer(
+#             transformers=[
+#                 ('num', StandardScaler(), num_cols),
+#                 ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols),
+#             ],
+#             remainder='drop',
+#         )
+#         # ───────────────────────────────────────────────────────────
+#     """).strip() + "\n\n"
+
+#     # simply prepend; model code that follows can wrap estimator in a Pipeline
+#     return prep_snippet + code
+
+
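The prepended snippet builds a standard numeric/categorical ColumnTransformer. Run on its own it behaves like this (toy columns; requires scikit-learn):

    import pandas as pd
    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import StandardScaler, OneHotEncoder

    df = pd.DataFrame({"age": [30, 45, 52], "sex": ["F", "M", "F"], "target": [0, 1, 0]})
    num_cols = [c for c in df.select_dtypes(include=["number"]).columns if c != "target"]
    cat_cols = [c for c in df.select_dtypes(exclude=["number"]).columns if c != "target"]

    preproc = ColumnTransformer(
        transformers=[
            ("num", StandardScaler(), num_cols),
            ("cat", OneHotEncoder(handle_unknown="ignore"), cat_cols),
        ],
        remainder="drop",
    )
    print(preproc.fit_transform(df[num_cols + cat_cols]).shape)  # (3, 3): 1 scaled + 2 one-hot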
+# def fix_to_datetime_errors(code: str) -> str:
+#     """
+#     Force every pd.to_datetime(…) call to ignore bad dates so that
+#     'year 16500 is out of range' and similar issues don’t crash runs.
+#     """
+#     import re
+#     # look for any pd.to_datetime( … )
+#     pat = re.compile(r"pd\.to_datetime\(([^)]+)\)")
+#     def repl(m):
+#         inside = m.group(1)
+#         # if the call already has errors=, leave it unchanged
+#         if "errors=" in inside:
+#             return m.group(0)
+#         return f"pd.to_datetime({inside}, errors='coerce')"
+#     return pat.sub(repl, code)
+
+
+# def fix_numeric_sum(code: str) -> str:
+#     """
+#     Make .sum(...) code safe across pandas versions by removing any
+#     numeric_only=... argument (True/False/None) from function calls.
+
+#     This avoids errors on pandas versions where numeric_only is not
+#     supported for Series/grouped sums, and we rely instead on explicit
+#     numeric column selection (e.g. select_dtypes) in the generated code.
+#     """
+#     # Case 1: ..., numeric_only=True/False/None
+#     code = re.sub(
+#         r",\s*numeric_only\s*=\s*(True|False|None)",
+#         "",
+#         code,
+#         flags=re.IGNORECASE,
+#     )
+
+#     # Case 2: numeric_only=True/False/None, ...  (as first argument)
+#     code = re.sub(
+#         r"numeric_only\s*=\s*(True|False|None)\s*,\s*",
+#         "",
+#         code,
+#         flags=re.IGNORECASE,
+#     )
+
+#     # Case 3: numeric_only=True/False/None  (only argument)
+#     code = re.sub(
+#         r"numeric_only\s*=\s*(True|False|None)",
+#         "",
+#         code,
+#         flags=re.IGNORECASE,
+#     )
+
+#     return code
+
+
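errors='coerce' turns unparseable or out-of-range timestamps into NaT instead of raising, which is why the rewrite only adds it when no errors= argument is already present:

    import pandas as pd

    s = pd.Series(["2021-05-01", "not a date", "16500-01-01"])
    print(pd.to_datetime(s, errors="coerce"))
    # 0   2021-05-01
    # 1          NaT   <- unparseable
    # 2          NaT   <- year out of pandas' Timestamp range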
+# def fix_concat_empty_list(code: str) -> str:
+#     """
+#     Make pd.concat calls resilient to empty lists of objects.
+
+#     Transforms patterns like:
+#         pd.concat(frames, ignore_index=True)
+#         pd.concat(frames)
+
+#     into:
+#         pd.concat(frames or [pd.DataFrame()], ignore_index=True)
+#         pd.concat(frames or [pd.DataFrame()])
+
+#     Only triggers when the first argument is a simple variable name.
+#     """
+#     pattern = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
+
+#     def _repl(m):
+#         name = m.group(1)
+#         sep = m.group(2)  # ',' or ')'
+#         return f"pd.concat({name} or [pd.DataFrame()]{sep}"
+
+#     return pattern.sub(_repl, code)
+
+
+# def fix_numeric_aggs(code: str) -> str:
+#     _AGG_FUNCS = ("sum", "mean")
+#     pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
+#     def _repl(m):
+#         func, args = m.group(1), m.group(2) or ""
+#         if "numeric_only" in args:
+#             return m.group(0)
+#         args = args.rstrip()
+#         if args:
+#             args += ", "
+#         return f".{func}({args}numeric_only=True)"
+#     return pat.sub(_repl, code)
+
+
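pd.concat([]) raises ValueError ("No objects to concatenate"), so the `frames or [pd.DataFrame()]` rewrite keeps that path alive:

    import pandas as pd

    frames = []  # e.g. a filter loop that matched nothing
    result = pd.concat(frames or [pd.DataFrame()], ignore_index=True)
    print(result.empty)  # True, instead of a ValueError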
+# def ensure_accuracy_block(code: str) -> str:
+#     """
+#     Inject a sensible evaluation block right after the last `<est>.fit(...)`
+#         Classification → accuracy + weighted F1
+#         Regression     → R², RMSE, MAE
+#     Heuristic: infer task from estimator names present in the code.
+#     """
+#     import re, textwrap
+
+#     # If any proper metric already exists, do nothing
+#     if re.search(r"\b(accuracy_score|f1_score|r2_score|mean_squared_error|mean_absolute_error)\b", code):
+#         return code
+
+#     # Find the last "<var>.fit(" occurrence to reuse the estimator variable name
+#     m = list(re.finditer(r"(\w+)\.fit\s*\(", code))
+#     if not m:
+#         return code  # no estimator
+
+#     var = m[-1].group(1)
+#     # indent with same leading whitespace used on that line
+#     indent = re.match(r"\s*", code[m[-1].start():]).group(0)
+
+#     # Detect regression by estimator names / hints in code
+#     is_regression = bool(
+#         re.search(
+#             r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
+#             r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
+#             r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
+#         )
+#         or re.search(r"\bOLS\s*\(", code)
+#         or re.search(r"\bRegressor\b", code)
+#     )
+
+#     if is_regression:
+#         # inject numpy import if needed for RMSE
+#         if "import numpy as np" not in code and "np." not in code:
+#             code = "import numpy as np\n" + code
+#         eval_block = textwrap.dedent(f"""
+#         {indent}# ── automatic regression evaluation ─────────
+#         {indent}from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
+#         {indent}y_pred = {var}.predict(X_test)
+#         {indent}r2 = r2_score(y_test, y_pred)
+#         {indent}rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
+#         {indent}mae = float(mean_absolute_error(y_test, y_pred))
+#         {indent}print(f"R²: {{r2:.4f}} | RMSE: {{rmse:.4f}} | MAE: {{mae:.4f}}")
+#         """)
+#     else:
+#         eval_block = textwrap.dedent(f"""
+#         {indent}# ── automatic classification evaluation ─────────
+#         {indent}from sklearn.metrics import accuracy_score, f1_score
+#         {indent}y_pred = {var}.predict(X_test)
+#         {indent}acc = accuracy_score(y_test, y_pred)
+#         {indent}f1 = f1_score(y_test, y_pred, average='weighted')
+#         {indent}print(f"Accuracy: {{acc:.2%}} | F1 (weighted): {{f1:.3f}}")
+#         """)
+
+#     insert_at = code.find("\n", m[-1].end()) + 1
+#     return code[:insert_at] + eval_block + code[insert_at:]
+
+
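What the injected classification block boils down to when executed, on a toy split (requires scikit-learn; the estimator variable name is arbitrary):

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import accuracy_score, f1_score
    from sklearn.model_selection import train_test_split

    X, y = make_classification(n_samples=200, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    clf = RandomForestClassifier(random_state=0).fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average="weighted")
    print(f"Accuracy: {acc:.2%} | F1 (weighted): {f1:.3f}")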
+# def fix_scatter_and_summary(code: str) -> str:
+#     """
+#     1. Change cmap='spectral' (any case) → cmap='Spectral'
+#     2. If the LLM forgets to close the parenthesis in
+#        summary_table = ( df.describe()...   <missing )>
+#        insert the ')' right before the next 'from' or 'show('.
+#     """
+#     # 1️⃣ colormap case
+#     code = re.sub(
+#         r"cmap\s*=\s*['\"]spectral['\"]",   # insensitive pattern
+#         "cmap='Spectral'",
+#         code,
+#         flags=re.IGNORECASE,
+#     )
+
+#     # 2️⃣ close summary_table = ( ... )
+#     code = re.sub(
+#         r"(summary_table\s*=\s*\(\s*df\.describe\([^\n]+?\n)"
+#         r"(?=\s*(from|show\())",   # look-ahead: next line starts with 'from' or 'show('
+#         r"\1)",                    # keep group 1 and add ')'
+#         code,
+#         flags=re.MULTILINE,
+#     )
+
+#     return code
+
+
+# def auto_format_with_black(code: str) -> str:
+#     """
+#     Format the generated code with Black. Falls back silently if Black
+#     is missing or raises (so the dashboard never 500s).
+#     """
+#     try:
+#         import black  # make sure black is in your v-env: pip install black
+
+#         mode = black.FileMode()  # default settings
+#         return black.format_str(code, mode=mode)
+
+#     except Exception:
+#         return code
+
+
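Black is used here as an optional, fail-soft pretty-printer; the core call is just format_str with a FileMode (requires `pip install black`):

    import black

    ugly = "x=[1,2 ,3];print( x )"
    print(black.format_str(ugly, mode=black.FileMode()))
    # x = [1, 2, 3]
    # print(x)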
+# def ensure_preproc_in_pipeline(code: str) -> str:
+#     """
+#     If code defines `preproc = ColumnTransformer(...)` but then builds
+#     `Pipeline([('scaler', StandardScaler()), ('clf', ...)])`, replace
+#     that stanza with `Pipeline([('prep', preproc), ('clf', ...)])`.
+#     """
+#     return re.sub(
+#         r"Pipeline\(\s*\[\('scaler',\s*StandardScaler\(\)\)",
+#         "Pipeline([('prep', preproc)",
+#         code
+#     )
+
+# def drop_bad_classification_metrics(code: str, y_or_df) -> str:
+#     """
+#     Remove classification metrics (accuracy_score, classification_report, confusion_matrix)
+#     if the generated cell is *regression*. We infer this from:
+#       1) The estimator names in the code (LinearRegression, OLS, Regressor*, etc.), OR
+#       2) The target dtype if we can parse y = df['...'] and have the DataFrame.
+#     Safe across datasets and queries.
+#     """
+#     import re
+#     import pandas as pd
+
+#     # 1) Heuristic by estimator names in the *code* (fast path)
+#     regression_by_model = bool(re.search(
+#         r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
+#         r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
+#         r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
+#     ) or re.search(r"\bOLS\s*\(", code))
+
+#     is_regression = regression_by_model
+
+#     # 2) If not obvious from the model, try to infer from y dtype (if we can)
+#     if not is_regression:
+#         try:
+#             # Try to parse: y = df['target']
+#             m = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
+#             if m and hasattr(y_or_df, "columns") and m.group(1) in getattr(y_or_df, "columns", []):
+#                 y = y_or_df[m.group(1)]
+#                 if pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
+#                     is_regression = True
+#             else:
+#                 # If a Series was passed
+#                 y = y_or_df
+#                 if hasattr(y, "dtype") and pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
+#                     is_regression = True
+#         except Exception:
+#             pass
+
+#     if is_regression:
+#         # Strip classification-only lines
+#         for pat in (r"\n.*accuracy_score[^\n]*", r"\n.*classification_report[^\n]*", r"\n.*confusion_matrix[^\n]*"):
+#             code = re.sub(pat, "", code, flags=re.I)
+
+#     return code
+
+
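The task sniffing is just a name match over the generated code; a trimmed sketch of the same fast path (the regressor list is abbreviated here):

    import re

    REGRESSORS = re.compile(
        r"\b(LinearRegression|Ridge|Lasso|ElasticNet|RandomForestRegressor|"
        r"GradientBoostingRegressor|SVR|XGBRegressor|LGBMRegressor)\b"
    )

    code = "model = RandomForestRegressor().fit(X_train, y_train)\nprint(accuracy_score(y_test, y_pred))"
    if REGRESSORS.search(code):
        # regression detected -> strip classification-only lines
        code = re.sub(r"\n.*accuracy_score[^\n]*", "", code, flags=re.I)
    print(code)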
+# def force_capture_display(code: str) -> str:
+#     """
+#     Ensure our executor captures HTML output:
+#     - Remove any import that would override our 'display' hook.
+#     - Keep/allow importing HTML only.
+#     - Handle alias cases like 'display as d'.
+#     """
+#     import re
+#     new = code
+
+#     # 'from IPython.display import display, HTML' -> keep HTML only
+#     new = re.sub(
+#         r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*,\s*HTML\s*(?:as\s+([A-Za-z_]\w*))?\s*$",
+#         r"from IPython.display import HTML\1", new
+#     )
+
+#     # 'from IPython.display import display as d' -> 'd = display'
+#     new = re.sub(
+#         r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s+as\s+([A-Za-z_]\w+)\s*$",
+#         r"\1 = display", new
+#     )
+
+#     # 'from IPython.display import display' -> remove (use our injected display)
+#     new = re.sub(
+#         r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*$",
+#         r"# display import removed (SMX capture active)", new
+#     )
+
+#     # If someone does 'import IPython.display as disp' and calls disp.display(...), rewrite to display(...)
+#     new = re.sub(
+#         r"(?m)\bIPython\.display\.display\s*\(",
+#         "display(", new
+#     )
+#     new = re.sub(
+#         r"(?m)\b([A-Za-z_]\w*)\.display\s*\("  # handles 'disp.display(' after 'import IPython.display as disp'
+#         r"(?=.*import\s+IPython\.display\s+as\s+\1)",
+#         "display(", new
+#     )
+#     return new
+
+
+# def strip_matplotlib_show(code: str) -> str:
+#     """Remove blocking plt.show() calls (we export base64 instead)."""
+#     import re
+#     return re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code)
+
+
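The show-stripper is a single anchored multi-line regex; note it only removes lines that contain nothing but the call, and leaves an empty line behind:

    import re

    code = "plt.plot(x, y)\nplt.show()\nprint('done')"
    print(re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code))
    # the plt.show() line becomes blank; inline 'a(); plt.show()' is untouched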
+# def inject_display_shim(code: str) -> str:
+#     """
+#     Provide display()/HTML() if missing, forwarding to our executor hook.
+#     Harmless if the names already exist.
+#     """
+#     shim = (
+#         "try:\n"
+#         "    display\n"
+#         "except NameError:\n"
+#         "    def display(obj=None, **kwargs):\n"
+#         "        __builtins__.get('_smx_display', print)(obj)\n"
+#         "try:\n"
+#         "    HTML\n"
+#         "except NameError:\n"
+#         "    class HTML:\n"
+#         "        def __init__(self, data): self.data = str(data)\n"
+#         "        def _repr_html_(self): return self.data\n"
+#         "\n"
+#     )
+#     return shim + code
+
+
+# def strip_spurious_column_tokens(code: str) -> str:
+#     """
+#     Remove common stop-words ('the','whether', ...) when they appear
+#     inside column lists, e.g.:
+#         predictors = ['BMI','the','HbA1c']
+#         df[['GGT','whether','BMI']]
+#     Leaves other strings intact.
+#     """
+#     STOP = {
+#         "the","whether","a","an","and","or","of","to","in","on","for","by",
+#         "with","as","at","from","that","this","these","those","is","are","was","were",
+#         "coef", "Coef", "coefficient", "Coefficient"
+#     }
+
+#     def _norm(s: str) -> str:
+#         return re.sub(r"[^a-z0-9]+", "", s.lower())
+
+#     def _clean_list(content: str) -> str:
+#         # Rebuild a string list, keeping only non-stopword items
+#         items = re.findall(r"(['\"])(.*?)\1", content)
+#         if not items:
+#             return "[" + content + "]"
+#         keep = [f"{q}{s}{q}" for (q, s) in items if _norm(s) not in STOP]
+#         return "[" + ", ".join(keep) + "]"
+
+#     # Variable assignments: predictors/features/columns/cols = [...]
+#     code = re.sub(
+#         r"(?m)\b(predictors|features|columns|cols)\s*=\s*\[([^\]]+)\]",
+#         lambda m: f"{m.group(1)} = " + _clean_list(m.group(2)),
+#         code
+#     )
+
+#     # df[[ ... ]] selections
+#     code = re.sub(
+#         r"df\s*\[\s*\[([^\]]+)\]\s*\]", lambda m: "df[" + _clean_list(m.group(1)) + "]", code)
+
+#     return code
+
+
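Cleaning stop-words out of column lists, in miniature; the predictor list is invented:

    import re

    STOP = {"the", "whether", "and", "of"}
    code = "predictors = ['BMI', 'the', 'HbA1c']"

    def clean(m):
        items = re.findall(r"(['\"])(.*?)\1", m.group(2))
        keep = [f"{q}{s}{q}" for q, s in items if s.lower() not in STOP]
        return f"{m.group(1)} = [" + ", ".join(keep) + "]"

    print(re.sub(r"\b(predictors|features|cols)\s*=\s*\[([^\]]+)\]", clean, code))
    # predictors = ['BMI', 'HbA1c']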
+# def patch_prefix_seaborn_calls(code: str) -> str:
+#     """
+#     Ensure bare seaborn calls are prefixed with `sns.`.
+#     E.g., `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, etc.
+#     """
+#     if "sns." in code:
+#         # still fix any leftover bare calls alongside prefixed ones
+#         pass
+
+#     # functions commonly used from seaborn
+#     funcs = [
+#         "barplot","countplot","boxplot","violinplot","stripplot","swarmplot",
+#         "histplot","kdeplot","jointplot","pairplot","heatmap","clustermap",
+#         "scatterplot","lineplot","catplot","displot","lmplot"
+#     ]
+#     # Replace bare function calls not already qualified by a dot (e.g., obj.barplot)
+#     # (?<![\w.]) ensures no preceding word char or dot; avoids touching obj.barplot or mybarplot
+#     pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(", flags=re.MULTILINE)
+
+#     def _add_prefix(m):
+#         fn = m.group(1)
+#         return f"sns.{fn}("
+
+#     return pattern.sub(_add_prefix, code)
+
+
+# def patch_ensure_seaborn_import(code: str) -> str:
+#     """
+#     If seaborn is used (sns.) ensure `import seaborn as sns` exists once.
+#     Also set a quiet theme for consistent visuals.
+#     """
+#     needs_sns = "sns." in code
+#     has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
+#     if needs_sns and not has_import:
+#         # Insert after the first block of imports if possible, else at top
+#         import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
+#         inject = "import seaborn as sns\ntry:\n    sns.set_theme()\nexcept Exception:\n    pass\n"
+#         if import_block:
+#             start = import_block.end()
+#             code = code[:start] + inject + code[start:]
+#         else:
+#             code = inject + code
+#     return code
+
+
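The prefixer relies on a negative look-behind so calls that are already qualified (sns.barplot, obj.barplot) or merely contain the name (mybarplot) are left alone:

    import re

    FUNCS = ("barplot", "heatmap", "scatterplot")
    pattern = re.compile(r"(?<![\w.])(" + "|".join(FUNCS) + r")\s*\(")

    code = "heatmap(corr); sns.barplot(data=df)"
    print(pattern.sub(lambda m: f"sns.{m.group(1)}(", code))
    # sns.heatmap(corr); sns.barplot(data=df)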
+# def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
+#     """
+#     Normalise pie-chart requests.
+
+#     Supports three patterns:
+#       A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
+#       B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
+#          → one pie per facet level (grid) + counts bar of facet sizes.
+#       C) Single pie when no split/facet is requested.
+
+#     Notes:
+#       - Pie variables must be categorical (or numeric binned).
+#       - Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
+#     """
+
+#     q = (user_question or "")
+#     q_low = q.lower()
+
+#     # Prefer explicit: df['col'].value_counts()
+#     m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
+#     col = m.group(1) if m else None
+
+#     # ---------- helpers ----------
+#     def _is_cat(col):
+#         return (str(df[col].dtype).startswith("category")
+#                 or df[col].dtype == "object"
+#                 or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
+
+#     def _cats_from_question(question: str):
+#         found = []
+#         for c in df.columns:
+#             if c.lower() in question.lower() and _is_cat(c):
+#                 found.append(c)
+#         # dedupe preserve order
+#         seen, out = set(), []
+#         for c in found:
+#             if c not in seen:
+#                 out.append(c); seen.add(c)
+#         return out
+
+#     def _fallback_cat():
+#         cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
+#         if not cats: return None
+#         cats.sort(key=lambda t: t[1])
+#         return cats[0][0]
+
+#     def _infer_comp_pref(question: str) -> str:
+#         ql = (question or "").lower()
+#         if "heatmap" in ql or "matrix" in ql:
+#             return "heatmap"
+#         if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
+#             return "stacked_bar_pct"
+#         if "stacked" in ql:
+#             return "stacked_bar"
+#         if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
+#             return "grouped_bar"
+#         return "counts_bar"
+
+#     # parse threshold split like "HbA1c ≥ 6.5"
+#     def _parse_split(question: str):
+#         ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
+#         m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
+#         if not m: return None
+#         col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
+#         op = ops_map.get(op_raw)
+#         if not op: return None
+#         # case-insensitive column match
+#         candidates = {c.lower(): c for c in df.columns}
+#         col = candidates.get(col_raw.lower())
+#         if not col: return None
+#         try: val = float(val_raw)
+#         except Exception: return None
+#         return (col, op, val)
+
+#     # facet extractor: "by / across / within each / per <col>", or "bin <col>", or named category list
+#     def _extract_facet(question: str):
+#         # 1) explicit "by / across / within / per <col>"
+#         for kw in [" by ", " across ", " within ", " within each ", " per "]:
+#             m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
+#             if m:
+#                 col_raw = m.group(1).strip()
+#                 candidates = {c.lower(): c for c in df.columns}
+#                 if col_raw.lower() in candidates:
+#                     return (candidates[col_raw.lower()], "auto")
+#         # 2) "bin <col>"
+#         m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
+#         if m2:
+#             col_raw = m2.group(1).strip()
+#             candidates = {c.lower(): c for c in df.columns}
+#             if col_raw.lower() in candidates:
+#                 return (candidates[col_raw.lower()], "bin")
+#         # 3) BMI special: mentions of normal/overweight/obese imply BMI categories
+#         if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
+#            any(c.lower() == "bmi" for c in df.columns.str.lower()):
+#             bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
+#             return (bmi_col, "bmi")
+#         return None
+
+#     def _bmi_bins(series: pd.Series):
+#         # WHO cutoffs
+#         bins = [-np.inf, 18.5, 25, 30, np.inf]
+#         labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
+#         return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
+
+#     wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
+#     if not wants_pie:
+#         return code
+
+#     split = _parse_split(q)
+#     facet = _extract_facet(q)
+#     cats = _cats_from_question(q)
+#     _comp_pref = _infer_comp_pref(q)
+
+#     # Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
+#     for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
+#         if hard in df.columns and hard not in cats and hard.lower() in q_low:
+#             cats.append(hard)
+
+#     # --------------- CASE A: threshold split (cohorts) ---------------
+#     if split:
+#         if not (cats or any(_is_cat(c) for c in df.columns)):
+#             return code
+#         if not cats:
+#             pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
+#             pool.sort(key=lambda t: t[1])
+#             cats = [t[0] for t in pool[:3]] if pool else []
+#         if not cats:
+#             return code
+
+#         split_col, op, val = split
+#         cond_str = f"(df['{split_col}'] {op} {val})"
+#         snippet = f"""
+# import numpy as np
+# import pandas as pd
+# import matplotlib.pyplot as plt
+#
+# _mask_a = ({cond_str}) & df['{split_col}'].notna()
+# _mask_b = (~({cond_str})) & df['{split_col}'].notna()
+#
+# _cohort_a_name = "{split_col} {op} {val}"
+# _cohort_b_name = "NOT ({split_col} {op} {val})"
+# _comp_pref = "{_comp_pref}"
+#
+# _cat_cols = {cats!r}
+# n = len(_cat_cols)
+# fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
+# if n == 1:
+#     axes = np.array([axes])
+#
+# for i, col in enumerate(_cat_cols):
+#     s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
+#     s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})
+#
+#     ax_a = axes[i, 0]; ax_b = axes[i, 1]
+#     if len(s_a) > 0:
+#         ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
+#                  autopct='%1.1f%%', startangle=90, counterclock=False)
+#         ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')
+#
+#     if len(s_b) > 0:
+#         ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
+#                  autopct='%1.1f%%', startangle=90, counterclock=False)
+#         ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')
+#
+# plt.tight_layout(); plt.show()
+#
+# # grouped bar complement
+# for col in _cat_cols:
+#     _tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
+#               .assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
+#     _tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
+#     _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
+#
+#     if _comp_pref == "grouped_bar":
+#         ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
+#         ax.set_title(f"{{col}} by cohort (grouped)")
+#         ax.set_xlabel(col); ax.set_ylabel("Count")
+#         plt.tight_layout(); plt.show()
+#
+#     elif _comp_pref == "stacked_bar":
+#         ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
+#         ax.set_title(f"{{col}} by cohort (stacked)")
+#         ax.set_xlabel(col); ax.set_ylabel("Count")
+#         plt.tight_layout(); plt.show()
+#
+#     elif _comp_pref == "stacked_bar_pct":
+#         _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
+#         ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
+#         ax.set_title(f"{{col}} by cohort (100% stacked)")
+#         ax.set_xlabel(col); ax.set_ylabel("Percent")
+#         plt.tight_layout(); plt.show()
+#
+#     elif _comp_pref == "heatmap":
+#         _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
+#         fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
+#         im = ax.imshow(_perc.values, aspect='auto')
+#         ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
+#         ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
+#         ax.set_title(f"{{col}} by cohort — % heatmap")
+#         for i in range(_perc.shape[0]):
+#             for j in range(_perc.shape[1]):
+#                 ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
+#         fig.colorbar(im, ax=ax, label="%")
+#         plt.tight_layout(); plt.show()
+#
+#     else:  # counts_bar (default)
+#         ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
+#         ax.set_title(f"{{col}}: total counts (both cohorts)")
+#         ax.set_xlabel(col); ax.set_ylabel("Count")
+#         plt.tight_layout(); plt.show()
+# """.lstrip()
+#         return snippet
+
+#     # --------------- CASE B: facet-by (categories/bins) ---------------
+#     if facet:
+#         facet_col, how = facet
+#         # Build facet series
+#         if pd.api.types.is_numeric_dtype(df[facet_col]):
+#             if how == "bmi":
+#                 facet_series = _bmi_bins(df[facet_col])
+#             else:
+#                 # generic numeric bins: 3 equal-width bins by default
+#                 facet_series = pd.cut(df[facet_col].astype(float), bins=3)
+#         else:
+#             facet_series = df[facet_col].astype(str)
+
+#         # Choose pie dimension (categorical to count inside each facet)
+#         pie_dim = None
+#         for c in cats:
+#             if c in df.columns and _is_cat(c):
+#                 pie_dim = c; break
+#         if pie_dim is None:
+#             pie_dim = _fallback_cat()
+#         if pie_dim is None:
+#             return code
+
+#         snippet = f"""
+# import math
+# import numpy as np
+# import pandas as pd
+# import matplotlib.pyplot as plt
+#
+# df = df.copy()
+# _preferred = "{facet_col}" if "{facet_col}" in df.columns else None
+#
+# def _select_facet_col(df, preferred=None):
+#     if preferred is not None:
+#         return preferred
+#     # Prefer low-cardinality categoricals (readable pies/grids)
+#     cat_cols = [
+#         c for c in df.columns
+#         if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
+#         and df[c].nunique() > 1 and df[c].nunique() <= 20
+#     ]
+#     if cat_cols:
+#         cat_cols.sort(key=lambda c: df[c].nunique())
+#         return cat_cols[0]
+#     # Else fall back to first usable numeric
+#     num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
+#     return num_cols[0] if num_cols else None
+#
+# _facet_col = _select_facet_col(df, _preferred)
+#
+# if _facet_col is None:
+#     # Nothing suitable → single facet keeps pipeline alive
+#     df["__facet__"] = "All"
+# else:
+#     s = df[_facet_col]
+#     if pd.api.types.is_numeric_dtype(s):
+#         # Robust numeric binning: quantiles first, fallback to equal-width
+#         uniq = pd.Series(s).dropna().nunique()
+#         q = 3 if uniq < 10 else 4 if uniq < 30 else 5
+#         try:
+#             df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
+#         except Exception:
+#             df["__facet__"] = pd.cut(s.astype(float), bins=q)
+#     else:
+#         # Cap long tails; keep top categories
+#         vc = s.astype(str).value_counts()
+#         keep = vc.index[:{top_n}]
+#         df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")
+#
+# levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
+# levels = [x for x in levels if x != "nan"]
+# levels.sort()
+#
+# m = len(levels)
+# cols = 3 if m >= 3 else m or 1
+# rows = int(math.ceil(m / cols))
+#
+# fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
+# if not isinstance(axes, (list, np.ndarray)):
+#     axes = np.array([[axes]])
+# axes = axes.reshape(rows, cols)
+#
+# for i, lvl in enumerate(levels):
+#     r, c = divmod(i, cols)
+#     ax = axes[r, c]
+#     s = (df.loc[df["__facet__"].astype(str) == str(lvl), "{pie_dim}"]
+#            .astype(str).value_counts().nlargest({top_n}))
+#     if len(s) > 0:
+#         ax.pie(s.values, labels=[str(x) for x in s.index],
+#                autopct='%1.1f%%', startangle=90, counterclock=False)
+#         ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')
+#
+# # hide any empty subplots
+# for j in range(m, rows*cols):
+#     r, c = divmod(j, cols)
+#     axes[r, c].axis("off")
+#
+# plt.tight_layout(); plt.show()
+#
+# # --- companion visual (adaptive) ---
+# _comp_pref = "{_comp_pref}"
+# # build contingency table: pie_dim x facet
+# _tab = (df[["__facet__", "{pie_dim}"]]
+#           .dropna()
+#           .astype({{"__facet__": str, "{pie_dim}": str}})
+#           .value_counts()
+#           .unstack(level="__facet__")
+#           .fillna(0))
+#
+# # keep top categories by overall size
+# _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
+#
+# if _comp_pref == "grouped_bar":
+#     ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
+#     ax.set_title("{pie_dim} by {facet_col} (grouped)")
+#     ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
+#     plt.tight_layout(); plt.show()
+#
+# elif _comp_pref == "stacked_bar":
+#     ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
+#     ax.set_title("{pie_dim} by {facet_col} (stacked)")
+#     ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
+#     plt.tight_layout(); plt.show()
+#
+# elif _comp_pref == "stacked_bar_pct":
+#     _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100  # column-normalised to 100%
+#     ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
+#     ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
+#     ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
+#     plt.tight_layout(); plt.show()
+#
+# elif _comp_pref == "heatmap":
+#     _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
+#     fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
+#     im = ax.imshow(_perc.values, aspect='auto')
+#     ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
+#     ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
+#     ax.set_title("{pie_dim} by {facet_col} — % heatmap")
+#     for i in range(_perc.shape[0]):
+#         for j in range(_perc.shape[1]):
+#             ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
+#     fig.colorbar(im, ax=ax, label="%")
+#     plt.tight_layout(); plt.show()
+#
+# else:  # counts_bar (default denominators)
+#     _counts = df["__facet__"].value_counts()
+#     ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
+#     ax.set_title("Counts by {facet_col}")
+#     ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
+#     plt.tight_layout(); plt.show()
+# """.lstrip()
+#         return snippet
+
+#     # --------------- CASE C: single pie ---------------
+#     chosen = None
+#     for c in cats:
+#         if c in df.columns and _is_cat(c):
+#             chosen = c; break
+#     if chosen is None:
+#         chosen = _fallback_cat()
+
+#     if chosen:
+#         snippet = f"""
+# import matplotlib.pyplot as plt
+# counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
+# fig, ax = plt.subplots()
+# if len(counts) > 0:
+#     ax.pie(counts.values, labels=[str(i) for i in counts.index],
+#            autopct='%1.1f%%', startangle=90, counterclock=False)
+#     ax.set_title('Distribution of {chosen} (top {top_n})')
+#     ax.axis('equal')
+# plt.show()
+# """.lstrip()
+#         return snippet
+
+#     # numeric last resort
+#     num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
+#     if num_cols:
+#         col = num_cols[0]
+#         snippet = f"""
+# import pandas as pd
+# import matplotlib.pyplot as plt
+# bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
+# counts = bins.value_counts().sort_index()
+# fig, ax = plt.subplots()
+# if len(counts) > 0:
+#     ax.pie(counts.values, labels=[str(i) for i in counts.index],
+#            autopct='%1.1f%%', startangle=90, counterclock=False)
+#     ax.set_title('Distribution of {col} (binned)')
+#     ax.axis('equal')
+# plt.show()
+# """.lstrip()
+#         return snippet
+
+#     return code
+
+
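All three pie branches reduce to the same core: value_counts capped at top_n, percentage labels, equal aspect. A stand-alone sketch with toy data:

    import matplotlib.pyplot as plt
    import pandas as pd

    s = pd.Series(["A", "A", "B", "C", "B", "A"])
    counts = s.value_counts().nlargest(12)  # cap categories at top_n

    fig, ax = plt.subplots()
    ax.pie(counts.values, labels=[str(i) for i in counts.index],
           autopct="%1.1f%%", startangle=90, counterclock=False)
    ax.axis("equal")  # keep the pie circular
    plt.show()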
# def patch_fix_seaborn_palette_calls(code: str) -> str:
|
|
2635
|
+
# """
|
|
2636
|
+
# Removes seaborn `palette=` when no `hue=` is present in the same call.
|
|
2637
|
+
# Fixes FutureWarning: 'Passing `palette` without assigning `hue` ...'.
|
|
2638
|
+
# """
|
|
2639
|
+
# if "sns." not in code:
|
|
2640
|
+
# return code
|
|
2641
|
+
|
|
2642
|
+
# # Targets common seaborn plotters
|
|
2643
|
+
# funcs = r"(boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
|
|
2644
|
+
# pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)
|
|
2645
|
+
|
|
2646
|
+
# def _fix_call(m):
|
|
2647
|
+
# head, inner = m.group(1), m.group(2)
|
|
2648
|
+
# # If there's already hue=, keep as is
|
|
2649
|
+
# if re.search(r"(?<!\w)hue\s*=", inner):
|
|
2650
|
+
# return f"{head}{inner})"
|
|
2651
|
+
# # Otherwise remove palette=... safely (and any adjacent comma spacing)
|
|
2652
|
+
# inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
|
|
2653
|
+
# inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
|
|
2654
|
+
# inner2 = re.sub(r"\s*,\s*\)", ")", f"{inner2})")[:-1] # clean trailing comma before ')'
|
|
2655
|
+
# return f"{head}{inner2})"
|
|
2656
|
+
|
|
2657
|
+
# return pattern.sub(_fix_call, code)
|
|
2658
|
+
|
|
2659
|
+
+# def _norm_col_name(s: str) -> str:
+#     """normalise a column name: lowercase + strip non-alphanumerics."""
+#     return re.sub(r"[^a-z0-9]+", "", str(s).lower())
+
+
+# def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
+#     """return the actual df column that matches any candidate (after normalisation)."""
+#     norm_map = {_norm_col_name(c): c for c in df.columns}
+#     for cand in candidates:
+#         hit = norm_map.get(_norm_col_name(cand))
+#         if hit is not None:
+#             return hit
+#     return None
+
+
+# def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
+#     """
+#     If any alias exists, materialise a canonical copy at `target` (don’t drop the original).
+#     Returns (df, found_bool).
+#     """
+#     if target in df.columns:
+#         return df, True
+#     col = _first_present(df, [target, *aliases])
+#     if col is None:
+#         return df, False
+#     df[target] = df[col]
+#     return df, True
+
+
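The normalise-and-match idea behind these helpers, shown standalone on an invented frame:

    import re
    import pandas as pd

    def norm(s: str) -> str:
        # lowercase and strip every non-alphanumeric character
        return re.sub(r"[^a-z0-9]+", "", str(s).lower())

    df = pd.DataFrame(columns=["Patient ID", "HbA1c (%)", "bmi"])
    norm_map = {norm(c): c for c in df.columns}
    for cand in ["patient_id", "HBA1C", "BMI"]:
        print(cand, "->", norm_map.get(norm(cand)))
    # patient_id -> Patient ID
    # HBA1C -> HbA1c (%)
    # BMI -> bmi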
+# def strip_python_dotenv(code: str) -> str:
+#     """
+#     Remove any use of python-dotenv from generated code, including:
+#       - single and multi-line 'from dotenv import ...'
+#       - 'import dotenv' (with or without alias) and calls via any alias
+#       - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
+#       - IPython magics (%load_ext dotenv, %dotenv, %env …)
+#       - shell installs like '!pip install python-dotenv'
+#     """
+#     original = code
+
+#     # 0) Kill IPython magics & shell installs referencing dotenv
+#     code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
+#     code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
+#     code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
+#     code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)
+
+#     # 1) Remove single-line 'from dotenv import ...'
+#     code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)
+
+#     # 2) Remove multi-line 'from dotenv import ( ... )' blocks
+#     code = re.sub(
+#         r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
+#         "",
+#         code,
+#         flags=re.MULTILINE,
+#     )
+
+#     # 3) Remove 'import dotenv' (with optional alias). Capture alias names.
+#     aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
+#                          code, flags=re.MULTILINE)
+#     code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
+#                   "", code, flags=re.MULTILINE)
+
+#     # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
+#     #    e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
+#     fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
+#     # bare calls
+#     code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
+#     # dotted calls with any identifier prefix (alias or module)
+#     code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
+#                   "", code, flags=re.MULTILINE)
+
+#     # 5) If any alias imported earlier slipped through (method chains etc.), remove lines using that alias.
+#     for al in aliases:
+#         code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
+
+#     # 6) Tidy excess blank lines
+#     code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
+#     return code
+
+
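Applied to a typical generated cell, the dotenv scrubber amounts to substitutions like these (standalone sketch; the input string is invented):

    import re

    code = (
        "from dotenv import load_dotenv\n"
        "import dotenv as denv\n"
        "load_dotenv()\n"
        "denv.dotenv_values('.env')\n"
        "print('work')\n"
    )
    code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)
    aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$", code, flags=re.MULTILINE)
    code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$", "", code, flags=re.MULTILINE)
    fn = r"(?:load_dotenv|find_dotenv|dotenv_values)"
    code = re.sub(rf"^\s*{fn}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
    code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
    code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
    print(code)  # only: print('work')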
+# def fix_predict_calls_records_arg(code: str) -> str:
+#     """
+#     If generated code calls predict_* with a list-of-dicts via .to_dict('records')
+#     (or orient='records'), strip the .to_dict(...) so a DataFrame is passed instead.
+#     Works line-by-line to avoid over-rewrites elsewhere.
+#     Examples fixed:
+#       predict_patient(X_test.iloc[:5].to_dict('records'))
+#       predict_risk(df.head(3).to_dict(orient="records"))
+#       → predict_patient(X_test.iloc[:5])
+#     """
+#     fixed_lines = []
+#     for line in code.splitlines():
+#         if "predict_" in line and "to_dict" in line and "records" in line:
+#             line = re.sub(
+#                 r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)",
+#                 "",
+#                 line
+#             )
+#         fixed_lines.append(line)
+#     return "\n".join(fixed_lines)
+
+
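The records-argument rewrite is a single substitution; standalone:

    import re

    line = "preds = predict_patient(X_test.iloc[:5].to_dict('records'))"
    fixed = re.sub(r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)", "", line)
    print(fixed)
    # preds = predict_patient(X_test.iloc[:5])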
+# def fix_fstring_backslash_paths(code: str) -> str:
+#     """
+#     Fix bad f-strings like: f"...{out_dir\\plots\\img.png}..."
+#     → f"...{os.path.join(out_dir, r'plots\\img.png')}"
+#     Only touches f-strings that contain a backslash path inside {...}.
+#     """
+#     def _fix_line(line: str) -> str:
+#         # quick check: only f-strings need scanning
+#         if not (("f\"" in line) or ("f'" in line) or ("f\"\"\"" in line) or ("f'''" in line)):
+#             return line
+#         # {var\rest-of-path} where var can be dotted (e.g., cfg.out)
+#         pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
+#         def repl(m):
+#             left = m.group(1)
+#             right = m.group(2).strip().replace('"', '\\"')
+#             return "{os.path.join(" + left + ', r"' + right + '")}'
+#         return pattern.sub(repl, line)
+
+#     return "\n".join(_fix_line(ln) for ln in code.splitlines())
+
+
+# def ensure_os_import(code: str) -> str:
+#     """
+#     If os.path.join is used but 'import os' is missing, inject it at the top.
+#     """
+#     needs = "os.path.join(" in code
+#     has_import_os = re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
+#     has_from_os = re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
+#     if needs and not (has_import_os or has_from_os):
+#         return "import os\n" + code
+#     return code
+
+
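Together, the two helpers above rewrite a backslash path inside an f-string placeholder into an os.path.join call and guarantee the import; a standalone sketch of the core regex (the `out_dir` path is illustrative):

    import re

    line = 'plt.savefig(f"{out_dir\\plots\\img.png}")'  # backslashes break the f-string
    pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
    fixed = pattern.sub(lambda m: "{os.path.join(" + m.group(1) + ', r"' + m.group(2).strip() + '")}', line)
    print(fixed)
    # plt.savefig(f"{os.path.join(out_dir, r"plots\img.png")}")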
+# def fix_seaborn_boxplot_nameerror(code: str) -> str:
+#     """
+#     Fix bad calls like: sns.boxplot(boxplot)
+#     Heuristic:
+#       - If plot_df + FH_status + var exist → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
+#       - Else if plot_df + var → sns.boxplot(data=plot_df, y=var, ax=ax)
+#       - Else if plot_df only → sns.boxplot(data=plot_df, ax=ax)
+#       - Else → sns.boxplot(ax=ax)
+#     Ensures a matplotlib Axes 'ax' exists.
+#     """
+#     pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
+#     if not pattern.search(code):
+#         return code
+
+#     has_plot_df = re.search(r"\bplot_df\b", code) is not None
+#     has_var = re.search(r"\bvar\b", code) is not None
+#     has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
+
+#     if has_plot_df and has_var and has_fh:
+#         replacement = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
+#     elif has_plot_df and has_var:
+#         replacement = "sns.boxplot(data=plot_df, y=var, ax=ax)"
+#     elif has_plot_df:
+#         replacement = "sns.boxplot(data=plot_df, ax=ax)"
+#     else:
+#         replacement = "sns.boxplot(ax=ax)"
+
+#     fixed = pattern.sub(replacement, code)
+
+#     # Ensure 'fig, ax = plt.subplots(...)' exists
+#     if "ax=" in replacement and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", fixed):
+#         # Insert right before the first seaborn call
+#         m = re.search(r"^\s*sns\.", fixed, flags=re.MULTILINE)
+#         insert_at = m.start() if m else 0
+#         fixed = fixed[:insert_at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + fixed[insert_at:]
+
+#     return fixed
+
+
+# def fix_seaborn_barplot_nameerror(code: str) -> str:
+#     """
+#     Fix bad calls like: sns.barplot(barplot)
+#     Strategy mirrors boxplot fixer: prefer data=plot_df with x/y if available,
+#     otherwise degrade safely to an empty call on an existing Axes.
+#     """
+#     import re
+#     pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
+#     if not pattern.search(code):
+#         return code
+
+#     has_plot_df = re.search(r"\bplot_df\b", code) is not None
+#     has_var = re.search(r"\bvar\b", code) is not None
+#     has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
+
+#     if has_plot_df and has_var and has_fh:
+#         replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
+#     elif has_plot_df and has_var:
+#         replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
+#     elif has_plot_df:
+#         replacement = "sns.barplot(data=plot_df, ax=ax)"
+#     else:
+#         replacement = "sns.barplot(ax=ax)"
+
+#     # ensure an Axes 'ax' exists (no-op if already present)
+#     if "ax =" not in code:
+#         code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code
+
+#     return pattern.sub(replacement, code)
+
+
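Both NameError fixers follow the same shape; a standalone sketch of the boxplot variant, assuming the generated cell already defines plot_df and var:

    import re

    code = "sns.boxplot(boxplot)\nplot_df = df[[var]]"
    pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
    if pattern.search(code):
        # plot_df and var both appear in the cell, so target them explicitly
        code = pattern.sub("sns.boxplot(data=plot_df, y=var, ax=ax)", code)
    print(code.splitlines()[0])
    # sns.boxplot(data=plot_df, y=var, ax=ax)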
+# def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
+#     """
+#     Parses the raw text to extract and format the 'refined question',
+#     'intents (tasks)', and 'chronology of tasks' sections.
+#     Args:
+#         raw_text: The complete input string containing the ML pipeline structure.
+#     Returns:
+#         A tuple containing:
+#         (formatted_question_str, formatted_intents_str, formatted_chronology_str)
+#     """
+#     # --- 1. Regex Pattern to Extract Sections ---
+#     # The pattern uses capturing groups (?) to look for the section headers
+#     # (e.g., 'refined question:') and captures all the content until the next
+#     # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.
+
+#     pattern = re.compile(
+#         r"refined question:(?P<question>.*?)"
+#         r"intents \(tasks\):(?P<intents>.*?)"
+#         r"Chronology of tasks:(?P<chronology>.*)",
+#         re.IGNORECASE | re.DOTALL
+#     )
+
+#     match = pattern.search(raw_text)
+
+#     if not match:
+#         raise ValueError("Input text structure does not match the expected pattern.")

-
-
-
-
-            out_lines.append(ln)
-            continue
+#     # --- 2. Extract Content ---
+#     question_content = match.group('question').strip()
+#     intents_content = match.group('intents').strip()
+#     chronology_content = match.group('chronology').strip()

-
-        #   plt.show()
-        #   plt.tight_layout(); plt.show()
-        #   ... ; plt.show(); ...   (multiple on one line)
-        indent = re.match(r"^(\s*)", ln).group(1)
-        parts = ln.split("plt.show()")
+#     # --- 3. Formatting Functions ---

-
-
-
+#     def format_question(content):
+#         """Formats the Refined Question section."""
+#         # Clean up leading/trailing whitespace and ensure clean paragraphs
+#         content = content.strip().replace('\n', ' ').replace('  ', ' ')
+
+#         # Simple formatting using Markdown headers and bolding
+#         formatted = (
+#             # "## 1. Project Goal and Objectives\n\n"
+#             "<b> Refined Question:</b>\n"
+#             f"{content}\n"
+#         )
+#         return formatted
+
+#     def format_intents(content):
+#         """Formats the Intents (Tasks) section as a structured list."""
+#         # Use regex to find and format each numbered task
+#         # It finds 'N. **Text** - ...' and breaks it down.
+
+#         tasks = []
+#         # Pattern: N. **Text** - Content (including newlines, non-greedy)
+#         # We need to explicitly handle the list items starting with '-' within the content
+#         task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)
+
+#         # Split the content by lines and join tasks back into clean strings
+#         raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]
+
+#         for task in raw_tasks:
+#             # Replace the initial task number and **Heading** with a Heading 3
+#             task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)
+
+#             # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
+#             task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
+#             tasks.append(task)
+
+#         formatted_tasks = "\n\n".join(tasks)
+
+#         return (
+#             "\n---\n"
+#             "## 2. Methodology and Tasks\n\n"
+#             f"{formatted_tasks}\n"
+#         )
+
+#     def format_chronology(content):
+#         """Formats the Chronology section."""
+#         # Uses the given LaTeX format
+#         content = content.strip().replace(' ', ' \rightarrow ')
+#         formatted = (
+#             "\n---\n"
+#             "## 3. Chronology of Tasks\n"
+#             f"$$\\text{{{content}}}$$"
+#         )
+#         return formatted
+
+#     # --- 4. Format and Return ---
+#     formatted_question = format_question(question_content)
+#     formatted_intents = format_intents(intents_content)
+#     formatted_chronology = format_chronology(chronology_content)
+
+#     return formatted_question, formatted_intents, formatted_chronology
+
+
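The named-group pattern at the heart of the parser can be exercised on its own (the input string below is invented):

    import re

    raw = (
        "refined question: Can family history be predicted from labs?\n"
        "intents (tasks): 1. **EDA** - profile the data\n"
        "Chronology of tasks: EDA -> Modelling -> Evaluation"
    )
    pattern = re.compile(
        r"refined question:(?P<question>.*?)"
        r"intents \(tasks\):(?P<intents>.*?)"
        r"Chronology of tasks:(?P<chronology>.*)",
        re.IGNORECASE | re.DOTALL,
    )
    m = pattern.search(raw)
    print(m.group("question").strip())    # Can family history be predicted from labs?
    print(m.group("chronology").strip())  # EDA -> Modelling -> Evaluation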
+# def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
+#     """Combines all formatted parts into a final report string."""
+#     return (
+#         "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
+#         f"{formatted_question}\n"
+#         f"{formatted_intents}\n"
+#         f"{formatted_chronology}\n"
+#     )
+
+
+# def fix_confusion_matrix_for_multilabel(code: str) -> str:
+#     """
+#     Replace ConfusionMatrixDisplay.from_estimator(...) usages with
+#     from_predictions(...) which works for multi-label loops without requiring
+#     the estimator to expose _estimator_type.
+#     """
+#     return re.sub(
+#         r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
+#         r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
+#         code
+#     )
+
+
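The substitution swaps the estimator-based API for the prediction-based one, which sidesteps the _estimator_type check; the target pattern is runnable on its own:

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import ConfusionMatrixDisplay

    X, y = make_classification(n_samples=200, random_state=0)
    clf = LogisticRegression(max_iter=1000).fit(X, y)
    # from_predictions needs only y_true and y_pred, never the fitted estimator
    ConfusionMatrixDisplay.from_predictions(y, clf.predict(X))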
+# def smx_auto_title_plots(ctx=None, fallback="Analysis"):
+#     """
+#     Ensure every Matplotlib/Seaborn Axes has a title.
+#     Uses refined_question -> askai_question -> fallback.
+#     Only sets a title if it's currently empty.
+#     """
+#     import matplotlib.pyplot as plt
+
+#     def _all_figures():
+#         try:
+#             from matplotlib._pylab_helpers import Gcf
+#             return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
+#         except Exception:
+#             # Best effort fallback
+#             nums = plt.get_fignums()
+#             return [plt.figure(n) for n in nums] if nums else []
+
+#     # Choose a concise title
+#     title = None
+#     if isinstance(ctx, dict):
+#         title = ctx.get("refined_question") or ctx.get("askai_question")
+#     title = (str(title).strip().splitlines()[0][:120]) if title else fallback
+
+#     for fig in _all_figures():
+#         for ax in getattr(fig, "axes", []):
+#             try:
+#                 if not (ax.get_title() or "").strip():
+#                     ax.set_title(title)
+#             except Exception:
+#                 pass
+#         try:
+#             fig.tight_layout()
+#         except Exception:
+#             pass
+
+
+# def patch_fix_sentinel_plot_calls(code: str) -> str:
+#     """
+#     Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
+#       SB_barplot(barplot)       -> SB_barplot()
+#       SB_barplot(barplot, ...)  -> SB_barplot(...)
+#       sns.barplot(barplot)      -> SB_barplot()
+#       sns.barplot(barplot, ...) -> SB_barplot(...)
+#     Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
+#     """
+#     names = ['histplot','boxplot','barplot','lineplot','countplot','heatmap','pairplot','scatterplot']
+#     for n in names:
+#         # SB_* with sentinel as the first arg (with or without trailing args)
+#         code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
+#         code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
+#         # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
+#         code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
+#         code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
+#     return code
+
+
+# def patch_rmse_calls(code: str) -> str:
+#     """
+#     Make RMSE robust across sklearn versions.
+#       - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
+#       - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
+#     """
+#     import re
+#     # (a) Specific RMSE pattern
+#     code = re.sub(
+#         r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
+#         r"_SMX_rmse(\1)",
+#         code,
+#         flags=re.DOTALL
+#     )
+#     # (b) Guard any other MSE calls
+#     code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
+#     return code

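_SMX_rmse is presumably a shim that computes RMSE without the squared= keyword, which newer scikit-learn versions removed from mean_squared_error; a version-robust equivalent:

    import numpy as np
    from sklearn.metrics import mean_squared_error

    def rmse(y_true, y_pred):
        # sqrt of MSE works on every scikit-learn version, old or new
        return float(np.sqrt(mean_squared_error(y_true, y_pred)))

    print(rmse([1.0, 2.0, 3.0], [1.5, 2.0, 2.5]))  # ~0.4082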
-        # for every plt.show() we removed, insert exporter at same indent
-        for _ in range(len(parts) - 1):
-            out_lines.append(indent + "_SMX_export_png()")
-
-        # keep whatever comes after the last plt.show()
-        if parts[-1].strip():
-            out_lines.append(indent + parts[-1].lstrip())
-
-    return "\n".join(out_lines)
-
-
-def clean_llm_code(code: str) -> str:
-    """
-    Make LLM output safe to exec:
-      - If fenced blocks exist, keep the largest one (usually the real code).
-      - Otherwise strip any stray ``` / ```python lines.
-      - Remove common markdown/preamble junk.
-    """
-    code = str(code or "")
-
-    # Special case: sometimes the OpenAI SDK object repr (e.g. ChatCompletion(...))
-    # is accidentally passed here as `code`. In that case, extract the actual
-    # Python code from the ChatCompletionMessage(content=...) field.
-    if "ChatCompletion(" in code and "ChatCompletionMessage" in code and "content=" in code:
-        try:
-            extracted = None
-
-            class _ChatCompletionVisitor(ast.NodeVisitor):
-                def visit_Call(self, node):
-                    nonlocal extracted
-                    func = node.func
-                    fname = getattr(func, "id", None) or getattr(func, "attr", None)
-                    if fname == "ChatCompletionMessage":
-                        for kw in node.keywords:
-                            if kw.arg == "content" and isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str):
-                                extracted = kw.value.value
-                    self.generic_visit(node)
-
-            tree = ast.parse(code, mode="exec")
-            _ChatCompletionVisitor().visit(tree)
-            if extracted:
-                code = extracted
-        except Exception:
-            # Best-effort regex fallback if AST parsing fails
-            m = re.search(r"content=(?P<q>['\\\"])(?P<body>.*?)(?P=q)", code, flags=re.S)
-            if m:
-                code = m.group("body")
-
-    # Existing logic continues unchanged below...
-    # Extract fenced blocks (```python ... ``` or ``` ... ```)
-    blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
-
-    if blocks:
-        # pick the largest block; small trailing blocks are usually garbage
-        largest = max(blocks, key=lambda b: len(b.strip()))
-        if len(largest.strip().splitlines()) >= 10:
-            code = largest
-
-    # Extract fenced blocks (```python ... ``` or ``` ... ```)
-    blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
-
-    if blocks:
-        # pick the largest block; small trailing blocks are usually garbage
-        largest = max(blocks, key=lambda b: len(b.strip()))
-        if len(largest.strip().splitlines()) >= 10:
-            code = largest
-        else:
-            # if no meaningful block, just remove fence markers
-            code = re.sub(r"^```.*?$", "", code, flags=re.M)
-    else:
-        # no complete blocks — still remove any stray fence lines
-        code = re.sub(r"^```.*?$", "", code, flags=re.M)
-
-    # Strip common markdown/preamble lines
-    drop_prefixes = (
-        "here is", "here's", "below is", "sure,", "certainly",
-        "explanation", "note:", "```"
-    )
-    cleaned_lines = []
-    for ln in code.splitlines():
-        s = ln.strip().lower()
-        if any(s.startswith(p) for p in drop_prefixes):
-            continue
-        cleaned_lines.append(ln)
-
-    return "\n".join(cleaned_lines).strip()
-
-
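The fence-extraction step of the removed clean_llm_code, standalone (toy reply string):

    import re

    reply = "Here is the code:\n```python\nimport pandas as pd\nprint(pd.__version__)\n```\nHope this helps!"
    blocks = re.findall(r"```(?:python)?\s*(.*?)```", reply, flags=re.I | re.S)
    code = max(blocks, key=lambda b: len(b.strip())) if blocks else reply
    print(code.strip())
    # import pandas as pd
    # print(pd.__version__)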
-def fix_groupby_describe_slice(code: str) -> str:
-    """
-    Replaces df.groupby(...).describe()[[...] ] with a safe .agg(...)
-    so it works for both SeriesGroupBy and DataFrameGroupBy.
-    """
-    pat = re.compile(
-        r"(df\.groupby\(['\"][\w]+['\"]\)\['[\w]+['\"]\]\.describe\()\s*\[\[([^\]]+)\]\]\)",
-        re.MULTILINE
-    )
-    def repl(match):
-        inner = match.group(0)
-        # extract group and feature to build df.groupby('g')['f']
-        g = re.search(r"groupby\('([\w]+)'\)", inner).group(1)
-        f = re.search(r"\)\['([\w]+)'\]\.describe", inner).group(1)
-        return (
-            f"df.groupby('{g}')['{f}']"
-            ".agg(['count','mean','std','min','median','max'])"
-            ".rename(columns={'median':'50%'})"
-        )
-    return pat.sub(repl, code)
-
-
-def fix_importance_groupby(code: str) -> str:
-    pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
-    if "importance_df" in code:
-        return pattern.sub("importance_df.groupby('Importance')['Importance']", code)
-    return code
-
-def inject_auto_preprocessing(code: str) -> str:
-    """
-    • Detects a RandomForestClassifier in the generated code.
-    • Finds the target column from `y = df['target']`.
-    • Prepends a fully-dedented preprocessing snippet that:
-        – auto-detects numeric & categorical columns
-        – builds a ColumnTransformer (OneHotEncoder + StandardScaler)
-    The dedent() call guarantees no leading-space IndentationError.
-    """
-    if "RandomForestClassifier" not in code:
-        return code   # nothing to patch
-
-    y_match = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
-    if not y_match:
-        return code   # can't infer target safely
-    target = y_match.group(1)
-
-    prep_snippet = textwrap.dedent(f"""
-        # ── automatic preprocessing ───────────────────────────────
-        num_cols = df.select_dtypes(include=['number']).columns.tolist()
-        cat_cols = df.select_dtypes(exclude=['number']).columns.tolist()
-        num_cols = [c for c in num_cols if c != '{target}']
-        cat_cols = [c for c in cat_cols if c != '{target}']
-
-        from sklearn.compose import ColumnTransformer
-        from sklearn.preprocessing import StandardScaler, OneHotEncoder
-
-        preproc = ColumnTransformer(
-            transformers=[
-                ('num', StandardScaler(), num_cols),
-                ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols),
-            ],
-            remainder='drop',
-        )
-        # ───────────────────────────────────────────────────────────
-    """).strip() + "\n\n"
-
-    # simply prepend; model code that follows can wrap estimator in a Pipeline
-    return prep_snippet + code
-
-
-def fix_to_datetime_errors(code: str) -> str:
-    """
-    Force every pd.to_datetime(…) call to ignore bad dates so that
-
-    'year 16500 is out of range' and similar issues don’t crash runs.
-    """
-    import re
-    # look for any pd.to_datetime( … )
-    pat = re.compile(r"pd\.to_datetime\(([^)]+)\)")
-    def repl(m):
-        inside = m.group(1)
-        # if the call already has errors=, leave it unchanged
-        if "errors=" in inside:
-            return m.group(0)
-        return f"pd.to_datetime({inside}, errors='coerce')"
-    return pat.sub(repl, code)
-
-
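errors='coerce' turns unparseable dates into NaT instead of raising, which is exactly what the rewrite injects; standalone:

    import pandas as pd

    s = pd.Series(["2024-01-31", "not a date"])
    print(pd.to_datetime(s, errors="coerce"))
    # 0   2024-01-31
    # 1          NaT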
-def fix_numeric_sum(code: str) -> str:
-    """
-    Make .sum(...) code safe across pandas versions by removing any
-    numeric_only=... argument (True/False/None) from function calls.
-
-    This avoids errors on pandas versions where numeric_only is not
-    supported for Series/grouped sums, and we rely instead on explicit
-    numeric column selection (e.g. select_dtypes) in the generated code.
-    """
-    # Case 1: ..., numeric_only=True/False/None
-    code = re.sub(
-        r",\s*numeric_only\s*=\s*(True|False|None)",
-        "",
-        code,
-        flags=re.IGNORECASE,
-    )
-
-    # Case 2: numeric_only=True/False/None, ... (as first argument)
-    code = re.sub(
-        r"numeric_only\s*=\s*(True|False|None)\s*,\s*",
-        "",
-        code,
-        flags=re.IGNORECASE,
-    )
-
-    # Case 3: numeric_only=True/False/None (only argument)
-    code = re.sub(
-        r"numeric_only\s*=\s*(True|False|None)",
-        "",
-        code,
-        flags=re.IGNORECASE,
-    )
-
-    return code
-
-
-def fix_concat_empty_list(code: str) -> str:
-    """
-    Make pd.concat calls resilient to empty lists of objects.
-
-    Transforms patterns like:
-        pd.concat(frames, ignore_index=True)
-        pd.concat(frames)
-
-    into:
-        pd.concat(frames or [pd.DataFrame()], ignore_index=True)
-        pd.concat(frames or [pd.DataFrame()])
-
-    Only triggers when the first argument is a simple variable name.
-    """
-    pattern = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
-
-    def _repl(m):
-        name = m.group(1)
-        sep = m.group(2)   # ',' or ')'
-        return f"pd.concat({name} or [pd.DataFrame()]{sep}"
-
-    return pattern.sub(_repl, code)
-
-
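pd.concat raises ValueError on an empty list, so the guard substitutes a single empty frame; standalone:

    import pandas as pd

    frames = []  # e.g. nothing matched a filter upstream
    out = pd.concat(frames or [pd.DataFrame()], ignore_index=True)
    print(out.shape)  # (0, 0) instead of a ValueError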
-def fix_numeric_aggs(code: str) -> str:
-    _AGG_FUNCS = ("sum", "mean")
-    pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
-    def _repl(m):
-        func, args = m.group(1), m.group(2) or ""
-        if "numeric_only" in args:
-            return m.group(0)
-        args = args.rstrip()
-        if args:
-            args += ", "
-        return f".{func}({args}numeric_only=True)"
-    return pat.sub(_repl, code)
-
-
-def ensure_accuracy_block(code: str) -> str:
-    """
-    Inject a sensible evaluation block right after the last `<est>.fit(...)`
-      Classification → accuracy + weighted F1
-      Regression → R², RMSE, MAE
-    Heuristic: infer task from estimator names present in the code.
-    """
-    import re, textwrap
-
-    # If any proper metric already exists, do nothing
-    if re.search(r"\b(accuracy_score|f1_score|r2_score|mean_squared_error|mean_absolute_error)\b", code):
-        return code
-
-    # Find the last "<var>.fit(" occurrence to reuse the estimator variable name
-    m = list(re.finditer(r"(\w+)\.fit\s*\(", code))
-    if not m:
-        return code   # no estimator
-
-    var = m[-1].group(1)
-    # indent with same leading whitespace used on that line
-    indent = re.match(r"\s*", code[m[-1].start():]).group(0)
-
-    # Detect regression by estimator names / hints in code
-    is_regression = bool(
-        re.search(
-            r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
-            r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
-            r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
-        )
-        or re.search(r"\bOLS\s*\(", code)
-        or re.search(r"\bRegressor\b", code)
-    )
-
-    if is_regression:
-        # inject numpy import if needed for RMSE
-        if "import numpy as np" not in code and "np." not in code:
-            code = "import numpy as np\n" + code
-        eval_block = textwrap.dedent(f"""
-        {indent}# ── automatic regression evaluation ─────────
-        {indent}from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
-        {indent}y_pred = {var}.predict(X_test)
-        {indent}r2 = r2_score(y_test, y_pred)
-        {indent}rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
-        {indent}mae = float(mean_absolute_error(y_test, y_pred))
-        {indent}print(f"R²: {{r2:.4f}} | RMSE: {{rmse:.4f}} | MAE: {{mae:.4f}}")
-        """)
-    else:
-        eval_block = textwrap.dedent(f"""
-        {indent}# ── automatic classification evaluation ─────────
-        {indent}from sklearn.metrics import accuracy_score, f1_score
-        {indent}y_pred = {var}.predict(X_test)
-        {indent}acc = accuracy_score(y_test, y_pred)
-        {indent}f1 = f1_score(y_test, y_pred, average='weighted')
-        {indent}print(f"Accuracy: {{acc:.2%}} | F1 (weighted): {{f1:.3f}}")
-        """)
-
-    insert_at = code.find("\n", m[-1].end()) + 1
-    return code[:insert_at] + eval_block + code[insert_at:]
-
-
-def fix_scatter_and_summary(code: str) -> str:
-    """
-    1. Change cmap='spectral' (any case) → cmap='Spectral'
-    2. If the LLM forgets to close the parenthesis in
-          summary_table = ( df.describe()...  <missing )>
-       insert the ')' right before the next 'from' or 'show('.
-    """
-    # 1️⃣ colormap case
-    code = re.sub(
-        r"cmap\s*=\s*['\"]spectral['\"]",   # insensitive pattern
-        "cmap='Spectral'",
-        code,
-        flags=re.IGNORECASE,
-    )
-
-    # 2️⃣ close summary_table = ( ... )
-    code = re.sub(
-        r"(summary_table\s*=\s*\(\s*df\.describe\([^\n]+?\n)"
-        r"(?=\s*(from|show\())",   # look-ahead: next line starts with 'from' or 'show('
-        r"\1)",                    # keep group 1 and add ')'
-        code,
-        flags=re.MULTILINE,
-    )
-
-    return code
-
-
-def auto_format_with_black(code: str) -> str:
-    """
-    Format the generated code with Black. Falls back silently if Black
-    is missing or raises (so the dashboard never 500s).
-    """
-    try:
-        import black   # make sure black is in your v-env: pip install black
-
-        mode = black.FileMode()   # default settings
-        return black.format_str(code, mode=mode)
-
-    except Exception:
-        return code
-
-
-def ensure_preproc_in_pipeline(code: str) -> str:
-    """
-    If code defines `preproc = ColumnTransformer(...)` but then builds
-    `Pipeline([('scaler', StandardScaler()), ('clf', ...)])`, replace
-    that stanza with `Pipeline([('prep', preproc), ('clf', ...)])`.
-    """
-    return re.sub(
-        r"Pipeline\(\s*\[\('scaler',\s*StandardScaler\(\)\)",
-        "Pipeline([('prep', preproc)",
-        code
-    )
-
-
-def fix_plain_prints(code: str) -> str:
-    """
-    Rewrite bare `print(var)` where var looks like a dataframe/series/ndarray/etc
-    to go through SyntaxMatrix's smart display (so it renders in the dashboard).
-    Keeps string prints alone.
-    """
-
-    # Skip obvious string-literal prints
-    new = re.sub(
-        r"(?m)^\s*print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
-        r"from syntaxmatrix.display import show\nshow(\1)",
-        code,
-    )
-    return new
-
-
-def fix_print_html(code: str) -> str:
-    """
-    Ensure that HTML / DataFrame HTML are *displayed* (and captured by the kernel),
-    not printed as `<IPython.core.display.HTML object>` to the server console.
-    - Rewrites: print(HTML(...)) → display(HTML(...))
-                print(display(...)) → display(...)
-                print(df.to_html(...)) → display(HTML(df.to_html(...)))
-    Also prepends `from IPython.display import display, HTML` if required.
-    """
-    import re
-
-    new = code
-
-    # 1) print(HTML(...)) -> display(HTML(...))
-    new = re.sub(r"(?m)^\s*print\s*\(\s*HTML\s*\(", "display(HTML(", new)
-
-    # 2) print(display(...)) -> display(...)
-    new = re.sub(r"(?m)^\s*print\s*\(\s*display\s*\(", "display(", new)
-
-    # 3) print(<expr>.to_html(...)) -> display(HTML(<expr>.to_html(...)))
-    new = re.sub(
-        r"(?m)^\s*print\s*\(\s*([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*)\s*\.to_html\s*\(",
-        r"display(HTML(\1.to_html(", new
-    )
-
-    # If code references HTML() or display() make sure the import exists
-    if ("HTML(" in new or re.search(r"\bdisplay\s*\(", new)) and \
-       "from IPython.display import display, HTML" not in new:
-        new = "from IPython.display import display, HTML\n" + new
-
-    return new
-
-
-def ensure_ipy_display(code: str) -> str:
-    """
-    Guarantee that the cell has proper IPython display imports so that
-    display(HTML(...)) produces 'display_data' events the kernel captures.
-    """
-    if "display(" in code and "from IPython.display import display, HTML" not in code:
-        return "from IPython.display import display, HTML\n" + code
-    return code
-
-
-def drop_bad_classification_metrics(code: str, y_or_df) -> str:
-    """
-    Remove classification metrics (accuracy_score, classification_report, confusion_matrix)
-    if the generated cell is *regression*. We infer this from:
-      1) The estimator names in the code (LinearRegression, OLS, Regressor*, etc.), OR
-      2) The target dtype if we can parse y = df['...'] and have the DataFrame.
-    Safe across datasets and queries.
-    """
-    import re
-    import pandas as pd
-
-    # 1) Heuristic by estimator names in the *code* (fast path)
-    regression_by_model = bool(re.search(
-        r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
-        r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
-        r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
-    ) or re.search(r"\bOLS\s*\(", code))
-
-    is_regression = regression_by_model
-
-    # 2) If not obvious from the model, try to infer from y dtype (if we can)
-    if not is_regression:
-        try:
-            # Try to parse: y = df['target']
-            m = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
-            if m and hasattr(y_or_df, "columns") and m.group(1) in getattr(y_or_df, "columns", []):
-                y = y_or_df[m.group(1)]
-                if pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
-                    is_regression = True
-            else:
-                # If a Series was passed
-                y = y_or_df
-                if hasattr(y, "dtype") and pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
-                    is_regression = True
-        except Exception:
-            pass
-
-    if is_regression:
-        # Strip classification-only lines
-        for pat in (r"\n.*accuracy_score[^\n]*", r"\n.*classification_report[^\n]*", r"\n.*confusion_matrix[^\n]*"):
-            code = re.sub(pat, "", code, flags=re.I)
-
-    return code
-
-
-def force_capture_display(code: str) -> str:
-    """
-    Ensure our executor captures HTML output:
-      - Remove any import that would override our 'display' hook.
-      - Keep/allow importing HTML only.
-      - Handle alias cases like 'display as d'.
-    """
-    import re
-    new = code
-
-    # 'from IPython.display import display, HTML' -> keep HTML only
-    new = re.sub(
-        r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*,\s*HTML\s*(?:as\s+([A-Za-z_]\w*))?\s*$",
-        r"from IPython.display import HTML\1", new
-    )
-
-    # 'from IPython.display import display as d' -> 'd = display'
-    new = re.sub(
-        r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s+as\s+([A-Za-z_]\w+)\s*$",
-        r"\1 = display", new
-    )
-
-    # 'from IPython.display import display' -> remove (use our injected display)
-    new = re.sub(
-        r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*$",
-        r"# display import removed (SMX capture active)", new
-    )
-
-    # If someone does 'import IPython.display as disp' and calls disp.display(...), rewrite to display(...)
-    new = re.sub(
-        r"(?m)\bIPython\.display\.display\s*\(",
-        "display(", new
-    )
-    new = re.sub(
-        r"(?m)\b([A-Za-z_]\w*)\.display\s*\("   # handles 'disp.display(' after 'import IPython.display as disp'
-        r"(?=.*import\s+IPython\.display\s+as\s+\1)",
-        "display(", new
-    )
-    return new
-
-
-def strip_matplotlib_show(code: str) -> str:
-    """Remove blocking plt.show() calls (we export base64 instead)."""
-    import re
-    return re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code)
-
-
-def inject_display_shim(code: str) -> str:
-    """
-    Provide display()/HTML() if missing, forwarding to our executor hook.
-    Harmless if the names already exist.
-    """
-    shim = (
-        "try:\n"
-        "    display\n"
-        "except NameError:\n"
-        "    def display(obj=None, **kwargs):\n"
-        "        __builtins__.get('_smx_display', print)(obj)\n"
-        "try:\n"
-        "    HTML\n"
-        "except NameError:\n"
-        "    class HTML:\n"
-        "        def __init__(self, data): self.data = str(data)\n"
-        "        def _repr_html_(self): return self.data\n"
-        "\n"
-    )
-    return shim + code
-
-
-def strip_spurious_column_tokens(code: str) -> str:
-    """
-    Remove common stop-words ('the','whether', ...) when they appear
-    inside column lists, e.g.:
-        predictors = ['BMI','the','HbA1c']
-        df[['GGT','whether','BMI']]
-    Leaves other strings intact.
-    """
-    STOP = {
-        "the","whether","a","an","and","or","of","to","in","on","for","by",
-        "with","as","at","from","that","this","these","those","is","are","was","were",
-        "coef", "Coef", "coefficient", "Coefficient"
-    }
-
-    def _norm(s: str) -> str:
-        return re.sub(r"[^a-z0-9]+", "", s.lower())
-
-    def _clean_list(content: str) -> str:
-        # Rebuild a string list, keeping only non-stopword items
-        items = re.findall(r"(['\"])(.*?)\1", content)
-        if not items:
-            return "[" + content + "]"
-        keep = [f"{q}{s}{q}" for (q, s) in items if _norm(s) not in STOP]
-        return "[" + ", ".join(keep) + "]"
-
-    # Variable assignments: predictors/features/columns/cols = [...]
-    code = re.sub(
-        r"(?m)\b(predictors|features|columns|cols)\s*=\s*\[([^\]]+)\]",
-        lambda m: f"{m.group(1)} = " + _clean_list(m.group(2)),
-        code
-    )
-
-    # df[[ ... ]] selections
-    code = re.sub(
-        r"df\s*\[\s*\[([^\]]+)\]\s*\]", lambda m: "df[" + _clean_list(m.group(1)) + "]", code)
-
-    return code
-
-
-def patch_prefix_seaborn_calls(code: str) -> str:
-    """
-    Ensure bare seaborn calls are prefixed with `sns.`.
-    E.g., `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, etc.
-    """
-    if "sns." in code:
-        # still fix any leftover bare calls alongside prefixed ones
-        pass
-
-    # functions commonly used from seaborn
-    funcs = [
-        "barplot","countplot","boxplot","violinplot","stripplot","swarmplot",
-        "histplot","kdeplot","jointplot","pairplot","heatmap","clustermap",
-        "scatterplot","lineplot","catplot","displot","lmplot"
-    ]
-    # Replace bare function calls not already qualified by a dot (e.g., obj.barplot)
-    # (?<![\w.]) ensures no preceding word char or dot; avoids touching obj.barplot or mybarplot
-    pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(", flags=re.MULTILINE)
-
-    def _add_prefix(m):
-        fn = m.group(1)
-        return f"sns.{fn}("
-
-    return pattern.sub(_add_prefix, code)
-
-
-def patch_ensure_seaborn_import(code: str) -> str:
-    """
-    If seaborn is used (sns.) ensure `import seaborn as sns` exists once.
-    Also set a quiet theme for consistent visuals.
-    """
-    needs_sns = "sns." in code
-    has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
-    if needs_sns and not has_import:
-        # Insert after the first block of imports if possible, else at top
-        import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
-        inject = "import seaborn as sns\ntry:\n    sns.set_theme()\nexcept Exception:\n    pass\n"
-        if import_block:
-            start = import_block.end()
-            code = code[:start] + inject + code[start:]
-        else:
-            code = inject + code
-    return code
-
-
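Combined, the two seaborn patchers above qualify bare plot calls and inject the import; a standalone sketch of the prefixing regex (the lookbehind leaves dotted calls such as pandas' own df.boxplot untouched):

    import re

    code = "countplot(x='Sex', data=df)\ndf.boxplot(column='Age')"
    funcs = ["barplot", "countplot", "boxplot", "heatmap"]
    pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(")
    print(pattern.sub(lambda m: f"sns.{m.group(1)}(", code))
    # countplot( becomes sns.countplot(; df.boxplot( is preceded by a dot and is left alone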
def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
|
|
2237
|
-
"""
|
|
2238
|
-
Normalise pie-chart requests.
|
|
2239
|
-
|
|
2240
|
-
Supports three patterns:
|
|
2241
|
-
A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
|
|
2242
|
-
B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
|
|
2243
|
-
→ one pie per facet level (grid) + counts bar of facet sizes.
|
|
2244
|
-
C) Single pie when no split/facet is requested.
|
|
2245
|
-
|
|
2246
|
-
Notes:
|
|
2247
|
-
- Pie variables must be categorical (or numeric binned).
|
|
2248
|
-
- Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
|
|
2249
|
-
"""
|
|
2250
|
-
|
|
2251
|
-
q = (user_question or "")
|
|
2252
|
-
q_low = q.lower()
|
|
2253
|
-
|
|
2254
|
-
# Prefer explicit: df['col'].value_counts()
|
|
2255
|
-
m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
|
|
2256
|
-
col = m.group(1) if m else None
|
|
2257
|
-
|
|
2258
|
-
# ---------- helpers ----------
|
|
2259
|
-
def _is_cat(col):
|
|
2260
|
-
return (str(df[col].dtype).startswith("category")
|
|
2261
|
-
or df[col].dtype == "object"
|
|
2262
|
-
or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
|
|
2263
|
-
|
|
2264
|
-
def _cats_from_question(question: str):
|
|
2265
|
-
found = []
|
|
2266
|
-
for c in df.columns:
|
|
2267
|
-
if c.lower() in question.lower() and _is_cat(c):
|
|
2268
|
-
found.append(c)
|
|
2269
|
-
# dedupe preserve order
|
|
2270
|
-
seen, out = set(), []
|
|
2271
|
-
for c in found:
|
|
2272
|
-
if c not in seen:
|
|
2273
|
-
out.append(c); seen.add(c)
|
|
2274
|
-
return out
|
|
2275
|
-
|
|
2276
|
-
def _fallback_cat():
|
|
2277
|
-
cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
|
|
2278
|
-
if not cats: return None
|
|
2279
|
-
cats.sort(key=lambda t: t[1])
|
|
2280
|
-
return cats[0][0]
|
|
2281
|
-
|
|
2282
|
-
def _infer_comp_pref(question: str) -> str:
|
|
2283
|
-
ql = (question or "").lower()
|
|
2284
|
-
if "heatmap" in ql or "matrix" in ql:
|
|
2285
|
-
return "heatmap"
|
|
2286
|
-
if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
|
|
2287
|
-
return "stacked_bar_pct"
|
|
2288
|
-
if "stacked" in ql:
|
|
2289
|
-
return "stacked_bar"
|
|
2290
|
-
if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
|
|
2291
|
-
return "grouped_bar"
|
|
2292
|
-
return "counts_bar"
|
|
2293
|
-
|
|
2294
|
-
# parse threshold split like "HbA1c ≥ 6.5"
|
|
2295
|
-
def _parse_split(question: str):
|
|
2296
|
-
ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
|
|
2297
|
-
m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
|
|
2298
|
-
if not m: return None
|
|
2299
|
-
col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
|
|
2300
|
-
op = ops_map.get(op_raw);
|
|
2301
|
-
if not op: return None
|
|
2302
|
-
# case-insensitive column match
|
|
2303
|
-
candidates = {c.lower(): c for c in df.columns}
|
|
2304
|
-
col = candidates.get(col_raw.lower())
|
|
2305
|
-
if not col: return None
|
|
2306
|
-
try: val = float(val_raw)
|
|
2307
|
-
except Exception: return None
|
|
2308
|
-
return (col, op, val)
|
|
2309
|
-
|
|
2310
|
-
# facet extractor: "by/ across / within each / per <col>", or "bin <col>", or named category list
|
|
2311
|
-
def _extract_facet(question: str):
|
|
2312
|
-
# 1) explicit "by/ across / within / per <col>"
|
|
2313
|
-
for kw in [" by ", " across ", " within ", " within each ", " per "]:
|
|
2314
|
-
m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
|
|
2315
|
-
if m:
|
|
2316
|
-
col_raw = m.group(1).strip()
|
|
2317
|
-
candidates = {c.lower(): c for c in df.columns}
|
|
2318
|
-
if col_raw.lower() in candidates:
|
|
2319
|
-
return (candidates[col_raw.lower()], "auto")
|
|
2320
|
-
# 2) "bin <col>"
|
|
2321
|
-
m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
|
|
2322
|
-
if m2:
|
|
2323
|
-
col_raw = m2.group(1).strip()
|
|
2324
|
-
candidates = {c.lower(): c for c in df.columns}
|
|
2325
|
-
if col_raw.lower() in candidates:
|
|
2326
|
-
return (candidates[col_raw.lower()], "bin")
|
|
2327
|
-
# 3) BMI special: mentions of normal/overweight/obese imply BMI categories
|
|
2328
|
-
if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
|
|
2329
|
-
any(c.lower() == "bmi" for c in df.columns.str.lower()):
|
|
2330
|
-
bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
|
|
2331
|
-
return (bmi_col, "bmi")
|
|
2332
|
-
return None
|
|
2333
|
-
|
|
2334
|
-
def _bmi_bins(series: pd.Series):
|
|
2335
|
-
# WHO cutoffs
|
|
2336
|
-
bins = [-np.inf, 18.5, 25, 30, np.inf]
|
|
2337
|
-
labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
|
|
2338
|
-
return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
|
|
2339
|
-
|
|
2340
|
-
wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
|
|
2341
|
-
if not wants_pie:
|
|
2342
|
-
return code
|
|
2343
|
-
|
|
2344
|
-
split = _parse_split(q)
|
|
2345
|
-
facet = _extract_facet(q)
|
|
2346
|
-
cats = _cats_from_question(q)
|
|
2347
|
-
_comp_pref = _infer_comp_pref(q)
|
|
2348
|
-
|
|
2349
|
-
# Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
|
|
2350
|
-
for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
|
|
2351
|
-
if hard in df.columns and hard not in cats and hard.lower() in q_low:
|
|
2352
|
-
cats.append(hard)
|
|
2353
|
-
|
|
2354
|
-
# --------------- CASE A: threshold split (cohorts) ---------------
|
|
2355
|
-
if split:
|
|
2356
|
-
if not (cats or any(_is_cat(c) for c in df.columns)):
|
|
2357
|
-
return code
|
|
2358
|
-
if not cats:
|
|
2359
|
-
pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
|
|
2360
|
-
pool.sort(key=lambda t: t[1])
|
|
2361
|
-
cats = [t[0] for t in pool[:3]] if pool else []
|
|
2362
|
-
if not cats:
|
|
2363
|
-
return code
|
|
2364
|
-
|
|
2365
|
-
split_col, op, val = split
|
|
2366
|
-
cond_str = f"(df['{split_col}'] {op} {val})"
|
|
2367
|
-
snippet = f"""
|
|
2368
|
-
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

_mask_a = ({cond_str}) & df['{split_col}'].notna()
_mask_b = (~({cond_str})) & df['{split_col}'].notna()

_cohort_a_name = "{split_col} {op} {val}"
_cohort_b_name = "NOT ({split_col} {op} {val})"

_cat_cols = {cats!r}
n = len(_cat_cols)
fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
if n == 1:
    axes = np.array([axes])

for i, col in enumerate(_cat_cols):
    s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
    s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})

    ax_a = axes[i, 0]; ax_b = axes[i, 1]
    if len(s_a) > 0:
        ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
                 autopct='%1.1f%%', startangle=90, counterclock=False)
        ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')

    if len(s_b) > 0:
        ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
                 autopct='%1.1f%%', startangle=90, counterclock=False)
        ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')

plt.tight_layout(); plt.show()

# grouped bar complement
_comp_pref = "{_comp_pref}"  # companion-chart preference interpolated by the generator (mirrors CASE B)
for col in _cat_cols:
    _tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
            .assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
    _tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
    _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]

    if _comp_pref == "grouped_bar":
        ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
        ax.set_title(f"{{col}} by cohort (grouped)")
        ax.set_xlabel(col); ax.set_ylabel("Count")
        plt.tight_layout(); plt.show()

    elif _comp_pref == "stacked_bar":
        ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
        ax.set_title(f"{{col}} by cohort (stacked)")
        ax.set_xlabel(col); ax.set_ylabel("Count")
        plt.tight_layout(); plt.show()

    elif _comp_pref == "stacked_bar_pct":
        _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
        ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
        ax.set_title(f"{{col}} by cohort (100% stacked)")
        ax.set_xlabel(col); ax.set_ylabel("Percent")
        plt.tight_layout(); plt.show()

    elif _comp_pref == "heatmap":
        _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
        fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
        im = ax.imshow(_perc.values, aspect='auto')
        ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
        ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
        ax.set_title(f"{{col}} by cohort — % heatmap")
        for i in range(_perc.shape[0]):
            for j in range(_perc.shape[1]):
                ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
        fig.colorbar(im, ax=ax, label="%")
        plt.tight_layout(); plt.show()

    else:  # counts_bar (default)
        ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
        ax.set_title(f"{{col}}: total counts (both cohorts)")
        ax.set_xlabel(col); ax.set_ylabel("Count")
        plt.tight_layout(); plt.show()
""".lstrip()
        return snippet

    # --------------- CASE B: facet-by (categories/bins) ---------------
    if facet:
        facet_col, how = facet
        # Build facet series
        if pd.api.types.is_numeric_dtype(df[facet_col]):
            if how == "bmi":
                facet_series = _bmi_bins(df[facet_col])
            else:
                # generic numeric bins: 3 equal-width bins by default
                facet_series = pd.cut(df[facet_col].astype(float), bins=3)
        else:
            facet_series = df[facet_col].astype(str)

        # Choose pie dimension (categorical to count inside each facet)
        pie_dim = None
        for c in cats:
            if c in df.columns and _is_cat(c):
                pie_dim = c; break
        if pie_dim is None:
            pie_dim = _fallback_cat()
        if pie_dim is None:
            return code

        snippet = f"""
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = df.copy()
_preferred = "{facet_col}" if "{facet_col}" in df.columns else None

def _select_facet_col(df, preferred=None):
    if preferred is not None:
        return preferred
    # Prefer low-cardinality categoricals (readable pies/grids)
    cat_cols = [
        c for c in df.columns
        if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
        and df[c].nunique() > 1 and df[c].nunique() <= 20
    ]
    if cat_cols:
        cat_cols.sort(key=lambda c: df[c].nunique())
        return cat_cols[0]
    # Else fall back to first usable numeric
    num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
    return num_cols[0] if num_cols else None

_facet_col = _select_facet_col(df, _preferred)

if _facet_col is None:
    # Nothing suitable → single facet keeps pipeline alive
    df["__facet__"] = "All"
else:
    s = df[_facet_col]
    if pd.api.types.is_numeric_dtype(s):
        # Robust numeric binning: quantiles first, fallback to equal-width
        uniq = pd.Series(s).dropna().nunique()
        q = 3 if uniq < 10 else 4 if uniq < 30 else 5
        try:
            df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
        except Exception:
            df["__facet__"] = pd.cut(s.astype(float), bins=q)
    else:
        # Cap long tails; keep top categories
        vc = s.astype(str).value_counts()
        keep = vc.index[:{top_n}]
        df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")

levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
levels = [x for x in levels if x != "nan"]
levels.sort()

m = len(levels)
cols = 3 if m >= 3 else m or 1
rows = int(math.ceil(m / cols))

fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
if not isinstance(axes, (list, np.ndarray)):
    axes = np.array([[axes]])
axes = axes.reshape(rows, cols)

for i, lvl in enumerate(levels):
    r, c = divmod(i, cols)
    ax = axes[r, c]
    s = (df.loc[df["__facet__"].astype(str) == str(lvl), "{pie_dim}"]
         .astype(str).value_counts().nlargest({top_n}))
    if len(s) > 0:
        ax.pie(s.values, labels=[str(x) for x in s.index],
               autopct='%1.1f%%', startangle=90, counterclock=False)
        ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')

# hide any empty subplots
for j in range(m, rows*cols):
    r, c = divmod(j, cols)
    axes[r, c].axis("off")

plt.tight_layout(); plt.show()

# --- companion visual (adaptive) ---
_comp_pref = "{_comp_pref}"
# build contingency table: pie_dim x facet
_tab = (df[["__facet__", "{pie_dim}"]]
        .dropna()
        .astype({{"__facet__": str, "{pie_dim}": str}})
        .value_counts()
        .unstack(level="__facet__")
        .fillna(0))

# keep top categories by overall size
_tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]

if _comp_pref == "grouped_bar":
    ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
    ax.set_title("{pie_dim} by {facet_col} (grouped)")
    ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
    plt.tight_layout(); plt.show()

elif _comp_pref == "stacked_bar":
    ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
    ax.set_title("{pie_dim} by {facet_col} (stacked)")
    ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
    plt.tight_layout(); plt.show()

elif _comp_pref == "stacked_bar_pct":
    _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100  # column-normalised to 100%
    ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
    ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
    ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
    plt.tight_layout(); plt.show()

elif _comp_pref == "heatmap":
    _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
    fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
    im = ax.imshow(_perc.values, aspect='auto')
    ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
    ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
    ax.set_title("{pie_dim} by {facet_col} — % heatmap")
    for i in range(_perc.shape[0]):
        for j in range(_perc.shape[1]):
            ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
    fig.colorbar(im, ax=ax, label="%")
    plt.tight_layout(); plt.show()

else:  # counts_bar (default denominators)
    _counts = df["__facet__"].value_counts()
    ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
    ax.set_title("Counts by {facet_col}")
    ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
    plt.tight_layout(); plt.show()
""".lstrip()
        return snippet

    # --------------- CASE C: single pie ---------------
    chosen = None
    for c in cats:
        if c in df.columns and _is_cat(c):
            chosen = c; break
    if chosen is None:
        chosen = _fallback_cat()

    if chosen:
        snippet = f"""
import matplotlib.pyplot as plt
counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
fig, ax = plt.subplots()
if len(counts) > 0:
    ax.pie(counts.values, labels=[str(i) for i in counts.index],
           autopct='%1.1f%%', startangle=90, counterclock=False)
    ax.set_title('Distribution of {chosen} (top {top_n})')
    ax.axis('equal')
plt.show()
""".lstrip()
        return snippet

    # numeric last resort
    num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
    if num_cols:
        col = num_cols[0]
        snippet = f"""
import pandas as pd
import matplotlib.pyplot as plt
bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
counts = bins.value_counts().sort_index()
fig, ax = plt.subplots()
if len(counts) > 0:
    ax.pie(counts.values, labels=[str(i) for i in counts.index],
           autopct='%1.1f%%', startangle=90, counterclock=False)
    ax.set_title('Distribution of {col} (binned)')
    ax.axis('equal')
plt.show()
""".lstrip()
        return snippet

    return code

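# Editor's illustration (not part of the package): a minimal sketch of the
# CASE B numeric facet-binning rule above — quantile bins first, equal-width
# bins as the fallback. The sample series is invented.
if __name__ == "__main__":
    import pandas as pd
    _s = pd.Series([1.0, 1.0, 1.0, 2.0, 50.0, 60.0, 70.0, 80.0])
    try:
        _bins = pd.qcut(_s, q=3, duplicates="drop")
    except Exception:
        _bins = pd.cut(_s, bins=3)
    print(_bins.value_counts().sort_index())

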
def patch_fix_seaborn_palette_calls(code: str) -> str:
    """
    Removes seaborn `palette=` when no `hue=` is present in the same call.
    Fixes FutureWarning: 'Passing `palette` without assigning `hue` ...'.
    """
    if "sns." not in code:
        return code

    # Targets common seaborn plotters (non-capturing group, so the call's
    # arguments stay in group 2 for _fix_call below)
    funcs = r"(?:boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
    pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)

    def _fix_call(m):
        head, inner = m.group(1), m.group(2)
        # If there's already hue=, keep as is
        if re.search(r"(?<!\w)hue\s*=", inner):
            return f"{head}{inner})"
        # Otherwise remove palette=... safely (and any adjacent comma spacing)
        inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
        inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
        inner2 = re.sub(r"\s*,\s*\)", ")", f"{inner2})")[:-1]  # clean trailing comma before ')'
        return f"{head}{inner2})"

    return pattern.sub(_fix_call, code)


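# Editor's usage sketch (invented input string): the palette= keyword is
# dropped because no hue= is present in the call.
if __name__ == "__main__":
    print(patch_fix_seaborn_palette_calls("sns.countplot(x='sex', data=df, palette='pastel')"))
    # -> sns.countplot(x='sex', data=df)

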
def patch_quiet_specific_warnings(code: str) -> str:
    """
    Inserts targeted warning filters (not blanket ignores).
      - seaborn palette/hue deprecation
      - python-dotenv parse chatter
    """
    prelude = (
        "import warnings\n"
        "warnings.filterwarnings(\n"
        " 'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
        "warnings.filterwarnings(\n"
        " 'ignore', message=r'python-dotenv could not parse statement.*')\n"
    )
    # If warnings already imported once, just add filters; else insert full prelude.
    if "import warnings" in code:
        code = re.sub(
            r"(import warnings[^\n]*\n)",
            lambda m: m.group(1) + prelude.replace("import warnings\n", ""),
            code,
            count=1
        )
    else:
        # place after first import block if possible
        m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
        if m:
            idx = m.end()
            code = code[:idx] + prelude + code[idx:]
        else:
            code = prelude + code
    return code


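# Editor's usage sketch (invented snippet): with no 'import warnings'
# present, the full prelude lands just after the first import block.
if __name__ == "__main__":
    print(patch_quiet_specific_warnings("import pandas as pd\n\nprint('ok')\n"))

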
def _norm_col_name(s: str) -> str:
    """Normalise a column name: lowercase + strip non-alphanumerics."""
    return re.sub(r"[^a-z0-9]+", "", str(s).lower())


def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
    """Return the actual df column that matches any candidate (after normalisation)."""
    norm_map = {_norm_col_name(c): c for c in df.columns}
    for cand in candidates:
        hit = norm_map.get(_norm_col_name(cand))
        if hit is not None:
            return hit
    return None


def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
    """
    If any alias exists, materialise a canonical copy at `target` (don't drop the original).
    Returns (df, found_bool).
    """
    if target in df.columns:
        return df, True
    col = _first_present(df, [target, *aliases])
    if col is None:
        return df, False
    df[target] = df[col]
    return df, True


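# Editor's usage sketch (invented frame): 'Blood Pressure' is matched by its
# normalised name and mirrored to the canonical 'blood_pressure' column.
if __name__ == "__main__":
    import pandas as pd
    _df = pd.DataFrame({"Blood Pressure": [120, 130]})
    _df, _found = _ensure_canonical_alias(_df, "blood_pressure", ["bp"])
    print(_found, list(_df.columns))  # True ['Blood Pressure', 'blood_pressure']

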
def strip_python_dotenv(code: str) -> str:
    """
    Remove any use of python-dotenv from generated code, including:
      - single and multi-line 'from dotenv import ...'
      - 'import dotenv' (with or without alias) and calls via any alias
      - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
      - IPython magics (%load_ext dotenv, %dotenv, %env …)
      - shell installs like '!pip install python-dotenv'
    """
    original = code

    # 0) Kill IPython magics & shell installs referencing dotenv
    code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
    code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
    code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
    code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)

    # 1) Remove single-line 'from dotenv import ...'
    code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)

    # 2) Remove multi-line 'from dotenv import ( ... )' blocks
    code = re.sub(
        r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
        "",
        code,
        flags=re.MULTILINE,
    )

    # 3) Remove 'import dotenv' (with optional alias). Capture alias names.
    aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
                         code, flags=re.MULTILINE)
    code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
                  "", code, flags=re.MULTILINE)

    # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
    #    e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
    fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
    # bare calls
    code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
    # dotted calls with any identifier prefix (alias or module)
    code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
                  "", code, flags=re.MULTILINE)

    # 5) If any alias imported earlier slipped through (method chains etc.), remove lines using that alias.
    for al in aliases:
        code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)

    # 6) Tidy excess blank lines
    code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
    return code


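# Editor's usage sketch (invented snippet): only the non-dotenv line survives.
if __name__ == "__main__":
    _src = "from dotenv import load_dotenv\nload_dotenv()\nAPI_KEY = 'x'\n"
    print(strip_python_dotenv(_src))  # -> API_KEY = 'x'

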
def fix_predict_calls_records_arg(code: str) -> str:
    """
    If generated code calls predict_* with a list-of-dicts via .to_dict('records')
    (or orient='records'), strip the .to_dict(...) so a DataFrame is passed instead.
    Works line-by-line to avoid over-rewrites elsewhere.
    Examples fixed:
        predict_patient(X_test.iloc[:5].to_dict('records'))
        predict_risk(df.head(3).to_dict(orient="records"))
        → predict_patient(X_test.iloc[:5])
    """
    fixed_lines = []
    for line in code.splitlines():
        if "predict_" in line and "to_dict" in line and "records" in line:
            line = re.sub(
                r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)",
                "",
                line
            )
        fixed_lines.append(line)
    return "\n".join(fixed_lines)


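# Editor's usage sketch (invented call):
if __name__ == "__main__":
    print(fix_predict_calls_records_arg(
        "preds = predict_patient(X_test.iloc[:5].to_dict('records'))"
    ))
    # -> preds = predict_patient(X_test.iloc[:5])

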
def fix_fstring_backslash_paths(code: str) -> str:
    """
    Fix bad f-strings like: f"...{out_dir\\plots\\img.png}..."
    → f"...{os.path.join(out_dir, r'plots\\img.png')}"
    Only touches f-strings that contain a backslash path inside {...}.
    """
    def _fix_line(line: str) -> str:
        # quick check: only f-strings need scanning
        if not (("f\"" in line) or ("f'" in line) or ("f\"\"\"" in line) or ("f'''" in line)):
            return line
        # {var\rest-of-path} where var can be dotted (e.g., cfg.out)
        pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
        def repl(m):
            left = m.group(1)
            # emit a single-quoted raw string so the result stays valid inside
            # the (typically double-quoted) enclosing f-string
            right = m.group(2).strip().replace("'", "\\'")
            return "{os.path.join(" + left + ", r'" + right + "')}"
        return pattern.sub(repl, line)

    return "\n".join(_fix_line(ln) for ln in code.splitlines())


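# Editor's usage sketch (invented line; note the raw backslashes):
if __name__ == "__main__":
    _line = 'chart = f"{out_dir\\plots\\chart.png}"'
    print(fix_fstring_backslash_paths(_line))
    # -> chart = f"{os.path.join(out_dir, r'plots\chart.png')}"

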
def ensure_os_import(code: str) -> str:
    """
    If os.path.join is used but 'import os' is missing, inject it at the top.
    """
    needs = "os.path.join(" in code
    has_import_os = re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
    has_from_os = re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
    if needs and not (has_import_os or has_from_os):
        return "import os\n" + code
    return code


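# Editor's usage sketch: pairs naturally with the path fixer above.
if __name__ == "__main__":
    print(ensure_os_import('p = os.path.join("out", "img.png")'))
    # -> import os\np = os.path.join("out", "img.png")

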
def fix_seaborn_boxplot_nameerror(code: str) -> str:
    """
    Fix bad calls like: sns.boxplot(boxplot)
    Heuristic:
      - If plot_df + FH_status + var exist → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
      - Else if plot_df + var → sns.boxplot(data=plot_df, y=var, ax=ax)
      - Else if plot_df only → sns.boxplot(data=plot_df, ax=ax)
      - Else → sns.boxplot(ax=ax)
    Ensures a matplotlib Axes 'ax' exists.
    """
    pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
    if not pattern.search(code):
        return code

    has_plot_df = re.search(r"\bplot_df\b", code) is not None
    has_var = re.search(r"\bvar\b", code) is not None
    has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))

    if has_plot_df and has_var and has_fh:
        replacement = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
    elif has_plot_df and has_var:
        replacement = "sns.boxplot(data=plot_df, y=var, ax=ax)"
    elif has_plot_df:
        replacement = "sns.boxplot(data=plot_df, ax=ax)"
    else:
        replacement = "sns.boxplot(ax=ax)"

    fixed = pattern.sub(replacement, code)

    # Ensure 'fig, ax = plt.subplots(...)' exists
    if "ax=" in replacement and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", fixed):
        # Insert right before the first seaborn call
        m = re.search(r"^\s*sns\.", fixed, flags=re.MULTILINE)
        insert_at = m.start() if m else 0
        fixed = fixed[:insert_at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + fixed[insert_at:]

    return fixed


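# Editor's usage sketch (invented snippet): with no plot_df/var in scope the
# fixer degrades to an empty boxplot on a freshly created Axes.
if __name__ == "__main__":
    print(fix_seaborn_boxplot_nameerror("import seaborn as sns\nsns.boxplot(boxplot)"))

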
def fix_seaborn_barplot_nameerror(code: str) -> str:
    """
    Fix bad calls like: sns.barplot(barplot)
    Strategy mirrors the boxplot fixer: prefer data=plot_df with x/y if available,
    otherwise degrade safely to an empty call on an existing Axes.
    """
    pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
    if not pattern.search(code):
        return code

    has_plot_df = re.search(r"\bplot_df\b", code) is not None
    has_var = re.search(r"\bvar\b", code) is not None
    has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))

    if has_plot_df and has_var and has_fh:
        replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
    elif has_plot_df and has_var:
        replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
    elif has_plot_df:
        replacement = "sns.barplot(data=plot_df, ax=ax)"
    else:
        replacement = "sns.barplot(ax=ax)"

    # ensure an Axes 'ax' exists (no-op if already present)
    if "ax =" not in code:
        code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code

    return pattern.sub(replacement, code)


def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
    """
    Parses the raw text to extract and format the 'refined question',
    'intents (tasks)', and 'chronology of tasks' sections.

    Args:
        raw_text: The complete input string containing the ML pipeline structure.

    Returns:
        A tuple containing:
        (formatted_question_str, formatted_intents_str, formatted_chronology_str)
    """
    # --- 1. Regex Pattern to Extract Sections ---
    # The pattern uses named capturing groups to look for the section headers
    # (e.g., 'refined question:') and captures all the content until the next
    # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.

    pattern = re.compile(
        r"refined question:(?P<question>.*?)"
        r"intents \(tasks\):(?P<intents>.*?)"
        r"Chronology of tasks:(?P<chronology>.*)",
        re.IGNORECASE | re.DOTALL
    )

    match = pattern.search(raw_text)

    if not match:
        raise ValueError("Input text structure does not match the expected pattern.")

    # --- 2. Extract Content ---
    question_content = match.group('question').strip()
    intents_content = match.group('intents').strip()
    chronology_content = match.group('chronology').strip()

    # --- 3. Formatting Functions ---

    def format_question(content):
        """Formats the Refined Question section."""
        # Clean up whitespace: drop newlines and collapse double spaces
        content = content.strip().replace('\n', ' ').replace('  ', ' ')

        # Simple formatting using Markdown headers and bolding
        formatted = (
            # "## 1. Project Goal and Objectives\n\n"
            "<b> Refined Question:</b>\n"
            f"{content}\n"
        )
        return formatted

    def format_intents(content):
        """Formats the Intents (Tasks) section as a structured list."""
        # Use regex to find and format each numbered task.
        # It finds 'N. **Text** - ...' and breaks it down.

        tasks = []
        # Pattern: N. **Text** - Content (including newlines, non-greedy)
        # We need to explicitly handle the list items starting with '-' within the content
        task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)

        # Split the content by lines and join tasks back into clean strings
        raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]

        for task in raw_tasks:
            # Replace the initial task number and **Heading** with a Heading 3
            task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)

            # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
            task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
            tasks.append(task)

        formatted_tasks = "\n\n".join(tasks)

        return (
            "\n---\n"
            "## 2. Methodology and Tasks\n\n"
            f"{formatted_tasks}\n"
        )

    def format_chronology(content):
        """Formats the Chronology section."""
        # Render as LaTeX: separate the steps with \rightarrow arrows
        # (escaped backslash; a bare '\r' would be a carriage return)
        content = content.strip().replace(' ', ' \\rightarrow ')
        formatted = (
            "\n---\n"
            "## 3. Chronology of Tasks\n"
            f"$$\\text{{{content}}}$$"
        )
        return formatted

    # --- 4. Format and Return ---
    formatted_question = format_question(question_content)
    formatted_intents = format_intents(intents_content)
    formatted_chronology = format_chronology(chronology_content)

    return formatted_question, formatted_intents, formatted_chronology


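# Editor's usage sketch (invented raw text matching the expected headers):
if __name__ == "__main__":
    _raw = (
        "refined question: Can BMI predict diabetes risk?\n"
        "intents (tasks): 1. **EDA** - inspect distributions\n"
        "Chronology of tasks: EDA Modelling Evaluation"
    )
    _q, _i, _c = parse_and_format_ml_pipeline(_raw)
    print(_c)  # ends with $$\text{EDA \rightarrow Modelling \rightarrow Evaluation}$$

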
def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
    """Combines all formatted parts into a final report string."""
    return (
        "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
        f"{formatted_question}\n"
        f"{formatted_intents}\n"
        f"{formatted_chronology}\n"
    )


def fix_confusion_matrix_for_multilabel(code: str) -> str:
    """
    Replace ConfusionMatrixDisplay.from_estimator(...) usages with
    from_predictions(...) which works for multi-label loops without requiring
    the estimator to expose _estimator_type.
    """
    return re.sub(
        r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
        r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
        code
    )


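# Editor's usage sketch (invented call):
if __name__ == "__main__":
    print(fix_confusion_matrix_for_multilabel(
        "ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test)"
    ))
    # -> ConfusionMatrixDisplay.from_predictions(y_test, clf.predict(X_test))

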
def smx_auto_title_plots(ctx=None, fallback="Analysis"):
    """
    Ensure every Matplotlib/Seaborn Axes has a title.
    Uses refined_question -> askai_question -> fallback.
    Only sets a title if it's currently empty.
    """
    import matplotlib.pyplot as plt

    def _all_figures():
        try:
            from matplotlib._pylab_helpers import Gcf
            return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
        except Exception:
            # Best-effort fallback
            nums = plt.get_fignums()
            return [plt.figure(n) for n in nums] if nums else []

    # Choose a concise title
    title = None
    if isinstance(ctx, dict):
        title = ctx.get("refined_question") or ctx.get("askai_question")
    title = (str(title).strip().splitlines()[0][:120]) if title else fallback

    for fig in _all_figures():
        for ax in getattr(fig, "axes", []):
            try:
                if not (ax.get_title() or "").strip():
                    ax.set_title(title)
            except Exception:
                pass
        try:
            fig.tight_layout()
        except Exception:
            pass


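# Editor's usage sketch: any untitled Axes picks up the refined question.
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    _fig, _ax = plt.subplots()
    _ax.plot([0, 1], [0, 1])
    smx_auto_title_plots({"refined_question": "Trend check"})
    print(_ax.get_title())  # -> Trend check

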
def patch_fix_sentinel_plot_calls(code: str) -> str:
    """
    Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
      SB_barplot(barplot)       -> SB_barplot()
      SB_barplot(barplot, ...)  -> SB_barplot(...)
      sns.barplot(barplot)      -> SB_barplot()
      sns.barplot(barplot, ...) -> SB_barplot(...)
    Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
    """
    names = ['histplot', 'boxplot', 'barplot', 'lineplot', 'countplot', 'heatmap', 'pairplot', 'scatterplot']
    for n in names:
        # SB_* with sentinel as the first arg (with or without trailing args)
        code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
        code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
        # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
        code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
        code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
    return code


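# Editor's usage sketch (invented snippet):
if __name__ == "__main__":
    print(patch_fix_sentinel_plot_calls("sns.heatmap(heatmap)\nSB_barplot(barplot, x='a')"))
    # -> SB_heatmap()
    #    SB_barplot( x='a')

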
def patch_rmse_calls(code: str) -> str:
    """
    Make RMSE robust across sklearn versions.
    - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
    - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
    """
    # (a) Specific RMSE pattern
    code = re.sub(
        r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
        r"_SMX_rmse(\1)",
        code,
        flags=re.DOTALL
    )
    # (b) Guard any other MSE calls
    code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
    return code
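

# Editor's usage sketch (invented snippet): the explicit RMSE form is
# rewritten first, then any remaining MSE call is wrapped defensively.
if __name__ == "__main__":
    _src = (
        "rmse = mean_squared_error(y_true, y_pred, squared=False)\n"
        "mse = mean_squared_error(y_true, y_pred)\n"
    )
    print(patch_rmse_calls(_src))
    # rmse = _SMX_rmse(y_true, y_pred)
    # mse = _SMX_call(mean_squared_error, y_true, y_pred)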