agnt5-0.2.2-cp39-abi3-macosx_11_0_arm64.whl → agnt5-0.2.4-cp39-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of agnt5 might be problematic.
- agnt5/__init__.py +12 -12
- agnt5/_core.abi3.so +0 -0
- agnt5/_retry_utils.py +169 -0
- agnt5/_schema_utils.py +312 -0
- agnt5/_telemetry.py +28 -7
- agnt5/agent.py +153 -140
- agnt5/client.py +50 -12
- agnt5/context.py +36 -756
- agnt5/entity.py +368 -1160
- agnt5/function.py +208 -235
- agnt5/lm.py +71 -12
- agnt5/tool.py +25 -11
- agnt5/tracing.py +196 -0
- agnt5/worker.py +205 -173
- agnt5/workflow.py +444 -20
- {agnt5-0.2.2.dist-info → agnt5-0.2.4.dist-info}/METADATA +2 -1
- agnt5-0.2.4.dist-info/RECORD +22 -0
- agnt5-0.2.2.dist-info/RECORD +0 -19
- {agnt5-0.2.2.dist-info → agnt5-0.2.4.dist-info}/WHEEL +0 -0
agnt5/function.py
CHANGED
@@ -5,10 +5,11 @@ from __future__ import annotations
 import asyncio
 import functools
 import inspect
-import time
 import uuid
-from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
+from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
 
+from ._retry_utils import execute_with_retry, parse_backoff_policy, parse_retry_policy
+from ._schema_utils import extract_function_metadata, extract_function_schemas
 from .context import Context
 from .exceptions import RetryError
 from .types import BackoffPolicy, BackoffType, FunctionConfig, HandlerFunc, RetryPolicy
@@ -18,13 +19,117 @@ T = TypeVar("T")
 # Global function registry
 _FUNCTION_REGISTRY: Dict[str, FunctionConfig] = {}
 
+class FunctionContext(Context):
+    """
+    Lightweight context for stateless functions.
+
+    AGNT5 Philosophy: Context is a convenience, not a requirement.
+    The best function is one that doesn't need context at all!
+
+    Provides only:
+    - Quick logging (ctx.log())
+    - Execution metadata (run_id, attempt)
+    - Smart retry helper (should_retry())
+    - Non-durable sleep
+
+    Does NOT provide:
+    - Orchestration (task, parallel, gather) - use workflows
+    - State management (get, set, delete) - functions are stateless
+    - Checkpointing (step) - functions are atomic
+    """
+
+    def __init__(
+        self,
+        run_id: str,
+        attempt: int = 0,
+        runtime_context: Optional[Any] = None,
+        retry_policy: Optional[Any] = None,
+    ) -> None:
+        """
+        Initialize function context.
+
+        Args:
+            run_id: Unique execution identifier
+            attempt: Retry attempt number (0-indexed)
+            runtime_context: RuntimeContext for trace correlation
+            retry_policy: RetryPolicy for should_retry() checks
+        """
+        super().__init__(run_id, attempt, runtime_context)
+        self._retry_policy = retry_policy
+
+    # === Quick Logging ===
+
+    def log(self, message: str, **extra) -> None:
+        """
+        Quick logging shorthand with structured data.
+
+        Example:
+            ctx.log("Processing payment", amount=100.50, user_id="123")
+        """
+        self._logger.info(message, extra=extra)
+
+    # === Smart Execution ===
+
+    def should_retry(self, error: Exception) -> bool:
+        """
+        Check if error is retryable based on configured policy.
+
+        Example:
+            try:
+                result = await external_api()
+            except Exception as e:
+                if not ctx.should_retry(e):
+                    raise  # Fail fast for non-retryable errors
+                # Otherwise let retry policy handle it
+                raise
+
+        Returns:
+            True if error is retryable, False otherwise
+        """
+        # TODO: Implement retry policy checks
+        # For now, all errors are retryable (let retry policy handle it)
+        return True
+
+    async def sleep(self, seconds: float) -> None:
+        """
+        Non-durable async sleep.
+
+        For durable sleep across failures, use workflows.
+
+        Args:
+            seconds: Number of seconds to sleep
+        """
+        import asyncio
+        await asyncio.sleep(seconds)
+
+
 
 class FunctionRegistry:
     """Registry for function handlers."""
 
     @staticmethod
     def register(config: FunctionConfig) -> None:
-        """Register a function handler.
+        """Register a function handler.
+
+        Args:
+            config: Function configuration to register
+
+        Raises:
+            ValueError: If a function with the same name is already registered
+        """
+        # Check for name collision
+        if config.name in _FUNCTION_REGISTRY:
+            existing_config = _FUNCTION_REGISTRY[config.name]
+            existing_module = existing_config.handler.__module__
+            new_module = config.handler.__module__
+
+            raise ValueError(
+                f"Function name collision: '{config.name}' is already registered.\n"
+                f" Existing: {existing_module}.{existing_config.handler.__name__}\n"
+                f" New: {new_module}.{config.handler.__name__}\n"
+                f"Please use a different function name or use name= parameter to specify a unique name."
+            )
+
         _FUNCTION_REGISTRY[config.name] = config
 
     @staticmethod
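Taken together, the new FunctionContext gives a stateless function a thin helper surface (log, should_retry, non-durable sleep) with no orchestration or state APIs. A minimal usage sketch based on the docstrings in the hunk above; it assumes function and FunctionContext are importable from agnt5.function (the exact export path is not shown in this diff), and call_payment_gateway is a hypothetical stub added only to keep the sketch self-contained.

# Sketch only: illustrates the FunctionContext surface added above.
# Assumes `function` and `FunctionContext` are exported from agnt5.function.
from agnt5.function import FunctionContext, function


async def call_payment_gateway(user_id: str, amount: float) -> str:
    # Hypothetical external call, present only so the sketch is self-contained.
    return f"charged {user_id} {amount:.2f}"


@function(retries=3, backoff="exponential")
async def charge_card(ctx: FunctionContext, user_id: str, amount: float) -> str:
    ctx.log("Charging card", user_id=user_id, amount=amount)  # structured logging shorthand
    try:
        return await call_payment_gateway(user_id, amount)
    except Exception as e:
        if not ctx.should_retry(e):
            raise  # fail fast for non-retryable errors
        raise  # otherwise let the configured retry/backoff policy re-run the function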
@@ -43,216 +148,72 @@ class FunctionRegistry:
         _FUNCTION_REGISTRY.clear()
 
 
-def _type_to_json_schema(python_type: Any) -> Dict[str, Any]:
-    """Convert Python type hint to JSON Schema."""
-    # Handle None type
-    if python_type is type(None):
-        return {"type": "null"}
-
-    # Handle basic types
-    if python_type is str:
-        return {"type": "string"}
-    if python_type is int:
-        return {"type": "integer"}
-    if python_type is float:
-        return {"type": "number"}
-    if python_type is bool:
-        return {"type": "boolean"}
-
-    # Handle typing module types
-    origin = getattr(python_type, "__origin__", None)
-
-    if origin is list:
-        args = getattr(python_type, "__args__", ())
-        if args:
-            return {"type": "array", "items": _type_to_json_schema(args[0])}
-        return {"type": "array"}
-
-    if origin is dict:
-        return {"type": "object"}
-
-    if origin is Union:
-        args = getattr(python_type, "__args__", ())
-        # Handle Optional[T] (Union[T, None])
-        if len(args) == 2 and type(None) in args:
-            non_none = args[0] if args[1] is type(None) else args[1]
-            schema = _type_to_json_schema(non_none)
-            # Mark as nullable in JSON Schema
-            return {**schema, "nullable": True}
-        # Handle other unions as anyOf
-        return {"anyOf": [_type_to_json_schema(arg) for arg in args]}
-
-    # Default to object for unknown types
-    return {"type": "object"}
-
-
-def _extract_function_schemas(func: Callable[..., Any]) -> tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
-    """Extract input and output schemas from function type hints.
-
-    Returns:
-        Tuple of (input_schema, output_schema) where each is a JSON Schema dict or None
-    """
-    try:
-        # Get type hints
-        hints = get_type_hints(func)
-        sig = inspect.signature(func)
-
-        # Build input schema from parameters (excluding 'ctx')
-        input_properties = {}
-        required_params = []
-
-        for param_name, param in sig.parameters.items():
-            if param_name == "ctx":
-                continue
-
-            # Get type hint for this parameter
-            if param_name in hints:
-                param_type = hints[param_name]
-                input_properties[param_name] = _type_to_json_schema(param_type)
-            else:
-                # No type hint, use generic object
-                input_properties[param_name] = {"type": "object"}
-
-            # Check if parameter is required (no default value)
-            if param.default is inspect.Parameter.empty:
-                required_params.append(param_name)
-
-        input_schema = None
-        if input_properties:
-            input_schema = {
-                "type": "object",
-                "properties": input_properties,
-            }
-            if required_params:
-                input_schema["required"] = required_params
-
-            # Add description from docstring if available
-            if func.__doc__:
-                docstring = inspect.cleandoc(func.__doc__)
-                first_line = docstring.split('\n')[0].strip()
-                if first_line:
-                    input_schema["description"] = first_line
-
-        # Build output schema from return type hint
-        output_schema = None
-        if "return" in hints:
-            return_type = hints["return"]
-            output_schema = _type_to_json_schema(return_type)
-
-        return input_schema, output_schema
-
-    except Exception:
-        # If schema extraction fails, return None schemas
-        return None, None
-
-
-def _extract_function_metadata(func: Callable[..., Any]) -> Dict[str, str]:
-    """Extract metadata from function including description from docstring.
-
-    Returns:
-        Dictionary with metadata fields like 'description'
-    """
-    metadata = {}
-
-    # Extract description from docstring
-    if func.__doc__:
-        # Get first line of docstring as description
-        docstring = inspect.cleandoc(func.__doc__)
-        first_line = docstring.split('\n')[0].strip()
-        if first_line:
-            metadata["description"] = first_line
-
-    return metadata
-
-
-def _calculate_backoff_delay(
-    attempt: int,
-    retry_policy: RetryPolicy,
-    backoff_policy: BackoffPolicy,
-) -> float:
-    """Calculate backoff delay in seconds based on attempt number."""
-    if backoff_policy.type == BackoffType.CONSTANT:
-        delay_ms = retry_policy.initial_interval_ms
-    elif backoff_policy.type == BackoffType.LINEAR:
-        delay_ms = retry_policy.initial_interval_ms * (attempt + 1)
-    else:  # EXPONENTIAL
-        delay_ms = retry_policy.initial_interval_ms * (backoff_policy.multiplier**attempt)
-
-    # Cap at max_interval_ms
-    delay_ms = min(delay_ms, retry_policy.max_interval_ms)
-    return delay_ms / 1000.0  # Convert to seconds
-
-
-async def _execute_with_retry(
-    handler: HandlerFunc,
-    ctx: Context,
-    retry_policy: RetryPolicy,
-    backoff_policy: BackoffPolicy,
-    *args: Any,
-    **kwargs: Any,
-) -> Any:
-    """Execute handler with retry logic."""
-    last_error: Optional[Exception] = None
-
-    for attempt in range(retry_policy.max_attempts):
-        try:
-            # Update context attempt number
-            ctx._attempt = attempt
-
-            # Execute handler
-            result = await handler(ctx, *args, **kwargs)
-            return result
-
-        except Exception as e:
-            last_error = e
-            ctx.logger.warning(
-                f"Function execution failed (attempt {attempt + 1}/{retry_policy.max_attempts}): {e}"
-            )
-
-            # If this was the last attempt, raise RetryError
-            if attempt == retry_policy.max_attempts - 1:
-                raise RetryError(
-                    f"Function failed after {retry_policy.max_attempts} attempts",
-                    attempts=retry_policy.max_attempts,
-                    last_error=e,
-                )
-
-            # Calculate backoff delay
-            delay = _calculate_backoff_delay(attempt, retry_policy, backoff_policy)
-            ctx.logger.info(f"Retrying in {delay:.2f} seconds...")
-            await asyncio.sleep(delay)
-
-    # Should never reach here, but for type safety
-    assert last_error is not None
-    raise RetryError(
-        f"Function failed after {retry_policy.max_attempts} attempts",
-        attempts=retry_policy.max_attempts,
-        last_error=last_error,
-    )
-
-
 def function(
     _func: Optional[Callable[..., Any]] = None,
     *,
     name: Optional[str] = None,
-    retries: Optional[RetryPolicy] = None,
-    backoff: Optional[BackoffPolicy] = None,
+    retries: Optional[Union[int, Dict[str, Any], RetryPolicy]] = None,
+    backoff: Optional[Union[str, Dict[str, Any], BackoffPolicy]] = None,
 ) -> Callable[..., Any]:
     """
     Decorator to mark a function as an AGNT5 durable function.
 
     Args:
         name: Custom function name (default: function's __name__)
-        retries: Retry policy configuration
-
+        retries: Retry policy configuration. Can be:
+            - int: max attempts (e.g., 5)
+            - dict: RetryPolicy params (e.g., {"max_attempts": 5, "initial_interval_ms": 1000})
+            - RetryPolicy: full policy object
+        backoff: Backoff policy for retries. Can be:
+            - str: backoff type ("constant", "linear", "exponential")
+            - dict: BackoffPolicy params (e.g., {"type": "exponential", "multiplier": 2.0})
+            - BackoffPolicy: full policy object
+
+    Note:
+        Sync Functions: Synchronous functions are automatically executed in a thread pool
+        to prevent blocking the event loop. This is ideal for I/O-bound operations
+        (requests.get(), file I/O, etc.). For CPU-bound operations or when you need
+        explicit control over concurrency, use async functions instead.
 
     Example:
+        # Basic function with context
         @function
-        async def greet(ctx:
+        async def greet(ctx: FunctionContext, name: str) -> str:
+            ctx.log(f"Greeting {name}")  # AGNT5 shorthand!
            return f"Hello, {name}!"
 
-
-
+        # Simple function without context (optional)
+        @function
+        async def add(a: int, b: int) -> int:
+            return a + b
+
+        # With Pydantic models (automatic validation + rich schemas)
+        from pydantic import BaseModel
+
+        class UserInput(BaseModel):
+            name: str
+            age: int
+
+        class UserOutput(BaseModel):
+            greeting: str
+            is_adult: bool
+
+        @function
+        async def process_user(ctx: FunctionContext, user: UserInput) -> UserOutput:
+            ctx.log(f"Processing user {user.name}")
+            return UserOutput(
+                greeting=f"Hello, {user.name}!",
+                is_adult=user.age >= 18
+            )
+
+        # Simple retry count
+        @function(retries=5)
+        async def with_retries(data: str) -> str:
+            return data.upper()
+
+        # Dict configuration
+        @function(retries={"max_attempts": 5}, backoff="exponential")
+        async def advanced(a: int, b: int) -> int:
            return a + b
     """
 
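The widened retries/backoff parameters are normalized by the new agnt5/_retry_utils.py helpers (parse_retry_policy, parse_backoff_policy), whose source is not part of this diff. The following is only an illustrative sketch of what such normalization could look like, inferred from the accepted formats listed in the docstring above; the constructor keyword names, enum values, and the real helper behavior are assumptions, not the actual library code.

# Illustrative normalizers, NOT the actual agnt5._retry_utils implementation.
from typing import Any, Dict, Optional, Union

from agnt5.types import BackoffPolicy, BackoffType, RetryPolicy  # types referenced by this diff


def parse_retry_policy_sketch(retries: Optional[Union[int, Dict[str, Any], RetryPolicy]]) -> RetryPolicy:
    if retries is None:
        return RetryPolicy()                      # library defaults
    if isinstance(retries, RetryPolicy):
        return retries                            # already a full policy object
    if isinstance(retries, int):
        return RetryPolicy(max_attempts=retries)  # assumes a max_attempts keyword, per the docstring
    return RetryPolicy(**retries)                 # assumes dict keys match RetryPolicy params


def parse_backoff_policy_sketch(backoff: Optional[Union[str, Dict[str, Any], BackoffPolicy]]) -> BackoffPolicy:
    if backoff is None:
        return BackoffPolicy()
    if isinstance(backoff, BackoffPolicy):
        return backoff
    if isinstance(backoff, str):
        # Assumes BackoffType accepts "constant" / "linear" / "exponential" values.
        return BackoffPolicy(type=BackoffType(backoff))
    return BackoffPolicy(**backoff)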
@@ -260,39 +221,43 @@ def function(
         # Get function name
         func_name = name or func.__name__
 
-        # Validate function signature
+        # Validate function signature and check if context is needed
         sig = inspect.signature(func)
         params = list(sig.parameters.values())
 
-
-
-                f"Function '{func_name}' must have 'ctx: Context' as first parameter"
-            )
+        # Check if function declares 'ctx' parameter
+        needs_context = params and params[0].name == "ctx"
 
         # Convert sync to async if needed
         # Note: Async generators should NOT be wrapped - they need to be returned as-is
         if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
             handler_func = cast(HandlerFunc, func)
         else:
-            # Wrap sync function in
+            # Wrap sync function to run in thread pool (prevents blocking event loop)
             @functools.wraps(func)
             async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
-
+                loop = asyncio.get_running_loop()
+                # Run sync function in thread pool executor to prevent blocking
+                return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
 
             handler_func = cast(HandlerFunc, async_wrapper)
 
         # Extract schemas from type hints
-        input_schema, output_schema =
+        input_schema, output_schema = extract_function_schemas(func)
 
         # Extract metadata (description, etc.)
-        metadata =
+        metadata = extract_function_metadata(func)
+
+        # Parse retry and backoff policies from flexible formats
+        retry_policy = parse_retry_policy(retries)
+        backoff_policy = parse_backoff_policy(backoff)
 
         # Register function
         config = FunctionConfig(
             name=func_name,
             handler=handler_func,
-            retries=
-            backoff=
+            retries=retry_policy,
+            backoff=backoff_policy,
             input_schema=input_schema,
             output_schema=output_schema,
             metadata=metadata,
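The sync-to-async wrapping above uses the standard asyncio thread-pool pattern: the decorated sync callable is handed to the default executor so a blocking call (requests.get(), file I/O) does not stall the event loop. A standalone, stdlib-only illustration of that pattern, independent of agnt5:

# Standalone illustration of the loop.run_in_executor pattern used in the hunk above.
import asyncio
import time


def blocking_io(path: str) -> int:
    time.sleep(0.1)  # stands in for blocking file or network I/O
    return len(path)


async def main() -> None:
    loop = asyncio.get_running_loop()
    # Offload the blocking call to the default thread pool; the event loop stays responsive.
    size = await loop.run_in_executor(None, lambda: blocking_io("/tmp/example.txt"))
    print(size)


asyncio.run(main())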
@@ -302,33 +267,41 @@ def function(
         # Create wrapper with retry logic
         @functools.wraps(func)
         async def wrapper(*args: Any, **kwargs: Any) -> Any:
-            #
-            if
-            #
-
-
-
-
-
-                return await _execute_with_retry(
-                    handler_func,
-                    ctx,
-                    config.retries or RetryPolicy(),
-                    config.backoff or BackoffPolicy(),
-                    *args,
-                    **kwargs,
-                )
-            else:
-                # Context provided - use it
+            # Extract or create context based on function signature
+            if needs_context:
+                # Function declares ctx parameter - first argument must be FunctionContext
+                if not args or not isinstance(args[0], FunctionContext):
+                    raise TypeError(
+                        f"Function '{func_name}' requires FunctionContext as first argument. "
+                        f"Usage: await {func_name}(ctx, ...)"
+                    )
                 ctx = args[0]
-
-
-
-
-
-
-
-
+                func_args = args[1:]
+            else:
+                # Function doesn't use context - create a minimal one for internal use
+                # But first check if a context was passed anyway (for Worker execution)
+                if args and isinstance(args[0], FunctionContext):
+                    # Context was provided by Worker - use it but don't pass to function
+                    ctx = args[0]
+                    func_args = args[1:]
+                else:
+                    # No context provided - create a default one
+                    ctx = FunctionContext(
+                        run_id=f"local-{uuid.uuid4().hex[:8]}",
+                        retry_policy=retry_policy
+                    )
+                    func_args = args
+
+            # Execute with retry
+            return await execute_with_retry(
+                handler_func,
+                ctx,
+                config.retries or RetryPolicy(),
+                config.backoff or BackoffPolicy(),
+                needs_context,
+                *func_args,
+                **kwargs,
+            )
 
         # Store config on wrapper for introspection
         wrapper._agnt5_config = config  # type: ignore
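The wrapper above now resolves the context from the call site: a function that declares ctx must be given a FunctionContext as its first argument, while a function without ctx gets a default local context created for it. A short sketch of both calling styles, assuming function and FunctionContext are importable from agnt5.function (the exact export path is not shown in this diff):

# Sketch of the two calling conventions implied by the wrapper logic above.
import asyncio

from agnt5.function import FunctionContext, function


@function
async def add(a: int, b: int) -> int:
    # No ctx parameter: the wrapper builds a default FunctionContext ("local-...") internally.
    return a + b


@function
async def greet(ctx: FunctionContext, name: str) -> str:
    # Declares ctx: callers must pass a FunctionContext, otherwise the wrapper raises TypeError.
    ctx.log(f"Greeting {name}")
    return f"Hello, {name}!"


async def main() -> None:
    print(await add(1, 2))                 # context created automatically
    ctx = FunctionContext(run_id="local-demo")
    print(await greet(ctx, "Ada"))         # context supplied explicitly


asyncio.run(main())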
agnt5/lm.py
CHANGED
@@ -32,10 +32,13 @@ Supported Providers (via model prefix):
 
 from __future__ import annotations
 
+import json
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, AsyncIterator, Dict, List, Optional
 
+from ._schema_utils import detect_format_type
+
 try:
     from ._core import LanguageModel as RustLanguageModel
     from ._core import LanguageModelConfig as RustLanguageModelConfig
@@ -160,6 +163,39 @@ class GenerateResponse:
     usage: Optional[TokenUsage] = None
     finish_reason: Optional[str] = None
     tool_calls: Optional[List[Dict[str, Any]]] = None
+    _rust_response: Optional[Any] = field(default=None, repr=False)
+
+    @property
+    def structured_output(self) -> Optional[Any]:
+        """Parsed structured output (Pydantic model, dataclass, or dict).
+
+        Returns the parsed object when response_format is specified.
+        This is the recommended property name for accessing structured output.
+
+        Returns:
+            Parsed object according to the specified response_format, or None if not available
+        """
+        if self._rust_response and hasattr(self._rust_response, 'object'):
+            return self._rust_response.object
+        return None
+
+    @property
+    def parsed(self) -> Optional[Any]:
+        """Alias for structured_output (OpenAI SDK compatibility).
+
+        Returns:
+            Same as structured_output
+        """
+        return self.structured_output
+
+    @property
+    def object(self) -> Optional[Any]:
+        """Alias for structured_output.
+
+        Returns:
+            Same as structured_output
+        """
+        return self.structured_output
 
 
 @dataclass
@@ -172,6 +208,7 @@ class GenerateRequest:
     tools: List[ToolDefinition] = field(default_factory=list)
     tool_choice: Optional[ToolChoice] = None
     config: GenerationConfig = field(default_factory=GenerationConfig)
+    response_schema: Optional[str] = None  # JSON-encoded schema for structured output
 
 
 # Internal wrapper for the Rust-backed implementation
@@ -271,14 +308,19 @@ class _LanguageModel:
         if request.config.top_p is not None:
             kwargs["top_p"] = request.config.top_p
 
+        # Pass response schema for structured output if provided
+        if request.response_schema is not None:
+            kwargs["response_schema_kw"] = request.response_schema
+
         # TODO: Add tools and tool_choice support when needed
         # if request.tools:
         #     kwargs["tools"] = self._serialize_tools(request.tools)
         # if request.tool_choice:
         #     kwargs["tool_choice"] = request.tool_choice.value
 
-        # Call Rust implementation
-
+        # Call Rust implementation - it returns a proper Python coroutine now
+        # Using pyo3-async-runtimes for truly async HTTP calls without blocking
+        rust_response = await self._rust_lm.generate(prompt=prompt, **kwargs)
 
         # Convert Rust response to Python
         return self._convert_response(rust_response)
@@ -326,8 +368,9 @@ class _LanguageModel:
         # if request.tool_choice:
         #     kwargs["tool_choice"] = request.tool_choice.value
 
-        # Call Rust implementation
-
+        # Call Rust implementation - it returns a proper Python coroutine now
+        # Using pyo3-async-runtimes for truly async streaming without blocking
+        rust_chunks = await self._rust_lm.stream(prompt=prompt, **kwargs)
 
         # Yield each chunk
         for chunk in rust_chunks:
@@ -378,6 +421,7 @@ class _LanguageModel:
             usage=usage,
             finish_reason=None,  # TODO: Add finish_reason to Rust response
            tool_calls=None,  # TODO: Add tool calls support
+            _rust_response=rust_response,  # Store for .structured_output access
         )
 
 
@@ -394,6 +438,7 @@ async def generate(
     temperature: Optional[float] = None,
     max_tokens: Optional[int] = None,
     top_p: Optional[float] = None,
+    response_format: Optional[Any] = None,
 ) -> GenerateResponse:
     """Generate text using any LLM provider (simplified API).
 
@@ -408,9 +453,10 @@ async def generate(
         temperature: Sampling temperature (0.0-2.0)
         max_tokens: Maximum tokens to generate
         top_p: Nucleus sampling parameter
+        response_format: Pydantic model, dataclass, or JSON schema dict for structured output
 
     Returns:
-        GenerateResponse with text and
+        GenerateResponse with text, usage, and optional structured output
 
     Examples:
         Simple prompt:
@@ -421,15 +467,21 @@ async def generate(
         ... )
         >>> print(response.text)
 
-
+        Structured output with dataclass:
+        >>> from dataclasses import dataclass
+        >>>
+        >>> @dataclass
+        ... class CodeReview:
+        ...     issues: list[str]
+        ...     suggestions: list[str]
+        ...     overall_quality: int
+        >>>
         >>> response = await generate(
-        ...     model="
-        ...
-        ...
-        ...     {"role": "assistant", "content": "Calculus is..."},
-        ...     {"role": "user", "content": "Give me an example"}
-        ...     ]
+        ...     model="openai/gpt-4o",
+        ...     prompt="Analyze this code...",
+        ...     response_format=CodeReview
         ... )
+        >>> review = response.structured_output  # Returns dict
     """
     # Validate input
     if not prompt and not messages:
@@ -446,6 +498,12 @@ async def generate(
 
     provider, model_name = model.split('/', 1)
 
+    # Convert response_format to JSON schema if provided
+    response_schema_json = None
+    if response_format is not None:
+        format_type, json_schema = detect_format_type(response_format)
+        response_schema_json = json.dumps(json_schema)
+
     # Create language model client
     lm = _LanguageModel(provider=provider.lower(), default_model=None)
 
@@ -478,6 +536,7 @@ async def generate(
         messages=message_objects,
         system_prompt=system_prompt,
        config=config,
+        response_schema=response_schema_json,
     )
 
     # Generate and return
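End to end, the lm.py changes wire response_format through generate() into the Rust client as a JSON-encoded schema and expose the parsed result on the response via structured_output (with parsed and object as aliases). A usage sketch following the updated docstring above; the model name and prompt are illustrative, and agnt5.lm is assumed to be the import path for the module shown in this diff.

# Usage sketch for the new structured-output path, following the updated generate() docstring.
import asyncio
from dataclasses import dataclass

from agnt5.lm import generate  # assumed import path for the module shown in this diff


@dataclass
class CodeReview:
    issues: list[str]
    suggestions: list[str]
    overall_quality: int


async def main() -> None:
    response = await generate(
        model="openai/gpt-4o",
        prompt="Analyze this code...",
        response_format=CodeReview,  # converted to a JSON schema via detect_format_type
    )
    print(response.text)
    review = response.structured_output  # parsed result (a dict per the docstring); .parsed / .object are aliases
    print(review)


asyncio.run(main())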