flyteplugins-codegen 2.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,270 @@
1
+ import json
2
+ import logging
3
+ from typing import Optional
4
+
5
+ import litellm
6
+ import pandas as pd
7
+ import pandera.pandas as pa
8
+
9
+ from flyteplugins.codegen.core.types import _ConstraintParse
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
def schema_to_script(schema: pa.DataFrameSchema) -> str:
    """Render a Pandera schema as source text.

    Prefers ``DataFrameSchema.to_script()``; when its optional formatting
    dependency (black) is missing, degrades to ``repr(schema)``.

    Args:
        schema: Schema to serialize.

    Returns:
        Script string for the schema, or its repr as a fallback.
    """
    try:
        rendered = schema.to_script()
    except ImportError:
        # to_script needs black installed for formatting; repr is always available.
        rendered = repr(schema)
    return rendered
20
+
21
+
22
def infer_conservative_schema(df: pd.DataFrame) -> pa.DataFrameSchema:
    """Infer a Pandera schema conservatively - types only, no value constraints.

    Args:
        df: DataFrame to infer schema from

    Returns:
        DataFrameSchema with only dtype and nullability checks
    """
    signed_int_dtypes = (pd.Int8Dtype, pd.Int16Dtype, pd.Int32Dtype, pd.Int64Dtype)
    unsigned_int_dtypes = (pd.UInt8Dtype, pd.UInt16Dtype, pd.UInt32Dtype, pd.UInt64Dtype)
    float_dtypes = (pd.Float32Dtype, pd.Float64Dtype)

    # Normalize nullable extension dtypes that Pandera doesn't recognize.
    # Nullable integers/booleans with NAs must widen to float64/object,
    # since plain numpy int/bool dtypes cannot represent missing values.
    normalized = df.copy()
    for name in normalized.columns:
        series = normalized[name]
        dtype = series.dtype
        if isinstance(dtype, pd.StringDtype):
            normalized[name] = series.astype(object)
        elif isinstance(dtype, signed_int_dtypes):
            normalized[name] = series.astype("float64" if series.isna().any() else "int64")
        elif isinstance(dtype, unsigned_int_dtypes):
            normalized[name] = series.astype("float64" if series.isna().any() else "uint64")
        elif isinstance(dtype, float_dtypes):
            normalized[name] = series.astype("float64")
        elif isinstance(dtype, pd.BooleanDtype):
            normalized[name] = series.astype(object if series.isna().any() else "bool")

    # Pandera's built-in inference yields dtypes plus value-range checks;
    # keep only dtype and nullability so the schema stays conservative.
    inferred = pa.infer_schema(normalized)
    relaxed_columns = {
        name: pa.Column(
            dtype=column.dtype,
            nullable=column.nullable,
            checks=None,  # drop all inferred value checks
        )
        for name, column in inferred.columns.items()
    }

    # strict=False so extra columns in future data don't fail validation.
    return pa.DataFrameSchema(columns=relaxed_columns, strict=False)
64
+
65
+
66
def extract_token_usage(response) -> tuple[int, int]:
    """Extract token usage counts from an LLM response.

    Args:
        response: LiteLLM response object

    Returns:
        Tuple of (input_tokens, output_tokens); (0, 0) when usage
        information cannot be read.
    """
    try:
        usage = response.usage
        # Missing fields default to 0 rather than raising.
        return (
            getattr(usage, "prompt_tokens", 0),
            getattr(usage, "completion_tokens", 0),
        )
    except Exception as e:
        # Usage accounting is best-effort; never let it break the caller.
        logger.warning(f"Failed to extract token usage: {e}")
        return 0, 0
83
+
84
+
85
async def parse_constraint_with_llm(
    constraint: str,
    data_name: str,
    schema: pa.DataFrameSchema,
    model: str,
    litellm_params: Optional[dict] = None,
) -> tuple[Optional[_ConstraintParse], int, int]:
    """Use LLM to parse natural language constraint into Pandera check.

    Args:
        constraint: Natural language constraint (e.g., "quantity must be positive")
        data_name: Name of the data this constraint applies to
        schema: Inferred schema for column name validation
        model: LLM model to use
        litellm_params: Optional LiteLLM parameters

    Returns:
        Tuple of (ConstraintParse or None, input_tokens, output_tokens)
    """
    # Column names are embedded in the prompt and reused below to reject
    # parses that reference columns the schema doesn't have.
    column_names = list(schema.columns.keys())

    parse_prompt = f"""Parse this data constraint into a structured check.

Data name: {data_name}
Available columns: {", ".join(column_names)}

Constraint: "{constraint}"

Determine:
1. Which column does this apply to? (must be from available columns)
2. What type of check is this?
   - 'greater_than': column > value (for "positive", "must be at least X")
   - 'less_than': column < value
   - 'between': min <= column <= max
   - 'regex': column matches pattern (for format constraints like "YYYY-MM-DD")
   - 'isin': column value in list (only if specific values listed)
   - 'not_null': column cannot be null
   - 'none': constraint doesn't apply to data validation
3. Parameters needed for the check

Examples:
- "quantity must be positive" →
  column_name: "quantity", check_type: "greater_than",
  parameters: {{"value": 0}}, explanation: "quantity must be greater than 0"
- "price between 0 and 1000" →
  column_name: "price", check_type: "between",
  parameters: {{"min": 0, "max": 1000}},
  explanation: "price must be between 0 and 1000"
- "date in YYYY-MM-DD format" →
  column_name: "date", check_type: "regex",
  parameters: {{"pattern": "\\\\d{{4}}-\\\\d{{2}}-\\\\d{{2}}"}},
  explanation: "date must match YYYY-MM-DD format"
- "product must be Widget A or B" →
  column_name: "product", check_type: "isin",
  parameters: {{"values": ["Widget A", "Widget B"]}},
  explanation: "product must be one of the allowed values"

If constraint is unclear or doesn't apply to a specific column, use check_type: 'none'."""

    # Low temperature and a small token cap: we want a short, deterministic
    # structured parse, not free-form text.
    params = {
        "model": model,
        "messages": [{"role": "user", "content": parse_prompt}],
        "max_tokens": 300,
        "temperature": 0.1,
    }
    params.update(litellm_params or {})
    # Set response_format AFTER merging user params so callers cannot
    # accidentally override the structured-output contract.
    params["response_format"] = _ConstraintParse

    try:
        response = await litellm.acompletion(**params)
        input_tokens, output_tokens = extract_token_usage(response)

        # Structured outputs may arrive as a JSON string or as an already
        # parsed object depending on provider; handle both shapes.
        content = response.choices[0].message.content
        if isinstance(content, str):
            parse_dict = json.loads(content)
            parsed = _ConstraintParse(**parse_dict)
        else:
            parsed = content

        # Validate column exists
        if parsed.check_type != "none" and parsed.column_name not in column_names:
            logger.warning(f"Constraint '{constraint}' references unknown column '{parsed.column_name}'. Skipping.")
            return None, input_tokens, output_tokens

        return parsed, input_tokens, output_tokens

    except Exception as e:
        # Best-effort: one unparseable constraint should not abort the run.
        # Token usage is unknown on this path, so report zero.
        logger.warning(f"Failed to parse constraint '{constraint}': {e}")
        return None, 0, 0
174
+
175
+
176
def apply_parsed_constraint(
    schema: pa.DataFrameSchema,
    parsed: _ConstraintParse,
) -> pa.DataFrameSchema:
    """Apply a parsed constraint to the schema.

    Args:
        schema: DataFrameSchema to update
        parsed: Parsed constraint

    Returns:
        Updated schema
    """
    kind = parsed.check_type
    if kind == "none":
        # Constraint does not translate into a data validation check.
        return schema

    col_name = parsed.column_name
    params = parsed.parameters

    if kind == "not_null":
        # Nullability is a column property, not a Check; no log on this path.
        return schema.update_column(col_name, nullable=False)

    # Build the Pandera check for the remaining kinds, with safe defaults
    # when the parser left a parameter unset.
    check = None
    if kind == "greater_than":
        check = pa.Check.gt(0 if params.value is None else params.value)
    elif kind == "less_than":
        check = pa.Check.lt(0 if params.value is None else params.value)
    elif kind == "between":
        lower = 0 if params.min is None else params.min
        upper = 100 if params.max is None else params.max
        check = pa.Check.in_range(lower, upper)
    elif kind == "regex":
        check = pa.Check.str_matches(".*" if params.pattern is None else params.pattern)
    elif kind == "isin":
        allowed = params.values or []
        if allowed:
            check = pa.Check.isin(allowed)

    if check:
        # Append to (never replace) any checks already on the column.
        current = schema.columns[col_name].checks or []
        if not isinstance(current, list):
            current = [current]
        schema = schema.update_column(col_name, checks=[*current, check])

    logger.info(f"Applied constraint to '{col_name}': {parsed.explanation}")

    return schema
234
+
235
+
236
async def apply_user_constraints(
    schema: pa.DataFrameSchema,
    constraints: list[str],
    data_name: str,
    model: str,
    litellm_params: Optional[dict] = None,
) -> tuple[pa.DataFrameSchema, int, int]:
    """Apply user-specified constraints to schema using LLM parsing.

    Args:
        schema: Base schema (types only)
        constraints: List of natural language constraints
        data_name: Name of the data
        model: LLM model for parsing
        litellm_params: Optional LiteLLM parameters

    Returns:
        Tuple of (enhanced_schema, total_input_tokens, total_output_tokens)
    """
    result_schema = schema
    tokens_in = 0
    tokens_out = 0

    for text in constraints:
        # Parse against the ORIGINAL schema (column-name validation only
        # needs names); apply cumulatively onto the enhanced copy.
        parsed, used_in, used_out = await parse_constraint_with_llm(
            text, data_name, schema, model, litellm_params
        )
        tokens_in += used_in
        tokens_out += used_out

        if parsed:
            result_schema = apply_parsed_constraint(result_schema, parsed)

    return result_schema, tokens_in, tokens_out
@@ -0,0 +1,7 @@
1
+ from flyteplugins.codegen.execution.agent import code_gen_eval_agent
2
+ from flyteplugins.codegen.execution.docker import build_image
3
+
4
# Public API of the execution subpackage.
__all__ = [
    "build_image",
    "code_gen_eval_agent",
]