flyteplugins-codegen 2.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/codegen/__init__.py +18 -0
- flyteplugins/codegen/auto_coder_agent.py +1088 -0
- flyteplugins/codegen/core/__init__.py +19 -0
- flyteplugins/codegen/core/types.py +337 -0
- flyteplugins/codegen/data/__init__.py +27 -0
- flyteplugins/codegen/data/extraction.py +281 -0
- flyteplugins/codegen/data/schema.py +270 -0
- flyteplugins/codegen/execution/__init__.py +7 -0
- flyteplugins/codegen/execution/agent.py +671 -0
- flyteplugins/codegen/execution/docker.py +206 -0
- flyteplugins/codegen/generation/__init__.py +41 -0
- flyteplugins/codegen/generation/llm.py +1269 -0
- flyteplugins/codegen/generation/prompts.py +136 -0
- flyteplugins_codegen-2.0.6.dist-info/METADATA +441 -0
- flyteplugins_codegen-2.0.6.dist-info/RECORD +17 -0
- flyteplugins_codegen-2.0.6.dist-info/WHEEL +5 -0
- flyteplugins_codegen-2.0.6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Core type definitions for LLM code generation."""
|
|
2
|
+
|
|
3
|
+
from flyteplugins.codegen.core.types import (
|
|
4
|
+
CodeGenEvalResult,
|
|
5
|
+
CodePlan,
|
|
6
|
+
CodeSolution,
|
|
7
|
+
ErrorDiagnosis,
|
|
8
|
+
FixVerification,
|
|
9
|
+
TestFailure,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"CodeGenEvalResult",
|
|
14
|
+
"CodePlan",
|
|
15
|
+
"CodeSolution",
|
|
16
|
+
"ErrorDiagnosis",
|
|
17
|
+
"FixVerification",
|
|
18
|
+
"TestFailure",
|
|
19
|
+
]
|
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
from typing import Any, Literal, Optional
|
|
2
|
+
|
|
3
|
+
import flyte
|
|
4
|
+
from flyte.io import File
|
|
5
|
+
from flyte.syncify import syncify
|
|
6
|
+
from pydantic import BaseModel, Field, field_validator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CodePlan(BaseModel):
    """Structured plan for the code solution.

    Captures the high-level design (description + approach) separately from
    the concrete code, which lives in :class:`CodeSolution`.
    """

    # One-paragraph summary of the whole solution.
    description: str = Field(description="Overall description of the solution")
    # The algorithm/strategy the generated code is expected to follow.
    approach: str = Field(description="High-level approach and algorithm to solve the problem")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CodeSolution(BaseModel):
    """Structured code solution."""

    # Normalized to lowercase with surrounding whitespace stripped by the
    # validator below, so downstream comparisons can be exact-match.
    language: str = Field(
        default="python",
        description="Programming language",
    )
    code: str = Field(
        default="",
        description="Complete executable code including imports and dependencies",
    )
    system_packages: list[str] = Field(
        default_factory=list,
        description="System packages needed (e.g., gcc, build-essential, curl)",
    )

    @field_validator("language", mode="before")
    @classmethod
    def normalize_language(cls, v: str) -> str:
        """Canonicalize the language name: strip whitespace and lowercase."""
        return v.strip().lower()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class CodeGenEvalResult(BaseModel):
    """Result from code generation and evaluation.

    Bundles the generated solution, the outcome of evaluating it, token
    accounting, and everything needed to re-execute the code later: the built
    ``image``, declared input/output types, and any sample data files.
    """

    # High-level plan the solution corresponds to, if one was produced.
    plan: Optional[CodePlan] = None
    solution: CodeSolution
    # Generated test code, if any.
    tests: Optional[str] = None
    # Whether generation + evaluation succeeded end to end.
    success: bool
    # Captured output of the evaluation run.
    output: str
    exit_code: int
    error: Optional[str] = None
    # Number of generate/evaluate attempts made.
    attempts: int = 1
    conversation_history: list[dict[str, str]] = Field(default_factory=list)
    detected_packages: list[str] = Field(
        default_factory=list,
        description="Language packages detected by LLM from imports",
    )
    detected_system_packages: list[str] = Field(default_factory=list, description="System packages detected by LLM")
    image: Optional[str] = Field(
        default=None,
        description="The Flyte Image built with all dependencies",
    )
    total_input_tokens: int = Field(
        default=0,
        description="Total input tokens used across all LLM calls",
    )
    total_output_tokens: int = Field(
        default=0,
        description="Total output tokens used across all LLM calls",
    )
    declared_inputs: Optional[dict[str, type]] = Field(
        default=None,
        description="Input types (user-provided or inferred from samples)",
    )
    declared_outputs: Optional[dict[str, type]] = Field(
        default=None,
        description="Output types declared by user",
    )
    data_context: Optional[str] = Field(
        default=None,
        description="Extracted data context (schema, stats, patterns, samples) used for code generation",
    )
    original_samples: Optional[dict[str, File]] = Field(
        default=None,
        description="Sample data converted to Files (defaults for run()/as_task())",
    )
    generated_schemas: Optional[dict[str, str]] = Field(
        default=None,
        description="Auto-generated Pandera schemas (as Python code strings) for validating data inputs",
    )

    def _create_sandbox(
        self,
        name: str,
        resources: Optional[flyte.Resources],
        retries: int,
        timeout: Optional[int],
        env_vars: Optional[dict[str, str]],
        secrets: Optional[list],
        cache: str,
    ):
        """Build the Flyte sandbox shared by ``as_task()`` and ``run()``.

        Both public entry points previously duplicated this ``flyte.sandbox.create``
        call verbatim; keeping it in one place guarantees they stay in sync.
        """
        return flyte.sandbox.create(
            name=name,
            code=self.solution.code,
            inputs=self.declared_inputs or {},
            outputs=self.declared_outputs or {},
            auto_io=False,
            resources=resources or flyte.Resources(cpu=1, memory="1Gi"),
            retries=retries,
            timeout=timeout,
            env_vars=env_vars,
            secrets=secrets,
            cache=cache,
        )

    def as_task(
        self,
        name: str = "run_code_on_real_data",
        resources: Optional[flyte.Resources] = None,
        retries: int = 0,
        timeout: Optional[int] = None,
        env_vars: Optional[dict[str, str]] = None,
        secrets: Optional[list] = None,
        cache: str = "auto",
    ):
        """Create a sandbox that runs the generated code in an isolated sandbox.

        The generated code will write outputs to /var/outputs/{output_name} files.
        Returns a callable wrapper that automatically provides the script file.

        Args:
            name: Name for the sandbox
            resources: Optional resources for the task
            retries: Number of retries for the task. Defaults to 0.
            timeout: Timeout in seconds. Defaults to None.
            env_vars: Environment variables to pass to the sandbox.
            secrets: flyte.Secret objects to make available.
            cache: CacheRequest: "auto", "override", or "disable". Defaults to "auto".

        Returns:
            Callable task wrapper with the default inputs baked in. Call with your other declared inputs.

        Raises:
            ValueError: If code generation failed or no image was built.
        """
        if not self.success:
            raise ValueError("Cannot create task from failed code generation")

        if not self.image:
            raise ValueError("No image available - code generation did not build an image")

        sandbox = self._create_sandbox(name, resources, retries, timeout, env_vars, secrets, cache)

        # Bind locally so the closures below capture a plain string, not `self`.
        image = self.image

        # If we have samples, wrap to inject sample values as defaults
        if self.original_samples:
            sample_defaults = dict(self.original_samples)

            @syncify
            async def task_with_defaults(**kwargs):
                # Caller-supplied kwargs take precedence over sample defaults.
                merged = {**sample_defaults, **kwargs}
                return await sandbox.run.aio(image=image, **merged)

            return task_with_defaults

        @syncify
        async def task(**kwargs):
            return await sandbox.run.aio(image=image, **kwargs)

        return task

    async def run(
        self,
        *,
        name: str = "run_code_on_real_data",
        resources: Optional[flyte.Resources] = None,
        retries: int = 0,
        timeout: Optional[int] = None,
        env_vars: Optional[dict[str, str]] = None,
        secrets: Optional[list] = None,
        cache: str = "auto",
        **overrides,
    ) -> Any:
        """Run generated code in an isolated sandbox (one-off execution).

        If samples were provided during generate(), they are used as defaults.
        Override any input by passing it as a keyword argument. If no samples
        exist, all declared inputs must be provided via ``**overrides``.

        Args:
            name: Name for the sandbox
            resources: Optional resources for the task
            retries: Number of retries for the task. Defaults to 0.
            timeout: Timeout in seconds. Defaults to None.
            env_vars: Environment variables to pass to the sandbox.
            secrets: flyte.Secret objects to make available.
            cache: CacheRequest: "auto", "override", or "disable". Defaults to "auto".
            **overrides: Input values. Merged on top of sample defaults (if any).

        Returns:
            Tuple of typed outputs.

        Raises:
            ValueError: If code generation failed or no image was built.
        """
        if not self.success:
            raise ValueError("Cannot run failed code generation")

        if not self.image:
            raise ValueError("No image available - code generation did not build an image")

        sandbox = self._create_sandbox(name, resources, retries, timeout, env_vars, secrets, cache)

        # Overrides win over sample defaults.
        run_data = {**(self.original_samples or {}), **overrides}
        return await sandbox.run.aio(image=self.image, **run_data)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# Apply syncify after class definition to avoid Pydantic field detection:
# assigning the wrapped callable inside the class body would let Pydantic's
# metaclass inspect it during model construction.
CodeGenEvalResult.run = syncify(CodeGenEvalResult.run)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class TestFailure(BaseModel):
    """Individual test failure with diagnosis.

    One entry per failing test; aggregated in :class:`ErrorDiagnosis.failures`.
    """

    test_name: str = Field(description="Name of the failing test")
    error_message: str = Field(
        description="The exact final error message from test output "
        "(e.g., 'RecursionError: maximum recursion depth exceeded')"
    )
    expected_behavior: str = Field(description="What this test expected to happen")
    actual_behavior: str = Field(description="What actually happened when the code ran")
    root_cause: str = Field(description="Why the test failed (quote the exact code that's wrong)")
    suggested_fix: str = Field(description="Specific code changes using format: Replace `current code` with `new code`")
    # Classifies where the defect lives, so the failure can be routed to the
    # right remediation (dependency fix vs. solution fix vs. test fix).
    error_type: Literal["environment", "logic", "test_error"] = Field(
        description="Type of error: 'environment' (missing packages/dependencies), "
        "'logic' (bug in solution code), or 'test_error' (bug in test code)"
    )
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class ErrorDiagnosis(BaseModel):
    """Structured diagnosis of execution errors.

    Aggregates per-test failures plus any environment changes (system
    packages, language packages, extra image build commands) needed to fix
    them.
    """

    failures: list[TestFailure] = Field(description="Individual test failures with their diagnoses")
    needs_system_packages: list[str] = Field(
        default_factory=list,
        description="System packages needed (e.g., gcc, pkg-config).",
    )
    needs_language_packages: list[str] = Field(
        default_factory=list,
        description="Language packages needed.",
    )
    needs_additional_commands: list[str] = Field(
        default_factory=list,
        description="Additional RUN commands (e.g., apt-get update, mkdir /data, wget files).",
    )
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class FixVerification(BaseModel):
    """Verification that fixes were applied to code.

    ``applied_fixes`` / ``missing_fixes`` are keyed by test name, matching
    :class:`TestFailure.test_name`.
    """

    all_fixes_applied: bool = Field(description="True if all suggested fixes are present in the new code")
    applied_fixes: list[str] = Field(
        default_factory=list,
        description="List of fixes that were successfully applied (by test name)",
    )
    missing_fixes: list[str] = Field(
        default_factory=list,
        description="List of fixes that are still missing (by test name)",
    )
    explanation: str = Field(description="Brief explanation of what was checked and what's missing (if anything)")
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class TestFunctionPatch(BaseModel):
    """A single fixed test function.

    ``fixed_code`` is a full replacement for the named function, decorators
    included, not a diff.
    """

    test_name: str = Field(description="Name of the test function (e.g. test_basic_analysis)")
    fixed_code: str = Field(description="Complete fixed function body including the def line and decorators")
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class TestFixResponse(BaseModel):
    """Response containing only the fixed test functions.

    Only the patched functions are returned; unchanged tests are omitted.
    """

    patches: list[TestFunctionPatch] = Field(description="List of fixed test functions")
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class _PackageReplacementResponse(BaseModel):
    """Response format for suggesting a replacement system package."""

    # None (null in the LLM response) means "drop the package entirely".
    replacement: Optional[str] = Field(
        default=None,
        description="Correct Debian/Ubuntu apt package name, or null if no system package is needed",
    )
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class _PackageDetectionResponse(BaseModel):
    """Response format for LLM package detection."""

    # Third-party (non-stdlib) package names; empty list when none detected.
    packages: list[str] = Field(
        default_factory=list,
        description="List of third-party package names",
    )
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
class _TestCodeResponse(BaseModel):
    """Response format for LLM test generation."""

    # Complete, self-contained test source — not a fragment.
    test_code: str = Field(description="Complete test code")
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class _ConstraintParameters(BaseModel):
    """Parameters for a constraint check. Only the fields relevant to the check_type should be set.

    All fields default to None so the model can represent any check type;
    which fields matter is determined by ``_ConstraintParse.check_type``.
    """

    # For check_type in {"greater_than", "less_than"}.
    value: Optional[float] = Field(
        default=None,
        description="Threshold value for greater_than or less_than checks",
    )
    # For check_type == "between" (with `max`).
    min: Optional[float] = Field(
        default=None,
        description="Minimum value for between checks",
    )
    max: Optional[float] = Field(
        default=None,
        description="Maximum value for between checks",
    )
    # For check_type == "regex".
    pattern: Optional[str] = Field(
        default=None,
        description="Regex pattern for regex checks",
    )
    # For check_type == "isin".
    values: Optional[list[str]] = Field(
        default=None,
        description="Allowed values for isin checks",
    )
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
class _ConstraintParse(BaseModel):
    """LLM response for parsing a constraint into Pandera check.

    ``check_type`` selects which fields of ``parameters`` are meaningful;
    "none" indicates the constraint could not be mapped to a check.
    """

    column_name: str = Field(description="Name of the column this constraint applies to")
    check_type: Literal["greater_than", "less_than", "between", "regex", "isin", "not_null", "none"] = Field(
        description="Type of check to apply"
    )
    parameters: _ConstraintParameters = Field(
        default_factory=_ConstraintParameters,
        description="Parameters for the check. Set only the fields relevant to the check_type.",
    )
    explanation: str = Field(description="Brief explanation of what check will be applied")
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Data extraction and schema inference."""
|
|
2
|
+
|
|
3
|
+
from flyteplugins.codegen.data.extraction import (
|
|
4
|
+
extract_data_context,
|
|
5
|
+
extract_dataframe_context,
|
|
6
|
+
extract_file_context,
|
|
7
|
+
is_dataframe,
|
|
8
|
+
)
|
|
9
|
+
from flyteplugins.codegen.data.schema import (
|
|
10
|
+
apply_parsed_constraint,
|
|
11
|
+
apply_user_constraints,
|
|
12
|
+
extract_token_usage,
|
|
13
|
+
infer_conservative_schema,
|
|
14
|
+
parse_constraint_with_llm,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"apply_parsed_constraint",
|
|
19
|
+
"apply_user_constraints",
|
|
20
|
+
"extract_data_context",
|
|
21
|
+
"extract_dataframe_context",
|
|
22
|
+
"extract_file_context",
|
|
23
|
+
"extract_token_usage",
|
|
24
|
+
"infer_conservative_schema",
|
|
25
|
+
"is_dataframe",
|
|
26
|
+
"parse_constraint_with_llm",
|
|
27
|
+
]
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import aiofiles
|
|
6
|
+
import flyte
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import pandera.pandas as pa
|
|
9
|
+
from flyte.io import File
|
|
10
|
+
|
|
11
|
+
from flyteplugins.codegen.data.schema import (
|
|
12
|
+
apply_user_constraints,
|
|
13
|
+
infer_conservative_schema,
|
|
14
|
+
schema_to_script,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def is_dataframe(obj) -> bool:
    """Check if object is a pandas DataFrame.

    Args:
        obj: Object to check

    Returns:
        True if object is a DataFrame
    """
    # `pd` is imported unconditionally at module import time, so isinstance()
    # can never raise ImportError here — the previous try/except ImportError
    # wrapper was dead code and has been removed.
    return isinstance(obj, pd.DataFrame)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
async def extract_dataframe_context(
    df, name: str, max_sample_rows: int = 5, schema: Optional["pa.DataFrameSchema"] = None
) -> str:
    """Extract comprehensive context from DataFrame.

    Args:
        df: pandas DataFrame
        name: Name of the data input
        max_sample_rows: Number of sample rows to include
        schema: Optional Pandera schema to include in context

    Returns:
        Formatted string with all extracted context
    """
    context_parts = []

    # 1. Structural Context
    context_parts.append(f"## Data: {name}")
    context_parts.append(f"Shape: {df.shape[0]:,} rows x {df.shape[1]} columns")

    # Include Pandera schema if provided (use Pandera's built-in formatter)
    if schema:
        context_parts.append(f"\nPandera Schema for {name} (use for validation):")
        context_parts.append("```python")
        context_parts.append(schema_to_script(schema))
        context_parts.append("```")

    # 2. Statistical Context
    context_parts.append("\nStatistical Summary:")

    # Numeric columns
    numeric_cols = df.select_dtypes(include=["number"]).columns
    if len(numeric_cols) > 0:
        context_parts.append("  Numeric columns:")
        desc = df[numeric_cols].describe()
        for col in numeric_cols:
            stats = desc[col]
            context_parts.append(
                f"    {col}: min={stats['min']:.2g}, max={stats['max']:.2g}, "
                f"mean={stats['mean']:.2g}, median={stats['50%']:.2g}"
            )

    # Categorical/Object columns
    cat_cols = df.select_dtypes(include=["object", "category"]).columns
    if len(cat_cols) > 0:
        context_parts.append("  Categorical columns:")
        for col in cat_cols:
            unique_count = df[col].nunique()
            total_count = len(df[col].dropna())
            if unique_count <= 20 and total_count > 0:
                # Show value counts for low-cardinality columns
                top_values = df[col].value_counts().head(5)
                top_str = ", ".join([f"'{k}': {v}" for k, v in top_values.items()])
                context_parts.append(f"    {col}: {unique_count} unique values. Top 5: {{{top_str}}}")
            else:
                context_parts.append(f"    {col}: {unique_count} unique values")

    # DateTime columns
    date_cols = df.select_dtypes(include=["datetime64"]).columns
    if len(date_cols) > 0:
        context_parts.append("  DateTime columns:")
        for col in date_cols:
            min_date = df[col].min()
            max_date = df[col].max()
            context_parts.append(f"    {col}: {min_date} to {max_date}")

    # 3. Behavioral Context (patterns, invariants)
    context_parts.append("\nData Patterns:")

    # Check for duplicates
    dup_count = df.duplicated().sum()
    if dup_count > 0:
        context_parts.append(f"  - {dup_count:,} duplicate rows ({dup_count / len(df) * 100:.1f}%)")

    # Check for potential ID columns
    for col in df.columns:
        if df[col].nunique() == len(df) and not df[col].isna().any():
            context_parts.append(f"  - '{col}' appears to be a unique identifier")
            break

    # 4. Representative Samples
    context_parts.append(f"\nRepresentative Samples ({max_sample_rows} rows):")

    # Sample strategy: a couple of head rows + a random draw from the rest.
    #
    # BUG FIXES vs. the original:
    #  - The original fed *label* indices from df.sample().index into df.iloc,
    #    which is wrong for any DataFrame without a default RangeIndex.
    #  - The random draw could re-pick the head rows, producing duplicates.
    #  - max_sample_rows < 2 made the sample count negative and raised.
    # All selection below is purely positional (iloc) and disjoint.
    if len(df) > max_sample_rows:
        head = df.iloc[: min(2, len(df))]
        remaining = max(0, max_sample_rows - len(head))
        pool = df.iloc[len(head):]
        extra = pool.sample(n=min(remaining, len(pool)))
        sample_df = pd.concat([head, extra]).iloc[:max_sample_rows]
    else:
        sample_df = df

    # Format as CSV
    context_parts.append(sample_df.to_csv(index=False))

    return "\n".join(context_parts)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
async def extract_file_context(file: File, name: str, max_sample_rows: int = 5) -> str:
    """Extract context from non-tabular files (text, binary, unknown formats).

    This is a fallback for files that can't be loaded as DataFrames.
    Structured files (CSV, Parquet, JSON, Excel) are handled by extract_data_context()
    with Pandera schema inference.

    Args:
        file: File to extract context from
        name: Name of the data input
        max_sample_rows: Number of sample rows to include

    Returns:
        Formatted string with all extracted context
    """
    local_path = await file.download()
    file_ext = Path(local_path).suffix.lower()

    # Try to read as text file
    try:
        preview: list[str] = []
        async with aiofiles.open(local_path, "r", encoding="utf-8", errors="ignore") as handle:
            # Read at most max_sample_rows lines; stop early on EOF.
            while len(preview) < max_sample_rows:
                row = await handle.readline()
                if not row:
                    break
                preview.append(row)

        return "\n".join(
            [
                f"## Data: {name}",
                f"Type: Text file ({file_ext})",
                f"Lines: {len(preview)}",
                f"\nFirst {max_sample_rows} lines:",
                "".join(preview),
            ]
        )

    except Exception:
        # Binary or unreadable file
        file_size = Path(local_path).stat().st_size  # noqa: ASYNC240
        return "\n".join(
            [
                f"## Data: {name}",
                f"Type: Binary/Unknown ({file_ext})",
                f"Size: {file_size:,} bytes",
                "\n(Unable to extract text preview)",
            ]
        )
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@flyte.trace
async def extract_data_context(
    data: dict[str, "pd.DataFrame | File"],
    max_sample_rows: int = 5,
    constraints: Optional[list[str]] = None,
    model: Optional[str] = None,
    litellm_params: Optional[dict] = None,
) -> tuple[str, dict[str, str], int, int]:
    """Extract comprehensive context from data inputs with Pandera schema inference.

    Extracts:
    1. Structural context (schema, types, shape)
    2. Statistical context (distributions, ranges)
    3. Behavioral context (patterns, invariants)
    4. Operational context (scale, nulls)
    5. Representative samples
    6. Pandera schemas (inference + user constraints), returned as Python code strings

    Args:
        data: Dict of data inputs (File or DataFrame)
        max_sample_rows: Number of sample rows to include
        constraints: Optional list of user constraints to apply to schemas
        model: LLM model for constraint parsing (required if constraints provided)
        litellm_params: Optional LiteLLM parameters

    Returns:
        Tuple of (context_string, schemas_as_code_dict, total_input_tokens, total_output_tokens)
    """
    context_parts = []
    schemas: dict[str, str] = {}
    total_input_tokens = 0
    total_output_tokens = 0

    for name, value in data.items():
        df = None

        if isinstance(value, File):
            # Load file as DataFrame for schema inference
            local_path = await value.download()
            file_ext = Path(local_path).suffix.lower()

            try:
                if file_ext in [".csv", ".tsv"]:
                    delimiter = "\t" if file_ext == ".tsv" else ","
                    # Cap at 10k rows to bound memory and LLM context size.
                    df = pd.read_csv(local_path, delimiter=delimiter, nrows=10000)
                elif file_ext in [".parquet", ".pq"]:
                    df = pd.read_parquet(local_path)
                    if len(df) > 10000:
                        df = df.sample(n=10000)
                elif file_ext == ".json":
                    try:
                        # Prefer JSON Lines; fall back to a regular JSON document.
                        df = pd.read_json(local_path, lines=True, nrows=10000)
                    except Exception:
                        df = pd.read_json(local_path)
                elif file_ext in [".xlsx", ".xls"]:
                    df = pd.read_excel(local_path, nrows=10000)
                else:
                    # Non-tabular file (e.g., .log, .txt) — extract text context
                    context = await extract_file_context(value, name, max_sample_rows)
                    context_parts.append(context)
                    continue
            except Exception as e:
                logger.warning(f"Failed to load {name} as DataFrame for schema inference: {e}")
                # Fall back to non-schema extraction
                context = await extract_file_context(value, name, max_sample_rows)
                context_parts.append(context)
                continue

        elif is_dataframe(value):
            df = value
        else:
            context_parts.append(f"## Data: {name}\nType: {type(value)}\n(Unsupported type)")
            continue

        if df is not None:
            # Infer Pandera schema
            schema = infer_conservative_schema(df)

            # Apply user constraints if provided
            if constraints and model:
                schema, in_tok, out_tok = await apply_user_constraints(schema, constraints, name, model, litellm_params)
                total_input_tokens += in_tok
                total_output_tokens += out_tok

            # Convert to code string for serialization
            schemas[name] = schema_to_script(schema)

            # Extract context with schema
            context = await extract_dataframe_context(df, name, max_sample_rows, schema)
            context_parts.append(context)

    # BUG FIX: the original expression `"\n\n" + "=" * 80 + "\n\n".join(parts)`
    # only *prefixed* a single `=` bar (operator precedence) while joining the
    # sections with a bare "\n\n". Join every section with the full separator.
    separator = "\n\n" + "=" * 80 + "\n\n"
    context_str = separator.join(context_parts)
    return context_str, schemas, total_input_tokens, total_output_tokens
|