flyteplugins-codegen 2.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/codegen/__init__.py +18 -0
- flyteplugins/codegen/auto_coder_agent.py +1088 -0
- flyteplugins/codegen/core/__init__.py +19 -0
- flyteplugins/codegen/core/types.py +337 -0
- flyteplugins/codegen/data/__init__.py +27 -0
- flyteplugins/codegen/data/extraction.py +281 -0
- flyteplugins/codegen/data/schema.py +270 -0
- flyteplugins/codegen/execution/__init__.py +7 -0
- flyteplugins/codegen/execution/agent.py +671 -0
- flyteplugins/codegen/execution/docker.py +206 -0
- flyteplugins/codegen/generation/__init__.py +41 -0
- flyteplugins/codegen/generation/llm.py +1269 -0
- flyteplugins/codegen/generation/prompts.py +136 -0
- flyteplugins_codegen-2.0.6.dist-info/METADATA +441 -0
- flyteplugins_codegen-2.0.6.dist-info/RECORD +17 -0
- flyteplugins_codegen-2.0.6.dist-info/WHEEL +5 -0
- flyteplugins_codegen-2.0.6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import litellm
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import pandera.pandas as pa
|
|
8
|
+
|
|
9
|
+
from flyteplugins.codegen.core.types import _ConstraintParse
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def schema_to_script(schema: pa.DataFrameSchema) -> str:
    """Render a Pandera schema as Python source.

    Delegates to ``schema.to_script()``; if that raises ``ImportError``
    (Pandera's script export needs the optional ``black`` dependency),
    falls back to ``repr(schema)``.
    """
    try:
        script = schema.to_script()
    except ImportError:
        # black not installed: repr is a lossy but usable fallback.
        return repr(schema)
    return script
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def infer_conservative_schema(df: pd.DataFrame) -> pa.DataFrameSchema:
    """Infer Pandera schema conservatively - types only, no value constraints.

    Args:
        df: DataFrame to infer schema from

    Returns:
        DataFrameSchema with only dtype and nullability checks
    """
    # Work on a copy: nullable extension dtypes are coerced to plain
    # numpy-backed dtypes that Pandera's inference understands.
    normalized = df.copy()
    for column in normalized.columns:
        series = normalized[column]
        dtype = series.dtype
        if isinstance(dtype, pd.StringDtype):
            normalized[column] = series.astype(object)
        elif isinstance(dtype, pd.Int8Dtype | pd.Int16Dtype | pd.Int32Dtype | pd.Int64Dtype):
            # Missing values cannot live in a plain int64 column.
            normalized[column] = series.astype("float64" if series.isna().any() else "int64")
        elif isinstance(dtype, pd.UInt8Dtype | pd.UInt16Dtype | pd.UInt32Dtype | pd.UInt64Dtype):
            normalized[column] = series.astype("float64" if series.isna().any() else "uint64")
        elif isinstance(dtype, pd.Float32Dtype | pd.Float64Dtype):
            normalized[column] = series.astype("float64")
        elif isinstance(dtype, pd.BooleanDtype):
            # NA-bearing booleans fall back to object dtype.
            normalized[column] = series.astype(object if series.isna().any() else "bool")

    # Let Pandera infer a schema, then strip every value-based check so
    # only dtype and nullability survive.
    inferred = pa.infer_schema(normalized)
    relaxed = {
        name: pa.Column(dtype=col.dtype, nullable=col.nullable, checks=None)
        for name, col in inferred.columns.items()
    }

    # strict=False: additional columns beyond the inferred set are allowed.
    return pa.DataFrameSchema(columns=relaxed, strict=False)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def extract_token_usage(response) -> tuple[int, int]:
    """Extract token usage from LLM response.

    Args:
        response: LiteLLM response object

    Returns:
        Tuple of (input_tokens, output_tokens); (0, 0) when the response
        carries no readable usage information.
    """
    try:
        usage = response.usage
        return (
            getattr(usage, "prompt_tokens", 0),
            getattr(usage, "completion_tokens", 0),
        )
    except Exception as e:
        # Best effort: token accounting must never break the caller.
        logger.warning(f"Failed to extract token usage: {e}")
        return 0, 0
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
async def parse_constraint_with_llm(
    constraint: str,
    data_name: str,
    schema: pa.DataFrameSchema,
    model: str,
    litellm_params: Optional[dict] = None,
) -> tuple[Optional[_ConstraintParse], int, int]:
    """Use LLM to parse natural language constraint into Pandera check.

    Args:
        constraint: Natural language constraint (e.g., "quantity must be positive")
        data_name: Name of the data this constraint applies to
        schema: Inferred schema for column name validation
        model: LLM model to use
        litellm_params: Optional LiteLLM parameters

    Returns:
        Tuple of (ConstraintParse or None, input_tokens, output_tokens).
        None is returned when the LLM call fails or the parsed constraint
        references a column absent from *schema*; token counts are zero
        when the call itself raised.
    """
    column_names = list(schema.columns.keys())

    parse_prompt = f"""Parse this data constraint into a structured check.

Data name: {data_name}
Available columns: {", ".join(column_names)}

Constraint: "{constraint}"

Determine:
1. Which column does this apply to? (must be from available columns)
2. What type of check is this?
   - 'greater_than': column > value (for "positive", "must be at least X")
   - 'less_than': column < value
   - 'between': min <= column <= max
   - 'regex': column matches pattern (for format constraints like "YYYY-MM-DD")
   - 'isin': column value in list (only if specific values listed)
   - 'not_null': column cannot be null
   - 'none': constraint doesn't apply to data validation
3. Parameters needed for the check

Examples:
- "quantity must be positive" →
  column_name: "quantity", check_type: "greater_than",
  parameters: {{"value": 0}}, explanation: "quantity must be greater than 0"
- "price between 0 and 1000" →
  column_name: "price", check_type: "between",
  parameters: {{"min": 0, "max": 1000}},
  explanation: "price must be between 0 and 1000"
- "date in YYYY-MM-DD format" →
  column_name: "date", check_type: "regex",
  parameters: {{"pattern": "\\\\d{{4}}-\\\\d{{2}}-\\\\d{{2}}"}},
  explanation: "date must match YYYY-MM-DD format"
- "product must be Widget A or B" →
  column_name: "product", check_type: "isin",
  parameters: {{"values": ["Widget A", "Widget B"]}},
  explanation: "product must be one of the allowed values"

If constraint is unclear or doesn't apply to a specific column, use check_type: 'none'."""

    # Low temperature: parsing should be deterministic, not creative.
    params = {
        "model": model,
        "messages": [{"role": "user", "content": parse_prompt}],
        "max_tokens": 300,
        "temperature": 0.1,
    }
    params.update(litellm_params or {})
    # Set AFTER the update so caller-supplied litellm_params cannot
    # override the structured-output schema.
    params["response_format"] = _ConstraintParse

    try:
        response = await litellm.acompletion(**params)
        input_tokens, output_tokens = extract_token_usage(response)

        content = response.choices[0].message.content
        # content may arrive as a raw JSON string or as an already-parsed
        # object; handle both shapes.
        if isinstance(content, str):
            parse_dict = json.loads(content)
            parsed = _ConstraintParse(**parse_dict)
        else:
            parsed = content

        # Validate column exists
        if parsed.check_type != "none" and parsed.column_name not in column_names:
            logger.warning(f"Constraint '{constraint}' references unknown column '{parsed.column_name}'. Skipping.")
            return None, input_tokens, output_tokens

        return parsed, input_tokens, output_tokens

    except Exception as e:
        # Best effort: a malformed response or transport error skips this
        # constraint instead of failing the whole schema build.
        logger.warning(f"Failed to parse constraint '{constraint}': {e}")
        return None, 0, 0
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def apply_parsed_constraint(
    schema: pa.DataFrameSchema,
    parsed: _ConstraintParse,
) -> pa.DataFrameSchema:
    """Apply a parsed constraint to the schema.

    Args:
        schema: DataFrameSchema to update
        parsed: Parsed constraint

    Returns:
        Updated schema (``update_column`` returns a new schema object;
        the input is not mutated in place)
    """
    if parsed.check_type == "none":
        return schema

    col_name = parsed.column_name
    params = parsed.parameters

    # Build Pandera check based on type. An unrecognized check_type (or an
    # 'isin' with no values) leaves `check` as None and applies nothing.
    check = None

    if parsed.check_type == "greater_than":
        check = pa.Check.gt(params.value if params.value is not None else 0)

    elif parsed.check_type == "less_than":
        check = pa.Check.lt(params.value if params.value is not None else 0)

    elif parsed.check_type == "between":
        min_val = params.min if params.min is not None else 0
        max_val = params.max if params.max is not None else 100
        check = pa.Check.in_range(min_val, max_val)

    elif parsed.check_type == "regex":
        pattern = params.pattern if params.pattern is not None else ".*"
        check = pa.Check.str_matches(pattern)

    elif parsed.check_type == "isin":
        values = params.values or []
        if values:
            check = pa.Check.isin(values)

    elif parsed.check_type == "not_null":
        # Nullability is a column property, not a check; update and return.
        schema = schema.update_column(col_name, nullable=False)
        logger.info(f"Applied constraint to '{col_name}': {parsed.explanation}")
        return schema

    if check:
        # Append to any existing checks; Pandera may store a single check
        # or a list, so normalize to a list first.
        existing_checks = schema.columns[col_name].checks or []
        if not isinstance(existing_checks, list):
            existing_checks = [existing_checks]

        schema = schema.update_column(col_name, checks=[*existing_checks, check])
        # BUGFIX: previously "Applied constraint" was logged unconditionally,
        # even when no check was built; log only on actual application.
        logger.info(f"Applied constraint to '{col_name}': {parsed.explanation}")
    else:
        logger.warning(f"No check applied for constraint on '{col_name}' (check_type={parsed.check_type})")

    return schema
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
async def apply_user_constraints(
    schema: pa.DataFrameSchema,
    constraints: list[str],
    data_name: str,
    model: str,
    litellm_params: Optional[dict] = None,
) -> tuple[pa.DataFrameSchema, int, int]:
    """Apply user-specified constraints to schema using LLM parsing.

    Args:
        schema: Base schema (types only)
        constraints: List of natural language constraints
        data_name: Name of the data
        model: LLM model for parsing
        litellm_params: Optional LiteLLM parameters

    Returns:
        Tuple of (enhanced_schema, total_input_tokens, total_output_tokens)
    """
    result = schema
    tokens_in = 0
    tokens_out = 0

    for text in constraints:
        # Parse each constraint with the LLM (validated against the base
        # schema's columns), then fold it into the evolving schema.
        parsed, used_in, used_out = await parse_constraint_with_llm(text, data_name, schema, model, litellm_params)
        tokens_in += used_in
        tokens_out += used_out

        if parsed is not None:
            result = apply_parsed_constraint(result, parsed)

    return result, tokens_in, tokens_out
|