flyteplugins-codegen 2.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,19 @@
1
+ """Core type definitions for LLM code generation."""
2
+
3
+ from flyteplugins.codegen.core.types import (
4
+ CodeGenEvalResult,
5
+ CodePlan,
6
+ CodeSolution,
7
+ ErrorDiagnosis,
8
+ FixVerification,
9
+ TestFailure,
10
+ )
11
+
12
# Explicit public API of this package (mirrors the names re-exported above).
__all__ = [
    "CodeGenEvalResult",
    "CodePlan",
    "CodeSolution",
    "ErrorDiagnosis",
    "FixVerification",
    "TestFailure",
]
@@ -0,0 +1,337 @@
1
+ from typing import Any, Literal, Optional
2
+
3
+ import flyte
4
+ from flyte.io import File
5
+ from flyte.syncify import syncify
6
+ from pydantic import BaseModel, Field, field_validator
7
+
8
+
9
class CodePlan(BaseModel):
    """Structured plan for the code solution."""

    # NOTE: Field descriptions become part of the JSON schema, which is what a
    # structured-output LLM call sees — treat them as prompts, not just docs.
    description: str = Field(description="Overall description of the solution")
    approach: str = Field(description="High-level approach and algorithm to solve the problem")
14
+
15
+
16
class CodeSolution(BaseModel):
    """Structured code solution."""

    # Field descriptions are surfaced in the JSON schema for structured output.
    language: str = Field(
        default="python",
        description="Programming language",
    )
    code: str = Field(
        default="",
        description="Complete executable code including imports and dependencies",
    )
    system_packages: list[str] = Field(
        default_factory=list,
        description="System packages needed (e.g., gcc, build-essential, curl)",
    )

    # mode="before" runs on raw input, so language comparisons elsewhere can
    # rely on a canonical lowercase, whitespace-free value.
    @field_validator("language", mode="before")
    @classmethod
    def normalize_language(cls, v: str) -> str:
        """Canonicalize the language name: strip surrounding whitespace, lowercase."""
        return v.strip().lower()
36
+
37
+
38
class CodeGenEvalResult(BaseModel):
    """Result from code generation and evaluation.

    Carries the generated solution, evaluation outcome, token accounting, and
    enough context (declared I/O types, built image, sample files) to re-run
    the generated code via ``run()`` or wrap it as a task via ``as_task()``.
    """

    plan: Optional[CodePlan] = None
    solution: CodeSolution
    tests: Optional[str] = None
    success: bool
    output: str
    exit_code: int
    error: Optional[str] = None
    attempts: int = 1
    conversation_history: list[dict[str, str]] = Field(default_factory=list)
    detected_packages: list[str] = Field(
        default_factory=list,
        description="Language packages detected by LLM from imports",
    )
    detected_system_packages: list[str] = Field(default_factory=list, description="System packages detected by LLM")
    image: Optional[str] = Field(
        default=None,
        description="The Flyte Image built with all dependencies",
    )
    total_input_tokens: int = Field(
        default=0,
        description="Total input tokens used across all LLM calls",
    )
    total_output_tokens: int = Field(
        default=0,
        description="Total output tokens used across all LLM calls",
    )
    declared_inputs: Optional[dict[str, type]] = Field(
        default=None,
        description="Input types (user-provided or inferred from samples)",
    )
    declared_outputs: Optional[dict[str, type]] = Field(
        default=None,
        description="Output types declared by user",
    )
    data_context: Optional[str] = Field(
        default=None,
        description="Extracted data context (schema, stats, patterns, samples) used for code generation",
    )
    original_samples: Optional[dict[str, File]] = Field(
        default=None,
        description="Sample data converted to Files (defaults for run()/as_task())",
    )
    generated_schemas: Optional[dict[str, str]] = Field(
        default=None,
        description="Auto-generated Pandera schemas (as Python code strings) for validating data inputs",
    )

    def _create_sandbox(
        self,
        name: str,
        resources: Optional[flyte.Resources],
        retries: int,
        timeout: Optional[int],
        env_vars: Optional[dict[str, str]],
        secrets: Optional[list],
        cache: str,
    ):
        """Build the flyte sandbox shared by ``as_task()`` and ``run()``.

        Centralizes the previously duplicated image check and
        ``flyte.sandbox.create`` call so the two entry points cannot drift.

        Raises:
            ValueError: If code generation did not build an image.
        """
        if not self.image:
            raise ValueError("No image available - code generation did not build an image")

        return flyte.sandbox.create(
            name=name,
            code=self.solution.code,
            inputs=self.declared_inputs or {},
            outputs=self.declared_outputs or {},
            auto_io=False,
            resources=resources or flyte.Resources(cpu=1, memory="1Gi"),
            retries=retries,
            timeout=timeout,
            env_vars=env_vars,
            secrets=secrets,
            cache=cache,
        )

    def as_task(
        self,
        name: str = "run_code_on_real_data",
        resources: Optional[flyte.Resources] = None,
        retries: int = 0,
        timeout: Optional[int] = None,
        env_vars: Optional[dict[str, str]] = None,
        secrets: Optional[list] = None,
        cache: str = "auto",
    ):
        """Create a sandbox that runs the generated code in an isolated sandbox.

        The generated code will write outputs to /var/outputs/{output_name} files.
        Returns a callable wrapper that automatically provides the script file.

        Args:
            name: Name for the sandbox
            resources: Optional resources for the task
            retries: Number of retries for the task. Defaults to 0.
            timeout: Timeout in seconds. Defaults to None.
            env_vars: Environment variables to pass to the sandbox.
            secrets: flyte.Secret objects to make available.
            cache: CacheRequest: "auto", "override", or "disable". Defaults to "auto".

        Returns:
            Callable task wrapper with the default inputs baked in. Call with your other declared inputs.

        Raises:
            ValueError: If generation failed or no image was built.
        """
        if not self.success:
            raise ValueError("Cannot create task from failed code generation")

        sandbox = self._create_sandbox(name, resources, retries, timeout, env_vars, secrets, cache)

        # Bind to a local so the closures below capture the string, not `self`.
        image = self.image

        # If we have samples, wrap to inject sample values as defaults
        if self.original_samples:
            sample_defaults = dict(self.original_samples)

            @syncify
            async def task_with_defaults(**kwargs):
                # Caller-supplied kwargs take precedence over sample defaults.
                merged = {**sample_defaults, **kwargs}
                return await sandbox.run.aio(image=image, **merged)

            return task_with_defaults

        @syncify
        async def task(**kwargs):
            return await sandbox.run.aio(image=image, **kwargs)

        return task

    async def run(
        self,
        *,
        name: str = "run_code_on_real_data",
        resources: Optional[flyte.Resources] = None,
        retries: int = 0,
        timeout: Optional[int] = None,
        env_vars: Optional[dict[str, str]] = None,
        secrets: Optional[list] = None,
        cache: str = "auto",
        **overrides,
    ) -> Any:
        """Run generated code in an isolated sandbox (one-off execution).

        If samples were provided during generate(), they are used as defaults.
        Override any input by passing it as a keyword argument. If no samples
        exist, all declared inputs must be provided via ``**overrides``.

        Args:
            name: Name for the sandbox
            resources: Optional resources for the task
            retries: Number of retries for the task. Defaults to 0.
            timeout: Timeout in seconds. Defaults to None.
            env_vars: Environment variables to pass to the sandbox.
            secrets: flyte.Secret objects to make available.
            cache: CacheRequest: "auto", "override", or "disable". Defaults to "auto".
            **overrides: Input values. Merged on top of sample defaults (if any).

        Returns:
            Tuple of typed outputs.

        Raises:
            ValueError: If generation failed or no image was built.
        """
        if not self.success:
            raise ValueError("Cannot run failed code generation")

        sandbox = self._create_sandbox(name, resources, retries, timeout, env_vars, secrets, cache)

        # Sample defaults (if any) merged under caller-supplied overrides.
        run_data = {**(self.original_samples or {}), **overrides}
        return await sandbox.run.aio(image=self.image, **run_data)
207
+
208
+
209
# Apply syncify after the class body has executed: decorating `run` inline
# would hand the wrapped object to Pydantic during class creation, where it
# could be picked up by field detection instead of staying a plain method.
CodeGenEvalResult.run = syncify(CodeGenEvalResult.run)
211
+
212
+
213
class TestFailure(BaseModel):
    """Individual test failure with diagnosis."""

    # Field descriptions double as schema documentation for structured LLM
    # output; keep them precise — they steer the diagnosis quality.
    test_name: str = Field(description="Name of the failing test")
    error_message: str = Field(
        description="The exact final error message from test output "
        "(e.g., 'RecursionError: maximum recursion depth exceeded')"
    )
    expected_behavior: str = Field(description="What this test expected to happen")
    actual_behavior: str = Field(description="What actually happened when the code ran")
    root_cause: str = Field(description="Why the test failed (quote the exact code that's wrong)")
    suggested_fix: str = Field(description="Specific code changes using format: Replace `current code` with `new code`")
    # Classification of where the fault lies; presumably used by callers to
    # choose between rebuilding the image, fixing code, or patching tests —
    # confirm against the retry loop.
    error_type: Literal["environment", "logic", "test_error"] = Field(
        description="Type of error: 'environment' (missing packages/dependencies), "
        "'logic' (bug in solution code), or 'test_error' (bug in test code)"
    )
229
+
230
+
231
class ErrorDiagnosis(BaseModel):
    """Structured diagnosis of execution errors."""

    # One TestFailure entry per failing test observed in the run output.
    failures: list[TestFailure] = Field(description="Individual test failures with their diagnoses")
    # The needs_* lists describe environment changes required before retrying.
    needs_system_packages: list[str] = Field(
        default_factory=list,
        description="System packages needed (e.g., gcc, pkg-config).",
    )
    needs_language_packages: list[str] = Field(
        default_factory=list,
        description="Language packages needed.",
    )
    needs_additional_commands: list[str] = Field(
        default_factory=list,
        description="Additional RUN commands (e.g., apt-get update, mkdir /data, wget files).",
    )
247
+
248
+
249
class FixVerification(BaseModel):
    """Verification that fixes were applied to code."""

    all_fixes_applied: bool = Field(description="True if all suggested fixes are present in the new code")
    # Both lists key off test names (see TestFailure.test_name per the
    # descriptions below), not code locations.
    applied_fixes: list[str] = Field(
        default_factory=list,
        description="List of fixes that were successfully applied (by test name)",
    )
    missing_fixes: list[str] = Field(
        default_factory=list,
        description="List of fixes that are still missing (by test name)",
    )
    explanation: str = Field(description="Brief explanation of what was checked and what's missing (if anything)")
262
+
263
+
264
class TestFunctionPatch(BaseModel):
    """A single fixed test function."""

    test_name: str = Field(description="Name of the test function (e.g. test_basic_analysis)")
    # Complete replacement text for the named function; per the description it
    # must include decorators and the `def` line, not just the body.
    fixed_code: str = Field(description="Complete fixed function body including the def line and decorators")
269
+
270
+
271
class TestFixResponse(BaseModel):
    """Response containing only the fixed test functions."""

    # Only patched functions are returned; untouched tests are omitted.
    patches: list[TestFunctionPatch] = Field(description="List of fixed test functions")
275
+
276
+
277
class _PackageReplacementResponse(BaseModel):
    """Response format for suggesting a replacement system package."""

    # None means "no system package is needed at all", distinct from an empty
    # string — callers should check for None explicitly.
    replacement: Optional[str] = Field(
        default=None,
        description="Correct Debian/Ubuntu apt package name, or null if no system package is needed",
    )
284
+
285
+
286
class _PackageDetectionResponse(BaseModel):
    """Response format for LLM package detection."""

    # Third-party packages only — stdlib modules are excluded per the schema
    # description sent to the model.
    packages: list[str] = Field(
        default_factory=list,
        description="List of third-party package names",
    )
293
+
294
+
295
class _TestCodeResponse(BaseModel):
    """Response format for LLM test generation."""

    # Full test module source as a single string.
    test_code: str = Field(description="Complete test code")
299
+
300
+
301
class _ConstraintParameters(BaseModel):
    """Parameters for a constraint check. Only the fields relevant to the check_type should be set."""

    # Each field maps to specific _ConstraintParse.check_type values:
    #   value       -> greater_than / less_than
    #   min, max    -> between
    #   pattern     -> regex
    #   values      -> isin
    # not_null and none take no parameters.
    value: Optional[float] = Field(
        default=None,
        description="Threshold value for greater_than or less_than checks",
    )
    min: Optional[float] = Field(
        default=None,
        description="Minimum value for between checks",
    )
    max: Optional[float] = Field(
        default=None,
        description="Maximum value for between checks",
    )
    pattern: Optional[str] = Field(
        default=None,
        description="Regex pattern for regex checks",
    )
    values: Optional[list[str]] = Field(
        default=None,
        description="Allowed values for isin checks",
    )
324
+
325
+
326
class _ConstraintParse(BaseModel):
    """LLM response for parsing a constraint into Pandera check."""

    column_name: str = Field(description="Name of the column this constraint applies to")
    # 'none' signals the constraint could not be mapped to a supported check.
    check_type: Literal["greater_than", "less_than", "between", "regex", "isin", "not_null", "none"] = Field(
        description="Type of check to apply"
    )
    # Defaults to an all-None _ConstraintParameters for parameterless checks
    # (not_null, none).
    parameters: _ConstraintParameters = Field(
        default_factory=_ConstraintParameters,
        description="Parameters for the check. Set only the fields relevant to the check_type.",
    )
    explanation: str = Field(description="Brief explanation of what check will be applied")
@@ -0,0 +1,27 @@
1
+ """Data extraction and schema inference."""
2
+
3
+ from flyteplugins.codegen.data.extraction import (
4
+ extract_data_context,
5
+ extract_dataframe_context,
6
+ extract_file_context,
7
+ is_dataframe,
8
+ )
9
+ from flyteplugins.codegen.data.schema import (
10
+ apply_parsed_constraint,
11
+ apply_user_constraints,
12
+ extract_token_usage,
13
+ infer_conservative_schema,
14
+ parse_constraint_with_llm,
15
+ )
16
+
17
# Explicit public API: names re-exported from the extraction and schema modules.
__all__ = [
    "apply_parsed_constraint",
    "apply_user_constraints",
    "extract_data_context",
    "extract_dataframe_context",
    "extract_file_context",
    "extract_token_usage",
    "infer_conservative_schema",
    "is_dataframe",
    "parse_constraint_with_llm",
]
@@ -0,0 +1,281 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import aiofiles
6
+ import flyte
7
+ import pandas as pd
8
+ import pandera.pandas as pa
9
+ from flyte.io import File
10
+
11
+ from flyteplugins.codegen.data.schema import (
12
+ apply_user_constraints,
13
+ infer_conservative_schema,
14
+ schema_to_script,
15
+ )
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def is_dataframe(obj) -> bool:
    """Check if object is a pandas DataFrame.

    Args:
        obj: Object to check

    Returns:
        True if object is a DataFrame
    """
    # pandas is imported unconditionally at module top, so the module cannot
    # load without it; the previous try/except ImportError here was dead code
    # (isinstance() never raises ImportError).
    return isinstance(obj, pd.DataFrame)
33
+
34
+
35
async def extract_dataframe_context(
    df, name: str, max_sample_rows: int = 5, schema: Optional[pa.DataFrameSchema] = None
) -> str:
    """Extract comprehensive context from DataFrame.

    Args:
        df: pandas DataFrame
        name: Name of the data input
        max_sample_rows: Number of sample rows to include
        schema: Optional Pandera schema to include in context

    Returns:
        Formatted string with all extracted context
    """
    context_parts = []

    # 1. Structural Context
    context_parts.append(f"## Data: {name}")
    context_parts.append(f"Shape: {df.shape[0]:,} rows x {df.shape[1]} columns")

    # Include Pandera schema if provided (use Pandera's built-in formatter)
    if schema:
        context_parts.append(f"\nPandera Schema for {name} (use for validation):")
        context_parts.append("```python")
        context_parts.append(schema_to_script(schema))
        context_parts.append("```")

    # 2. Statistical Context
    context_parts.append("\nStatistical Summary:")

    # Numeric columns
    numeric_cols = df.select_dtypes(include=["number"]).columns
    if len(numeric_cols) > 0:
        context_parts.append(" Numeric columns:")
        desc = df[numeric_cols].describe()
        for col in numeric_cols:
            stats = desc[col]
            context_parts.append(
                f" {col}: min={stats['min']:.2g}, max={stats['max']:.2g}, "
                f"mean={stats['mean']:.2g}, median={stats['50%']:.2g}"
            )

    # Categorical/Object columns
    cat_cols = df.select_dtypes(include=["object", "category"]).columns
    if len(cat_cols) > 0:
        context_parts.append(" Categorical columns:")
        for col in cat_cols:
            unique_count = df[col].nunique()
            total_count = len(df[col].dropna())
            if unique_count <= 20 and total_count > 0:
                # Show value counts for low-cardinality columns
                top_values = df[col].value_counts().head(5)
                top_str = ", ".join([f"'{k}': {v}" for k, v in top_values.items()])
                context_parts.append(f" {col}: {unique_count} unique values. Top 5: {{{top_str}}}")
            else:
                context_parts.append(f" {col}: {unique_count} unique values")

    # DateTime columns
    date_cols = df.select_dtypes(include=["datetime64"]).columns
    if len(date_cols) > 0:
        context_parts.append(" DateTime columns:")
        for col in date_cols:
            min_date = df[col].min()
            max_date = df[col].max()
            context_parts.append(f" {col}: {min_date} to {max_date}")

    # 3. Behavioral Context (patterns, invariants)
    context_parts.append("\nData Patterns:")

    # Check for duplicates
    dup_count = df.duplicated().sum()
    if dup_count > 0:
        context_parts.append(f" - {dup_count:,} duplicate rows ({dup_count / len(df) * 100:.1f}%)")

    # Check for potential ID columns
    for col in df.columns:
        if df[col].nunique() == len(df) and not df[col].isna().any():
            context_parts.append(f" - '{col}' appears to be a unique identifier")
            break

    # 4. Representative Samples
    context_parts.append(f"\nRepresentative Samples ({max_sample_rows} rows):")

    # Sample strategy: first rows + random rows, tracked as integer *positions*.
    sample_positions = list(range(min(2, len(df))))

    # Random sample
    if len(df) > max_sample_rows:
        remaining = max_sample_rows - len(sample_positions)
        # BUGFIX: the previous code extended with df.sample(...).index.tolist()
        # (index *labels*) and then indexed via .iloc (*positions*), which
        # crashes or selects wrong rows for any non-default index (string
        # labels, non-contiguous ints, duplicates). Sample positions instead.
        random_positions = pd.Series(range(len(df))).sample(n=remaining).tolist()
        sample_positions.extend(random_positions)
    else:
        sample_positions = list(range(len(df)))

    sample_df = df.iloc[sample_positions[:max_sample_rows]]

    # Format as CSV
    context_parts.append(sample_df.to_csv(index=False))

    return "\n".join(context_parts)
138
+
139
+
140
async def extract_file_context(file: File, name: str, max_sample_rows: int = 5) -> str:
    """Build a textual context block for a file that is not tabular data.

    Fallback for files that can't be loaded as DataFrames; structured files
    (CSV, Parquet, JSON, Excel) go through extract_data_context() with Pandera
    schema inference instead.

    Args:
        file: File to extract context from
        name: Name of the data input
        max_sample_rows: Number of sample rows to include

    Returns:
        Formatted string with all extracted context
    """
    local_path = await file.download()
    file_ext = Path(local_path).suffix.lower()

    try:
        # Attempt a UTF-8 text preview; any failure (binary content, I/O
        # error) drops us into the binary/unknown branch below.
        async with aiofiles.open(local_path, "r", encoding="utf-8", errors="ignore") as f:
            lines = []
            while len(lines) < max_sample_rows and (line := await f.readline()):
                lines.append(line)

        return "\n".join(
            [
                f"## Data: {name}",
                f"Type: Text file ({file_ext})",
                # NOTE: this is the count of *sampled* lines, not the file total.
                f"Lines: {len(lines)}",
                f"\nFirst {max_sample_rows} lines:",
                "".join(lines),
            ]
        )

    except Exception:
        # Binary or unreadable file — report size only.
        file_size = Path(local_path).stat().st_size  # noqa: ASYNC240
        return "\n".join(
            [
                f"## Data: {name}",
                f"Type: Binary/Unknown ({file_ext})",
                f"Size: {file_size:,} bytes",
                "\n(Unable to extract text preview)",
            ]
        )
187
+
188
+
189
@flyte.trace
async def extract_data_context(
    data: dict[str, pd.DataFrame | File],
    max_sample_rows: int = 5,
    constraints: Optional[list[str]] = None,
    model: Optional[str] = None,
    litellm_params: Optional[dict] = None,
) -> tuple[str, dict[str, str], int, int]:
    """Extract comprehensive context from data inputs with Pandera schema inference.

    Extracts:
    1. Structural context (schema, types, shape)
    2. Statistical context (distributions, ranges)
    3. Behavioral context (patterns, invariants)
    4. Operational context (scale, nulls)
    5. Representative samples
    6. Pandera schemas (inference + user constraints), returned as Python code strings

    Args:
        data: Dict of data inputs (File or DataFrame)
        max_sample_rows: Number of sample rows to include
        constraints: Optional list of user constraints to apply to schemas
        model: LLM model for constraint parsing (required if constraints provided)
        litellm_params: Optional LiteLLM parameters

    Returns:
        Tuple of (context_string, schemas_as_code_dict, total_input_tokens, total_output_tokens)
    """
    context_parts = []
    schemas: dict[str, str] = {}
    total_input_tokens = 0
    total_output_tokens = 0

    for name, value in data.items():
        df = None

        if isinstance(value, File):
            # Load file as DataFrame for schema inference
            local_path = await value.download()
            file_ext = Path(local_path).suffix.lower()

            try:
                # Cap all loaders at ~10k rows to bound memory and context size.
                if file_ext in [".csv", ".tsv"]:
                    delimiter = "\t" if file_ext == ".tsv" else ","
                    df = pd.read_csv(local_path, delimiter=delimiter, nrows=10000)
                elif file_ext in [".parquet", ".pq"]:
                    df = pd.read_parquet(local_path)
                    if len(df) > 10000:
                        df = df.sample(n=10000)
                elif file_ext == ".json":
                    try:
                        # Prefer JSON Lines; fall back to a single JSON document.
                        df = pd.read_json(local_path, lines=True, nrows=10000)
                    except Exception:
                        df = pd.read_json(local_path)
                elif file_ext in [".xlsx", ".xls"]:
                    df = pd.read_excel(local_path, nrows=10000)
                else:
                    # Non-tabular file (e.g., .log, .txt) — extract text context
                    context = await extract_file_context(value, name, max_sample_rows)
                    context_parts.append(context)
                    continue
            except Exception as e:
                logger.warning(f"Failed to load {name} as DataFrame for schema inference: {e}")
                # Fall back to non-schema extraction
                context = await extract_file_context(value, name, max_sample_rows)
                context_parts.append(context)
                continue

        elif is_dataframe(value):
            df = value
        else:
            context_parts.append(f"## Data: {name}\nType: {type(value)}\n(Unsupported type)")
            continue

        if df is not None:
            # Infer Pandera schema
            schema = infer_conservative_schema(df)

            # Apply user constraints if provided
            if constraints and model:
                schema, in_tok, out_tok = await apply_user_constraints(schema, constraints, name, model, litellm_params)
                total_input_tokens += in_tok
                total_output_tokens += out_tok

            # Convert to code string for serialization
            schemas[name] = schema_to_script(schema)

            # Extract context with schema
            context = await extract_dataframe_context(df, name, max_sample_rows, schema)
            context_parts.append(context)

    # BUGFIX: separate each input's section with an "=" divider line. The old
    # expression `"\n\n" + "=" * 80 + "\n\n".join(context_parts)` bound .join
    # tighter than +, so the banner appeared once as a prefix and sections
    # were joined by bare blank lines instead of the intended divider.
    separator = "\n\n" + "=" * 80 + "\n\n"
    context_str = separator.join(context_parts)
    return context_str, schemas, total_input_tokens, total_output_tokens