agnt5 0.2.2__cp39-abi3-macosx_11_0_arm64.whl → 0.2.4__cp39-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agnt5 might be problematic; consult the package registry's advisory page for more details.

agnt5/function.py CHANGED
@@ -5,10 +5,11 @@ from __future__ import annotations
5
5
  import asyncio
6
6
  import functools
7
7
  import inspect
8
- import time
9
8
  import uuid
10
- from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast, get_type_hints
9
+ from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
11
10
 
11
+ from ._retry_utils import execute_with_retry, parse_backoff_policy, parse_retry_policy
12
+ from ._schema_utils import extract_function_metadata, extract_function_schemas
12
13
  from .context import Context
13
14
  from .exceptions import RetryError
14
15
  from .types import BackoffPolicy, BackoffType, FunctionConfig, HandlerFunc, RetryPolicy
@@ -18,13 +19,117 @@ T = TypeVar("T")
18
19
  # Global function registry
19
20
  _FUNCTION_REGISTRY: Dict[str, FunctionConfig] = {}
20
21
 
22
+ class FunctionContext(Context):
23
+ """
24
+ Lightweight context for stateless functions.
25
+
26
+ AGNT5 Philosophy: Context is a convenience, not a requirement.
27
+ The best function is one that doesn't need context at all!
28
+
29
+ Provides only:
30
+ - Quick logging (ctx.log())
31
+ - Execution metadata (run_id, attempt)
32
+ - Smart retry helper (should_retry())
33
+ - Non-durable sleep
34
+
35
+ Does NOT provide:
36
+ - Orchestration (task, parallel, gather) - use workflows
37
+ - State management (get, set, delete) - functions are stateless
38
+ - Checkpointing (step) - functions are atomic
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ run_id: str,
44
+ attempt: int = 0,
45
+ runtime_context: Optional[Any] = None,
46
+ retry_policy: Optional[Any] = None,
47
+ ) -> None:
48
+ """
49
+ Initialize function context.
50
+
51
+ Args:
52
+ run_id: Unique execution identifier
53
+ attempt: Retry attempt number (0-indexed)
54
+ runtime_context: RuntimeContext for trace correlation
55
+ retry_policy: RetryPolicy for should_retry() checks
56
+ """
57
+ super().__init__(run_id, attempt, runtime_context)
58
+ self._retry_policy = retry_policy
59
+
60
+ # === Quick Logging ===
61
+
62
+ def log(self, message: str, **extra) -> None:
63
+ """
64
+ Quick logging shorthand with structured data.
65
+
66
+ Example:
67
+ ctx.log("Processing payment", amount=100.50, user_id="123")
68
+ """
69
+ self._logger.info(message, extra=extra)
70
+
71
+ # === Smart Execution ===
72
+
73
+ def should_retry(self, error: Exception) -> bool:
74
+ """
75
+ Check if error is retryable based on configured policy.
76
+
77
+ Example:
78
+ try:
79
+ result = await external_api()
80
+ except Exception as e:
81
+ if not ctx.should_retry(e):
82
+ raise # Fail fast for non-retryable errors
83
+ # Otherwise let retry policy handle it
84
+ raise
85
+
86
+ Returns:
87
+ True if error is retryable, False otherwise
88
+ """
89
+ # TODO: Implement retry policy checks
90
+ # For now, all errors are retryable (let retry policy handle it)
91
+ return True
92
+
93
+ async def sleep(self, seconds: float) -> None:
94
+ """
95
+ Non-durable async sleep.
96
+
97
+ For durable sleep across failures, use workflows.
98
+
99
+ Args:
100
+ seconds: Number of seconds to sleep
101
+ """
102
+ import asyncio
103
+ await asyncio.sleep(seconds)
104
+
105
+
21
106
 
22
107
  class FunctionRegistry:
23
108
  """Registry for function handlers."""
24
109
 
25
110
  @staticmethod
26
111
  def register(config: FunctionConfig) -> None:
27
- """Register a function handler."""
112
+ """Register a function handler.
113
+
114
+ Args:
115
+ config: Function configuration to register
116
+
117
+ Raises:
118
+ ValueError: If a function with the same name is already registered
119
+ """
120
+ # Check for name collision
121
+ if config.name in _FUNCTION_REGISTRY:
122
+ existing_config = _FUNCTION_REGISTRY[config.name]
123
+ existing_module = existing_config.handler.__module__
124
+ new_module = config.handler.__module__
125
+
126
+ raise ValueError(
127
+ f"Function name collision: '{config.name}' is already registered.\n"
128
+ f" Existing: {existing_module}.{existing_config.handler.__name__}\n"
129
+ f" New: {new_module}.{config.handler.__name__}\n"
130
+ f"Please use a different function name or use name= parameter to specify a unique name."
131
+ )
132
+
28
133
  _FUNCTION_REGISTRY[config.name] = config
29
134
 
30
135
  @staticmethod
@@ -43,216 +148,72 @@ class FunctionRegistry:
43
148
  _FUNCTION_REGISTRY.clear()
44
149
 
45
150
 
46
- def _type_to_json_schema(python_type: Any) -> Dict[str, Any]:
47
- """Convert Python type hint to JSON Schema."""
48
- # Handle None type
49
- if python_type is type(None):
50
- return {"type": "null"}
51
-
52
- # Handle basic types
53
- if python_type is str:
54
- return {"type": "string"}
55
- if python_type is int:
56
- return {"type": "integer"}
57
- if python_type is float:
58
- return {"type": "number"}
59
- if python_type is bool:
60
- return {"type": "boolean"}
61
-
62
- # Handle typing module types
63
- origin = getattr(python_type, "__origin__", None)
64
-
65
- if origin is list:
66
- args = getattr(python_type, "__args__", ())
67
- if args:
68
- return {"type": "array", "items": _type_to_json_schema(args[0])}
69
- return {"type": "array"}
70
-
71
- if origin is dict:
72
- return {"type": "object"}
73
-
74
- if origin is Union:
75
- args = getattr(python_type, "__args__", ())
76
- # Handle Optional[T] (Union[T, None])
77
- if len(args) == 2 and type(None) in args:
78
- non_none = args[0] if args[1] is type(None) else args[1]
79
- schema = _type_to_json_schema(non_none)
80
- # Mark as nullable in JSON Schema
81
- return {**schema, "nullable": True}
82
- # Handle other unions as anyOf
83
- return {"anyOf": [_type_to_json_schema(arg) for arg in args]}
84
-
85
- # Default to object for unknown types
86
- return {"type": "object"}
87
-
88
-
89
- def _extract_function_schemas(func: Callable[..., Any]) -> tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
90
- """Extract input and output schemas from function type hints.
91
-
92
- Returns:
93
- Tuple of (input_schema, output_schema) where each is a JSON Schema dict or None
94
- """
95
- try:
96
- # Get type hints
97
- hints = get_type_hints(func)
98
- sig = inspect.signature(func)
99
-
100
- # Build input schema from parameters (excluding 'ctx')
101
- input_properties = {}
102
- required_params = []
103
-
104
- for param_name, param in sig.parameters.items():
105
- if param_name == "ctx":
106
- continue
107
-
108
- # Get type hint for this parameter
109
- if param_name in hints:
110
- param_type = hints[param_name]
111
- input_properties[param_name] = _type_to_json_schema(param_type)
112
- else:
113
- # No type hint, use generic object
114
- input_properties[param_name] = {"type": "object"}
115
-
116
- # Check if parameter is required (no default value)
117
- if param.default is inspect.Parameter.empty:
118
- required_params.append(param_name)
119
-
120
- input_schema = None
121
- if input_properties:
122
- input_schema = {
123
- "type": "object",
124
- "properties": input_properties,
125
- }
126
- if required_params:
127
- input_schema["required"] = required_params
128
-
129
- # Add description from docstring if available
130
- if func.__doc__:
131
- docstring = inspect.cleandoc(func.__doc__)
132
- first_line = docstring.split('\n')[0].strip()
133
- if first_line:
134
- input_schema["description"] = first_line
135
-
136
- # Build output schema from return type hint
137
- output_schema = None
138
- if "return" in hints:
139
- return_type = hints["return"]
140
- output_schema = _type_to_json_schema(return_type)
141
-
142
- return input_schema, output_schema
143
-
144
- except Exception:
145
- # If schema extraction fails, return None schemas
146
- return None, None
147
-
148
-
149
- def _extract_function_metadata(func: Callable[..., Any]) -> Dict[str, str]:
150
- """Extract metadata from function including description from docstring.
151
-
152
- Returns:
153
- Dictionary with metadata fields like 'description'
154
- """
155
- metadata = {}
156
-
157
- # Extract description from docstring
158
- if func.__doc__:
159
- # Get first line of docstring as description
160
- docstring = inspect.cleandoc(func.__doc__)
161
- first_line = docstring.split('\n')[0].strip()
162
- if first_line:
163
- metadata["description"] = first_line
164
-
165
- return metadata
166
-
167
-
168
- def _calculate_backoff_delay(
169
- attempt: int,
170
- retry_policy: RetryPolicy,
171
- backoff_policy: BackoffPolicy,
172
- ) -> float:
173
- """Calculate backoff delay in seconds based on attempt number."""
174
- if backoff_policy.type == BackoffType.CONSTANT:
175
- delay_ms = retry_policy.initial_interval_ms
176
- elif backoff_policy.type == BackoffType.LINEAR:
177
- delay_ms = retry_policy.initial_interval_ms * (attempt + 1)
178
- else: # EXPONENTIAL
179
- delay_ms = retry_policy.initial_interval_ms * (backoff_policy.multiplier**attempt)
180
-
181
- # Cap at max_interval_ms
182
- delay_ms = min(delay_ms, retry_policy.max_interval_ms)
183
- return delay_ms / 1000.0 # Convert to seconds
184
-
185
-
186
- async def _execute_with_retry(
187
- handler: HandlerFunc,
188
- ctx: Context,
189
- retry_policy: RetryPolicy,
190
- backoff_policy: BackoffPolicy,
191
- *args: Any,
192
- **kwargs: Any,
193
- ) -> Any:
194
- """Execute handler with retry logic."""
195
- last_error: Optional[Exception] = None
196
-
197
- for attempt in range(retry_policy.max_attempts):
198
- try:
199
- # Update context attempt number
200
- ctx._attempt = attempt
201
-
202
- # Execute handler
203
- result = await handler(ctx, *args, **kwargs)
204
- return result
205
-
206
- except Exception as e:
207
- last_error = e
208
- ctx.logger.warning(
209
- f"Function execution failed (attempt {attempt + 1}/{retry_policy.max_attempts}): {e}"
210
- )
211
-
212
- # If this was the last attempt, raise RetryError
213
- if attempt == retry_policy.max_attempts - 1:
214
- raise RetryError(
215
- f"Function failed after {retry_policy.max_attempts} attempts",
216
- attempts=retry_policy.max_attempts,
217
- last_error=e,
218
- )
219
-
220
- # Calculate backoff delay
221
- delay = _calculate_backoff_delay(attempt, retry_policy, backoff_policy)
222
- ctx.logger.info(f"Retrying in {delay:.2f} seconds...")
223
- await asyncio.sleep(delay)
224
-
225
- # Should never reach here, but for type safety
226
- assert last_error is not None
227
- raise RetryError(
228
- f"Function failed after {retry_policy.max_attempts} attempts",
229
- attempts=retry_policy.max_attempts,
230
- last_error=last_error,
231
- )
232
-
233
-
234
151
  def function(
235
152
  _func: Optional[Callable[..., Any]] = None,
236
153
  *,
237
154
  name: Optional[str] = None,
238
- retries: Optional[RetryPolicy] = None,
239
- backoff: Optional[BackoffPolicy] = None,
155
+ retries: Optional[Union[int, Dict[str, Any], RetryPolicy]] = None,
156
+ backoff: Optional[Union[str, Dict[str, Any], BackoffPolicy]] = None,
240
157
  ) -> Callable[..., Any]:
241
158
  """
242
159
  Decorator to mark a function as an AGNT5 durable function.
243
160
 
244
161
  Args:
245
162
  name: Custom function name (default: function's __name__)
246
- retries: Retry policy configuration
247
- backoff: Backoff policy for retries
163
+ retries: Retry policy configuration. Can be:
164
+ - int: max attempts (e.g., 5)
165
+ - dict: RetryPolicy params (e.g., {"max_attempts": 5, "initial_interval_ms": 1000})
166
+ - RetryPolicy: full policy object
167
+ backoff: Backoff policy for retries. Can be:
168
+ - str: backoff type ("constant", "linear", "exponential")
169
+ - dict: BackoffPolicy params (e.g., {"type": "exponential", "multiplier": 2.0})
170
+ - BackoffPolicy: full policy object
171
+
172
+ Note:
173
+ Sync Functions: Synchronous functions are automatically executed in a thread pool
174
+ to prevent blocking the event loop. This is ideal for I/O-bound operations
175
+ (requests.get(), file I/O, etc.). For CPU-bound operations or when you need
176
+ explicit control over concurrency, use async functions instead.
248
177
 
249
178
  Example:
179
+ # Basic function with context
250
180
  @function
251
- async def greet(ctx: Context, name: str) -> str:
181
+ async def greet(ctx: FunctionContext, name: str) -> str:
182
+ ctx.log(f"Greeting {name}") # AGNT5 shorthand!
252
183
  return f"Hello, {name}!"
253
184
 
254
- @function(name="add_numbers", retries=RetryPolicy(max_attempts=5))
255
- async def add(ctx: Context, a: int, b: int) -> int:
185
+ # Simple function without context (optional)
186
+ @function
187
+ async def add(a: int, b: int) -> int:
188
+ return a + b
189
+
190
+ # With Pydantic models (automatic validation + rich schemas)
191
+ from pydantic import BaseModel
192
+
193
+ class UserInput(BaseModel):
194
+ name: str
195
+ age: int
196
+
197
+ class UserOutput(BaseModel):
198
+ greeting: str
199
+ is_adult: bool
200
+
201
+ @function
202
+ async def process_user(ctx: FunctionContext, user: UserInput) -> UserOutput:
203
+ ctx.log(f"Processing user {user.name}")
204
+ return UserOutput(
205
+ greeting=f"Hello, {user.name}!",
206
+ is_adult=user.age >= 18
207
+ )
208
+
209
+ # Simple retry count
210
+ @function(retries=5)
211
+ async def with_retries(data: str) -> str:
212
+ return data.upper()
213
+
214
+ # Dict configuration
215
+ @function(retries={"max_attempts": 5}, backoff="exponential")
216
+ async def advanced(a: int, b: int) -> int:
256
217
  return a + b
257
218
  """
258
219
 
@@ -260,39 +221,43 @@ def function(
260
221
  # Get function name
261
222
  func_name = name or func.__name__
262
223
 
263
- # Validate function signature
224
+ # Validate function signature and check if context is needed
264
225
  sig = inspect.signature(func)
265
226
  params = list(sig.parameters.values())
266
227
 
267
- if not params or params[0].name != "ctx":
268
- raise ValueError(
269
- f"Function '{func_name}' must have 'ctx: Context' as first parameter"
270
- )
228
+ # Check if function declares 'ctx' parameter
229
+ needs_context = params and params[0].name == "ctx"
271
230
 
272
231
  # Convert sync to async if needed
273
232
  # Note: Async generators should NOT be wrapped - they need to be returned as-is
274
233
  if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
275
234
  handler_func = cast(HandlerFunc, func)
276
235
  else:
277
- # Wrap sync function in async
236
+ # Wrap sync function to run in thread pool (prevents blocking event loop)
278
237
  @functools.wraps(func)
279
238
  async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
280
- return func(*args, **kwargs)
239
+ loop = asyncio.get_running_loop()
240
+ # Run sync function in thread pool executor to prevent blocking
241
+ return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
281
242
 
282
243
  handler_func = cast(HandlerFunc, async_wrapper)
283
244
 
284
245
  # Extract schemas from type hints
285
- input_schema, output_schema = _extract_function_schemas(func)
246
+ input_schema, output_schema = extract_function_schemas(func)
286
247
 
287
248
  # Extract metadata (description, etc.)
288
- metadata = _extract_function_metadata(func)
249
+ metadata = extract_function_metadata(func)
250
+
251
+ # Parse retry and backoff policies from flexible formats
252
+ retry_policy = parse_retry_policy(retries)
253
+ backoff_policy = parse_backoff_policy(backoff)
289
254
 
290
255
  # Register function
291
256
  config = FunctionConfig(
292
257
  name=func_name,
293
258
  handler=handler_func,
294
- retries=retries or RetryPolicy(),
295
- backoff=backoff or BackoffPolicy(),
259
+ retries=retry_policy,
260
+ backoff=backoff_policy,
296
261
  input_schema=input_schema,
297
262
  output_schema=output_schema,
298
263
  metadata=metadata,
@@ -302,33 +267,41 @@ def function(
302
267
  # Create wrapper with retry logic
303
268
  @functools.wraps(func)
304
269
  async def wrapper(*args: Any, **kwargs: Any) -> Any:
305
- # Create context if not provided
306
- if not args or not isinstance(args[0], Context):
307
- # Auto-create context for direct function calls
308
- ctx = Context(
309
- run_id=f"local-{uuid.uuid4().hex[:8]}",
310
- component_type="function",
311
- )
312
- # Execute with retry
313
- return await _execute_with_retry(
314
- handler_func,
315
- ctx,
316
- config.retries or RetryPolicy(),
317
- config.backoff or BackoffPolicy(),
318
- *args,
319
- **kwargs,
320
- )
321
- else:
322
- # Context provided - use it
270
+ # Extract or create context based on function signature
271
+ if needs_context:
272
+ # Function declares ctx parameter - first argument must be FunctionContext
273
+ if not args or not isinstance(args[0], FunctionContext):
274
+ raise TypeError(
275
+ f"Function '{func_name}' requires FunctionContext as first argument. "
276
+ f"Usage: await {func_name}(ctx, ...)"
277
+ )
323
278
  ctx = args[0]
324
- return await _execute_with_retry(
325
- handler_func,
326
- ctx,
327
- config.retries or RetryPolicy(),
328
- config.backoff or BackoffPolicy(),
329
- *args[1:],
330
- **kwargs,
331
- )
279
+ func_args = args[1:]
280
+ else:
281
+ # Function doesn't use context - create a minimal one for internal use
282
+ # But first check if a context was passed anyway (for Worker execution)
283
+ if args and isinstance(args[0], FunctionContext):
284
+ # Context was provided by Worker - use it but don't pass to function
285
+ ctx = args[0]
286
+ func_args = args[1:]
287
+ else:
288
+ # No context provided - create a default one
289
+ ctx = FunctionContext(
290
+ run_id=f"local-{uuid.uuid4().hex[:8]}",
291
+ retry_policy=retry_policy
292
+ )
293
+ func_args = args
294
+
295
+ # Execute with retry
296
+ return await execute_with_retry(
297
+ handler_func,
298
+ ctx,
299
+ config.retries or RetryPolicy(),
300
+ config.backoff or BackoffPolicy(),
301
+ needs_context,
302
+ *func_args,
303
+ **kwargs,
304
+ )
332
305
 
333
306
  # Store config on wrapper for introspection
334
307
  wrapper._agnt5_config = config # type: ignore
agnt5/lm.py CHANGED
@@ -32,10 +32,13 @@ Supported Providers (via model prefix):
32
32
 
33
33
  from __future__ import annotations
34
34
 
35
+ import json
35
36
  from dataclasses import dataclass, field
36
37
  from enum import Enum
37
38
  from typing import Any, AsyncIterator, Dict, List, Optional
38
39
 
40
+ from ._schema_utils import detect_format_type
41
+
39
42
  try:
40
43
  from ._core import LanguageModel as RustLanguageModel
41
44
  from ._core import LanguageModelConfig as RustLanguageModelConfig
@@ -160,6 +163,39 @@ class GenerateResponse:
160
163
  usage: Optional[TokenUsage] = None
161
164
  finish_reason: Optional[str] = None
162
165
  tool_calls: Optional[List[Dict[str, Any]]] = None
166
+ _rust_response: Optional[Any] = field(default=None, repr=False)
167
+
168
+ @property
169
+ def structured_output(self) -> Optional[Any]:
170
+ """Parsed structured output (Pydantic model, dataclass, or dict).
171
+
172
+ Returns the parsed object when response_format is specified.
173
+ This is the recommended property name for accessing structured output.
174
+
175
+ Returns:
176
+ Parsed object according to the specified response_format, or None if not available
177
+ """
178
+ if self._rust_response and hasattr(self._rust_response, 'object'):
179
+ return self._rust_response.object
180
+ return None
181
+
182
+ @property
183
+ def parsed(self) -> Optional[Any]:
184
+ """Alias for structured_output (OpenAI SDK compatibility).
185
+
186
+ Returns:
187
+ Same as structured_output
188
+ """
189
+ return self.structured_output
190
+
191
+ @property
192
+ def object(self) -> Optional[Any]:
193
+ """Alias for structured_output.
194
+
195
+ Returns:
196
+ Same as structured_output
197
+ """
198
+ return self.structured_output
163
199
 
164
200
 
165
201
  @dataclass
@@ -172,6 +208,7 @@ class GenerateRequest:
172
208
  tools: List[ToolDefinition] = field(default_factory=list)
173
209
  tool_choice: Optional[ToolChoice] = None
174
210
  config: GenerationConfig = field(default_factory=GenerationConfig)
211
+ response_schema: Optional[str] = None # JSON-encoded schema for structured output
175
212
 
176
213
 
177
214
  # Internal wrapper for the Rust-backed implementation
@@ -271,14 +308,19 @@ class _LanguageModel:
271
308
  if request.config.top_p is not None:
272
309
  kwargs["top_p"] = request.config.top_p
273
310
 
311
+ # Pass response schema for structured output if provided
312
+ if request.response_schema is not None:
313
+ kwargs["response_schema_kw"] = request.response_schema
314
+
274
315
  # TODO: Add tools and tool_choice support when needed
275
316
  # if request.tools:
276
317
  # kwargs["tools"] = self._serialize_tools(request.tools)
277
318
  # if request.tool_choice:
278
319
  # kwargs["tool_choice"] = request.tool_choice.value
279
320
 
280
- # Call Rust implementation
281
- rust_response = self._rust_lm.generate(prompt=prompt, **kwargs)
321
+ # Call Rust implementation - it returns a proper Python coroutine now
322
+ # Using pyo3-async-runtimes for truly async HTTP calls without blocking
323
+ rust_response = await self._rust_lm.generate(prompt=prompt, **kwargs)
282
324
 
283
325
  # Convert Rust response to Python
284
326
  return self._convert_response(rust_response)
@@ -326,8 +368,9 @@ class _LanguageModel:
326
368
  # if request.tool_choice:
327
369
  # kwargs["tool_choice"] = request.tool_choice.value
328
370
 
329
- # Call Rust implementation (it returns a list of chunks)
330
- rust_chunks = self._rust_lm.stream(prompt=prompt, **kwargs)
371
+ # Call Rust implementation - it returns a proper Python coroutine now
372
+ # Using pyo3-async-runtimes for truly async streaming without blocking
373
+ rust_chunks = await self._rust_lm.stream(prompt=prompt, **kwargs)
331
374
 
332
375
  # Yield each chunk
333
376
  for chunk in rust_chunks:
@@ -378,6 +421,7 @@ class _LanguageModel:
378
421
  usage=usage,
379
422
  finish_reason=None, # TODO: Add finish_reason to Rust response
380
423
  tool_calls=None, # TODO: Add tool calls support
424
+ _rust_response=rust_response, # Store for .structured_output access
381
425
  )
382
426
 
383
427
 
@@ -394,6 +438,7 @@ async def generate(
394
438
  temperature: Optional[float] = None,
395
439
  max_tokens: Optional[int] = None,
396
440
  top_p: Optional[float] = None,
441
+ response_format: Optional[Any] = None,
397
442
  ) -> GenerateResponse:
398
443
  """Generate text using any LLM provider (simplified API).
399
444
 
@@ -408,9 +453,10 @@ async def generate(
408
453
  temperature: Sampling temperature (0.0-2.0)
409
454
  max_tokens: Maximum tokens to generate
410
455
  top_p: Nucleus sampling parameter
456
+ response_format: Pydantic model, dataclass, or JSON schema dict for structured output
411
457
 
412
458
  Returns:
413
- GenerateResponse with text and usage information
459
+ GenerateResponse with text, usage, and optional structured output
414
460
 
415
461
  Examples:
416
462
  Simple prompt:
@@ -421,15 +467,21 @@ async def generate(
421
467
  ... )
422
468
  >>> print(response.text)
423
469
 
424
- Multi-turn conversation:
470
+ Structured output with dataclass:
471
+ >>> from dataclasses import dataclass
472
+ >>>
473
+ >>> @dataclass
474
+ ... class CodeReview:
475
+ ... issues: list[str]
476
+ ... suggestions: list[str]
477
+ ... overall_quality: int
478
+ >>>
425
479
  >>> response = await generate(
426
- ... model="anthropic/claude-3-5-haiku",
427
- ... messages=[
428
- ... {"role": "user", "content": "What is calculus?"},
429
- ... {"role": "assistant", "content": "Calculus is..."},
430
- ... {"role": "user", "content": "Give me an example"}
431
- ... ]
480
+ ... model="openai/gpt-4o",
481
+ ... prompt="Analyze this code...",
482
+ ... response_format=CodeReview
432
483
  ... )
484
+ >>> review = response.structured_output # Returns dict
433
485
  """
434
486
  # Validate input
435
487
  if not prompt and not messages:
@@ -446,6 +498,12 @@ async def generate(
446
498
 
447
499
  provider, model_name = model.split('/', 1)
448
500
 
501
+ # Convert response_format to JSON schema if provided
502
+ response_schema_json = None
503
+ if response_format is not None:
504
+ format_type, json_schema = detect_format_type(response_format)
505
+ response_schema_json = json.dumps(json_schema)
506
+
449
507
  # Create language model client
450
508
  lm = _LanguageModel(provider=provider.lower(), default_model=None)
451
509
 
@@ -478,6 +536,7 @@ async def generate(
478
536
  messages=message_objects,
479
537
  system_prompt=system_prompt,
480
538
  config=config,
539
+ response_schema=response_schema_json,
481
540
  )
482
541
 
483
542
  # Generate and return