erdo 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of erdo might be problematic. Click here for more details.

erdo/state.py ADDED
@@ -0,0 +1,376 @@
1
+ """
2
+ Erdo State Management
3
+
4
+ Provides a magic `state` object that allows clean Python syntax like:
5
+ - state.code
6
+ - state.dataset.id
7
+ - f"Analysis for: {state.code}"
8
+
9
+ The state object tracks field access for static analysis and template conversion.
10
+ """
11
+
12
+ import ast
13
+ from collections import defaultdict
14
+ from typing import Any, Dict, Optional, Set
15
+
16
+ # Import template functions list - no fallback, fail fast if missing
17
+ from ._generated.template_functions import ALL_TEMPLATE_FUNCTIONS
18
+
19
+
20
class StateFieldTracker:
    """Collects the state field paths touched while agent code is evaluated.

    Attributes:
        accessed_fields: every dotted path passed to ``record_access``.
        nested_access: maps the first path segment to the set of remainders
            seen under it (``"dataset.id"`` -> ``{"dataset": {"id"}}``).
    """

    def __init__(self):
        self.nested_access: Dict[str, Set[str]] = defaultdict(set)
        self.accessed_fields: Set[str] = set()

    def record_access(self, field_path: str) -> None:
        """Remember that ``field_path`` was accessed on the state object.

        Args:
            field_path: dotted path such as ``"code"`` or ``"dataset.id"``.
        """
        self.accessed_fields.add(field_path)

        # Only dotted paths contribute to the parent -> children index; the
        # child is everything after the first dot.
        head, dot, remainder = field_path.partition(".")
        if dot:
            self.nested_access[head].add(remainder)
37
+
38
+
39
class StateMethodProxy:
    """Callable stand-in for template functions invoked as ``state.toJSON(x)``."""

    def __init__(self, method_name: str, tracker: StateFieldTracker):
        self._tracker = tracker
        self._method_name = method_name

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Record the call and return a ``{{name args}}`` placeholder string.

        ``toJSON`` and ``len`` render only their first positional argument
        (or nothing when called without arguments); every other method renders
        all positional arguments separated by spaces. Keyword arguments are
        accepted but not rendered.
        """
        self._tracker.record_access(f"{self._method_name}(*args)")

        if self._method_name in ("toJSON", "len"):
            rendered = args[0] if args else ""
        else:
            rendered = " ".join(str(arg) for arg in args)
        return f"{{{{{self._method_name} {rendered}}}}}"

    def __str__(self) -> str:
        return f"{{{{.{self._method_name}}}}}"

    def __repr__(self) -> str:
        return f"StateMethodProxy('{self._method_name}')"
65
+
66
+
67
class NestedStateProxy(str):
    """String-backed proxy for nested state access such as ``state.dataset.id``.

    The instance is a real ``str`` whose underlying value is always
    ``{{.Data.<path>}}``. ``str(proxy)`` and f-string formatting go through
    ``__str__`` instead, which only adds the ``.Data.`` prefix for paths that
    start with ``steps.`` or ``system.``.

    Attribute access on a proxy records ``<path>.<name>`` with the tracker and
    returns a new proxy for the longer path, so arbitrarily deep chains work.
    """

    def __new__(cls, parent_path: str, tracker: StateFieldTracker):
        obj = str.__new__(cls, f"{{{{.Data.{parent_path}}}}}")
        obj._parent_path = parent_path
        obj._tracker = tracker
        return obj

    def __getattr__(self, name: str) -> Any:
        """Record and proxy access to ``<path>.<name>``.

        Only called when normal lookup fails. Underscore-prefixed names are
        passed to the default lookup, which raises ``AttributeError``; this
        keeps ``hasattr`` and protocol probes (copy, pickle, ...) working.
        """
        if name.startswith("_"):
            return super().__getattribute__(name)
        field_path = f"{self._parent_path}.{name}"
        self._tracker.record_access(field_path)
        return NestedStateProxy(field_path, self._tracker)

    def __str__(self) -> str:
        """Render the Go-template reference used in f-strings."""
        if self._parent_path.startswith(("steps.", "system.")):
            return f"{{{{.Data.{self._parent_path}}}}}"
        return f"{{{{{self._parent_path}}}}}"

    def __repr__(self) -> str:
        return f"NestedStateProxy('{self._parent_path}')"

    def __reduce__(self):
        """Pickle/copy as a plain ``str`` holding ``{{.Data.<path>}}``."""
        return (str, (f"{{{{.Data.{self._parent_path}}}}}",))

    def __eq__(self, other: Any) -> bool:
        """Two proxies are equal when their paths match; nothing else is equal."""
        if isinstance(other, NestedStateProxy):
            return self._parent_path == other._parent_path
        return False

    def __ne__(self, other: Any) -> bool:
        """Exact inverse of ``__eq__``.

        ``str`` defines its own ``__ne__`` that compares the underlying text,
        so without this override ``proxy == s`` and ``proxy != s`` could both
        be False for a plain string ``s`` equal to the underlying value.
        """
        return not self.__eq__(other)

    def __bool__(self):
        """Proxies are always truthy, regardless of ``__len__``."""
        return True

    def __hash__(self):
        """Hash by path, consistent with ``__eq__``."""
        return hash(self._parent_path)

    def __iter__(self):
        """Iterating a proxy yields nothing."""
        return iter([])

    def __len__(self):
        """``len(proxy)`` is always 0."""
        return 0

    def __getitem__(self, key: Any) -> "NestedStateProxy":
        """Indexing returns a proxy for ``<path>[<key>]`` (not recorded)."""
        return NestedStateProxy(f"{self._parent_path}[{key}]", self._tracker)

    def __setattr__(self, name: str, value: Any):
        """Store underscore attributes; record (but do not store) anything else."""
        if name.startswith("_"):
            super().__setattr__(name, value)
            return
        # hasattr() is safe here: missing underscore names raise AttributeError
        # from __getattr__ rather than creating a proxy.
        field_path = (
            f"{self._parent_path}.{name}" if hasattr(self, "_parent_path") else name
        )
        if hasattr(self, "_tracker"):
            self._tracker.record_access(field_path)
140
+
141
+
142
class StateMagic:
    """Magic state object that tracks field access and provides clean Python syntax.

    Reading ``state.<name>`` records ``<name>`` and returns either a
    ``StateMethodProxy`` (when ``<name>`` is listed in
    ``ALL_TEMPLATE_FUNCTIONS``) or a ``NestedStateProxy``. Assigning
    ``state.<name> = value`` records the name and stores the value as a test
    value. Underscore-prefixed names are ordinary instance attributes.
    """

    def __init__(self):
        self._tracker = StateFieldTracker()
        self._test_values: Dict[str, Any] = {}

    def __getattr__(self, name: str) -> Any:
        """Handle attribute access like ``state.code`` or ``state.dataset``.

        Raises:
            AttributeError: for underscore-prefixed names. ``__getattr__`` only
                runs when normal lookup failed, so such a name is genuinely
                missing. Without this guard, a missing ``_tracker`` (for
                example on an instance rebuilt by copy/pickle) would recurse
                forever, and dunder probes like ``__deepcopy__`` would get a
                non-callable proxy back.
        """
        if name.startswith("_"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        self._tracker.record_access(name)

        # Template functions are exposed as callables: state.toJSON(x)
        if name in ALL_TEMPLATE_FUNCTIONS:
            return StateMethodProxy(name, self._tracker)

        # Everything else is a proxy so that chains like
        # state.organization.name keep working.
        proxy = NestedStateProxy(name, self._tracker)
        test_value = self._test_values.get(name)
        if isinstance(test_value, dict):
            # Attach nested test data to the proxy for local testing.
            proxy._test_data = test_value
        return proxy

    def __str__(self) -> str:
        """Fallback rendering; individual fields are normally formatted instead."""
        return "{{state}}"

    def __setattr__(self, name: str, value: Any):
        """Store underscore attributes normally; treat others as test values."""
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            # An assignment also counts as a field access.
            self._tracker.record_access(name)
            self._test_values[name] = value

    def set_test_value(self, field_path: str, value: Any):
        """Set a test value for local development/testing.

        Args:
            field_path: dotted path; intermediate segments become nested dicts.
            value: value stored under the final segment.
        """
        *parents, leaf = field_path.split(".")
        current = self._test_values
        for part in parents:
            current = current.setdefault(part, {})
        current[leaf] = value

    def get_accessed_fields(self) -> Set[str]:
        """Return a copy of every field path recorded so far."""
        return self._tracker.accessed_fields.copy()

    def clear_tracking(self):
        """Discard recorded accesses by installing a fresh tracker."""
        self._tracker = StateFieldTracker()
206
+
207
+
208
# Global state object for use in agent definitions.
# Created once at import time, so every importer of this module shares the
# same StateMagic instance (and therefore the same tracker and test values).
state = StateMagic()
210
+
211
+
212
def extract_state_references_from_ast(source_code: str) -> Set[str]:
    """Extract all ``state.*`` references from Python source code via the AST.

    Every attribute chain rooted at the bare name ``state`` is reported without
    the ``state.`` prefix, including each intermediate prefix of a longer chain
    (``state.dataset.id`` yields both ``"dataset"`` and ``"dataset.id"``).
    A bare ``{state}`` inside an f-string is reported as ``"state"``.

    Args:
        source_code: Python source text to scan.

    Returns:
        The set of referenced field paths; empty if the source cannot be parsed.
    """
    try:
        tree = ast.parse(source_code)
    except (SyntaxError, ValueError):
        # Depending on the interpreter version, ast.parse reports source
        # containing NUL bytes as ValueError rather than SyntaxError.
        return set()

    prefix = "state."
    state_refs: Set[str] = set()

    def dotted_path(node):
        """Return ``a.b.c`` for a pure Name/Attribute chain, else None."""
        names = []
        while isinstance(node, ast.Attribute):
            names.append(node.attr)
            node = node.value
        if not isinstance(node, ast.Name):
            # Chains passing through calls/subscripts are not plain paths.
            return None
        names.append(node.id)
        return ".".join(reversed(names))

    class StateVisitor(ast.NodeVisitor):
        def visit_Attribute(self, node):
            """Record attribute chains like state.code, state.dataset.id."""
            path = dotted_path(node)
            if path and path.startswith(prefix):
                state_refs.add(path[len(prefix):])
            # Descend so shorter prefixes of the chain are recorded too.
            self.generic_visit(node)

        def visit_JoinedStr(self, node):
            """Record a bare ``{state}`` placeholder inside an f-string.

            Attribute expressions inside the f-string are picked up by
            ``visit_Attribute`` during the generic traversal below.
            """
            for value in node.values:
                if (
                    isinstance(value, ast.FormattedValue)
                    and isinstance(value.value, ast.Name)
                    and value.value.id == "state"
                ):
                    state_refs.add("state")
            self.generic_visit(node)

    StateVisitor().visit(tree)
    return state_refs
271
+
272
+
273
def convert_fstring_to_template(source_code: str, state_refs: Set[str]) -> str:
    """Convert f-strings with state references to Go template format.

    Each ``{state.a.b}`` placeholder becomes the literal text ``{{a.b}}``.
    An f-string made up only of literal text and state references is replaced
    by a plain string. An f-string that also interpolates other expressions
    stays an f-string, so those expressions are still evaluated; only its
    state references are replaced by (brace-escaped) literal text.
    Conversions and format specs on a state placeholder are dropped.

    Args:
        source_code: Python source text to rewrite.
        state_refs: currently unused; kept for interface compatibility.

    Returns:
        The rewritten source as produced by ``ast.unparse``, or ``source_code``
        unchanged if parsing or unparsing fails.
    """
    prefix = "state."

    def dotted_path(node):
        """Return ``a.b.c`` for a pure Name/Attribute chain, else None."""
        names = []
        while isinstance(node, ast.Attribute):
            names.append(node.attr)
            node = node.value
        if not isinstance(node, ast.Name):
            return None
        names.append(node.id)
        return ".".join(reversed(names))

    class FStringConverter(ast.NodeTransformer):
        def visit_JoinedStr(self, node):
            """Replace state placeholders with Go template literals."""
            rewritten = []
            has_state_ref = False

            for value in node.values:
                if isinstance(value, ast.FormattedValue) and isinstance(
                    value.value, ast.Attribute
                ):
                    path = dotted_path(value.value)
                    if path and path.startswith(prefix):
                        # state.field -> literal "{{field}}"
                        field_path = path[len(prefix):]
                        value = ast.Constant(value=f"{{{{{field_path}}}}}")
                        has_state_ref = True
                rewritten.append(value)

            if not has_state_ref:
                return node

            if all(isinstance(part, ast.Constant) for part in rewritten):
                # Nothing left to interpolate: collapse to a regular string.
                text = "".join(part.value for part in rewritten)
                return ast.copy_location(ast.Constant(value=text), node)

            # Other expressions remain, so keep an f-string. Adjacent literal
            # parts are merged to mirror what the parser itself would produce.
            merged = []
            for part in rewritten:
                if (
                    merged
                    and isinstance(part, ast.Constant)
                    and isinstance(merged[-1], ast.Constant)
                ):
                    merged[-1] = ast.Constant(value=merged[-1].value + part.value)
                else:
                    merged.append(part)
            return ast.copy_location(ast.JoinedStr(values=merged), node)

    try:
        tree = ast.parse(source_code)
        new_tree = FStringConverter().visit(tree)
        return ast.unparse(new_tree)
    except Exception:
        # Deliberate best-effort: any failure leaves the source untouched.
        return source_code
330
+
331
+
332
def validate_state_fields(
    state_refs: Set[str], available_fields: Optional[Set[str]] = None
) -> Dict[str, str]:
    """Report referenced state fields that are not known to be available.

    A reference is accepted when it appears in ``available_fields`` verbatim or
    when its first dotted segment does; deeper segments are not checked because
    that would need schema knowledge.

    Args:
        state_refs: Set of state field references found in the code
        available_fields: Optional set of known available fields. If None, no validation is performed.

    Returns a dict of {invalid_field: error_message} for any issues.
    """
    # User-defined state is flexible: without a field list there is nothing
    # to validate against.
    if available_fields is None:
        return {}

    problems: Dict[str, str] = {}
    for ref in state_refs:
        if ref in available_fields:
            continue
        # For a plain field the root is the field itself; for a nested one it
        # is the parent whose presence decides validity.
        root = ref.split(".", 1)[0]
        if root not in available_fields:
            problems[ref] = f"State field '{root}' is not available"
    return problems
363
+
364
+
365
def setup_test_state(**test_values):
    """Register test values on the global ``state`` object.

    Each keyword argument is forwarded to ``state.set_test_value`` with the
    keyword name as the field path.

    Example:
        setup_test_state(
            code="print('hello')",
            dataset={'id': 'test123', 'config': {'type': 'csv'}},
            query="analyze this data"
        )
    """
    for name in test_values:
        state.set_test_value(name, test_values[name])
erdo/template.py ADDED
@@ -0,0 +1,136 @@
1
+ """
2
+ Template string handling for export/import roundtrip compatibility.
3
+
4
+ This module provides utilities for handling template strings during the
5
+ export/import roundtrip process. Template strings are Go template expressions
6
+ that can't be executed as Python, so they need special handling.
7
+ """
8
+
9
+ from typing import TYPE_CHECKING, Union
10
+
11
+ if TYPE_CHECKING:
12
+ from erdo.types import Prompt
13
+
14
+
15
class TemplateString:
    """
    A wrapper class for template strings during export/import roundtrip.

    This class represents a Go template string (like {{.Data.field}}) in a way
    that can be executed as Python code and then converted back to the original
    template string during import.

    It implements various duck-typing methods to be compatible with Pydantic
    validation while preserving the template content for later extraction.
    """

    template: str

    def __init__(self, template: Union[str, "Prompt"]):
        """
        Initialize a TemplateString with the template content.

        Args:
            template: The template string content (e.g., "{{.Data.field}}") or a Prompt object
        """
        # str() covers both plain strings and Prompt objects, so the stored
        # template is always a str.
        self.template = str(template)

    def __str__(self) -> str:
        """Return the template string for display purposes."""
        return self.template

    def __repr__(self) -> str:
        """Return a representation of the TemplateString."""
        return f"TemplateString({self.template!r})"

    def __eq__(self, other: object) -> bool:
        """Check equality with another TemplateString (never equal to other types)."""
        if isinstance(other, TemplateString):
            return self.template == other.template
        return False

    def __hash__(self) -> int:
        """Make TemplateString hashable, consistent with ``__eq__``."""
        return hash(self.template)

    # Duck typing methods to make it behave like a string for Pydantic validation
    def __len__(self) -> int:
        """Return length to behave like a string."""
        return len(self.template)

    def __contains__(self, item) -> bool:
        """Support 'in' operator to behave like a string."""
        return item in self.template

    def __getitem__(self, key):
        """Support indexing to behave like a string."""
        return self.template[key]

    def __iter__(self):
        """Support iteration to behave like a string."""
        return iter(self.template)

    # Pickling support
    def __getstate__(self):
        """Support pickling."""
        return {"template": self.template}

    def __setstate__(self, state):
        """Support unpickling."""
        self.template = state["template"]

    # Additional methods to help with Pydantic validation
    @classmethod
    def __get_validators__(cls):
        """Pydantic v1 compatibility."""
        yield cls.validate

    @classmethod
    def validate(cls, v, *args, **kwargs):
        """Validate that the value is a string, TemplateString, or basic type.

        Returns:
            ``v`` unchanged, except that objects with a ``content`` attribute
            (treated as Prompt-like) are returned as ``str(v)``.

        Raises:
            ValueError: if ``v`` is callable, or if ``str(v)`` fails (the
                original exception is attached as ``__cause__``).
        """
        if isinstance(v, cls):
            return v
        # Allow strings and basic types that can be converted to strings
        if isinstance(v, (str, int, float, bool)):
            return v
        # Objects exposing `content` are treated as Prompt-like and stringified
        if hasattr(v, "content"):
            return str(v)
        # Reject functions, lambdas, and other complex types
        if callable(v):
            raise ValueError(
                f"Template fields cannot accept callable objects like {type(v).__name__}"
            )
        # For other types, only require that they can be converted to a string
        try:
            str(v)
        except Exception as err:
            raise ValueError(
                f"Template fields cannot accept objects of type {type(v).__name__}"
            ) from err
        return v

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, handler):
        """Pydantic v2 compatibility."""
        # Imported lazily so pydantic_core is only needed when this hook runs.
        from pydantic_core import core_schema

        return core_schema.no_info_plain_validator_function(cls.validate)

    def to_template_string(self) -> str:
        """
        Convert back to the original template string format.

        This is used during the import process to convert the TemplateString
        object back to the raw template string.

        Returns:
            The original template string
        """
        return self.template