syntaxmatrix-2.5.6-py3-none-any.whl → syntaxmatrix-2.6.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- syntaxmatrix/agentic/agents.py +1220 -169
- syntaxmatrix/agentic/agents_orchestrer.py +326 -0
- syntaxmatrix/agentic/code_tools_registry.py +27 -32
- syntaxmatrix/commentary.py +16 -16
- syntaxmatrix/core.py +185 -81
- syntaxmatrix/db.py +460 -4
- syntaxmatrix/{display.py → display_html.py} +2 -6
- syntaxmatrix/gpt_models_latest.py +1 -1
- syntaxmatrix/media/__init__.py +0 -0
- syntaxmatrix/media/media_pixabay.py +277 -0
- syntaxmatrix/models.py +1 -1
- syntaxmatrix/page_builder_defaults.py +183 -0
- syntaxmatrix/page_builder_generation.py +1122 -0
- syntaxmatrix/page_layout_contract.py +644 -0
- syntaxmatrix/page_patch_publish.py +1471 -0
- syntaxmatrix/preface.py +142 -21
- syntaxmatrix/profiles.py +28 -10
- syntaxmatrix/routes.py +1740 -453
- syntaxmatrix/selftest_page_templates.py +360 -0
- syntaxmatrix/settings/client_items.py +28 -0
- syntaxmatrix/settings/model_map.py +1022 -207
- syntaxmatrix/settings/prompts.py +328 -130
- syntaxmatrix/static/assets/hero-default.svg +22 -0
- syntaxmatrix/static/icons/bot-icon.png +0 -0
- syntaxmatrix/static/icons/favicon.png +0 -0
- syntaxmatrix/static/icons/logo.png +0 -0
- syntaxmatrix/static/icons/logo3.png +0 -0
- syntaxmatrix/templates/admin_branding.html +104 -0
- syntaxmatrix/templates/admin_features.html +63 -0
- syntaxmatrix/templates/admin_secretes.html +108 -0
- syntaxmatrix/templates/dashboard.html +296 -133
- syntaxmatrix/templates/dataset_resize.html +535 -0
- syntaxmatrix/templates/edit_page.html +2535 -0
- syntaxmatrix/utils.py +2431 -2383
- {syntaxmatrix-2.5.6.dist-info → syntaxmatrix-2.6.2.dist-info}/METADATA +6 -2
- {syntaxmatrix-2.5.6.dist-info → syntaxmatrix-2.6.2.dist-info}/RECORD +39 -24
- syntaxmatrix/generate_page.py +0 -644
- syntaxmatrix/static/icons/hero_bg.jpg +0 -0
- {syntaxmatrix-2.5.6.dist-info → syntaxmatrix-2.6.2.dist-info}/WHEEL +0 -0
- {syntaxmatrix-2.5.6.dist-info → syntaxmatrix-2.6.2.dist-info}/licenses/LICENSE.txt +0 -0
- {syntaxmatrix-2.5.6.dist-info → syntaxmatrix-2.6.2.dist-info}/top_level.txt +0 -0
syntaxmatrix/utils.py
CHANGED
@@ -1,96 +1,46 @@
  from __future__ import annotations
- import re, textwrap
- import pandas as pd
- import numpy as np
+ import re, textwrap, ast
+ import pandas as pd, numpy as np
  import warnings
-
  from difflib import get_close_matches
  from typing import Iterable, Tuple, Dict
  import inspect
  from sklearn.preprocessing import OneHotEncoder
-
-
- from syntaxmatrix.agentic.model_templates import (
-     classification, regression, multilabel_classification,
-     eda_overview, eda_correlation,
-     anomaly_detection, ts_anomaly_detection,
-     dimensionality_reduction, feature_selection,
-     time_series_forecasting, time_series_classification,
-     unknown_group_proxy_pack, viz_line,
-     clustering, recommendation, topic_modelling,
-     viz_pie, viz_count_bar, viz_box, viz_scatter,
-     viz_stacked_bar, viz_distribution, viz_area, viz_kde,
- )
  import ast

  warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

- …
-     "classification",
-     "multilabel_classification",
-     "regression",
-     "anomaly_detection",
-     "time_series_forecasting",
-     "time_series_classification",
-     "ts_anomaly_detection",
-     "dimensionality_reduction",
-     "feature_selection",
-     "clustering",
-     "eda",
-     "correlation_analysis",
-     "visualisation",
-     "recommendation",
-     "topic_modelling",
- ]
-
- def classify_ml_job(prompt: str) -> str:
+ def patch_quiet_specific_warnings(code: str) -> str:
      """
-     …
-     'classification' | 'regression' | 'eda'
+     Inserts targeted warning filters (not blanket ignores).
+       - seaborn palette/hue deprecation
+       - python-dotenv parse chatter
      """
-     …
-     ))
-     …
-         "suspicious"
-     )):
-         return "anomaly_detection"
-
-     if any(k in p for k in ("t-test", "anova", "p-value")):
-         return "stat_test"
-     if "forecast" in p or "prophet" in p:
-         return "time_series"
-     if "cluster" in p or "kmeans" in p:
-         return "clustering"
-     if any(k in p for k in ("accuracy", "precision", "roc")):
-         return "classification"
-     if any(k in p for k in ("rmse", "r2", "mae")):
-         return "regression"
-
-     return "eda"
+     prelude = (
+         "import warnings\n"
+         "warnings.filterwarnings(\n"
+         "    'ignore', message=r'.*Passing `palette` without assigning `hue`.*', category=FutureWarning)\n"
+         "warnings.filterwarnings(\n"
+         "    'ignore', message=r'python-dotenv could not parse statement.*')\n"
+     )
+     # If warnings already imported once, just add filters; else insert full prelude.
+     if "import warnings" in code:
+         code = re.sub(
+             r"(import warnings[^\n]*\n)",
+             lambda m: m.group(1) + prelude.replace("import warnings\n", ""),
+             code,
+             count=1
+         )
+
+     else:
+         # place after first import block if possible
+         m = re.search(r"^(?:from\s+\S+\s+import\s+.+|import\s+\S+).*\n+", code, flags=re.MULTILINE)
+         if m:
+             idx = m.end()
+             code = code[:idx] + prelude + code[idx:]
+         else:
+             code = prelude + code
+     return code


  def _indent(code: str, spaces: int = 4) -> str:
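Note: patch_quiet_specific_warnings is a new top-level helper in utils.py, so it should be importable as shown. A minimal usage sketch (illustrative only, not part of the diff; the cell string is hypothetical):

    from syntaxmatrix.utils import patch_quiet_specific_warnings

    cell = "import warnings\nimport seaborn as sns\nsns.histplot(df['age'])\n"
    patched = patch_quiet_specific_warnings(cell)
    # The two targeted warnings.filterwarnings(...) calls are spliced in right
    # after the cell's existing `import warnings` line; the rest is untouched.
    print(patched)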
@@ -123,7 +73,7 @@ def wrap_llm_code_safe(body: str) -> str:
      "    df_local = globals().get('df')\n"
      "    if df_local is not None:\n"
      "        import pandas as pd\n"
-     "        from syntaxmatrix.preface import SB_histplot, _SMX_export_png\n"
+     "        from syntaxmatrix.preface import SB_histplot, SB_boxplot, SB_scatterplot, SB_heatmap, _SMX_export_png\n"
      "        num_cols = df_local.select_dtypes(include=['number', 'bool']).columns.tolist()\n"
      "        cat_cols = [c for c in df_local.columns if c not in num_cols]\n"
      "        info = {\n"
@@ -137,11 +87,78 @@ def wrap_llm_code_safe(body: str) -> str:
      "        if num_cols:\n"
      "            SB_histplot()\n"
      "            _SMX_export_png()\n"
+     "        if len(num_cols) >= 2:\n"
+     "            SB_scatterplot()\n"
+     "            _SMX_export_png()\n"
+     "        if num_cols and cat_cols:\n"
+     "            SB_boxplot()\n"
+     "            _SMX_export_png()\n"
+     "        if len(num_cols) >= 2:\n"
+     "            SB_heatmap()\n"
+     "            _SMX_export_png()\n"
      "except Exception as _f:\n"
      "    show(f\"⚠️ Fallback EDA failed: {type(_f).__name__}: {_f}\")\n"
  )

+ def fix_print_html(code: str) -> str:
+     """
+     Ensure that HTML / DataFrame HTML are *displayed* (and captured by the kernel),
+     not printed as `<IPython.core.display.HTML object>` to the server console.
+       - Rewrites: print(HTML(...))       → display(HTML(...))
+                   print(display(...))    → display(...)
+                   print(df.to_html(...)) → display(HTML(df.to_html(...)))
+     Also prepends `from IPython.display import display, HTML` if required.
+     """
+     import re
+
+     new = code
+
+     # 1) print(HTML(...)) -> display(HTML(...))
+     new = re.sub(r"(?m)^\s*print\s*\(\s*HTML\s*\(", "display(HTML(", new)
+
+     # 2) print(display(...)) -> display(...)
+     new = re.sub(r"(?m)^\s*print\s*\(\s*display\s*\(", "display(", new)
+
+     # 3) print(<expr>.to_html(...)) -> display(HTML(<expr>.to_html(...)))
+     new = re.sub(
+         r"(?m)^\s*print\s*\(\s*([A-Za-z_]\w*(?:\.[A-Za-z_]\w*)*)\s*\.to_html\s*\(",
+         r"display(HTML(\1.to_html(", new
+     )
+
+     # If code references HTML() or display() make sure the import exists
+     if ("HTML(" in new or re.search(r"\bdisplay\s*\(", new)) and \
+        "from IPython.display import display, HTML" not in new:
+         new = "from IPython.display import display, HTML\n" + new
+
+     return new
+
+
+ def ensure_ipy_display(code: str) -> str:
+     """
+     Guarantee that the cell has proper IPython display imports so that
+     display(HTML(...)) produces 'display_data' events the kernel captures.
+     """
+     if "display(" in code and "from IPython.display import display, HTML" not in code:
+         return "from IPython.display import display, HTML\n" + code
+     return code
+
+
+ def fix_plain_prints(code: str) -> str:
+     """
+     Rewrite bare `print(var)` where var looks like a dataframe/series/ndarray/etc
+     to go through SyntaxMatrix's smart display (so it renders in the dashboard).
+     Keeps string prints alone.
+     """
+     # Skip obvious string-literal prints
+     new = re.sub(
+         r"(?m)^\s*print\(\s*([A-Za-z_]\w*)\s*\)\s*$",
+         r"from syntaxmatrix.display import show\nshow(\1)",
+         code,
+     )
+     return new
+
  def harden_ai_code(code: str) -> str:
      """
      Make any AI-generated cell resilient:
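Note: a quick before/after for the new fix_print_html pass (illustrative only; df.head() is a placeholder expression):

    from syntaxmatrix.utils import fix_print_html

    before = "print(HTML(df.head().to_html()))"
    after = fix_print_html(before)
    # after == "from IPython.display import display, HTML\n"
    #          "display(HTML(df.head().to_html()))"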
@@ -155,31 +172,6 @@ def harden_ai_code(code: str) -> str:
      # Remove any LLM-added try/except blocks (hardener adds its own)
      import re

-     def strip_placeholders(code: str) -> str:
-         code = re.sub(r"\bshow\(\s*\.\.\.\s*\)",
-                       "show('⚠ Block skipped due to an error.')",
-                       code)
-         code = re.sub(r"\breturn\s+\.\.\.", "return None", code)
-         return code
-
-     def _SMX_OHE(**k):
-         # normalise arg name across sklearn versions
-         if "sparse" in k and "sparse_output" not in k:
-             k["sparse_output"] = k.pop("sparse")
-         # default behaviour we want
-         k.setdefault("handle_unknown", "ignore")
-         k.setdefault("sparse_output", False)
-         try:
-             # if running on old sklearn without sparse_output, translate back
-             if "sparse_output" not in inspect.signature(OneHotEncoder).parameters:
-                 if "sparse_output" in k:
-                     k["sparse"] = k.pop("sparse_output")
-             return OneHotEncoder(**k)
-         except TypeError:
-             # final fallback: try legacy name
-             if "sparse_output" in k:
-                 k["sparse"] = k.pop("sparse_output")
-             return OneHotEncoder(**k)

      def _strip_stray_backrefs(code: str) -> str:
          code = re.sub(r'(?m)^\s*\\\d+\s*', '', code)

@@ -229,21 +221,37 @@ def harden_ai_code(code: str) -> str:

          return pattern.sub(repl, code)

+
      def _wrap_metric_calls(code: str) -> str:
          names = [
-             "r2_score",
-             …
+             "r2_score",
+             "accuracy_score",
+             "precision_score",
+             "recall_score",
+             "f1_score",
+             "roc_auc_score",
+             "classification_report",
+             "confusion_matrix",
+             "mean_absolute_error",
+             "mean_absolute_percentage_error",
+             "explained_variance_score",
+             "log_loss",
+             "average_precision_score",
+             "precision_recall_fscore_support",
+             "mean_squared_error",
          ]
-         pat = re.compile(
+         pat = re.compile(
+             r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\("
+         )
+
          def repl(m):
              prefix = m.group(1) or ""  # "", "metrics.", or "sklearn.metrics."
              name = m.group(2)
              return f"_SMX_call({prefix}{name}, "
+
          return pat.sub(repl, code)

+
      def _smx_patch_mean_squared_error_squared_kw():
          """
          sklearn<0.22 doesn't accept mean_squared_error(..., squared=False).
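Note: _wrap_metric_calls is nested inside harden_ai_code, so it cannot be imported directly; this standalone sketch reproduces the same rewrite with two metric names for brevity:

    import re

    names = ["accuracy_score", "r2_score"]
    pat = re.compile(r"\b(?:(sklearn\.metrics\.|metrics\.)?(" + "|".join(names) + r"))\s*\(")

    def repl(m):
        prefix = m.group(1) or ""
        return f"_SMX_call({prefix}{m.group(2)}, "

    print(pat.sub(repl, "acc = metrics.accuracy_score(y_test, y_pred)"))
    # -> acc = _SMX_call(metrics.accuracy_score, y_test, y_pred)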
@@ -318,6 +326,8 @@ def harden_ai_code(code: str) -> str:
              needed.add("r2_score")
          if "mean_absolute_error" in code:
              needed.add("mean_absolute_error")
+         if "mean_squared_error" in code:
+             needed.add("mean_squared_error")
          # ... add others if you like ...

          if not needed:
@@ -378,7 +388,8 @@ def harden_ai_code(code: str) -> str:
      - then falls back to generic but useful EDA.

      It assumes `from syntaxmatrix.preface import *` has already been done,
-     so `_SMX_OHE`, `_SMX_call`, `SB_histplot`, `…
+     so `_SMX_OHE`, `_SMX_call`, `SB_histplot`, `SB_boxplot`,
+     `SB_scatterplot`, `SB_heatmap`, `_SMX_export_png` and the
      patched `show()` are available.
      """
      import textwrap
@@ -488,15 +499,91 @@ def harden_ai_code(code: str) -> str:
          show(df.head(), title='Sample of data')
          show(info, title='Dataset summary')

-         #…
+         # 1) Distribution of a numeric column
          if num_cols:
              SB_histplot()
              _SMX_export_png()
+
+         # 2) Relationship between two numeric columns
+         if len(num_cols) >= 2:
+             SB_scatterplot()
+             _SMX_export_png()
+
+         # 3) Distribution of a numeric by the first categorical column
+         if num_cols and cat_cols:
+             SB_boxplot()
+             _SMX_export_png()
+
+         # 4) Correlation heatmap across numeric columns
+         if len(num_cols) >= 2:
+             SB_heatmap()
+             _SMX_export_png()
      except Exception as _eda_e:
          show(f"⚠ EDA fallback failed: {type(_eda_e).__name__}: {_eda_e}")
      """
      )

+     def _strip_file_io_ops(code: str) -> str:
+         """
+         Remove obvious local file I/O operations in LLM code
+         so nothing writes to the container filesystem.
+         """
+         # 1) Methods like df.to_csv(...), df.to_excel(...), etc.
+         FILE_WRITE_METHODS = (
+             "to_csv", "to_excel", "to_pickle", "to_parquet",
+             "to_json", "to_hdf",
+         )
+
+         for mname in FILE_WRITE_METHODS:
+             pat = re.compile(
+                 rf"(?m)^(\s*)([A-Za-z_][A-Za-z0-9_\.]*)\s*\.\s*{mname}\s*\([^)]*\)\s*$"
+             )
+
+             def _repl(match):
+                 indent = match.group(1)
+                 expr = match.group(2)
+                 return f"{indent}# [SMX] stripped file write: {expr}.{mname}(...)"
+
+             code = pat.sub(_repl, code)
+
+         # 2) plt.savefig(...) calls
+         pat_savefig = re.compile(r"(?m)^(\s*)(plt\.savefig\s*\([^)]*\)\s*)$")
+         code = pat_savefig.sub(
+             lambda m: f"{m.group(1)}# [SMX] stripped savefig: {m.group(2).strip()}",
+             code,
+         )
+
+         # 3) with open(..., 'w'/'wb') as f:
+         pat_with_open = re.compile(
+             r"(?m)^(\s*)with\s+open\([^)]*['\"]w[b]?['\"][^)]*\)\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*:\s*$"
+         )
+
+         def _with_open_repl(match):
+             indent = match.group(1)
+             var = match.group(2)
+             return f"{indent}if False:  # [SMX] file write stripped (was: with open(... as {var}))"
+
+         code = pat_with_open.sub(_with_open_repl, code)
+
+         # 4) joblib.dump(...), pickle.dump(...)
+         for mod in ("joblib", "pickle"):
+             pat = re.compile(rf"(?m)^(\s*){mod}\.dump\s*\([^)]*\)\s*$")
+             code = pat.sub(
+                 lambda m: f"{m.group(1)}# [SMX] stripped {mod}.dump(...)",
+                 code,
+             )
+
+         # 5) bare open(..., 'w'/'wb') calls
+         pat_open = re.compile(
+             r"(?m)^(\s*)open\([^)]*['\"]w[b]?['\"][^)]*\)\s*$"
+         )
+         code = pat_open.sub(
+             lambda m: f"{m.group(1)}# [SMX] stripped open(..., 'w'/'wb')",
+             code,
+         )
+
+         return code
+
      # Register and run patches once per execution
      for _patch in (
          _smx_patch_mean_squared_error_squared_kw,
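Note: a standalone sketch of just the to_csv rule from _strip_file_io_ops (the helper itself is nested inside harden_ai_code); the source line is hypothetical:

    import re

    pat = re.compile(r"(?m)^(\s*)([A-Za-z_][A-Za-z0-9_\.]*)\s*\.\s*to_csv\s*\([^)]*\)\s*$")
    src = "df.to_csv('out.csv', index=False)"
    print(pat.sub(lambda m: f"{m.group(1)}# [SMX] stripped file write: {m.group(2)}.to_csv(...)", src))
    # -> # [SMX] stripped file write: df.to_csv(...)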
@@ -562,19 +649,21 @@ def harden_ai_code(code: str) -> str:
      except (SyntaxError, IndentationError):
          fixed = _fallback_snippet()

+
+     # Fix placeholder Ellipsis handlers from LLM
      fixed = re.sub(
          r"except\s+Exception\s+as\s+e:\s*\n\s*show\(\.\.\.\)",
          "except Exception as e:\n    show(f\"⚠ Block skipped due to: {type(e).__name__}: {e}\")",
          fixed,
      )

-     #…
+     # redirect that import to the real template module.
      fixed = re.sub(
-         r"…
-         "…
+         r"from\s+syntaxmatrix\.templates\s+import\s+([^\n]+)",
+         r"from syntaxmatrix.agentic.model_templates import \1",
          fixed,
      )
-
+
      try:
          class _SMXMatmulRewriter(ast.NodeTransformer):
              def visit_BinOp(self, node):
@@ -593,11 +682,25 @@ def harden_ai_code(code: str) -> str:
      # 6) Final safety wrapper
      fixed = fixed.replace("\t", "    ")
      fixed = textwrap.dedent(fixed).strip("\n")
+
      fixed = _ensure_metrics_imports(fixed)
      fixed = _strip_stray_backrefs(fixed)
      fixed = _wrap_metric_calls(fixed)
      fixed = _fix_unexpected_indent(fixed)
      fixed = _patch_feature_coef_dataframe(fixed)
+     fixed = _strip_file_io_ops(fixed)
+
+     metric_defaults = "\n".join([
+         "acc = None",
+         "accuracy = None",
+         "r2 = None",
+         "mae = None",
+         "rmse = None",
+         "f1 = None",
+         "precision = None",
+         "recall = None",
+     ]) + "\n"
+     fixed = metric_defaults + fixed

      # Import shared preface helpers once and wrap the LLM body safely
      header = "from syntaxmatrix.preface import *\n\n"
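Note: the prepended metric defaults exist so that later reporting code can read names like acc or r2 without a NameError when the cell never computed them. A minimal illustration with two of the eight names:

    ns = {}
    exec("acc = None\nr2 = None\n" + "print('cell ran')", ns)
    assert ns["acc"] is None and ns["r2"] is None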
@@ -605,822 +708,822 @@ def harden_ai_code(code: str) -> str:
|
|
|
605
708
|
return wrapped
|
|
606
709
|
|
|
607
710
|
|
|
608
|
-
def indent_code(code: str, spaces: int = 4) -> str:
|
|
609
|
-
|
|
610
|
-
|
|
711
|
+
# def indent_code(code: str, spaces: int = 4) -> str:
|
|
712
|
+
# pad = " " * spaces
|
|
713
|
+
# return "\n".join(pad + line for line in code.splitlines())
|
|
611
714
|
|
|
612
715
|
|
|
613
|
-
def fix_boxplot_placeholder(code: str) -> str:
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
716
|
+
# def fix_boxplot_placeholder(code: str) -> str:
|
|
717
|
+
# # Replace invalid 'sns.boxplot(boxplot)' with a safe call using df/group_label/m
|
|
718
|
+
# return re.sub(
|
|
719
|
+
# r"sns\.boxplot\(\s*boxplot\s*\)",
|
|
720
|
+
# "sns.boxplot(x=group_label, y=m, data=df.loc[df[m].notnull()], showfliers=False)",
|
|
721
|
+
# code
|
|
722
|
+
# )
|
|
620
723
|
|
|
621
724
|
|
|
622
|
-
def relax_required_columns(code: str) -> str:
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
725
|
+
# def relax_required_columns(code: str) -> str:
|
|
726
|
+
# # Remove hard failure on required_cols; keep a soft filter instead
|
|
727
|
+
# return re.sub(
|
|
728
|
+
# r"required_cols\s*=\s*\[.*?\]\s*?\n\s*missing\s*=\s*\[.*?\]\s*?\n\s*if\s+missing:\s*raise[^\n]+",
|
|
729
|
+
# "required_cols = [c for c in df.columns]\n",
|
|
730
|
+
# code,
|
|
731
|
+
# flags=re.S
|
|
732
|
+
# )
|
|
630
733
|
|
|
631
734
|
|
|
632
|
-
def make_numeric_vars_dynamic(code: str) -> str:
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
735
|
+
# def make_numeric_vars_dynamic(code: str) -> str:
|
|
736
|
+
# # Replace any static numeric_vars list with a dynamic selection
|
|
737
|
+
# return re.sub(
|
|
738
|
+
# r"numeric_vars\s*=\s*\[.*?\]",
|
|
739
|
+
# "numeric_vars = df.select_dtypes(include=['number','bool']).columns.tolist()",
|
|
740
|
+
# code,
|
|
741
|
+
# flags=re.S
|
|
742
|
+
# )
|
|
640
743
|
|
|
641
744
|
|
|
642
|
-
def auto_inject_template(code: str, intents, df) -> str:
|
|
643
|
-
|
|
745
|
+
# def auto_inject_template(code: str, intents, df) -> str:
|
|
746
|
+
# """If the LLM forgot the core logic, prepend a skeleton block."""
|
|
644
747
|
|
|
645
|
-
|
|
646
|
-
|
|
748
|
+
# has_fit = ".fit(" in code
|
|
749
|
+
# has_plot = any(k in code for k in ("plt.", "sns.", ".plot(", ".hist("))
|
|
647
750
|
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
751
|
+
# UNKNOWN_TOKENS = {
|
|
752
|
+
# "unknown","not reported","not_reported","not known","n/a","na",
|
|
753
|
+
# "none","nan","missing","unreported","unspecified","null","-",""
|
|
754
|
+
# }
|
|
755
|
+
|
|
756
|
+
# # --- Safe template caller: passes only supported kwargs, falls back cleanly ---
|
|
757
|
+
# def _call_template(func, df, **hints):
|
|
758
|
+
# import inspect
|
|
759
|
+
# try:
|
|
760
|
+
# params = inspect.signature(func).parameters
|
|
761
|
+
# kw = {k: v for k, v in hints.items() if k in params}
|
|
762
|
+
# try:
|
|
763
|
+
# return func(df, **kw)
|
|
764
|
+
# except TypeError:
|
|
765
|
+
# # In case the template changed its signature at runtime
|
|
766
|
+
# return func(df)
|
|
767
|
+
# except Exception:
|
|
768
|
+
# # Absolute safety net
|
|
769
|
+
# try:
|
|
770
|
+
# return func(df)
|
|
771
|
+
# except Exception:
|
|
772
|
+
# # As a last resort, return empty code so we don't 500
|
|
773
|
+
# return ""
|
|
774
|
+
|
|
775
|
+
# def _guess_classification_target(df: pd.DataFrame) -> str | None:
|
|
776
|
+
# cols = list(df.columns)
|
|
777
|
+
|
|
778
|
+
# # Helper: does this column look like a sensible label?
|
|
779
|
+
# def _is_reasonable_class_col(s: pd.Series, col_name: str) -> bool:
|
|
780
|
+
# try:
|
|
781
|
+
# nunq = s.dropna().nunique()
|
|
782
|
+
# except Exception:
|
|
783
|
+
# return False
|
|
784
|
+
# # need at least 2 classes, but not hundreds
|
|
785
|
+
# if nunq < 2 or nunq > 20:
|
|
786
|
+
# return False
|
|
787
|
+
# bad_name_keys = ("id", "identifier", "index", "uuid", "key")
|
|
788
|
+
# name = str(col_name).lower()
|
|
789
|
+
# if any(k in name for k in bad_name_keys):
|
|
790
|
+
# return False
|
|
791
|
+
# return True
|
|
792
|
+
|
|
793
|
+
# # 1) columns whose names look like labels
|
|
794
|
+
# label_keys = ("target", "label", "outcome", "class", "y", "status")
|
|
795
|
+
# name_candidates: list[str] = []
|
|
796
|
+
# for key in label_keys:
|
|
797
|
+
# for c in cols:
|
|
798
|
+
# if key in str(c).lower():
|
|
799
|
+
# name_candidates.append(c)
|
|
800
|
+
# if name_candidates:
|
|
801
|
+
# break # keep the earliest matching key-group
|
|
802
|
+
|
|
803
|
+
# # prioritise name-based candidates that also look like proper label columns
|
|
804
|
+
# for c in name_candidates:
|
|
805
|
+
# if _is_reasonable_class_col(df[c], c):
|
|
806
|
+
# return c
|
|
807
|
+
# if name_candidates:
|
|
808
|
+
# # fall back to the first name-based candidate if none passed the shape test
|
|
809
|
+
# return name_candidates[0]
|
|
810
|
+
|
|
811
|
+
# # 2) any column with a small number of distinct values (likely a class label)
|
|
812
|
+
# for c in cols:
|
|
813
|
+
# s = df[c]
|
|
814
|
+
# if _is_reasonable_class_col(s, c):
|
|
815
|
+
# return c
|
|
816
|
+
|
|
817
|
+
# # Nothing suitable found
|
|
818
|
+
# return None
|
|
819
|
+
|
|
820
|
+
# def _guess_regression_target(df: pd.DataFrame) -> str | None:
|
|
821
|
+
# num_cols = df.select_dtypes(include=[np.number, "bool"]).columns.tolist()
|
|
822
|
+
# if not num_cols:
|
|
823
|
+
# return None
|
|
824
|
+
# # Avoid obvious ID-like columns
|
|
825
|
+
# bad_keys = ("id", "identifier", "index")
|
|
826
|
+
# candidates = [c for c in num_cols if not any(k in str(c).lower() for k in bad_keys)]
|
|
827
|
+
# return (candidates or num_cols)[-1]
|
|
725
828
|
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
829
|
+
# def _guess_time_col(df: pd.DataFrame) -> str | None:
|
|
830
|
+
# # Prefer actual datetime dtype
|
|
831
|
+
# dt_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.datetime64)]
|
|
832
|
+
# if dt_cols:
|
|
833
|
+
# return dt_cols[0]
|
|
834
|
+
|
|
835
|
+
# # Fallback: name-based hints
|
|
836
|
+
# name_keys = ["date", "time", "timestamp", "datetime", "ds", "period"]
|
|
837
|
+
# for c in df.columns:
|
|
838
|
+
# name = str(c).lower()
|
|
839
|
+
# if any(k in name for k in name_keys):
|
|
840
|
+
# return c
|
|
841
|
+
# return None
|
|
842
|
+
|
|
843
|
+
# def _guess_entity_col(df: pd.DataFrame) -> str | None:
|
|
844
|
+
# # Typical sequence IDs: id, patient, subject, device, series, entity
|
|
845
|
+
# keys = ["id", "patient", "subject", "device", "series", "entity"]
|
|
846
|
+
# candidates = []
|
|
847
|
+
# for c in df.columns:
|
|
848
|
+
# name = str(c).lower()
|
|
849
|
+
# if any(k in name for k in keys):
|
|
850
|
+
# candidates.append(c)
|
|
851
|
+
# return candidates[0] if candidates else None
|
|
852
|
+
|
|
853
|
+
# def _guess_ts_class_target(df: pd.DataFrame) -> str | None:
|
|
854
|
+
# # Try label-like names first
|
|
855
|
+
# keys = ["target", "label", "class", "outcome", "y"]
|
|
856
|
+
# for key in keys:
|
|
857
|
+
# for c in df.columns:
|
|
858
|
+
# if key in str(c).lower():
|
|
859
|
+
# return c
|
|
860
|
+
|
|
861
|
+
# # Fallback: any column with few distinct values (e.g. <= 10)
|
|
862
|
+
# for c in df.columns:
|
|
863
|
+
# s = df[c]
|
|
864
|
+
# # avoid obvious IDs
|
|
865
|
+
# if any(k in str(c).lower() for k in ["id", "index"]):
|
|
866
|
+
# continue
|
|
867
|
+
# try:
|
|
868
|
+
# nunq = s.dropna().nunique()
|
|
869
|
+
# except Exception:
|
|
870
|
+
# continue
|
|
871
|
+
# if 1 < nunq <= 10:
|
|
872
|
+
# return c
|
|
873
|
+
|
|
874
|
+
# return None
|
|
875
|
+
|
|
876
|
+
# def _guess_multilabel_cols(df: pd.DataFrame) -> list[str]:
|
|
877
|
+
# cols = list(df.columns)
|
|
878
|
+
# lbl_like = [c for c in cols if str(c).startswith(("LBL_", "lbl_"))]
|
|
879
|
+
# # also include boolean/binary columns with suitable names
|
|
880
|
+
# for c in cols:
|
|
881
|
+
# s = df[c]
|
|
882
|
+
# try:
|
|
883
|
+
# nunq = s.dropna().nunique()
|
|
884
|
+
# except Exception:
|
|
885
|
+
# continue
|
|
886
|
+
# if nunq in (2,) and c not in lbl_like:
|
|
887
|
+
# # avoid obvious IDs
|
|
888
|
+
# if not any(k in str(c).lower() for k in ("id","index","uuid","identifier")):
|
|
889
|
+
# lbl_like.append(c)
|
|
890
|
+
# # keep at most, say, 12 to avoid accidental flood
|
|
891
|
+
# return lbl_like[:12]
|
|
892
|
+
|
|
893
|
+
# def _find_unknownish_column(df: pd.DataFrame) -> str | None:
|
|
894
|
+
# # Search categorical-like columns for any 'unknown-like' values or high missingness
|
|
895
|
+
# candidates = []
|
|
896
|
+
# for c in df.columns:
|
|
897
|
+
# s = df[c]
|
|
898
|
+
# # focus on object/category/boolean-ish or low-card columns
|
|
899
|
+
# if not (pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= 20):
|
|
900
|
+
# continue
|
|
901
|
+
# try:
|
|
902
|
+
# vals = s.astype(str).str.strip().str.lower()
|
|
903
|
+
# except Exception:
|
|
904
|
+
# continue
|
|
905
|
+
# # score: presence of unknown tokens + missing rate
|
|
906
|
+
# token_hit = int(vals.isin(UNKNOWN_TOKENS).any())
|
|
907
|
+
# miss_rate = s.isna().mean()
|
|
908
|
+
# name_bonus = int(any(k in str(c).lower() for k in ("status","history","report","known","flag")))
|
|
909
|
+
# score = 3*token_hit + 2*name_bonus + miss_rate
|
|
910
|
+
# if token_hit or miss_rate > 0.05 or name_bonus:
|
|
911
|
+
# candidates.append((score, c))
|
|
912
|
+
# if not candidates:
|
|
913
|
+
# return None
|
|
914
|
+
# candidates.sort(reverse=True)
|
|
915
|
+
# return candidates[0][1]
|
|
916
|
+
|
|
917
|
+
# def _guess_numeric_cols(df: pd.DataFrame, max_n: int = 6) -> list[str]:
|
|
918
|
+
# cols = [c for c in df.select_dtypes(include=[np.number, "bool"]).columns if not any(k in str(c).lower() for k in ("id","identifier","index","uuid"))]
|
|
919
|
+
# # prefer non-constant columns
|
|
920
|
+
# scored = []
|
|
921
|
+
# for c in cols:
|
|
922
|
+
# try:
|
|
923
|
+
# v = df[c].dropna()
|
|
924
|
+
# var = float(v.var()) if len(v) else 0.0
|
|
925
|
+
# scored.append((var, c))
|
|
926
|
+
# except Exception:
|
|
927
|
+
# continue
|
|
928
|
+
# scored.sort(reverse=True)
|
|
929
|
+
# return [c for _, c in scored[:max_n]]
|
|
930
|
+
|
|
931
|
+
# def _guess_categorical_cols(df: pd.DataFrame, exclude: set[str] | None = None, max_card: int = 12, max_n: int = 5) -> list[str]:
|
|
932
|
+
# exclude = exclude or set()
|
|
933
|
+
# picks = []
|
|
934
|
+
# for c in df.columns:
|
|
935
|
+
# if c in exclude:
|
|
936
|
+
# continue
|
|
937
|
+
# s = df[c]
|
|
938
|
+
# if pd.api.types.is_object_dtype(s) or pd.api.types.is_categorical_dtype(s) or s.dropna().nunique() <= max_card:
|
|
939
|
+
# nunq = s.dropna().nunique()
|
|
940
|
+
# if 2 <= nunq <= max_card and not any(k in str(c).lower() for k in ("id","identifier","index","uuid")):
|
|
941
|
+
# picks.append((nunq, c))
|
|
942
|
+
# picks.sort(reverse=True)
|
|
943
|
+
# return [c for _, c in picks[:max_n]]
|
|
944
|
+
|
|
945
|
+
# def _guess_outcome_col(df: pd.DataFrame, exclude: set[str] | None = None) -> str | None:
|
|
946
|
+
# exclude = exclude or set()
|
|
947
|
+
# # name hints first
|
|
948
|
+
# name_keys = ("outcome","target","label","risk","score","result","prevalence","positivity")
|
|
949
|
+
# for c in df.columns:
|
|
950
|
+
# if c in exclude:
|
|
951
|
+
# continue
|
|
952
|
+
# name = str(c).lower()
|
|
953
|
+
# if any(k in name for k in name_keys) and pd.api.types.is_numeric_dtype(df[c]):
|
|
954
|
+
# return c
|
|
955
|
+
# # fallback: any binary numeric
|
|
956
|
+
# for c in df.select_dtypes(include=[np.number, "bool"]).columns:
|
|
957
|
+
# if c in exclude:
|
|
958
|
+
# continue
|
|
959
|
+
# try:
|
|
960
|
+
# if df[c].dropna().nunique() == 2:
|
|
961
|
+
# return c
|
|
962
|
+
# except Exception:
|
|
963
|
+
# continue
|
|
964
|
+
# return None
|
|
862
965
|
|
|
863
966
|
|
|
864
|
-
|
|
865
|
-
|
|
967
|
+
# def _pick_viz_template(signal: str):
|
|
968
|
+
# s = signal.lower()
|
|
866
969
|
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
970
|
+
# # explicit chart requests
|
|
971
|
+
# if any(k in s for k in ("pie", "donut")):
|
|
972
|
+
# return viz_pie
|
|
870
973
|
|
|
871
|
-
|
|
872
|
-
|
|
974
|
+
# if any(k in s for k in ("stacked", "100% stacked", "composition", "proportion", "share by")):
|
|
975
|
+
# return viz_stacked_bar
|
|
873
976
|
|
|
874
|
-
|
|
875
|
-
|
|
977
|
+
# if any(k in s for k in ("distribution", "hist", "histogram", "bins")):
|
|
978
|
+
# return viz_distribution
|
|
876
979
|
|
|
877
|
-
|
|
878
|
-
|
|
980
|
+
# if any(k in s for k in ("kde", "density")):
|
|
981
|
+
# return viz_kde
|
|
879
982
|
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
983
|
+
# # these three you asked about
|
|
984
|
+
# if any(k in s for k in ("box", "boxplot", "violin", "spread", "outlier")):
|
|
985
|
+
# return viz_box
|
|
883
986
|
|
|
884
|
-
|
|
885
|
-
|
|
987
|
+
# if any(k in s for k in ("scatter", "relationship", "vs ", "correlate")):
|
|
988
|
+
# return viz_scatter
|
|
886
989
|
|
|
887
|
-
|
|
888
|
-
|
|
990
|
+
# if any(k in s for k in ("count", "counts", "frequency", "bar chart", "barplot")):
|
|
991
|
+
# return viz_count_bar
|
|
889
992
|
|
|
890
|
-
|
|
891
|
-
|
|
993
|
+
# if any(k in s for k in ("area", "trend", "over time", "time series")):
|
|
994
|
+
# return viz_area
|
|
892
995
|
|
|
893
|
-
|
|
894
|
-
|
|
996
|
+
# # fallback
|
|
997
|
+
# return viz_line
|
|
895
998
|
|
|
896
|
-
|
|
999
|
+
# for intent in intents:
|
|
897
1000
|
|
|
898
|
-
|
|
899
|
-
|
|
1001
|
+
# if intent not in INJECTABLE_INTENTS:
|
|
1002
|
+
# return code
|
|
900
1003
|
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
1004
|
+
# # Correlation analysis
|
|
1005
|
+
# if intent == "correlation_analysis" and not has_fit:
|
|
1006
|
+
# return eda_correlation(df) + "\n\n" + code
|
|
1007
|
+
|
|
1008
|
+
# # Generic visualisation (keyword-based)
|
|
1009
|
+
# if intent == "visualisation" and not has_fit and not has_plot:
|
|
1010
|
+
# rq = str(globals().get("refined_question", ""))
|
|
1011
|
+
# # aq = str(globals().get("askai_question", ""))
|
|
1012
|
+
# signal = rq + "\n" + str(intents) + "\n" + code
|
|
1013
|
+
# tpl = _pick_viz_template(signal)
|
|
1014
|
+
# return tpl(df) + "\n\n" + code
|
|
912
1015
|
|
|
913
|
-
|
|
914
|
-
|
|
1016
|
+
# if intent == "clustering" and not has_fit:
|
|
1017
|
+
# return clustering(df) + "\n\n" + code
|
|
915
1018
|
|
|
916
|
-
|
|
917
|
-
|
|
1019
|
+
# if intent == "recommendation" and not has_fit:
|
|
1020
|
+
# return recommendation(df) + "\\n\\n" + code
|
|
918
1021
|
|
|
919
|
-
|
|
920
|
-
|
|
1022
|
+
# if intent == "topic_modelling" and not has_fit:
|
|
1023
|
+
# return topic_modelling(df) + "\\n\\n" + code
|
|
921
1024
|
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
1025
|
+
# if intent == "eda" and not has_fit:
|
|
1026
|
+
# return code + "\n\nSB_heatmap(df.corr())" # Inject heatmap if 'eda' intent
|
|
1027
|
+
|
|
1028
|
+
# # --- Classification ------------------------------------------------
|
|
1029
|
+
# if intent == "classification" and not has_fit:
|
|
1030
|
+
# target = _guess_classification_target(df)
|
|
1031
|
+
# if target:
|
|
1032
|
+
# return classification(df) + "\n\n" + code
|
|
1033
|
+
# # return _call_template(classification, df, target) + "\n\n" + code
|
|
1034
|
+
|
|
1035
|
+
# # --- Regression ----------------------------------------------------
|
|
1036
|
+
# if intent == "regression" and not has_fit:
|
|
1037
|
+
# target = _guess_regression_target(df)
|
|
1038
|
+
# if target:
|
|
1039
|
+
# return regression(df) + "\n\n" + code
|
|
1040
|
+
# # return _call_template(regression, df, target) + "\n\n" + code
|
|
1041
|
+
|
|
1042
|
+
# # --- Anomaly detection --------------------------------------------
|
|
1043
|
+
# if intent == "anomaly_detection":
|
|
1044
|
+
# uses_anomaly = any(k in code for k in ("IsolationForest", "LocalOutlierFactor", "OneClassSVM"))
|
|
1045
|
+
# if not uses_anomaly:
|
|
1046
|
+
# return anomaly_detection(df) + "\n\n" + code
|
|
1047
|
+
|
|
1048
|
+
# # --- Time-series anomaly detection --------------------------------
|
|
1049
|
+
# if intent == "ts_anomaly_detection":
|
|
1050
|
+
# uses_ts = "STL(" in code or "seasonal_decompose(" in code
|
|
1051
|
+
# if not uses_ts:
|
|
1052
|
+
# return ts_anomaly_detection(df) + "\n\n" + code
|
|
950
1053
|
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
1054
|
+
# # --- Time-series classification --------------------------------
|
|
1055
|
+
# if intent == "time_series_classification" and not has_fit:
|
|
1056
|
+
# time_col = _guess_time_col(df)
|
|
1057
|
+
# entity_col = _guess_entity_col(df)
|
|
1058
|
+
# target_col = _guess_ts_class_target(df)
|
|
1059
|
+
|
|
1060
|
+
# # If we can't confidently identify these, do NOT inject anything
|
|
1061
|
+
# if time_col and entity_col and target_col:
|
|
1062
|
+
# return time_series_classification(df, entity_col, time_col, target_col) + "\n\n" + code
|
|
1063
|
+
|
|
1064
|
+
# # --- Dimensionality reduction --------------------------------------
|
|
1065
|
+
# if intent == "dimensionality_reduction":
|
|
1066
|
+
# uses_dr = any(k in code for k in ("PCA(", "TSNE("))
|
|
1067
|
+
# if not uses_dr:
|
|
1068
|
+
# return dimensionality_reduction(df) + "\n\n" + code
|
|
1069
|
+
|
|
1070
|
+
# # --- Feature selection ---------------------------------------------
|
|
1071
|
+
# if intent == "feature_selection":
|
|
1072
|
+
# uses_fs = any(k in code for k in (
|
|
1073
|
+
# "mutual_info_", "permutation_importance(", "SelectKBest(", "RFE("
|
|
1074
|
+
# ))
|
|
1075
|
+
# if not uses_fs:
|
|
1076
|
+
# return feature_selection(df) + "\n\n" + code
|
|
1077
|
+
|
|
1078
|
+
# # --- EDA / correlation / visualisation -----------------------------
|
|
1079
|
+
# if intent in ("eda", "correlation_analysis", "visualisation") and not has_plot:
|
|
1080
|
+
# if intent == "correlation_analysis":
|
|
1081
|
+
# return eda_correlation(df) + "\n\n" + code
|
|
1082
|
+
# else:
|
|
1083
|
+
# return eda_overview(df) + "\n\n" + code
|
|
1084
|
+
|
|
1085
|
+
# # --- Time-series forecasting ---------------------------------------
|
|
1086
|
+
# if intent == "time_series_forecasting" and not has_fit:
|
|
1087
|
+
# uses_ts_forecast = any(k in code for k in (
|
|
1088
|
+
# "ARIMA", "ExponentialSmoothing", "forecast", "predict("
|
|
1089
|
+
# ))
|
|
1090
|
+
# if not uses_ts_forecast:
|
|
1091
|
+
# return time_series_forecasting(df) + "\n\n" + code
|
|
989
1092
|
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1093
|
+
# # --- Multi-label classification -----------------------------------
|
|
1094
|
+
# if intent in ("multilabel_classification",) and not has_fit:
|
|
1095
|
+
# label_cols = _guess_multilabel_cols(df)
|
|
1096
|
+
# if len(label_cols) >= 2:
|
|
1097
|
+
# return multilabel_classification(df, label_cols) + "\n\n" + code
|
|
1098
|
+
|
|
1099
|
+
# group_col = _find_unknownish_column(df)
|
|
1100
|
+
# if group_col:
|
|
1101
|
+
# num_cols = _guess_numeric_cols(df)
|
|
1102
|
+
# cat_cols = _guess_categorical_cols(df, exclude={group_col})
|
|
1103
|
+
# outcome_col = None # generic; let template skip if not present
|
|
1104
|
+
# tpl = unknown_group_proxy_pack(df, group_col, UNKNOWN_TOKENS, num_cols, cat_cols, outcome_col)
|
|
1105
|
+
|
|
1106
|
+
# # Return template + guarded (repaired) LLM code, so it never crashes
|
|
1107
|
+
# repaired = make_numeric_vars_dynamic(relax_required_columns(fix_boxplot_placeholder(code)))
|
|
1108
|
+
# return tpl + "\n\n" + wrap_llm_code_safe(repaired)
|
|
1006
1109
|
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
def fix_values_sum_numeric_only_bug(code: str) -> str:
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
def strip_describe_slice(code: str) -> str:
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
def remove_plt_show(code: str) -> str:
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
def patch_plot_with_table(code: str) -> str:
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
+
# return code
|
|
1111
|
+
|
|
1112
|
+
|
|
1113
|
+
# def fix_values_sum_numeric_only_bug(code: str) -> str:
|
|
1114
|
+
# """
|
|
1115
|
+
# If a previous pass injected numeric_only=True into a NumPy-style sum,
|
|
1116
|
+
# e.g. .values.sum(numeric_only=True), strip it and canonicalize to .to_numpy().sum().
|
|
1117
|
+
# """
|
|
1118
|
+
# # .values.sum(numeric_only=True, ...)
|
|
1119
|
+
# code = re.sub(
|
|
1120
|
+
# r"\.values\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
|
|
1121
|
+
# ".to_numpy().sum()",
|
|
1122
|
+
# code,
|
|
1123
|
+
# flags=re.IGNORECASE,
|
|
1124
|
+
# )
|
|
1125
|
+
# # .to_numpy().sum(numeric_only=True, ...)
|
|
1126
|
+
# code = re.sub(
|
|
1127
|
+
# r"\.to_numpy\(\)\s*\.sum\s*\(\s*[^)]*numeric_only\s*=\s*True[^)]*\)",
|
|
1128
|
+
# ".to_numpy().sum()",
|
|
1129
|
+
# code,
|
|
1130
|
+
# flags=re.IGNORECASE,
|
|
1131
|
+
# )
|
|
1132
|
+
# return code
|
|
1133
|
+
|
|
1134
|
+
|
|
1135
|
+
# def strip_describe_slice(code: str) -> str:
|
|
1136
|
+
# """
|
|
1137
|
+
# Remove any pattern like df.groupby(...).describe()[[ ... ]] because
|
|
1138
|
+
# slicing a SeriesGroupBy.describe() causes AttributeError.
|
|
1139
|
+
# We leave the plain .describe() in place (harmless) and let our own
|
|
1140
|
+
# table patcher add the correct .agg() table afterwards.
|
|
1141
|
+
# """
|
|
1142
|
+
# pat = re.compile(
|
|
1143
|
+
# r"(df\.groupby\([^)]+\)\[[^\]]+\]\.describe\()\s*\[[^\]]+\]\)",
|
|
1144
|
+
# flags=re.DOTALL,
|
|
1145
|
+
# )
|
|
1146
|
+
# return pat.sub(r"\1)", code)
|
|
1147
|
+
|
|
1148
|
+
|
|
1149
|
+
# def remove_plt_show(code: str) -> str:
|
|
1150
|
+
# """Removes all plt.show() calls from the generated code string."""
|
|
1151
|
+
# return "\n".join(line for line in code.splitlines() if "plt.show()" not in line)
|
|
1152
|
+
|
|
1153
|
+
|
|
1154
|
+
# def patch_plot_with_table(code: str) -> str:
|
|
1155
|
+
# """
|
|
1156
|
+
# ▸ strips every `plt.show()` (avoids warnings)
|
|
1157
|
+
# ▸ converts the *last* Matplotlib / Seaborn figure to PNG-HTML so it is
|
|
1158
|
+
# rendered in the dashboard
|
|
1159
|
+
# ▸ appends a summary-stats table **after** the plot
|
|
1160
|
+
# """
|
|
1161
|
+
# # 0. drop plt.show()
|
|
1162
|
+
# lines = [ln for ln in code.splitlines() if "plt.show()" not in ln]
|
|
1163
|
+
|
|
1164
|
+
# # 1. locate the last plotting line
|
|
1165
|
+
# plot_kw = ['plt.', 'sns.', '.plot(', '.boxplot(', '.hist(']
|
|
1166
|
+
# last_plot = max((i for i,l in enumerate(lines) if any(k in l for k in plot_kw)), default=-1)
|
|
1167
|
+
# if last_plot == -1:
|
|
1168
|
+
# return "\n".join(lines) # nothing to do
|
|
1169
|
+
|
|
1170
|
+
# whole = "\n".join(lines)
|
|
1171
|
+
|
|
1172
|
+
# # 2. detect group / feature (if any)
|
|
1173
|
+
# group, feature = None, None
|
|
1174
|
+
# xm = re.search(r"x\s*=\s*['\"](\w+)['\"]", whole)
|
|
1175
|
+
# ym = re.search(r"y\s*=\s*['\"](\w+)['\"]", whole)
|
|
1176
|
+
# if xm and ym:
|
|
1177
|
+
# group, feature = xm.group(1), ym.group(1)
|
|
1178
|
+
# else:
|
|
1179
|
+
# cm = re.search(r"column\s*=\s*['\"](\w+)['\"].*by\s*=\s*['\"](\w+)['\"]", whole)
|
|
1180
|
+
# if cm:
|
|
1181
|
+
# feature, group = cm.group(1), cm.group(2)
|
|
1182
|
+
|
|
1183
|
+
# # 3. code that captures current fig → PNG → HTML
|
|
1184
|
+
# img_block = textwrap.dedent("""
|
|
1185
|
+
# import io, base64
|
|
1186
|
+
# buf = io.BytesIO()
|
|
1187
|
+
# plt.savefig(buf, format='png', bbox_inches='tight')
|
|
1188
|
+
# buf.seek(0)
|
|
1189
|
+
# img_b64 = base64.b64encode(buf.read()).decode('utf-8')
|
|
1190
|
+
# from IPython.display import display, HTML
|
|
1191
|
+
# display(HTML(f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%;">'))
|
|
1192
|
+
# plt.close()
|
|
1193
|
+
# """)
|
|
1194
|
+
|
|
1195
|
+
# # 4. build summary-table code
|
|
1196
|
+
# if group and feature:
|
|
1197
|
+
# tbl_block = (
|
|
1198
|
+
# f"summary_table = (\n"
|
|
1199
|
+
# f" df.groupby('{group}')['{feature}']\n"
|
|
1200
|
+
# f" .agg(['count','mean','std','min','median','max'])\n"
|
|
1201
|
+
# f" .rename(columns={{'median':'50%'}})\n"
|
|
1202
|
+
# f")\n"
|
|
1203
|
+
# )
|
|
1204
|
+
# elif ym:
|
|
1205
|
+
# feature = ym.group(1)
|
|
1206
|
+
# tbl_block = (
|
|
1207
|
+
# f"summary_table = (\n"
|
|
1208
|
+
# f" df['{feature}']\n"
|
|
1209
|
+
# f" .agg(['count','mean','std','min','median','max'])\n"
|
|
1210
|
+
# f" .rename(columns={{'median':'50%'}})\n"
|
|
1211
|
+
# f")\n"
|
|
1212
|
+
# )
|
|
1110
1213
|
|
|
1111
|
-
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
def refine_eda_question(raw_question, df=None, max_points=1000):
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1214
|
+
# # 3️⃣ grid-search results
|
|
1215
|
+
# elif "GridSearchCV(" in code:
|
|
1216
|
+
# tbl_block = textwrap.dedent("""
|
|
1217
|
+
# # build tidy CV-results table
|
|
1218
|
+
# cv_df = (
|
|
1219
|
+
# pd.DataFrame(grid_search.cv_results_)
|
|
1220
|
+
# .loc[:, ['param_n_estimators', 'param_max_depth',
|
|
1221
|
+
# 'mean_test_score', 'std_test_score']]
|
|
1222
|
+
# .rename(columns={
|
|
1223
|
+
# 'param_n_estimators': 'n_estimators',
|
|
1224
|
+
# 'param_max_depth': 'max_depth',
|
|
1225
|
+
# 'mean_test_score': 'mean_cv_accuracy',
|
|
1226
|
+
# 'std_test_score': 'std'
|
|
1227
|
+
# })
|
|
1228
|
+
# .sort_values('mean_cv_accuracy', ascending=False)
|
|
1229
|
+
# .reset_index(drop=True)
|
|
1230
|
+
# )
|
|
1231
|
+
# summary_table = cv_df
|
|
1232
|
+
# """).strip() + "\n"
|
|
1233
|
+
# else:
|
|
1234
|
+
# tbl_block = (
|
|
1235
|
+
# "summary_table = (\n"
|
|
1236
|
+
# " df.describe().T[['count','mean','std','min','50%','max']]\n"
|
|
1237
|
+
# ")\n"
|
|
1238
|
+
# )
|
|
1239
|
+
|
|
1240
|
+
# tbl_block += "show(summary_table, title='Summary Statistics')"
|
|
1241
|
+
|
|
1242
|
+
# # 5. inject image-export block, then table block, after the plot
|
|
1243
|
+
# patched = (
|
|
1244
|
+
# lines[:last_plot+1]
|
|
1245
|
+
# + img_block.splitlines()
|
|
1246
|
+
# + tbl_block.splitlines()
|
|
1247
|
+
# + lines[last_plot+1:]
|
|
1248
|
+
# )
|
|
1249
|
+
# patched_code = "\n".join(patched)
|
|
1250
|
+
# # ⬇️ strip every accidental left-indent so top-level lines are flush‐left
|
|
1251
|
+
# return textwrap.dedent(patched_code)
|
|
1252
|
+
|
|
1253
|
+
|
|
1254
|
+
# def refine_eda_question(raw_question, df=None, max_points=1000):
|
|
1255
|
+
# """
|
|
1256
|
+
# Rewrites user's EDA question to avoid classic mistakes:
|
|
1257
|
+
# - For line plots and scatter: recommend aggregation or sampling if large.
|
|
1258
|
+
# - For histograms/bar: clarify which variable to plot and bin count.
|
|
1259
|
+
# - For correlation: suggest a heatmap.
|
|
1260
|
+
# - For counts: direct request for df.shape[0].
|
|
1261
|
+
# df (optional): pass DataFrame for column inspection.
|
|
1262
|
+
# """
|
|
1263
|
+
|
|
1264
|
+
# # --- SPECIFIC PEARSON CORRELATION DETECTION ----------------------
|
|
1265
|
+
# pc = re.match(
|
|
1266
|
+
# r".*\bpearson\b.*\bcorrelation\b.*between\s+(\w+)\s+(and|vs)\s+(\w+)",
|
|
1267
|
+
# raw_question, re.I
|
|
1268
|
+
# )
|
|
1269
|
+
# if pc:
|
|
1270
|
+
# col1, col2 = pc.group(1), pc.group(3)
|
|
1271
|
+
# # Return an instruction that preserves the exact intent
|
|
1272
+#     return (
+#         f"Compute the Pearson correlation coefficient (r) and p-value "
+#         f"between {col1} and {col2}. "
+#         f"Print a short interpretation."
+#     )
+#     # -----------------------------------------------------------------
+#     # ── Detect "predict <column>" intent ──────────────────────────────
+#     c = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", raw_question, re.I)
+#     if c:
+#         target = c.group(1)
+#         raw_question += (
+#             f" IMPORTANT: do NOT recreate or overwrite the existing target column "
+#             f"“{target}”. Use it as-is for y = df['{target}']."
+#         )
+
+#     q = raw_question.strip()
+#     # REMOVE explicit summary-table instructions
+#     # ── strip any “table” request: “…table of …”, “…include table…”, “…with a table…”
+#     q = re.sub(r"\b(include|with|and)\b[^.]*\btable[s]?\b[^.]*", "", q, flags=re.I).strip()
+#     q = re.sub(r"\s*,\s*$", "", q)  # drop trailing comma, if any
+
+#     ql = q.lower()
+
+#     # ── NEW: if the text contains an exact column name, leave it alone ──
+#     if df is not None:
+#         for col in df.columns:
+#             if col.lower() in ql:
+#                 return q
+
+#     modelling_keywords = (
+#         "random forest", "gradient-boost", "tree-based model",
+#         "feature importance", "feature importances",
+#         "overall accuracy", "train a model", "predict "
+#     )
+#     if any(k in ql for k in modelling_keywords):
+#         return q
+
+#     # 1. Line plots: average if plotting raw numeric vs numeric
+#     if "line plot" in ql and any(word in ql for word in ["over", "by", "vs"]):
+#         match = re.search(r'line plot of ([\w_]+) (over|by|vs) ([\w_]+)', ql)
+#         if match:
+#             y, _, x = match.groups()
+#             return f"Show me the average {y} by {x} as a line plot."
+
+#     # 2. Scatter plots: sample if too large
+#     if "scatter" in ql or "scatter plot" in ql:
+#         if df is not None and df.shape[0] > max_points:
+#             return q + " (use only a random sample of 1000 points to avoid overplotting)"
+#         else:
+#             return q
+
+#     # 3. Histogram: specify bins and column
+#     if "histogram" in ql:
+#         match = re.search(r'histogram of ([\w_]+)', ql)
+#         if match:
+#             col = match.group(1)
+#             return f"Show me a histogram of {col} using 20 bins."
+
+#         # Special case: histogram for column with most missing values
+#         if "most missing" in ql:
+#             return (
+#                 "Show a histogram for the column with the most missing values. "
+#                 "First, select the column using: "
+#                 "column_with_most_missing = df.isnull().sum().idxmax(); "
+#                 "then plot its histogram with: "
+#                 "df[column_with_most_missing].hist()"
+#             )
+
+#     # 4. Bar plot: show top N
+#     if "bar plot" in ql or "bar chart" in ql:
+#         match = re.search(r'bar (plot|chart) of ([\w_]+)', ql)
+#         if match:
+#             col = match.group(2)
+#             return f"Show me a bar plot of the top 10 {col} values."
+
+#     # 5. Correlation or heatmap
+#     if "correlation" in ql:
+#         return (
+#             "Show a correlation heatmap for all numeric columns only. "
+#             "Use: correlation_matrix = df.select_dtypes(include='number').corr()"
+#         )
+
+
+#     # 6. Counts/size
+#     if "how many record" in ql or "row count" in ql or "number of rows" in ql:
+#         return "How many rows are in the dataset?"
+
+#     # 7. General best-practices fallback: add axis labels/titles
+#     if "plot" in ql:
+#         return q + " (make sure the axes are labeled and the plot is readable)"

+#     # 8.
+#     if (("how often" in ql or "count" in ql or "frequency" in ql) and "category" in ql) or ("value_counts" in q):
+#         match = re.search(r'(?:categories? in |bar plot of |bar chart of )([\w_]+)', ql)
+#         col = match.group(1) if match else None
+#         if col:
+#             return (
+#                 f"Show a bar plot of the counts of {col} using: "
+#                 f"df['{col}'].value_counts().plot(kind='bar'); "
+#                 "add axis labels and a title, then plt.show()."
+#             )

+#     if ("mean" in ql and "median" in ql and "standard deviation" in ql) or ("summary statistics" in ql):
+#         return (
+#             "Show a table of the mean, median, and standard deviation for all numeric columns. "
+#             "Use: tbl = df.describe().loc[['mean', '50%', 'std']].rename(index={'50%': 'median'}); display(tbl)"
+#         )

+#     # 9. Fallback: return the raw question
+#     return q

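The rules above belong to a question-refinement helper whose `def` line falls outside this hunk; 2.6.2 ships the whole body commented out. A minimal, self-contained sketch of how two of the disabled rewrite rules behaved (the `refine` name is hypothetical, not a package API):

```python
import re

def refine(q: str) -> str:
    """Sketch of two of the disabled question-rewrite rules."""
    ql = q.lower()
    # Rule 1: "line plot of <y> over/by/vs <x>" -> ask for the average instead
    m = re.search(r'line plot of ([\w_]+) (over|by|vs) ([\w_]+)', ql)
    if m:
        y, _, x = m.groups()
        return f"Show me the average {y} by {x} as a line plot."
    # Rule 3: "histogram of <col>" -> pin the bin count
    m = re.search(r'histogram of ([\w_]+)', ql)
    if m:
        return f"Show me a histogram of {m.group(1)} using 20 bins."
    return q

print(refine("Line plot of price over year"))  # Show me the average price by year as a line plot.
print(refine("Histogram of age, please"))      # Show me a histogram of age using 20 bins.
```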
-def patch_plot_code(code, df, user_question=None):
+# def patch_plot_code(code, df, user_question=None):

+#     # ── Early guard: abort nicely if the generated code references columns that
+#     #    do not exist in the DataFrame. This prevents KeyError crashes.
+#     import re


+#     # ── Detect columns referenced in the code ──────────────────────────
+#     col_refs = re.findall(r"df\[['\"](\w+)['\"]\]", code)
+
+#     # Columns that will be newly CREATED (appear left of '=')
+#     new_cols = re.findall(r"df\[['\"](\w+)['\"]\]\s*=", code)
+
+#     missing_cols = [
+#         col for col in col_refs
+#         if col not in df.columns and col not in new_cols
+#     ]
+
+#     if missing_cols:
+#         cols_list = ", ".join(missing_cols)
+#         warning = (
+#             f"show('⚠️ Warning: code references missing column(s): \"{cols_list}\". "
+#             "These must either exist in df or be created earlier in the code; "
+#             "otherwise you may see a KeyError.')\n"
+#         )
+#         # Prepend the warning but keep the original code so it can still run
+#         code = warning + code

+#     # 1. For line plots (auto-aggregate)
+#     m_l = re.search(r"plt\.plot\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
+#     if m_l:
+#         x, y = m_l.groups()
+#         if pd.api.types.is_numeric_dtype(df[x]) and pd.api.types.is_numeric_dtype(df[y]) and df[x].nunique() > 20:
+#             return (
+#                 f"agg_df = df.groupby('{x}')['{y}'].mean().reset_index()\n"
+#                 f"plt.plot(agg_df['{x}'], agg_df['{y}'], marker='o')\n"
+#                 f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('Average {y} by {x}')\nplt.show()"
+#             )
+
+#     # 2. For scatter plots: sample to 1000 points max
+#     m_s = re.search(r"plt\.scatter\(\s*df\[['\"](\w+)['\"]\]\s*,\s*df\[['\"](\w+)['\"]\]", code)
+#     if m_s:
+#         x, y = m_s.groups()
+#         if len(df) > 1000:
+#             return (
+#                 f"samp = df.sample(1000, random_state=42)\n"
+#                 f"plt.scatter(samp['{x}'], samp['{y}'])\n"
+#                 f"plt.xlabel('{x}')\nplt.ylabel('{y}')\nplt.title('{y} vs {x} (sampled)')\nplt.show()"
+#             )
+
+#     # 3. For histograms: use bins=20 for numeric, value_counts for categorical
+#     m_h = re.search(r"plt\.hist\(\s*df\[['\"](\w+)['\"]\]", code)
+#     if m_h:
+#         col = m_h.group(1)
+#         if pd.api.types.is_numeric_dtype(df[col]):
+#             return (
+#                 f"plt.hist(df['{col}'], bins=20, edgecolor='black')\n"
+#                 f"plt.xlabel('{col}')\nplt.ylabel('Frequency')\nplt.title('Histogram of {col}')\nplt.show()"
+#             )
+#         else:
+#             # If categorical, show bar plot of value counts
+#             return (
+#                 f"df['{col}'].value_counts().plot(kind='bar')\n"
+#                 f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Counts of {col}')\nplt.show()"
+#             )
+
+#     # 4. For bar plots: show only top 20
+#     m_b = re.search(r"(?:df\[['\"](\w+)['\"]\]\.value_counts\(\).plot\(kind=['\"]bar['\"]\))", code)
+#     if m_b:
+#         col = m_b.group(1)
+#         if df[col].nunique() > 20:
+#             return (
+#                 f"topN = df['{col}'].value_counts().head(20)\n"
+#                 f"topN.plot(kind='bar')\n"
+#                 f"plt.xlabel('{col}')\nplt.ylabel('Count')\nplt.title('Top 20 {col} Categories')\nplt.show()"
+#             )
+
+#     # 5. For any DataFrame plot with len(df)>10000, sample before plotting!
+#     if "df.plot" in code and len(df) > 10000:
+#         return (
+#             f"samp = df.sample(1000, random_state=42)\n"
+#             + code.replace("df.", "samp.")
+#         )
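The disabled early guard above scans generated code for `df['col']` references and warns about columns that neither exist in the DataFrame nor get created earlier in the snippet. A short standalone sketch of that detection step (the sample DataFrame and code string are illustrative):

```python
import re
import pandas as pd

df = pd.DataFrame({"age": [30, 41], "bmi": [22.5, 27.1]})
code = "df['Age'].hist()"  # note the wrong capitalisation

col_refs = re.findall(r"df\[['\"](\w+)['\"]\]", code)          # columns read
new_cols = re.findall(r"df\[['\"](\w+)['\"]\]\s*=", code)      # columns assigned
missing = [c for c in col_refs if c not in df.columns and c not in new_cols]
print(missing)  # ['Age'] — the guard would prepend a warning instead of letting a KeyError crash
```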

+#     # ── Block assignment to an existing target column ────────────────
+#     # *******************************************************
+#     target_match = re.search(r"\bpredict\s+([A-Za-z0-9_]+)", user_question or "", re.I)
+#     if target_match:
+#         target = target_match.group(1)
+
+#         # pattern for an assignment to that target
+#         assign_pat = rf"df\[['\"]{re.escape(target)}['\"]\]\s*="
+#         assign_line = re.search(assign_pat + r".*", code)
+#         if assign_line:
+#             # runtime check: keep the assignment **only if** the column is absent
+#             guard = (
+#                 f"if '{target}' in df.columns:\n"
+#                 f"    print('⚠️ {target} already exists – overwrite skipped.');\n"
+#                 f"else:\n"
+#                 f"    {assign_line.group(0)}"
+#             )
+#             # remove original assignment line and insert guarded block
+#             code = code.replace(assign_line.group(0), guard, 1)
+#     # ***************************************************

-def ensure_matplotlib_title(code, title_var="refined_question"):
+#     # 6. Grouped bar plot for two categoricals
+#     # Grouped bar plot for two categoricals (.value_counts().unstack() or .groupby().size().unstack())
+#     if ".value_counts().unstack()" in code or ".groupby(" in code and ".size().unstack()" in code:
+#         # Try to infer columns from user question if possible:
+#         group, cat = None, None
+#         if user_question:
+#             # crude parse for "counts of X for each Y"
+#             m = re.search(r"counts? of (\w+) for each (\w+)", user_question)
+#             if m:
+#                 cat, group = m.groups()
+#         if not (cat and group):
+#             # fallback: use two most frequent categoricals
+#             categoricals = [col for col in df.columns if pd.api.types.is_categorical_dtype(df[col]) or df[col].dtype == "object"]
+#             if len(categoricals) >= 2:
+#                 cat, group = categoricals[:2]
+#             else:
+#                 # fallback: any
+#                 cat, group = df.columns[:2]
+#         return (
+#             f"import pandas as pd\n"
+#             f"import matplotlib.pyplot as plt\n"
+#             f"ct = pd.crosstab(df['{group}'], df['{cat}'])\n"
+#             f"ct.plot(kind='bar')\n"
+#             f"plt.title('Counts of {cat} for each {group}')\n"
+#             f"plt.xlabel('{group}')\nplt.ylabel('Count')\nplt.xticks(rotation=0)\nplt.show()"
+#         )
+
+#     # Fallback: Return original code
+#     return code
+
+
+# def ensure_matplotlib_title(code, title_var="refined_question"):
+#     import re
+#     makes_plot = re.search(r"\b(plt\.(plot|scatter|bar|hist)|ax\.(plot|scatter|bar|hist))\b", code)
+#     has_title = re.search(r"\b(plt\.title|ax\.set_title)\s*\(", code)
+#     if makes_plot and not has_title:
+#         code += f"\ntry:\n    plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
+#     return code
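`ensure_matplotlib_title`, commented out above, appended a guarded `plt.title(...)` when generated plotting code had none. A runnable sketch of the same idea (simplified to the `plt.` forms only; the `ensure_title` name is illustrative):

```python
import re

def ensure_title(code: str, title_var: str = "refined_question") -> str:
    makes_plot = re.search(r"\bplt\.(plot|scatter|bar|hist)\b", code)
    has_title = re.search(r"\bplt\.title\s*\(", code)
    if makes_plot and not has_title:
        # Title is truncated to 120 chars and wrapped in try/except so it can never break the run
        code += f"\ntry:\n    plt.title(str({title_var})[:120])\nexcept Exception: pass\n"
    return code

print(ensure_title("plt.hist(df['age'])"))
```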


def ensure_output(code: str) -> str:
@@ -1499,1510 +1602,1455 @@ def ensure_output(code: str) -> str:
    return "\n".join(lines)


-def get_plotting_imports(code):
-def patch_pairplot(code, df):
-def ensure_image_output(code: str) -> str:
-def clean_llm_code(code: str) -> str:
+# def get_plotting_imports(code):
+#     imports = []
+#     if "plt." in code and "import matplotlib.pyplot as plt" not in code:
+#         imports.append("import matplotlib.pyplot as plt")
+#     if "sns." in code and "import seaborn as sns" not in code:
+#         imports.append("import seaborn as sns")
+#     if "px." in code and "import plotly.express as px" not in code:
+#         imports.append("import plotly.express as px")
+#     if "pd." in code and "import pandas as pd" not in code:
+#         imports.append("import pandas as pd")
+#     if "np." in code and "import numpy as np" not in code:
+#         imports.append("import numpy as np")
+#     if "display(" in code and "from IPython.display import display" not in code:
+#         imports.append("from IPython.display import display")
+#     # Optionally, add more as you see usage (e.g., import scipy, statsmodels, etc)
+#     if imports:
+#         code = "\n".join(imports) + "\n\n" + code
+#     return code
+
+
+# def patch_pairplot(code, df):
+#     if "sns.pairplot" in code:
+#         # Always assign and print pairgrid
+#         code = re.sub(r"sns\.pairplot\((.+)\)", r"pairgrid = sns.pairplot(\1)", code)
+#         if "plt.show()" not in code:
+#             code += "\nplt.show()"
+#         if "print(pairgrid)" not in code:
+#             code += "\nprint(pairgrid)"
+#     return code
+
+
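The disabled `get_plotting_imports` scans generated code for library prefixes and prepends whichever imports are missing. A trimmed sketch of the technique (the `add_missing_imports` name is illustrative; only two of the six checks are shown):

```python
def add_missing_imports(code: str) -> str:
    imports = []
    if "plt." in code and "import matplotlib.pyplot as plt" not in code:
        imports.append("import matplotlib.pyplot as plt")
    if "pd." in code and "import pandas as pd" not in code:
        imports.append("import pandas as pd")
    # Prepend the collected imports, leaving the original code untouched
    return "\n".join(imports) + "\n\n" + code if imports else code

print(add_missing_imports("plt.plot(pd.Series([1, 2, 3]))"))
```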
+# def ensure_image_output(code: str) -> str:
+#     """
+#     Replace each plt.show() with an indented _SMX_export_png() call.
+#     This keeps block indentation valid and still renders images in the dashboard.
+#     """
+#     if "plt.show()" not in code:
+#         return code
+
+#     import re
+#     out_lines = []
+#     for ln in code.splitlines():
+#         if "plt.show()" not in ln:
+#             out_lines.append(ln)
+#             continue
+
+#         # works for:
+#         #   plt.show()
+#         #   plt.tight_layout(); plt.show()
+#         #   ... ; plt.show(); ...   (multiple on one line)
+#         indent = re.match(r"^(\s*)", ln).group(1)
+#         parts = ln.split("plt.show()")
+
+#         # keep whatever is before the first plt.show()
+#         if parts[0].strip():
+#             out_lines.append(parts[0].rstrip())
+
+#         # for every plt.show() we removed, insert exporter at same indent
+#         for _ in range(len(parts) - 1):
+#             out_lines.append(indent + "_SMX_export_png()")
+
+#         # keep whatever comes after the last plt.show()
+#         if parts[-1].strip():
+#             out_lines.append(indent + parts[-1].lstrip())
+
+#     return "\n".join(out_lines)
+
+
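The disabled `ensure_image_output` helper rewrites generated plotting code so each blocking `plt.show()` becomes a call to the package's `_SMX_export_png()` hook at the same indentation. A standalone sketch of the transform (the `replace_show` name and sample input are illustrative):

```python
import re

def replace_show(code: str, exporter: str = "_SMX_export_png()") -> str:
    out = []
    for ln in code.splitlines():
        if "plt.show()" not in ln:
            out.append(ln)
            continue
        indent = re.match(r"^(\s*)", ln).group(1)     # preserve block indentation
        parts = ln.split("plt.show()")
        if parts[0].strip():
            out.append(parts[0].rstrip())             # code before the first show()
        for _ in range(len(parts) - 1):
            out.append(indent + exporter)             # one exporter per removed show()
        if parts[-1].strip():
            out.append(indent + parts[-1].lstrip())   # code after the last show()
    return "\n".join(out)

print(replace_show("for c in cols:\n    df[c].hist()\n    plt.show()"))
```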
+# def clean_llm_code(code: str) -> str:
+#     """
+#     Make LLM output safe to exec:
+#       - If fenced blocks exist, keep the largest one (usually the real code).
+#       - Otherwise strip any stray ``` / ```python lines.
+#       - Remove common markdown/preamble junk.
+#     """
+#     code = str(code or "")
+
+#     # Special case: sometimes the OpenAI SDK object repr (e.g. ChatCompletion(...))
+#     # is accidentally passed here as `code`. In that case, extract the actual
+#     # Python code from the ChatCompletionMessage(content=...) field.
+#     if "ChatCompletion(" in code and "ChatCompletionMessage" in code and "content=" in code:
+#         try:
+#             extracted = None
+
+#             class _ChatCompletionVisitor(ast.NodeVisitor):
+#                 def visit_Call(self, node):
+#                     nonlocal extracted
+#                     func = node.func
+#                     fname = getattr(func, "id", None) or getattr(func, "attr", None)
+#                     if fname == "ChatCompletionMessage":
+#                         for kw in node.keywords:
+#                             if kw.arg == "content" and isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str):
+#                                 extracted = kw.value.value
+#                     self.generic_visit(node)
+
+#             tree = ast.parse(code, mode="exec")
+#             _ChatCompletionVisitor().visit(tree)
+#             if extracted:
+#                 code = extracted
+#         except Exception:
+#             # Best-effort regex fallback if AST parsing fails
+#             m = re.search(r"content=(?P<q>['\\\"])(?P<body>.*?)(?P=q)", code, flags=re.S)
+#             if m:
+#                 code = m.group("body")
+
+#     # Existing logic continues unchanged below...
+#     # Extract fenced blocks (```python ... ``` or ``` ... ```)
+#     blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
+
+#     if blocks:
+#         # pick the largest block; small trailing blocks are usually garbage
+#         largest = max(blocks, key=lambda b: len(b.strip()))
+#         if len(largest.strip().splitlines()) >= 10:
+#             code = largest
+
+#     # Extract fenced blocks (```python ... ``` or ``` ... ```)
+#     blocks = re.findall(r"```(?:python)?\s*(.*?)```", code, flags=re.I | re.S)
+
+#     if blocks:
+#         # pick the largest block; small trailing blocks are usually garbage
+#         largest = max(blocks, key=lambda b: len(b.strip()))
+#         if len(largest.strip().splitlines()) >= 10:
+#             code = largest
+#         else:
+#             # if no meaningful block, just remove fence markers
+#             code = re.sub(r"^```.*?$", "", code, flags=re.M)
+#     else:
+#         # no complete blocks — still remove any stray fence lines
+#         code = re.sub(r"^```.*?$", "", code, flags=re.M)
+
+#     # Strip common markdown/preamble lines
+#     drop_prefixes = (
+#         "here is", "here's", "below is", "sure,", "certainly",
+#         "explanation", "note:", "```"
+#     )
+#     cleaned_lines = []
+#     for ln in code.splitlines():
+#         s = ln.strip().lower()
+#         if any(s.startswith(p) for p in drop_prefixes):
+#             continue
+#         cleaned_lines.append(ln)
+
+#     return "\n".join(cleaned_lines).strip()
+
+
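The core move in the disabled `clean_llm_code` is pulling the largest fenced block out of a chatty LLM reply. A simplified sketch (it omits the ten-line minimum and the ChatCompletion-repr special case; the sample `reply` string is illustrative):

```python
import re

reply = "Here is the code:\n```python\nfor i in range(3):\n    print(i)\n```\nHope that helps!"
blocks = re.findall(r"```(?:python)?\s*(.*?)```", reply, flags=re.I | re.S)
code = max(blocks, key=lambda b: len(b.strip())) if blocks else reply
print(code.strip())  # just the runnable body, fences and chatter removed
```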
+# def fix_groupby_describe_slice(code: str) -> str:
+#     """
+#     Replaces df.groupby(...).describe()[[...]] with a safe .agg(...)
+#     so it works for both SeriesGroupBy and DataFrameGroupBy.
+#     """
+#     pat = re.compile(
+#         r"(df\.groupby\(['\"][\w]+['\"]\)\['[\w]+['\"]\]\.describe\()\s*\[\[([^\]]+)\]\]\)",
+#         re.MULTILINE
+#     )
+#     def repl(match):
+#         inner = match.group(0)
+#         # extract group and feature to build df.groupby('g')['f']
+#         g = re.search(r"groupby\('([\w]+)'\)", inner).group(1)
+#         f = re.search(r"\)\['([\w]+)'\]\.describe", inner).group(1)
+#         return (
+#             f"df.groupby('{g}')['{f}']"
+#             ".agg(['count','mean','std','min','median','max'])"
+#             ".rename(columns={'median':'50%'})"
+#         )
+#     return pat.sub(repl, code)
+
+
+# def fix_importance_groupby(code: str) -> str:
+#     pattern = re.compile(r"df\.groupby\(['\"]Importance['\"]\)\['\"?Importance['\"]?\]")
+#     if "importance_df" in code:
+#         return pattern.sub("importance_df.groupby('Importance')['Importance']", code)
+#     return code
+
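What the `fix_groupby_describe_slice` rewrite targets: the `.agg(...)` form works uniformly where slicing `.describe()` output can fail. A runnable sketch of the replacement pattern on a toy frame (data is illustrative):

```python
import pandas as pd

df = pd.DataFrame({"grp": ["a", "a", "b"], "x": [1.0, 3.0, 2.0]})
# Safe equivalent of df.groupby('grp')['x'].describe()[[...]]:
tbl = (df.groupby("grp")["x"]
         .agg(["count", "mean", "std", "min", "median", "max"])
         .rename(columns={"median": "50%"}))
print(tbl)
```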
+# def inject_auto_preprocessing(code: str) -> str:
+#     """
+#     • Detects a RandomForestClassifier in the generated code.
+#     • Finds the target column from `y = df['target']`.
+#     • Prepends a fully-dedented preprocessing snippet that:
+#         – auto-detects numeric & categorical columns
+#         – builds a ColumnTransformer (OneHotEncoder + StandardScaler)
+#     The dedent() call guarantees no leading-space IndentationError.
+#     """
+#     if "RandomForestClassifier" not in code:
+#         return code  # nothing to patch
+
+#     y_match = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
+#     if not y_match:
+#         return code  # can't infer target safely
+#     target = y_match.group(1)
+
+#     prep_snippet = textwrap.dedent(f"""
+#         # ── automatic preprocessing ───────────────────────────────
+#         num_cols = df.select_dtypes(include=['number']).columns.tolist()
+#         cat_cols = df.select_dtypes(exclude=['number']).columns.tolist()
+#         num_cols = [c for c in num_cols if c != '{target}']
+#         cat_cols = [c for c in cat_cols if c != '{target}']
+
+#         from sklearn.compose import ColumnTransformer
+#         from sklearn.preprocessing import StandardScaler, OneHotEncoder
+
+#         preproc = ColumnTransformer(
+#             transformers=[
+#                 ('num', StandardScaler(), num_cols),
+#                 ('cat', OneHotEncoder(handle_unknown='ignore'), cat_cols),
+#             ],
+#             remainder='drop',
+#         )
+#         # ───────────────────────────────────────────────────────────
+#     """).strip() + "\n\n"
+
+#     # simply prepend; model code that follows can wrap estimator in a Pipeline
+#     return prep_snippet + code
+
+
+# def fix_to_datetime_errors(code: str) -> str:
+#     """
+#     Force every pd.to_datetime(…) call to ignore bad dates so that

-def fix_numeric_sum(code: str) -> str:
-def fix_concat_empty_list(code: str) -> str:
-def fix_numeric_aggs(code: str) -> str:
-def ensure_accuracy_block(code: str) -> str:
-def fix_scatter_and_summary(code: str) -> str:
-def auto_format_with_black(code: str) -> str:
-def ensure_preproc_in_pipeline(code: str) -> str:
-    """
-    needs_sns = "sns." in code
-    has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
-    if needs_sns and not has_import:
-        # Insert after the first block of imports if possible, else at top
-        import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
-        inject = "import seaborn as sns\ntry:\n    sns.set_theme()\nexcept Exception:\n    pass\n"
-        if import_block:
-            start = import_block.end()
-            code = code[:start] + inject + code[start:]
-        else:
-            code = inject + code
-    return code
-
-
-def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
-    """
-    Normalise pie-chart requests.
-
-    Supports three patterns:
-      A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
-      B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
-         → one pie per facet level (grid) + counts bar of facet sizes.
-      C) Single pie when no split/facet is requested.
-
-    Notes:
-      - Pie variables must be categorical (or numeric binned).
-      - Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
-    """
-
-    q = (user_question or "")
-    q_low = q.lower()
-
-    # Prefer explicit: df['col'].value_counts()
-    m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
-    col = m.group(1) if m else None
-
-    # ---------- helpers ----------
-    def _is_cat(col):
-        return (str(df[col].dtype).startswith("category")
-                or df[col].dtype == "object"
-                or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
-
-    def _cats_from_question(question: str):
-        found = []
-        for c in df.columns:
-            if c.lower() in question.lower() and _is_cat(c):
-                found.append(c)
-        # dedupe preserve order
-        seen, out = set(), []
-        for c in found:
-            if c not in seen:
-                out.append(c); seen.add(c)
-        return out
-
-    def _fallback_cat():
-        cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
-        if not cats: return None
-        cats.sort(key=lambda t: t[1])
-        return cats[0][0]
+#     'year 16500 is out of range' and similar issues don’t crash runs.
+#     """
+#     import re
+#     # look for any pd.to_datetime( … )
+#     pat = re.compile(r"pd\.to_datetime\(([^)]+)\)")
+#     def repl(m):
+#         inside = m.group(1)
+#         # if the call already has errors=, leave it unchanged
+#         if "errors=" in inside:
+#             return m.group(0)
+#         return f"pd.to_datetime({inside}, errors='coerce')"
+#     return pat.sub(repl, code)
+
+
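The `fix_to_datetime_errors` rewrite above is self-contained enough to demonstrate directly; this sketch applies the same regex to a one-line snippet (the input string is illustrative):

```python
import re

pat = re.compile(r"pd\.to_datetime\(([^)]+)\)")

def coerce(m):
    inside = m.group(1)
    # leave calls alone if the author already chose an errors= policy
    return m.group(0) if "errors=" in inside else f"pd.to_datetime({inside}, errors='coerce')"

print(pat.sub(coerce, "df['d'] = pd.to_datetime(df['d'])"))
# df['d'] = pd.to_datetime(df['d'], errors='coerce')
```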
+# def fix_numeric_sum(code: str) -> str:
+#     """
+#     Make .sum(...) code safe across pandas versions by removing any
+#     numeric_only=... argument (True/False/None) from function calls.
+
+#     This avoids errors on pandas versions where numeric_only is not
+#     supported for Series/grouped sums, and we rely instead on explicit
+#     numeric column selection (e.g. select_dtypes) in the generated code.
+#     """
+#     # Case 1: ..., numeric_only=True/False/None
+#     code = re.sub(
+#         r",\s*numeric_only\s*=\s*(True|False|None)",
+#         "",
+#         code,
+#         flags=re.IGNORECASE,
+#     )
+
+#     # Case 2: numeric_only=True/False/None, ...  (as first argument)
+#     code = re.sub(
+#         r"numeric_only\s*=\s*(True|False|None)\s*,\s*",
+#         "",
+#         code,
+#         flags=re.IGNORECASE,
+#     )
+
+#     # Case 3: numeric_only=True/False/None  (only argument)
+#     code = re.sub(
+#         r"numeric_only\s*=\s*(True|False|None)",
+#         "",
+#         code,
+#         flags=re.IGNORECASE,
+#     )
+
+#     return code
+
+
+# def fix_concat_empty_list(code: str) -> str:
+#     """
+#     Make pd.concat calls resilient to empty lists of objects.
+
+#     Transforms patterns like:
+#         pd.concat(frames, ignore_index=True)
+#         pd.concat(frames)
+
+#     into:
+#         pd.concat(frames or [pd.DataFrame()], ignore_index=True)
+#         pd.concat(frames or [pd.DataFrame()])
+
+#     Only triggers when the first argument is a simple variable name.
+#     """
+#     pattern = re.compile(r"pd\.concat\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*(,|\))")
+
+#     def _repl(m):
+#         name = m.group(1)
+#         sep = m.group(2)  # ',' or ')'
+#         return f"pd.concat({name} or [pd.DataFrame()]{sep}"
+
+#     return pattern.sub(_repl, code)
+
+
+# def fix_numeric_aggs(code: str) -> str:
+#     _AGG_FUNCS = ("sum", "mean")
+#     pat = re.compile(rf"\.({'|'.join(_AGG_FUNCS)})\(\s*([^)]+)?\)")
+#     def _repl(m):
+#         func, args = m.group(1), m.group(2) or ""
+#         if "numeric_only" in args:
+#             return m.group(0)
+#         args = args.rstrip()
+#         if args:
+#             args += ", "
+#         return f".{func}({args}numeric_only=True)"
+#     return pat.sub(_repl, code)
+
+
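Note that the two disabled helpers pull in opposite directions around the same pandas quirk: `fix_numeric_sum` strips `numeric_only=...` for old pandas, while `fix_numeric_aggs` injects `numeric_only=True` so mixed-dtype frames aggregate cleanly. This sketch shows the behaviour the injected keyword buys on current pandas (toy data, illustrative):

```python
import pandas as pd

df = pd.DataFrame({"grp": ["a", "a", "b"], "x": [1, 2, 3], "label": ["u", "v", "w"]})
# With numeric_only=True the non-numeric 'label' column is skipped instead of raising
print(df.groupby("grp").sum(numeric_only=True))
```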
+# def ensure_accuracy_block(code: str) -> str:
+#     """
+#     Inject a sensible evaluation block right after the last `<est>.fit(...)`
+#       Classification → accuracy + weighted F1
+#       Regression     → R², RMSE, MAE
+#     Heuristic: infer task from estimator names present in the code.
+#     """
+#     import re, textwrap
+
+#     # If any proper metric already exists, do nothing
+#     if re.search(r"\b(accuracy_score|f1_score|r2_score|mean_squared_error|mean_absolute_error)\b", code):
+#         return code
+
+#     # Find the last "<var>.fit(" occurrence to reuse the estimator variable name
+#     m = list(re.finditer(r"(\w+)\.fit\s*\(", code))
+#     if not m:
+#         return code  # no estimator
+
+#     var = m[-1].group(1)
+#     # indent with same leading whitespace used on that line
+#     indent = re.match(r"\s*", code[m[-1].start():]).group(0)
+
+#     # Detect regression by estimator names / hints in code
+#     is_regression = bool(
+#         re.search(
+#             r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
+#             r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
+#             r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
+#         )
+#         or re.search(r"\bOLS\s*\(", code)
+#         or re.search(r"\bRegressor\b", code)
+#     )
+
+#     if is_regression:
+#         # inject numpy import if needed for RMSE
+#         if "import numpy as np" not in code and "np." not in code:
+#             code = "import numpy as np\n" + code
+#         eval_block = textwrap.dedent(f"""
+#             {indent}# ── automatic regression evaluation ─────────
+#             {indent}from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
+#             {indent}y_pred = {var}.predict(X_test)
+#             {indent}r2 = r2_score(y_test, y_pred)
+#             {indent}rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
+#             {indent}mae = float(mean_absolute_error(y_test, y_pred))
+#             {indent}print(f"R²: {{r2:.4f}} | RMSE: {{rmse:.4f}} | MAE: {{mae:.4f}}")
+#         """)
+#     else:
+#         eval_block = textwrap.dedent(f"""
+#             {indent}# ── automatic classification evaluation ─────────
+#             {indent}from sklearn.metrics import accuracy_score, f1_score
+#             {indent}y_pred = {var}.predict(X_test)
+#             {indent}acc = accuracy_score(y_test, y_pred)
+#             {indent}f1 = f1_score(y_test, y_pred, average='weighted')
+#             {indent}print(f"Accuracy: {{acc:.2%}} | F1 (weighted): {{f1:.3f}}")
+#         """)
+
+#     insert_at = code.find("\n", m[-1].end()) + 1
+#     return code[:insert_at] + eval_block + code[insert_at:]
+
+
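For reference, the regression branch above injects essentially this evaluation block; here it is as a runnable end-to-end sketch (random synthetic data, illustrative only):

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split

X = np.random.rand(100, 3)
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * np.random.rand(100)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = LinearRegression().fit(X_train, y_train)

# The block ensure_accuracy_block would splice in after the last .fit(...):
y_pred = model.predict(X_test)
r2 = r2_score(y_test, y_pred)
rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
mae = float(mean_absolute_error(y_test, y_pred))
print(f"R²: {r2:.4f} | RMSE: {rmse:.4f} | MAE: {mae:.4f}")
```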
+# def fix_scatter_and_summary(code: str) -> str:
+#     """
+#     1. Change cmap='spectral' (any case) → cmap='Spectral'
+#     2. If the LLM forgets to close the parenthesis in
+#        summary_table = ( df.describe()...   <missing )>
+#        insert the ')' right before the next 'from' or 'show('.
+#     """
+#     # 1️⃣ colormap case
+#     code = re.sub(
+#         r"cmap\s*=\s*['\"]spectral['\"]",  # insensitive pattern
+#         "cmap='Spectral'",
+#         code,
+#         flags=re.IGNORECASE,
+#     )
+
+#     # 2️⃣ close summary_table = ( ... )
+#     code = re.sub(
+#         r"(summary_table\s*=\s*\(\s*df\.describe\([^\n]+?\n)"
+#         r"(?=\s*(from|show\())",  # look-ahead: next line starts with 'from' or 'show('
+#         r"\1)",                   # keep group 1 and add ')'
+#         code,
+#         flags=re.MULTILINE,
+#     )
+
+#     return code
+
+
+# def auto_format_with_black(code: str) -> str:
+#     """
+#     Format the generated code with Black. Falls back silently if Black
+#     is missing or raises (so the dashboard never 500s).
+#     """
+#     try:
+#         import black  # make sure black is in your v-env: pip install black
+
+#         mode = black.FileMode()  # default settings
+#         return black.format_str(code, mode=mode)
+
+#     except Exception:
+#         return code
+
+
+# def ensure_preproc_in_pipeline(code: str) -> str:
+#     """
+#     If code defines `preproc = ColumnTransformer(...)` but then builds
+#     `Pipeline([('scaler', StandardScaler()), ('clf', ...)])`, replace
+#     that stanza with `Pipeline([('prep', preproc), ('clf', ...)])`.
+#     """
+#     return re.sub(
+#         r"Pipeline\(\s*\[\('scaler',\s*StandardScaler\(\)\)",
+#         "Pipeline([('prep', preproc)",
+#         code
+#     )
+
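The `ensure_preproc_in_pipeline` substitution is easy to verify in isolation; this sketch runs the same regex against a representative line of generated code (the input string is illustrative):

```python
import re

code = "model = Pipeline([('scaler', StandardScaler()), ('clf', RandomForestClassifier())])"
patched = re.sub(
    r"Pipeline\(\s*\[\('scaler',\s*StandardScaler\(\)\)",
    "Pipeline([('prep', preproc)",
    code,
)
print(patched)
# model = Pipeline([('prep', preproc), ('clf', RandomForestClassifier())])
```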
+# def drop_bad_classification_metrics(code: str, y_or_df) -> str:
+#     """
+#     Remove classification metrics (accuracy_score, classification_report, confusion_matrix)
+#     if the generated cell is *regression*. We infer this from:
+#       1) The estimator names in the code (LinearRegression, OLS, Regressor*, etc.), OR
+#       2) The target dtype if we can parse y = df['...'] and have the DataFrame.
+#     Safe across datasets and queries.
+#     """
+#     import re
+#     import pandas as pd
+
+#     # 1) Heuristic by estimator names in the *code* (fast path)
+#     regression_by_model = bool(re.search(
+#         r"\b(LinearRegression|Ridge|Lasso|ElasticNet|ElasticNetCV|HuberRegressor|TheilSenRegressor|RANSACRegressor|"
+#         r"RandomForestRegressor|GradientBoostingRegressor|DecisionTreeRegressor|KNeighborsRegressor|SVR|"
+#         r"XGBRegressor|LGBMRegressor|CatBoostRegressor)\b", code
+#     ) or re.search(r"\bOLS\s*\(", code))
+
+#     is_regression = regression_by_model
+
+#     # 2) If not obvious from the model, try to infer from y dtype (if we can)
+#     if not is_regression:
+#         try:
+#             # Try to parse: y = df['target']
+#             m = re.search(r"y\s*=\s*df\[['\"]([^'\"]+)['\"]\]", code)
+#             if m and hasattr(y_or_df, "columns") and m.group(1) in getattr(y_or_df, "columns", []):
+#                 y = y_or_df[m.group(1)]
+#                 if pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
+#                     is_regression = True
+#             else:
+#                 # If a Series was passed
+#                 y = y_or_df
+#                 if hasattr(y, "dtype") and pd.api.types.is_numeric_dtype(y) and y.nunique(dropna=True) > 10:
+#                     is_regression = True
+#         except Exception:
+#             pass
+
+#     if is_regression:
+#         # Strip classification-only lines
+#         for pat in (r"\n.*accuracy_score[^\n]*", r"\n.*classification_report[^\n]*", r"\n.*confusion_matrix[^\n]*"):
+#             code = re.sub(pat, "", code, flags=re.I)
+
+#     return code
+
+
+# def force_capture_display(code: str) -> str:
+#     """
+#     Ensure our executor captures HTML output:
+#       - Remove any import that would override our 'display' hook.
+#       - Keep/allow importing HTML only.
+#       - Handle alias cases like 'display as d'.
+#     """
+#     import re
+#     new = code
+
+#     # 'from IPython.display import display, HTML' -> keep HTML only
+#     new = re.sub(
+#         r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*,\s*HTML\s*(?:as\s+([A-Za-z_]\w*))?\s*$",
+#         r"from IPython.display import HTML\1", new
+#     )
+
+#     # 'from IPython.display import display as d' -> 'd = display'
+#     new = re.sub(
+#         r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s+as\s+([A-Za-z_]\w+)\s*$",
+#         r"\1 = display", new
+#     )
+
+#     # 'from IPython.display import display' -> remove (use our injected display)
+#     new = re.sub(
+#         r"(?m)^\s*from\s+IPython\.display\s+import\s+display\s*$",
+#         r"# display import removed (SMX capture active)", new
+#     )
+
+#     # If someone does 'import IPython.display as disp' and calls disp.display(...), rewrite to display(...)
+#     new = re.sub(
+#         r"(?m)\bIPython\.display\.display\s*\(",
+#         "display(", new
+#     )
+#     new = re.sub(
+#         r"(?m)\b([A-Za-z_]\w*)\.display\s*\("  # handles 'disp.display(' after 'import IPython.display as disp'
+#         r"(?=.*import\s+IPython\.display\s+as\s+\1)",
+#         "display(", new
+#     )
+#     return new
+
+
+# def strip_matplotlib_show(code: str) -> str:
+#     """Remove blocking plt.show() calls (we export base64 instead)."""
+#     import re
+#     return re.sub(r"(?m)^\s*plt\.show\(\)\s*$", "", code)
+
+
+# def inject_display_shim(code: str) -> str:
+#     """
+#     Provide display()/HTML() if missing, forwarding to our executor hook.
+#     Harmless if the names already exist.
+#     """
+#     shim = (
+#         "try:\n"
+#         "    display\n"
+#         "except NameError:\n"
+#         "    def display(obj=None, **kwargs):\n"
+#         "        __builtins__.get('_smx_display', print)(obj)\n"
+#         "try:\n"
+#         "    HTML\n"
+#         "except NameError:\n"
+#         "    class HTML:\n"
+#         "        def __init__(self, data): self.data = str(data)\n"
+#         "        def _repr_html_(self): return self.data\n"
+#         "\n"
+#     )
+#     return shim + code
+
+
+# def strip_spurious_column_tokens(code: str) -> str:
+#     """
+#     Remove common stop-words ('the','whether', ...) when they appear
+#     inside column lists, e.g.:
+#         predictors = ['BMI','the','HbA1c']
+#         df[['GGT','whether','BMI']]
+#     Leaves other strings intact.
+#     """
+#     STOP = {
+#         "the","whether","a","an","and","or","of","to","in","on","for","by",
+#         "with","as","at","from","that","this","these","those","is","are","was","were",
+#         "coef", "Coef", "coefficient", "Coefficient"
+#     }
+
+#     def _norm(s: str) -> str:
+#         return re.sub(r"[^a-z0-9]+", "", s.lower())
+
+#     def _clean_list(content: str) -> str:
+#         # Rebuild a string list, keeping only non-stopword items
+#         items = re.findall(r"(['\"])(.*?)\1", content)
+#         if not items:
+#             return "[" + content + "]"
+#         keep = [f"{q}{s}{q}" for (q, s) in items if _norm(s) not in STOP]
+#         return "[" + ", ".join(keep) + "]"
+
+#     # Variable assignments: predictors/features/columns/cols = [...]
+#     code = re.sub(
+#         r"(?m)\b(predictors|features|columns|cols)\s*=\s*\[([^\]]+)\]",
+#         lambda m: f"{m.group(1)} = " + _clean_list(m.group(2)),
+#         code
+#     )
+
+#     # df[[ ... ]] selections
+#     code = re.sub(
+#         r"df\s*\[\s*\[([^\]]+)\]\s*\]", lambda m: "df[" + _clean_list(m.group(1)) + "]", code)
+
+#     return code
+
+
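The list-cleaning step inside `strip_spurious_column_tokens` is the interesting part: it reparses string items out of a bracketed list and drops stop-words. A trimmed sketch with a two-word stop set (the `clean_list` name and sample line are illustrative):

```python
import re

STOP = {"the", "whether"}

def clean_list(content: str) -> str:
    # findall returns (quote, item) pairs for each quoted string in the list body
    items = re.findall(r"(['\"])(.*?)\1", content)
    keep = [f"{q}{s}{q}" for q, s in items if s.lower() not in STOP]
    return "[" + ", ".join(keep) + "]"

line = "predictors = ['BMI','the','HbA1c']"
print(re.sub(r"\[([^\]]+)\]", lambda m: clean_list(m.group(1)), line))
# predictors = ['BMI', 'HbA1c']
```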
+# def patch_prefix_seaborn_calls(code: str) -> str:
+#     """
+#     Ensure bare seaborn calls are prefixed with `sns.`.
+#     E.g., `barplot(...)` → `sns.barplot(...)`, `heatmap(...)` → `sns.heatmap(...)`, etc.
+#     """
+#     if "sns." in code:
+#         # still fix any leftover bare calls alongside prefixed ones
+#         pass
+
+#     # functions commonly used from seaborn
+#     funcs = [
+#         "barplot","countplot","boxplot","violinplot","stripplot","swarmplot",
+#         "histplot","kdeplot","jointplot","pairplot","heatmap","clustermap",
+#         "scatterplot","lineplot","catplot","displot","lmplot"
+#     ]
+#     # Replace bare function calls not already qualified by a dot (e.g., obj.barplot)
+#     # (?<![\w.]) ensures no preceding word char or dot; avoids touching obj.barplot or mybarplot
+#     pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(", flags=re.MULTILINE)
+
+#     def _add_prefix(m):
+#         fn = m.group(1)
+#         return f"sns.{fn}("
+
+#     return pattern.sub(_add_prefix, code)
+
+
+# def patch_ensure_seaborn_import(code: str) -> str:
+#     """
+#     If seaborn is used (sns.) ensure `import seaborn as sns` exists once.
+#     Also set a quiet theme for consistent visuals.
+#     """
+#     needs_sns = "sns." in code
+#     has_import = bool(re.search(r"^\s*import\s+seaborn\s+as\s+sns\s*$", code, flags=re.MULTILINE))
+#     if needs_sns and not has_import:
+#         # Insert after the first block of imports if possible, else at top
+#         import_block = re.search(r"^(?:\s*(?:from\s+\S+\s+import\s+.+|import\s+\S+)\s*\n)+", code, flags=re.MULTILINE)
+#         inject = "import seaborn as sns\ntry:\n    sns.set_theme()\nexcept Exception:\n    pass\n"
+#         if import_block:
+#             start = import_block.end()
+#             code = code[:start] + inject + code[start:]
+#         else:
+#             code = inject + code
+#     return code
+
+
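The negative lookbehind in `patch_prefix_seaborn_calls` is what keeps qualified calls untouched; this sketch exercises it with a three-function subset (the sample `code` string is illustrative):

```python
import re

funcs = ["barplot", "heatmap", "scatterplot"]
# (?<![\w.]) blocks matches preceded by a word char or dot, so ax.barplot( stays as-is
pattern = re.compile(r"(?<![\w\.])(" + "|".join(funcs) + r")\s*\(")

code = "heatmap(corr); ax.barplot(x)"
print(pattern.sub(lambda m: f"sns.{m.group(1)}(", code))
# sns.heatmap(corr); ax.barplot(x)
```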
+# def patch_pie_chart(code, df, user_question=None, top_n: int = 12):
+#     """
+#     Normalise pie-chart requests.
+
+#     Supports three patterns:
+#       A) Threshold split cohorts, e.g. "HbA1c ≥ 6.5 vs < 6.5" → two pies per categorical + grouped bar.
+#       B) Facet-by categories, e.g. "Ethnicity across BMI categories" or "bin BMI into Normal/Overweight/Obese"
+#          → one pie per facet level (grid) + counts bar of facet sizes.
+#       C) Single pie when no split/facet is requested.
+
+#     Notes:
+#       - Pie variables must be categorical (or numeric binned).
+#       - Facet variables can be categorical or numeric (we bin numeric; BMI gets WHO bins).
+#     """
+
+#     q = (user_question or "")
+#     q_low = q.lower()
+
+#     # Prefer explicit: df['col'].value_counts()
+#     m = re.search(r"df\[['\"](\w+)['\"]\]\.value_counts\(", code)
+#     col = m.group(1) if m else None
+
+#     # ---------- helpers ----------
+#     def _is_cat(col):
+#         return (str(df[col].dtype).startswith("category")
+#                 or df[col].dtype == "object"
+#                 or (pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() <= 20))
+
+#     def _cats_from_question(question: str):
+#         found = []
+#         for c in df.columns:
+#             if c.lower() in question.lower() and _is_cat(c):
+#                 found.append(c)
+#         # dedupe preserve order
+#         seen, out = set(), []
+#         for c in found:
+#             if c not in seen:
+#                 out.append(c); seen.add(c)
+#         return out
+
+#     def _fallback_cat():
+#         cats = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
+#         if not cats: return None
+#         cats.sort(key=lambda t: t[1])
+#         return cats[0][0]
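The `_is_cat` heuristic above (object/category dtype, or a numeric column with at most 20 distinct values) decides which columns qualify for pies. A standalone sketch of that rule on toy data (illustrative names and values):

```python
import pandas as pd
import numpy as np

def is_cat(s: pd.Series) -> bool:
    return (str(s.dtype).startswith("category")
            or s.dtype == "object"
            or (pd.api.types.is_numeric_dtype(s) and s.nunique() <= 20))

df = pd.DataFrame({
    "ethnicity": ["A", "B"] * 15,        # object dtype -> categorical
    "bmi": np.linspace(18.0, 40.0, 30),  # 30 distinct floats -> not categorical
    "stage": [1, 2, 3] * 10,             # few distinct ints -> treated as categorical
})
print({c: is_cat(df[c]) for c in df.columns})
# {'ethnicity': True, 'bmi': False, 'stage': True}
```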

-def patch_fix_seaborn_palette_calls(code: str) -> str:
-def
-
|
|
2765
|
-
|
|
2766
|
-
|
|
2767
|
-
|
|
2768
|
-
|
|
2769
|
-
|
|
2770
|
-
|
|
2771
|
-
|
|
2772
|
-
|
|
2773
|
-
|
|
2774
|
-
|
|
2775
|
-
|
|
2776
|
-
|
|
2777
|
-
|
|
2778
|
-
|
|
2779
|
-
|
|
2780
|
-
|
|
2781
|
-
|
|
2782
|
-
|
|
2783
|
-
|
|
2784
|
-
|
|
2785
|
-
|
|
2786
|
-
|
|
2787
|
-
|
|
2788
|
-
|
|
2789
|
-
|
|
2790
|
-
|
|
2791
|
-
|
|
2792
|
-
|
|
2793
|
-
|
|
2794
|
-
|
|
2795
|
-
|
|
2796
|
-
|
|
2797
|
-
|
|
-    has_plot_df = re.search(r"\bplot_df\b", code) is not None
-    has_var = re.search(r"\bvar\b", code) is not None
-    has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
-
-    if has_plot_df and has_var and has_fh:
-        replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
-    elif has_plot_df and has_var:
-        replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
-    elif has_plot_df:
-        replacement = "sns.barplot(data=plot_df, ax=ax)"
-    else:
-        replacement = "sns.barplot(ax=ax)"
-
-    # ensure an Axes 'ax' exists (no-op if already present)
-    if "ax =" not in code:
-        code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code
-
-    return pattern.sub(replacement, code)
-
-
-def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
-    """
-    Parses the raw text to extract and format the 'refined question',
-    'intents (tasks)', and 'chronology of tasks' sections.
-    Args:
-        raw_text: The complete input string containing the ML pipeline structure.
-    Returns:
-        A tuple containing:
-        (formatted_question_str, formatted_intents_str, formatted_chronology_str)
-    """
-    # --- 1. Regex Pattern to Extract Sections ---
-    # The pattern uses capturing groups (?) to look for the section headers
-    # (e.g., 'refined question:') and captures all the content until the next
-    # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.
+# def _infer_comp_pref(question: str) -> str:
+#     ql = (question or "").lower()
+#     if "heatmap" in ql or "matrix" in ql:
+#         return "heatmap"
+#     if "100%" in ql or "100 percent" in ql or "proportion" in ql or "share" in ql or "composition" in ql:
+#         return "stacked_bar_pct"
+#     if "stacked" in ql:
+#         return "stacked_bar"
+#     if "grouped" in ql or "clustered" in ql or "side-by-side" in ql:
+#         return "grouped_bar"
+#     return "counts_bar"
+
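The commented-out helper above infers a companion-chart preference from plain keywords in the user's question. A self-contained sketch of the same dispatch, with invented test phrases:

    # Keyword -> companion-chart preference, mirroring the commented helper above.
    def infer_comp_pref(question: str) -> str:
        ql = (question or "").lower()
        if "heatmap" in ql or "matrix" in ql:
            return "heatmap"
        if any(k in ql for k in ("100%", "100 percent", "proportion", "share", "composition")):
            return "stacked_bar_pct"
        if "stacked" in ql:
            return "stacked_bar"
        if any(k in ql for k in ("grouped", "clustered", "side-by-side")):
            return "grouped_bar"
        return "counts_bar"  # default when nothing matches

    assert infer_comp_pref("share of smokers by ethnicity") == "stacked_bar_pct"
    assert infer_comp_pref("grouped bars of BMI class") == "grouped_bar"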
+# # parse threshold split like "HbA1c ≥ 6.5"
+# def _parse_split(question: str):
+#     ops_map = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}
+#     m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
+#     if not m: return None
+#     col_raw, op_raw, val_raw = m.group(1).strip(), m.group(2), m.group(3)
+#     op = ops_map.get(op_raw);
+#     if not op: return None
+#     # case-insensitive column match
+#     candidates = {c.lower(): c for c in df.columns}
+#     col = candidates.get(col_raw.lower())
+#     if not col: return None
+#     try: val = float(val_raw)
+#     except Exception: return None
+#     return (col, op, val)
+
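_parse_split pulls a (column, operator, value) triple out of phrases like "HbA1c ≥ 6.5". A minimal standalone version using the same pattern; the package additionally resolves the captured name against df.columns case-insensitively, which is omitted here:

    import re

    OPS = {"≥": ">=", "≤": "<=", ">=": ">=", "<=": "<=", ">": ">", "<": "<", "==": "==", "=": "=="}

    def parse_split(question: str):
        # Name, comparison operator, then a number (int or decimal).
        m = re.search(r"([A-Za-z_][A-Za-z0-9_ ]*)\s*(≥|<=|≤|>=|>|<|==|=)\s*([0-9]+(?:\.[0-9]+)?)", question)
        if not m:
            return None
        op = OPS.get(m.group(2))
        return (m.group(1).strip(), op, float(m.group(3))) if op else None

    print(parse_split("HbA1c ≥ 6.5"))  # ('HbA1c', '>=', 6.5)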
+# # facet extractor: "by/ across / within each / per <col>", or "bin <col>", or named category list
+# def _extract_facet(question: str):
+#     # 1) explicit "by/ across / within / per <col>"
+#     for kw in [" by ", " across ", " within ", " within each ", " per "]:
+#         m = re.search(kw + r"([A-Za-z_][A-Za-z0-9_ ]*)", " " + question + " ", flags=re.IGNORECASE)
+#         if m:
+#             col_raw = m.group(1).strip()
+#             candidates = {c.lower(): c for c in df.columns}
+#             if col_raw.lower() in candidates:
+#                 return (candidates[col_raw.lower()], "auto")
+#     # 2) "bin <col>"
+#     m2 = re.search(r"bin\s+([A-Za-z_][A-Za-z0-9_ ]*)", question, flags=re.IGNORECASE)
+#     if m2:
+#         col_raw = m2.group(1).strip()
+#         candidates = {c.lower(): c for c in df.columns}
+#         if col_raw.lower() in candidates:
+#             return (candidates[col_raw.lower()], "bin")
+#     # 3) BMI special: mentions of normal/overweight/obese imply BMI categories
+#     if any(kw in question.lower() for kw in ["normal", "overweight", "obese", "obesity"]) and \
+#        any(c.lower() == "bmi" for c in df.columns.str.lower()):
+#         bmi_col = [c for c in df.columns if c.lower() == "bmi"][0]
+#         return (bmi_col, "bmi")
+#     return None
+
+# def _bmi_bins(series: pd.Series):
+#     # WHO cutoffs
+#     bins = [-np.inf, 18.5, 25, 30, np.inf]
+#     labels = ["Underweight (<18.5)", "Normal (18.5–24.9)", "Overweight (25–29.9)", "Obese (≥30)"]
+#     return pd.cut(series.astype(float), bins=bins, labels=labels, right=False)
+
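_bmi_bins applies the WHO cutoffs with left-closed intervals (right=False), so a BMI of exactly 25 lands in "Overweight". A runnable illustration with made-up values:

    import numpy as np
    import pandas as pd

    bmi = pd.Series([17.0, 22.4, 25.0, 31.2])
    bins = [-np.inf, 18.5, 25, 30, np.inf]          # WHO cutoffs
    labels = ["Underweight", "Normal", "Overweight", "Obese"]
    print(pd.cut(bmi, bins=bins, labels=labels, right=False).tolist())
    # ['Underweight', 'Normal', 'Overweight', 'Obese']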
+# wants_pie = ("pie" in q_low) or ("plt.pie(" in code) or ("kind='pie'" in code) or ('kind="pie"' in code)
+# if not wants_pie:
+#     return code
+
+# split = _parse_split(q)
+# facet = _extract_facet(q)
+# cats = _cats_from_question(q)
+# _comp_pref = _infer_comp_pref(q)
+
+# # Prefer explicitly referenced categorical like Ethnicity, Smoking_Status, Physical_Activity_Level
+# for hard in ["Ethnicity", "Smoking_Status", "Physical_Activity_Level"]:
+#     if hard in df.columns and hard not in cats and hard.lower() in q_low:
+#         cats.append(hard)
+
+# # --------------- CASE A: threshold split (cohorts) ---------------
+# if split:
+#     if not (cats or any(_is_cat(c) for c in df.columns)):
+#         return code
+#     if not cats:
+#         pool = [(c, df[c].nunique()) for c in df.columns if _is_cat(c) and df[c].nunique() > 1]
+#         pool.sort(key=lambda t: t[1])
+#         cats = [t[0] for t in pool[:3]] if pool else []
+#     if not cats:
+#         return code
+
+#     split_col, op, val = split
+#     cond_str = f"(df['{split_col}'] {op} {val})"
+#     snippet = f"""
+# import numpy as np
+# import pandas as pd
+# import matplotlib.pyplot as plt
+
+# _mask_a = ({cond_str}) & df['{split_col}'].notna()
+# _mask_b = (~({cond_str})) & df['{split_col}'].notna()
+
+# _cohort_a_name = "{split_col} {op} {val}"
+# _cohort_b_name = "NOT ({split_col} {op} {val})"
+
+# _cat_cols = {cats!r}
+# n = len(_cat_cols)
+# fig, axes = plt.subplots(nrows=n, ncols=2, figsize=(12, 5*n))
+# if n == 1:
+#     axes = np.array([axes])
+
+# for i, col in enumerate(_cat_cols):
+#     s_a = df.loc[_mask_a, col].astype(str).value_counts().nlargest({top_n})
+#     s_b = df.loc[_mask_b, col].astype(str).value_counts().nlargest({top_n})
+
+#     ax_a = axes[i, 0]; ax_b = axes[i, 1]
+#     if len(s_a) > 0:
+#         ax_a.pie(s_a.values, labels=[str(x) for x in s_a.index],
+#                  autopct='%1.1f%%', startangle=90, counterclock=False)
+#         ax_a.set_title(f"{{col}} — {{_cohort_a_name}}"); ax_a.axis('equal')
+
+#     if len(s_b) > 0:
+#         ax_b.pie(s_b.values, labels=[str(x) for x in s_b.index],
+#                  autopct='%1.1f%%', startangle=90, counterclock=False)
+#         ax_b.set_title(f"{{col}} — {{_cohort_b_name}}"); ax_b.axis('equal')
+
+# plt.tight_layout(); plt.show()
+
+# # grouped bar complement
+# for col in _cat_cols:
+#     _tmp = (df.loc[df['{split_col}'].notna(), [col, '{split_col}']]
+#             .assign(__cohort=np.where({cond_str}, _cohort_a_name, _cohort_b_name)))
+#     _tab = _tmp.groupby([col, "__cohort"]).size().unstack("__cohort").fillna(0)
+#     _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
+
+#     if _comp_pref == "grouped_bar":
+#         ax = _tab.plot(kind='bar', rot=0, figsize=(10, 4))
+#         ax.set_title(f"{col} by cohort (grouped)")
+#         ax.set_xlabel(col); ax.set_ylabel("Count")
+#         plt.tight_layout(); plt.show()
+
+#     elif _comp_pref == "stacked_bar":
+#         ax = _tab.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
+#         ax.set_title(f"{col} by cohort (stacked)")
+#         ax.set_xlabel(col); ax.set_ylabel("Count")
+#         plt.tight_layout(); plt.show()
+
+#     elif _comp_pref == "stacked_bar_pct":
+#         _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
+#         ax = _perc.plot(kind='bar', stacked=True, rot=0, figsize=(10, 4))
+#         ax.set_title(f"{col} by cohort (100% stacked)")
+#         ax.set_xlabel(col); ax.set_ylabel("Percent")
+#         plt.tight_layout(); plt.show()
+
+#     elif _comp_pref == "heatmap":
+#         _perc = _tab.div(_tab.sum(axis=1), axis=0) * 100
+#         import numpy as np
+#         fig, ax = plt.subplots(figsize=(8, max(3, 0.35*len(_perc))))
+#         im = ax.imshow(_perc.values, aspect='auto')
+#         ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
+#         ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
+#         ax.set_title(f"{col} by cohort — % heatmap")
+#         for i in range(_perc.shape[0]):
+#             for j in range(_perc.shape[1]):
+#                 ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
+#         fig.colorbar(im, ax=ax, label="%")
+#         plt.tight_layout(); plt.show()
+
+#     else:  # counts_bar (default)
+#         ax = _tab.sum(axis=1).plot(kind='bar', rot=0, figsize=(10, 3))
+#         ax.set_title(f"{col}: total counts (both cohorts)")
+#         ax.set_xlabel(col); ax.set_ylabel("Count")
+#         plt.tight_layout(); plt.show()
+# """.lstrip()
+#     return snippet
+
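CASE A's generated snippet rests on two complementary masks over the split column, with nulls excluded from both cohorts. The pattern in isolation, on invented data:

    import pandas as pd

    df = pd.DataFrame({"HbA1c": [5.9, 6.7, 7.1, None],
                       "Smoking_Status": ["Never", "Current", "Never", "Former"]})
    cond = df["HbA1c"] >= 6.5
    mask_a = cond & df["HbA1c"].notna()      # cohort meeting the threshold
    mask_b = (~cond) & df["HbA1c"].notna()   # complement, still excluding NaN rows
    print(df.loc[mask_a, "Smoking_Status"].value_counts())  # counts for the >= 6.5 cohort
    print(df.loc[mask_b, "Smoking_Status"].value_counts())  # counts for the rest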
+# # --------------- CASE B: facet-by (categories/bins) ---------------
+# if facet:
+#     facet_col, how = facet
+#     # Build facet series
+#     if pd.api.types.is_numeric_dtype(df[facet_col]):
+#         if how == "bmi":
+#             facet_series = _bmi_bins(df[facet_col])
+#         else:
+#             # generic numeric bins: 3 equal-width bins by default
+#             facet_series = pd.cut(df[facet_col].astype(float), bins=3)
+#     else:
+#         facet_series = df[facet_col].astype(str)
+
+#     # Choose pie dimension (categorical to count inside each facet)
+#     pie_dim = None
+#     for c in cats:
+#         if c in df.columns and _is_cat(c):
+#             pie_dim = c; break
+#     if pie_dim is None:
+#         pie_dim = _fallback_cat()
+#     if pie_dim is None:
+#         return code
+
+#     snippet = f"""
+# import math
+# import pandas as pd
+# import matplotlib.pyplot as plt
+
+# df = df.copy()
+# _preferred = "{facet_col}" if "{facet_col}" in df.columns else None
+
+# def _select_facet_col(df, preferred=None):
+#     if preferred is not None:
+#         return preferred
+#     # Prefer low-cardinality categoricals (readable pies/grids)
+#     cat_cols = [
+#         c for c in df.columns
+#         if (df[c].dtype == 'object' or str(df[c].dtype).startswith('category'))
+#         and df[c].nunique() > 1 and df[c].nunique() <= 20
+#     ]
+#     if cat_cols:
+#         cat_cols.sort(key=lambda c: df[c].nunique())
+#         return cat_cols[0]
+#     # Else fall back to first usable numeric
+#     num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c]) and df[c].nunique() > 1]
+#     return num_cols[0] if num_cols else None
+
+# _facet_col = _select_facet_col(df, _preferred)
+
+# if _facet_col is None:
+#     # Nothing suitable → single facet keeps pipeline alive
+#     df["__facet__"] = "All"
+# else:
+#     s = df[_facet_col]
+#     if pd.api.types.is_numeric_dtype(s):
+#         # Robust numeric binning: quantiles first, fallback to equal-width
+#         uniq = pd.Series(s).dropna().nunique()
+#         q = 3 if uniq < 10 else 4 if uniq < 30 else 5
+#         try:
+#             df["__facet__"] = pd.qcut(s.astype(float), q=q, duplicates="drop")
+#         except Exception:
+#             df["__facet__"] = pd.cut(s.astype(float), bins=q)
+#     else:
+#         # Cap long tails; keep top categories
+#         vc = s.astype(str).value_counts()
+#         keep = vc.index[:{top_n}]
+#         df["__facet__"] = s.astype(str).where(s.astype(str).isin(keep), other="Other")
+
+# levels = [str(x) for x in df["__facet__"].dropna().unique().tolist()]
+# levels = [x for x in levels if x != "nan"]
+# levels.sort()
+
+# m = len(levels)
+# cols = 3 if m >= 3 else m or 1
+# rows = int(math.ceil(m / cols))
+
+# fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4*cols, 4*rows))
+# if not isinstance(axes, (list, np.ndarray)):
+#     axes = np.array([[axes]])
+# axes = axes.reshape(rows, cols)
+
+# for i, lvl in enumerate(levels):
+#     r, c = divmod(i, cols)
+#     ax = axes[r, c]
+#     s = (df.loc[df["__facet"].astype(str) == str(lvl), "{pie_dim}"]
+#          .astype(str).value_counts().nlargest({top_n}))
+#     if len(s) > 0:
+#         ax.pie(s.values, labels=[str(x) for x in s.index],
+#                autopct='%1.1f%%', startangle=90, counterclock=False)
+#         ax.set_title(f"{pie_dim} — {{lvl}}"); ax.axis('equal')
+
+# # hide any empty subplots
+# for j in range(m, rows*cols):
+#     r, c = divmod(j, cols)
+#     axes[r, c].axis("off")
+
+# plt.tight_layout(); plt.show()
+
+# # --- companion visual (adaptive) ---
+# _comp_pref = "{_comp_pref}"
+# # build contingency table: pie_dim x facet
+# _tab = (df[["__facet__", "{pie_dim}"]]
+#         .dropna()
+#         .astype({{"__facet__": str, "{pie_dim}": str}})
+#         .value_counts()
+#         .unstack(level="__facet__")
+#         .fillna(0))
+
+# # keep top categories by overall size
+# _tab = _tab.loc[_tab.sum(axis=1).sort_values(ascending=False).index[:{top_n}]]
+
+# if _comp_pref == "grouped_bar":
+#     ax = _tab.T.plot(kind="bar", rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
+#     ax.set_title("{pie_dim} by {facet_col} (grouped)")
+#     ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
+#     plt.tight_layout(); plt.show()
+
+# elif _comp_pref == "stacked_bar":
+#     ax = _tab.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_tab.columns)), 4))
+#     ax.set_title("{pie_dim} by {facet_col} (stacked)")
+#     ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
+#     plt.tight_layout(); plt.show()
+
+# elif _comp_pref == "stacked_bar_pct":
+#     _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100  # column-normalised to 100%
+#     ax = _perc.T.plot(kind="bar", stacked=True, rot=0, figsize=(max(8, 1.2*len(_perc.columns)), 4))
+#     ax.set_title("{pie_dim} by {facet_col} (100% stacked)")
+#     ax.set_xlabel("{facet_col}"); ax.set_ylabel("Percent")
+#     plt.tight_layout(); plt.show()
+
+# elif _comp_pref == "heatmap":
+#     _perc = _tab.div(_tab.sum(axis=0), axis=1) * 100
+#     import numpy as np
+#     fig, ax = plt.subplots(figsize=(max(6, 0.9*len(_perc.columns)), max(4, 0.35*len(_perc))))
+#     im = ax.imshow(_perc.values, aspect='auto')
+#     ax.set_xticks(range(_perc.shape[1])); ax.set_xticklabels(_perc.columns, rotation=0)
+#     ax.set_yticks(range(_perc.shape[0])); ax.set_yticklabels(_perc.index)
+#     ax.set_title("{pie_dim} by {facet_col} — % heatmap")
+#     for i in range(_perc.shape[0]):
+#         for j in range(_perc.shape[1]):
+#             ax.text(j, i, f"{{_perc.values[i, j]:.1f}}%", ha="center", va="center")
+#     fig.colorbar(im, ax=ax, label="%")
+#     plt.tight_layout(); plt.show()
+
+# else:  # counts_bar (default denominators)
+#     _counts = df["__facet"].value_counts()
+#     ax = _counts.plot(kind="bar", rot=0, figsize=(6, 3))
+#     ax.set_title("Counts by {facet_col}")
+#     ax.set_xlabel("{facet_col}"); ax.set_ylabel("Count")
+#     plt.tight_layout(); plt.show()
+
+# """.lstrip()
+#     return snippet
+
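The facet-binning fallback in CASE B (quantile bins first, equal-width if qcut cannot produce distinct edges) is worth seeing on its own. A sketch under the same assumptions, with a deliberately skewed toy series:

    import pandas as pd

    def facet_bins(s: pd.Series, q: int = 4) -> pd.Series:
        try:
            # Quantile bins; duplicate edges from skewed data are dropped.
            return pd.qcut(s.astype(float), q=q, duplicates="drop")
        except Exception:
            # Degenerate distributions fall back to equal-width bins.
            return pd.cut(s.astype(float), bins=q)

    s = pd.Series([1, 1, 1, 1, 2, 3, 50, 100])
    print(facet_bins(s).value_counts().sort_index())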
+# # --------------- CASE C: single pie ---------------
+# chosen = None
+# for c in cats:
+#     if c in df.columns and _is_cat(c):
+#         chosen = c; break
+# if chosen is None:
+#     chosen = _fallback_cat()
+
+# if chosen:
+#     snippet = f"""
+# import matplotlib.pyplot as plt
+# counts = df['{chosen}'].astype(str).value_counts().nlargest({top_n})
+# fig, ax = plt.subplots()
+# if len(counts) > 0:
+#     ax.pie(counts.values, labels=[str(i) for i in counts.index],
+#            autopct='%1.1f%%', startangle=90, counterclock=False)
+#     ax.set_title('Distribution of {chosen} (top {top_n})')
+#     ax.axis('equal')
+# plt.show()
+# """.lstrip()
+#     return snippet
+
+# # numeric last resort
+# num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
+# if num_cols:
+#     col = num_cols[0]
+#     snippet = f"""
+# import pandas as pd
+# import matplotlib.pyplot as plt
+# bins = pd.qcut(df['{col}'], q=5, duplicates='drop')
+# counts = bins.value_counts().sort_index()
+# fig, ax = plt.subplots()
+# if len(counts) > 0:
+#     ax.pie(counts.values, labels=[str(i) for i in counts.index],
+#            autopct='%1.1f%%', startangle=90, counterclock=False)
+#     ax.set_title('Distribution of {col} (binned)')
+#     ax.axis('equal')
+# plt.show()
+# """.lstrip()
+#     return snippet
+
+# return code
+
+
+# def patch_fix_seaborn_palette_calls(code: str) -> str:
+#     """
+#     Removes seaborn `palette=` when no `hue=` is present in the same call.
+#     Fixes FutureWarning: 'Passing `palette` without assigning `hue` ...'.
+#     """
+#     if "sns." not in code:
+#         return code
+
+#     # Targets common seaborn plotters
+#     funcs = r"(boxplot|barplot|countplot|violinplot|stripplot|swarmplot|histplot|kdeplot)"
+#     pattern = re.compile(rf"(sns\.{funcs}\s*\()([^)]*)\)", re.DOTALL)
+
+#     def _fix_call(m):
+#         head, inner = m.group(1), m.group(2)
+#         # If there's already hue=, keep as is
+#         if re.search(r"(?<!\w)hue\s*=", inner):
+#             return f"{head}{inner})"
+#         # Otherwise remove palette=... safely (and any adjacent comma spacing)
+#         inner2 = re.sub(r",\s*palette\s*=\s*[^,)\n]+", "", inner)
+#         inner2 = re.sub(r"\bpalette\s*=\s*[^,)\n]+\s*,\s*", "", inner2)
+#         inner2 = re.sub(r"\s*,\s*\)", ")", f"{inner2})")[:-1]  # clean trailing comma before ')'
+#         return f"{head}{inner2})"
+
+#     return pattern.sub(_fix_call, code)
+
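The FutureWarning this patcher suppresses comes from palette= with no hue=. When the call site can be edited instead, seaborn 0.13+ accepts the x variable as hue with the legend disabled; a sketch (load_dataset fetches sample data over the network):

    import seaborn as sns
    import matplotlib.pyplot as plt

    tips = sns.load_dataset("tips")
    # Warns on seaborn >= 0.13: sns.barplot(data=tips, x="day", y="total_bill", palette="viridis")
    sns.barplot(data=tips, x="day", y="total_bill", hue="day", palette="viridis", legend=False)
    plt.show()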
+# def _norm_col_name(s: str) -> str:
+#     """normalise a column name: lowercase + strip non-alphanumerics."""
+#     return re.sub(r"[^a-z0-9]+", "", str(s).lower())
+
+
+# def _first_present(df: pd.DataFrame, candidates: list[str]) -> str | None:
+#     """return the actual df column that matches any candidate (after normalisation)."""
+#     norm_map = {_norm_col_name(c): c for c in df.columns}
+#     for cand in candidates:
+#         hit = norm_map.get(_norm_col_name(cand))
+#         if hit is not None:
+#             return hit
+#     return None
+
+
+# def _ensure_canonical_alias(df: pd.DataFrame, target: str, aliases: list[str]) -> tuple[pd.DataFrame, bool]:
+#     """
+#     If any alias exists, materialise a canonical copy at `target` (don’t drop the original).
+#     Returns (df, found_bool).
+#     """
+#     if target in df.columns:
+#         return df, True
+#     col = _first_present(df, [target, *aliases])
+#     if col is None:
+#         return df, False
+#     df[target] = df[col]
+#     return df, True
+
+
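These three helpers make header matching insensitive to case, spaces and punctuation by reducing every name to an alphanumeric key. The normalisation step on its own:

    import re

    def norm(s: str) -> str:
        return re.sub(r"[^a-z0-9]+", "", str(s).lower())

    columns = ["FH Status", "BMI", "Smoking_Status"]
    norm_map = {norm(c): c for c in columns}
    print(norm_map[norm("fh_status")])        # FH Status
    print(norm_map[norm("smoking-status")])   # Smoking_Status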
+# def strip_python_dotenv(code: str) -> str:
+#     """
+#     Remove any use of python-dotenv from generated code, including:
+#       - single and multi-line 'from dotenv import ...'
+#       - 'import dotenv' (with or without alias) and calls via any alias
+#       - load_dotenv/find_dotenv/dotenv_values calls (bare or prefixed)
+#       - IPython magics (%load_ext dotenv, %dotenv, %env …)
+#       - shell installs like '!pip install python-dotenv'
+#     """
+#     original = code
+
+#     # 0) Kill IPython magics & shell installs referencing dotenv
+#     code = re.sub(r"^\s*%load_ext\s+dotenv\s*$", "", code, flags=re.MULTILINE)
+#     code = re.sub(r"^\s*%dotenv\b.*$", "", code, flags=re.MULTILINE)
+#     code = re.sub(r"^\s*%env\b.*$", "", code, flags=re.MULTILINE)
+#     code = re.sub(r"^\s*!\s*pip\s+install\b.*dotenv.*$", "", code, flags=re.IGNORECASE | re.MULTILINE)
+
+#     # 1) Remove single-line 'from dotenv import ...'
+#     code = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", code, flags=re.MULTILINE)
+
+#     # 2) Remove multi-line 'from dotenv import ( ... )' blocks
+#     code = re.sub(
+#         r"^\s*from\s+dotenv\s+import\s*\([\s\S]*?\)\s*$",
+#         "",
+#         code,
+#         flags=re.MULTILINE,
+#     )
+
+#     # 3) Remove 'import dotenv' (with optional alias). Capture alias names.
+#     aliases = re.findall(r"^\s*import\s+dotenv\s+as\s+([A-Za-z_][A-Za-z0-9_]*)\s*$",
+#                          code, flags=re.MULTILINE)
+#     code = re.sub(r"^\s*import\s+dotenv\s*(?:as\s+[A-Za-z_][A-Za-z0-9_]*)?\s*$",
+#                   "", code, flags=re.MULTILINE)
+
+#     # 4) Remove calls to load_dotenv / find_dotenv / dotenv_values with any prefix
+#     #    e.g., load_dotenv(...), dotenv.load_dotenv(...), dtenv.load_dotenv(...)
+#     fn_names = r"(?:load_dotenv|find_dotenv|dotenv_values)"
+#     # bare calls
+#     code = re.sub(rf"^\s*{fn_names}\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
+#     # dotted calls with any identifier prefix (alias or module)
+#     code = re.sub(rf"^\s*[A-Za-z_][A-Za-z0-9_]*\s*\.\s*{fn_names}\s*\([^)]*\)\s*$",
+#                   "", code, flags=re.MULTILINE)
+
+#     # 5) If any alias imported earlier slipped through (method chains etc.), remove lines using that alias.
+#     for al in aliases:
+#         code = re.sub(rf"^\s*{al}\s*\.\s*\w+\s*\([^)]*\)\s*$", "", code, flags=re.MULTILINE)
+
+#     # 6) Tidy excess blank lines
+#     code = re.sub(r"\n{3,}", "\n\n", code).strip("\n") + "\n"
+#     return code
+
+
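A reduced demo of the same scrubbing idea, covering only the two most common dotenv patterns (the full function above also handles aliases, IPython magics and pip installs):

    import re

    src = "from dotenv import load_dotenv\nload_dotenv()\nprint('kept')\n"
    out = re.sub(r"^\s*from\s+dotenv\s+import\s+.*$", "", src, flags=re.MULTILINE)
    out = re.sub(r"^\s*(?:load_dotenv|find_dotenv|dotenv_values)\s*\([^)]*\)\s*$", "", out, flags=re.MULTILINE)
    print(out.strip())  # print('kept')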
+# def fix_predict_calls_records_arg(code: str) -> str:
+#     """
+#     If generated code calls predict_* with a list-of-dicts via .to_dict('records')
+#     (or orient='records'), strip the .to_dict(...) so a DataFrame is passed instead.
+#     Works line-by-line to avoid over-rewrites elsewhere.
+#     Examples fixed:
+#       predict_patient(X_test.iloc[:5].to_dict('records'))
+#       predict_risk(df.head(3).to_dict(orient="records"))
+#     → predict_patient(X_test.iloc[:5])
+#     """
+#     fixed_lines = []
+#     for line in code.splitlines():
+#         if "predict_" in line and "to_dict" in line and "records" in line:
+#             line = re.sub(
+#                 r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)",
+#                 "",
+#                 line
+#             )
+#         fixed_lines.append(line)
+#     return "\n".join(fixed_lines)
+
+
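The substitution in action on one of the docstring's own examples:

    import re

    line = "predict_patient(X_test.iloc[:5].to_dict('records'))"
    fixed = re.sub(r"\.to_dict\s*\(\s*(?:orient\s*=\s*)?['\"]records['\"]\s*\)", "", line)
    print(fixed)  # predict_patient(X_test.iloc[:5])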
+# def fix_fstring_backslash_paths(code: str) -> str:
+#     """
+#     Fix bad f-strings like: f"...{out_dir\\plots\\img.png}..."
+#     → f"...{os.path.join(out_dir, r'plots\\img.png')}"
+#     Only touches f-strings that contain a backslash path inside {...}.
+#     """
+#     def _fix_line(line: str) -> str:
+#         # quick check: only f-strings need scanning
+#         if not (("f\"" in line) or ("f'" in line) or ("f\"\"\"" in line) or ("f'''" in line)):
+#             return line
+#         # {var\rest-of-path} where var can be dotted (e.g., cfg.out)
+#         pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
+#         def repl(m):
+#             left = m.group(1)
+#             right = m.group(2).strip().replace('"', '\\"')
+#             return "{os.path.join(" + left + ', r"' + right + '")}'
+#         return pattern.sub(repl, line)
+
+#     return "\n".join(_fix_line(ln) for ln in code.splitlines())
+
+
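A backslash inside an f-string replacement field is a SyntaxError before Python 3.12, so this repair has to happen on the source text rather than at runtime. Demonstrating the substitution on one offending line:

    import re

    pattern = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\.]*)\\([^}]+)\}")
    line = 'path = f"{out_dir\\plots\\img.png}"'   # the broken original, held as text
    fixed = pattern.sub(lambda m: "{os.path.join(" + m.group(1) + ', r"' + m.group(2).strip() + '")}', line)
    print(fixed)  # path = f"{os.path.join(out_dir, r"plots\img.png")}"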
+# def ensure_os_import(code: str) -> str:
+#     """
+#     If os.path.join is used but 'import os' is missing, inject it at the top.
+#     """
+#     needs = "os.path.join(" in code
+#     has_import_os = re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE) is not None
+#     has_from_os = re.search(r"^\s*from\s+os\s+import\b", code, flags=re.MULTILINE) is not None
+#     if needs and not (has_import_os or has_from_os):
+#         return "import os\n" + code
+#     return code
+
+
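The guard in isolation; it only fires when os.path.join appears with no existing import:

    import re

    code = 'p = os.path.join(out_dir, "plot.png")\n'
    if "os.path.join(" in code and not re.search(r"^\s*import\s+os\b", code, flags=re.MULTILINE):
        code = "import os\n" + code
    print(code)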
+# def fix_seaborn_boxplot_nameerror(code: str) -> str:
+#     """
+#     Fix bad calls like: sns.boxplot(boxplot)
+#     Heuristic:
+#       - If plot_df + FH_status + var exist → sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)
+#       - Else if plot_df + var → sns.boxplot(data=plot_df, y=var, ax=ax)
+#       - Else if plot_df only → sns.boxplot(data=plot_df, ax=ax)
+#       - Else → sns.boxplot(ax=ax)
+#     Ensures a matplotlib Axes 'ax' exists.
+#     """
+#     pattern = re.compile(r"^\s*sns\.boxplot\s*\(\s*boxplot\s*\)\s*$", re.MULTILINE)
+#     if not pattern.search(code):
+#         return code
+
+#     has_plot_df = re.search(r"\bplot_df\b", code) is not None
+#     has_var = re.search(r"\bvar\b", code) is not None
+#     has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
+
+#     if has_plot_df and has_var and has_fh:
+#         replacement = "sns.boxplot(data=plot_df, x='FH_status', y=var, ax=ax)"
+#     elif has_plot_df and has_var:
+#         replacement = "sns.boxplot(data=plot_df, y=var, ax=ax)"
+#     elif has_plot_df:
+#         replacement = "sns.boxplot(data=plot_df, ax=ax)"
+#     else:
+#         replacement = "sns.boxplot(ax=ax)"
+
+#     fixed = pattern.sub(replacement, code)
+
+#     # Ensure 'fig, ax = plt.subplots(...)' exists
+#     if "ax=" in replacement and not re.search(r"\bfig\s*,\s*ax\s*=\s*plt\.subplots\s*\(", fixed):
+#         # Insert right before the first seaborn call
+#         m = re.search(r"^\s*sns\.", fixed, flags=re.MULTILINE)
+#         insert_at = m.start() if m else 0
+#         fixed = fixed[:insert_at] + "fig, ax = plt.subplots(figsize=(8,4))\n" + fixed[insert_at:]
+
+#     return fixed
+
+
+# def fix_seaborn_barplot_nameerror(code: str) -> str:
+#     """
+#     Fix bad calls like: sns.barplot(barplot)
+#     Strategy mirrors boxplot fixer: prefer data=plot_df with x/y if available,
+#     otherwise degrade safely to an empty call on an existing Axes.
+#     """
+#     import re
+#     pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
+#     if not pattern.search(code):
+#         return code
+
+#     has_plot_df = re.search(r"\bplot_df\b", code) is not None
+#     has_var = re.search(r"\bvar\b", code) is not None
+#     has_fh = bool(re.search(r"['\"]FH_status['\"]", code) or re.search(r"\bFH_status\b", code))
+
+#     if has_plot_df and has_var and has_fh:
+#         replacement = "sns.barplot(data=plot_df, x='FH_status', y=var, ax=ax)"
+#     elif has_plot_df and has_var:
+#         replacement = "sns.barplot(data=plot_df, y=var, ax=ax)"
+#     elif has_plot_df:
+#         replacement = "sns.barplot(data=plot_df, ax=ax)"
+#     else:
+#         replacement = "sns.barplot(ax=ax)"
+
+#     # ensure an Axes 'ax' exists (no-op if already present)
+#     if "ax =" not in code:
+#         code = "import matplotlib.pyplot as plt\nfig, ax = plt.subplots(figsize=(6,4))\n" + code
+
+#     return pattern.sub(replacement, code)
+
+
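Both NameError fixers share one core move: match the degenerate call where the function's own name was passed as its argument, then rebuild the call from whatever variables the surrounding code defines. The match-and-replace step alone:

    import re

    pattern = re.compile(r"^\s*sns\.barplot\s*\(\s*barplot\s*\)\s*$", re.MULTILINE)
    code = "sns.barplot(barplot)"
    print(pattern.sub("sns.barplot(data=plot_df, y=var, ax=ax)", code))
    # sns.barplot(data=plot_df, y=var, ax=ax)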
+# def parse_and_format_ml_pipeline(raw_text: str) -> tuple[str, str, str]:
+#     """
+#     Parses the raw text to extract and format the 'refined question',
+#     'intents (tasks)', and 'chronology of tasks' sections.
+#     Args:
+#         raw_text: The complete input string containing the ML pipeline structure.
+#     Returns:
+#         A tuple containing:
+#         (formatted_question_str, formatted_intents_str, formatted_chronology_str)
+#     """
+#     # --- 1. Regex Pattern to Extract Sections ---
+#     # The pattern uses capturing groups (?) to look for the section headers
+#     # (e.g., 'refined question:') and captures all the content until the next
+#     # section header or the end of the string. re.DOTALL is crucial for '.' to match newlines.
-[lines 2833-2838 removed; content not captured in this diff view]
+#     pattern = re.compile(
+#         r"refined question:(?P<question>.*?)"
+#         r"intents \(tasks\):(?P<intents>.*?)"
+#         r"Chronology of tasks:(?P<chronology>.*)",
+#         re.IGNORECASE | re.DOTALL
+#     )

-[line 2840 removed; content not captured in this diff view]
+#     match = pattern.search(raw_text)

-[lines 2842-2843 removed; content not captured in this diff view]
+#     if not match:
+#         raise ValueError("Input text structure does not match the expected pattern.")

-[lines 2845-2848 removed; content not captured in this diff view]
+#     # --- 2. Extract Content ---
+#     question_content = match.group('question').strip()
+#     intents_content = match.group('intents').strip()
+#     chronology_content = match.group('chronology').strip()

-[line 2850 removed; content not captured in this diff view]
+#     # --- 3. Formatting Functions ---

-[lines 2852-2855 removed; content not captured in this diff view]
+#     def format_question(content):
+#         """Formats the Refined Question section."""
+#         # Clean up leading/trailing whitespace and ensure clean paragraphs
+#         content = content.strip().replace('\n', ' ').replace(' ', ' ')

-[lines 2857-2868 removed; content not captured in this diff view]
+#         # Simple formatting using Markdown headers and bolding
+#         formatted = (
+#             # "## 1. Project Goal and Objectives\n\n"
+#             "<b> Refined Question:</b>\n"
+#             f"{content}\n"
+#         )
+#         return formatted
+
+#     def format_intents(content):
+#         """Formats the Intents (Tasks) section as a structured list."""
+#         # Use regex to find and format each numbered task
+#         # It finds 'N. **Text** - ...' and breaks it down.

-[lines 2870-2873 removed; content not captured in this diff view]
+#         tasks = []
+#         # Pattern: N. **Text** - Content (including newlines, non-greedy)
+#         # We need to explicitly handle the list items starting with '-' within the content
+#         task_pattern = re.compile(r'(\d+\. \*\*.*?\*\*.*?)(?=\n\d+\. \*\*|\Z)', re.DOTALL)

-[lines 2875-2876 removed; content not captured in this diff view]
+#         # Split the content by lines and join tasks back into clean strings
+#         raw_tasks = [m.group(1).strip() for m in task_pattern.finditer(content)]

-[lines 2878-2880 removed; content not captured in this diff view]
+#         for task in raw_tasks:
+#             # Replace the initial task number and **Heading** with a Heading 3
+#             task = re.sub(r'^\d+\. (\*\*.*?\*\*)', r'### \1', task, count=1, flags=re.MULTILINE)

-[lines 2882-2884 removed; content not captured in this diff view]
+#             # Replace list markers (' - ') with Markdown bullets ('* ') for clarity
+#             task = task.replace('\n - ', '\n* ').replace('- ', '* ', 1)
+#             tasks.append(task)

-[line 2886 removed; content not captured in this diff view]
+#         formatted_tasks = "\n\n".join(tasks)
+
+#         return (
+#             "\n---\n"
+#             "## 2. Methodology and Tasks\n\n"
+#             f"{formatted_tasks}\n"
+#         )
+
+#     def format_chronology(content):
+#         """Formats the Chronology section."""
+#         # Uses the given LaTeX format
+#         content = content.strip().replace(' ', ' \rightarrow ')
+#         formatted = (
+#             "\n---\n"
+#             "## 3. Chronology of Tasks\n"
+#             f"$$\\text{{{content}}}$$"
+#         )
+#         return formatted
+
+#     # --- 4. Format and Return ---
+#     formatted_question = format_question(question_content)
+#     formatted_intents = format_intents(intents_content)
+#     formatted_chronology = format_chronology(chronology_content)
+
+#     return formatted_question, formatted_intents, formatted_chronology
+
+
+# def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
+#     """Combines all formatted parts into a final report string."""
+#     return (
+#         "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
+#         f"{formatted_question}\n"
+#         f"{formatted_intents}\n"
+#         f"{formatted_chronology}\n"
+#     )
+
+
+# def fix_confusion_matrix_for_multilabel(code: str) -> str:
+#     """
+#     Replace ConfusionMatrixDisplay.from_estimator(...) usages with
+#     from_predictions(...) which works for multi-label loops without requiring
+#     the estimator to expose _estimator_type.
+#     """
+#     return re.sub(
+#         r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
+#         r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
+#         code
+#     )
+
+
+# def smx_auto_title_plots(ctx=None, fallback="Analysis"):
+#     """
+#     Ensure every Matplotlib/Seaborn Axes has a title.
+#     Uses refined_question -> askai_question -> fallback.
+#     Only sets a title if it's currently empty.
+#     """
+#     import matplotlib.pyplot as plt
+
+#     def _all_figures():
+#         try:
+#             from matplotlib._pylab_helpers import Gcf
+#             return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
+#         except Exception:
+#             # Best effort fallback
+#             nums = plt.get_fignums()
+#             return [plt.figure(n) for n in nums] if nums else []
+
+#     # Choose a concise title
+#     title = None
+#     if isinstance(ctx, dict):
+#         title = ctx.get("refined_question") or ctx.get("askai_question")
+#     title = (str(title).strip().splitlines()[0][:120]) if title else fallback
+
+#     for fig in _all_figures():
+#         for ax in getattr(fig, "axes", []):
+#             try:
+#                 if not (ax.get_title() or "").strip():
+#                     ax.set_title(title)
+#             except Exception:
+#                 pass
+#         try:
+#             fig.tight_layout()
+#         except Exception:
+#             pass
+
+
+# def patch_fix_sentinel_plot_calls(code: str) -> str:
+#     """
+#     Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
+#       SB_barplot(barplot)       -> SB_barplot()
+#       SB_barplot(barplot, ...)  -> SB_barplot(...)
+#       sns.barplot(barplot)      -> SB_barplot()
+#       sns.barplot(barplot, ...) -> SB_barplot(...)
+#     Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
+#     """
+#     names = ['histplot','boxplot','barplot','lineplot','countplot','heatmap','pairplot','scatterplot']
+#     for n in names:
+#         # SB_* with sentinel as the first arg (with or without trailing args)
+#         code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
+#         code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
+#         # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
+#         code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
+#         code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
+#     return code
+
+
+# def patch_rmse_calls(code: str) -> str:
+#     """
+#     Make RMSE robust across sklearn versions.
+#       - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
+#       - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
+#     """
+#     import re
+#     # (a) Specific RMSE pattern
+#     code = re.sub(
+#         r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
+#         r"_SMX_rmse(\1)",
+#         code,
+#         flags=re.DOTALL
+#     )
+#     # (b) Guard any other MSE calls
+#     code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
+#     return code
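The from_estimator -> from_predictions rewrite above works because from_predictions only needs label arrays, not a fitted classifier, so it is safe inside per-label loops. A minimal check with toy labels:

    import matplotlib.pyplot as plt
    from sklearn.metrics import ConfusionMatrixDisplay

    y_true = [0, 1, 1, 0, 1]
    y_pred = [0, 1, 0, 0, 1]
    ConfusionMatrixDisplay.from_predictions(y_true, y_pred)
    plt.show()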
-        return (
-            "\n---\n"
-            "## 2. Methodology and Tasks\n\n"
-            f"{formatted_tasks}\n"
-        )
-
-    def format_chronology(content):
-        """Formats the Chronology section."""
-        # Uses the given LaTeX format
-        content = content.strip().replace(' ', ' \rightarrow ')
-        formatted = (
-            "\n---\n"
-            "## 3. Chronology of Tasks\n"
-            f"$$\\text{{{content}}}$$"
-        )
-        return formatted
-
-    # --- 4. Format and Return ---
-    formatted_question = format_question(question_content)
-    formatted_intents = format_intents(intents_content)
-    formatted_chronology = format_chronology(chronology_content)
-
-    return formatted_question, formatted_intents, formatted_chronology
-
-
-def generate_full_report(formatted_question: str, formatted_intents: str, formatted_chronology: str) -> str:
-    """Combines all formatted parts into a final report string."""
-    return (
-        "# 🔬 Machine Learning Pipeline for Predicting Family History of Diabetes\n\n"
-        f"{formatted_question}\n"
-        f"{formatted_intents}\n"
-        f"{formatted_chronology}\n"
-    )
-
-
-def fix_confusion_matrix_for_multilabel(code: str) -> str:
-    """
-    Replace ConfusionMatrixDisplay.from_estimator(...) usages with
-    from_predictions(...) which works for multi-label loops without requiring
-    the estimator to expose _estimator_type.
-    """
-    return re.sub(
-        r"ConfusionMatrixDisplay\.from_estimator\(([^,]+),\s*([^,]+),\s*([^)]+)\)",
-        r"ConfusionMatrixDisplay.from_predictions(\3, \1.predict(\2))",
-        code
-    )
-
-
-def smx_auto_title_plots(ctx=None, fallback="Analysis"):
-    """
-    Ensure every Matplotlib/Seaborn Axes has a title.
-    Uses refined_question -> askai_question -> fallback.
-    Only sets a title if it's currently empty.
-    """
-    import matplotlib.pyplot as plt
-
-    def _all_figures():
-        try:
-            from matplotlib._pylab_helpers import Gcf
-            return [fm.canvas.figure for fm in Gcf.get_all_fig_managers()]
-        except Exception:
-            # Best effort fallback
-            nums = plt.get_fignums()
-            return [plt.figure(n) for n in nums] if nums else []
-
-    # Choose a concise title
-    title = None
-    if isinstance(ctx, dict):
-        title = ctx.get("refined_question") or ctx.get("askai_question")
-    title = (str(title).strip().splitlines()[0][:120]) if title else fallback
-
-    for fig in _all_figures():
-        for ax in getattr(fig, "axes", []):
-            try:
-                if not (ax.get_title() or "").strip():
-                    ax.set_title(title)
-            except Exception:
-                pass
-        try:
-            fig.tight_layout()
-        except Exception:
-            pass
-
-
-def patch_fix_sentinel_plot_calls(code: str) -> str:
-    """
-    Normalise 'sentinel first-arg' calls so wrappers can pick sane defaults.
-      SB_barplot(barplot)       -> SB_barplot()
-      SB_barplot(barplot, ...)  -> SB_barplot(...)
-      sns.barplot(barplot)      -> SB_barplot()
-      sns.barplot(barplot, ...) -> SB_barplot(...)
-    Same for: histplot, boxplot, lineplot, countplot, heatmap, pairplot, scatterplot.
-    """
-    names = ['histplot','boxplot','barplot','lineplot','countplot','heatmap','pairplot','scatterplot']
-    for n in names:
-        # SB_* with sentinel as the first arg (with or without trailing args)
-        code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
-        code = re.sub(rf"\bSB_{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
-        # sns.* with sentinel as the first arg → route to SB_* (so our wrappers handle it)
-        code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*\)", f"SB_{n}()", code)
-        code = re.sub(rf"\bsns\.{n}\s*\(\s*{n}\s*,", f"SB_{n}(", code)
-    return code
-
-
-def patch_rmse_calls(code: str) -> str:
-    """
-    Make RMSE robust across sklearn versions.
-      - Replace mean_squared_error(..., squared=False) -> _SMX_rmse(...)
-      - Wrap any remaining mean_squared_error(...) calls with _SMX_call for safety.
-    """
-    import re
-    # (a) Specific RMSE pattern
-    code = re.sub(
-        r"\bmean_squared_error\s*\(\s*(.+?)\s*,\s*squared\s*=\s*False\s*\)",
-        r"_SMX_rmse(\1)",
-        code,
-        flags=re.DOTALL
-    )
-    # (b) Guard any other MSE calls
-    code = re.sub(r"\bmean_squared_error\s*\(", r"_SMX_call(mean_squared_error, ", code)
-    return code
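On the _SMX_rmse target: scikit-learn deprecated mean_squared_error(..., squared=False) in 1.4 and removed it in 1.6 in favour of root_mean_squared_error. A version-tolerant helper in the same spirit (a sketch; the package's own _SMX_rmse is defined elsewhere):

    import numpy as np

    def rmse(y_true, y_pred) -> float:
        try:
            from sklearn.metrics import root_mean_squared_error  # sklearn >= 1.4
            return float(root_mean_squared_error(y_true, y_pred))
        except ImportError:
            from sklearn.metrics import mean_squared_error       # older sklearn
            return float(np.sqrt(mean_squared_error(y_true, y_pred)))

    print(rmse([3.0, 5.0], [2.0, 6.0]))  # 1.0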