agnt5 0.2.8a10__cp310-abi3-manylinux_2_34_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agnt5 might be problematic. Click here for more details.
- agnt5/__init__.py +91 -0
- agnt5/_compat.py +16 -0
- agnt5/_core.abi3.so +0 -0
- agnt5/_retry_utils.py +169 -0
- agnt5/_schema_utils.py +312 -0
- agnt5/_telemetry.py +182 -0
- agnt5/agent.py +1685 -0
- agnt5/client.py +741 -0
- agnt5/context.py +178 -0
- agnt5/entity.py +795 -0
- agnt5/exceptions.py +102 -0
- agnt5/function.py +321 -0
- agnt5/lm.py +813 -0
- agnt5/tool.py +648 -0
- agnt5/tracing.py +196 -0
- agnt5/types.py +110 -0
- agnt5/version.py +19 -0
- agnt5/worker.py +1619 -0
- agnt5/workflow.py +1048 -0
- agnt5-0.2.8a10.dist-info/METADATA +25 -0
- agnt5-0.2.8a10.dist-info/RECORD +22 -0
- agnt5-0.2.8a10.dist-info/WHEEL +4 -0
agnt5/lm.py
ADDED
|
@@ -0,0 +1,813 @@
|
|
|
1
|
+
"""Language Model interface for AGNT5 SDK.
|
|
2
|
+
|
|
3
|
+
Simplified API inspired by Vercel AI SDK for seamless multi-provider LLM access.
|
|
4
|
+
Uses Rust-backed implementation via PyO3 for performance and reliability.
|
|
5
|
+
|
|
6
|
+
Basic Usage:
|
|
7
|
+
>>> from agnt5 import lm
|
|
8
|
+
>>>
|
|
9
|
+
>>> # Simple generation
|
|
10
|
+
>>> response = await lm.generate(
|
|
11
|
+
... model="openai/gpt-4o-mini",
|
|
12
|
+
... prompt="What is love?",
|
|
13
|
+
... temperature=0.7
|
|
14
|
+
... )
|
|
15
|
+
>>> print(response.text)
|
|
16
|
+
>>>
|
|
17
|
+
>>> # Streaming
|
|
18
|
+
>>> async for chunk in lm.stream(
|
|
19
|
+
... model="anthropic/claude-3-5-haiku",
|
|
20
|
+
... prompt="Write a story"
|
|
21
|
+
... ):
|
|
22
|
+
... print(chunk, end="", flush=True)
|
|
23
|
+
|
|
24
|
+
Supported Providers (via model prefix):
|
|
25
|
+
- openai/model-name
|
|
26
|
+
- anthropic/model-name
|
|
27
|
+
- groq/model-name
|
|
28
|
+
- openrouter/provider/model-name
|
|
29
|
+
- azure/model-name
|
|
30
|
+
- bedrock/model-name
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from __future__ import annotations
|
|
34
|
+
|
|
35
|
+
import json
|
|
36
|
+
from abc import ABC, abstractmethod
|
|
37
|
+
from dataclasses import dataclass, field
|
|
38
|
+
from enum import Enum
|
|
39
|
+
from typing import Any, AsyncIterator, Dict, List, Optional
|
|
40
|
+
|
|
41
|
+
from ._schema_utils import detect_format_type
|
|
42
|
+
from .context import get_current_context
|
|
43
|
+
|
|
44
|
+
try:
|
|
45
|
+
from ._core import LanguageModel as RustLanguageModel
|
|
46
|
+
from ._core import LanguageModelConfig as RustLanguageModelConfig
|
|
47
|
+
from ._core import Response as RustResponse
|
|
48
|
+
from ._core import StreamChunk as RustStreamChunk
|
|
49
|
+
from ._core import Usage as RustUsage
|
|
50
|
+
_RUST_AVAILABLE = True
|
|
51
|
+
except ImportError:
|
|
52
|
+
_RUST_AVAILABLE = False
|
|
53
|
+
RustLanguageModel = None
|
|
54
|
+
RustLanguageModelConfig = None
|
|
55
|
+
RustResponse = None
|
|
56
|
+
RustStreamChunk = None
|
|
57
|
+
RustUsage = None
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Keep Python classes for backward compatibility and convenience
|
|
61
|
+
class MessageRole(str, Enum):
    """Speaker of a conversation turn.

    Members are also ``str`` instances, so ``MessageRole.USER == "user"`` holds
    and ``MessageRole("user")`` looks a member up by its wire value.
    """

    SYSTEM = "system"  # instructions that frame the conversation
    USER = "user"  # end-user input
    ASSISTANT = "assistant"  # model output
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
class Message:
    """One conversation turn: a role paired with its text content."""

    role: MessageRole  # who produced this turn
    content: str  # text of the turn

    @staticmethod
    def system(content: str) -> Message:
        """Return a ``Message`` whose role is ``MessageRole.SYSTEM``."""
        return Message(MessageRole.SYSTEM, content)

    @staticmethod
    def user(content: str) -> Message:
        """Return a ``Message`` whose role is ``MessageRole.USER``."""
        return Message(MessageRole.USER, content)

    @staticmethod
    def assistant(content: str) -> Message:
        """Return a ``Message`` whose role is ``MessageRole.ASSISTANT``."""
        return Message(MessageRole.ASSISTANT, content)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass
class ToolDefinition:
    """Description of a tool the model may call.

    ``_LanguageModel`` serializes the three fields below into the JSON list it
    passes to the Rust core.
    """

    name: str  # tool identifier
    description: Optional[str] = None  # human-readable purpose, if any
    parameters: Optional[Dict[str, Any]] = None  # parameter schema dict, if any
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class ToolChoice(str, Enum):
    """How the model is allowed to use the supplied tools.

    The member's ``value`` is what gets JSON-encoded and sent to the Rust core.
    """

    AUTO = "auto"
    NONE = "none"
    REQUIRED = "required"
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@dataclass
class ModelConfig:
    """Optional low-level settings for reaching a model endpoint.

    Only needed for advanced setups such as a custom API endpoint, extra
    request headers, or a different timeout. For ordinary use, a model string
    plus temperature/max_tokens is enough.

    Every field defaults to ``None``, meaning "not overridden".

    Example:
        >>> from agnt5.lm import ModelConfig
        >>> from agnt5 import Agent
        >>>
        >>> # Custom API endpoint
        >>> config = ModelConfig(
        ...     base_url="https://custom-api.example.com",
        ...     api_key="custom-key",
        ...     timeout=60,
        ...     headers={"X-Custom-Header": "value"}
        ... )
        >>>
        >>> agent = Agent(
        ...     name="custom_agent",
        ...     model="openai/gpt-4o-mini",
        ...     instructions="...",
        ...     model_config=config
        ... )
    """

    base_url: Optional[str] = None  # endpoint override
    api_key: Optional[str] = None  # credential override
    timeout: Optional[int] = None  # timeout override (units defined by the consumer — TODO confirm)
    headers: Optional[Dict[str, str]] = None  # extra request headers
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@dataclass
class GenerationConfig:
    """Sampling options for a generation call.

    ``None`` means "not set": ``_LanguageModel`` only forwards fields that are
    not ``None``.
    """

    temperature: Optional[float] = None  # sampling temperature
    max_tokens: Optional[int] = None  # cap on generated tokens
    top_p: Optional[float] = None  # nucleus-sampling threshold
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@dataclass
class TokenUsage:
    """Token counts reported for one generation call."""

    prompt_tokens: int  # tokens consumed by the input
    completion_tokens: int  # tokens produced by the model
    total_tokens: int  # total as reported by the backend
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@dataclass
class GenerateResponse:
    """Result of a generation call.

    ``_rust_response`` keeps the raw Rust response object (hidden from
    ``repr``) so that structured output can be read from it lazily.
    """

    text: str  # generated text
    usage: Optional[TokenUsage] = None  # token accounting, when reported
    finish_reason: Optional[str] = None  # stop reason, when reported
    tool_calls: Optional[List[Dict[str, Any]]] = None  # requested tool invocations
    _rust_response: Optional[Any] = field(default=None, repr=False)

    @property
    def structured_output(self) -> Optional[Any]:
        """Parsed structured output (Pydantic model, dataclass, or dict).

        This is the preferred accessor when a ``response_format`` was requested.

        Returns:
            The ``object`` attribute of the stored Rust response, or ``None``
            when there is no (truthy) Rust response or it has no ``object``.
        """
        raw = self._rust_response
        if not raw:
            return None
        return getattr(raw, 'object', None)

    @property
    def parsed(self) -> Optional[Any]:
        """Same value as ``structured_output`` (OpenAI SDK-style name)."""
        return self.structured_output

    @property
    def object(self) -> Optional[Any]:
        """Same value as ``structured_output``."""
        return self.structured_output
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@dataclass
class GenerateRequest:
    """Everything ``_LanguageModel`` needs for a single generation call."""

    model: str  # model identifier, e.g. 'openai/gpt-4o-mini'
    messages: List[Message] = field(default_factory=list)  # conversation turns
    system_prompt: Optional[str] = None  # sent separately from ``messages``
    tools: List[ToolDefinition] = field(default_factory=list)  # callable tools
    tool_choice: Optional[ToolChoice] = None  # tool-use policy, if any
    config: GenerationConfig = field(default_factory=GenerationConfig)  # sampling options
    response_schema: Optional[str] = None  # JSON-encoded schema for structured output
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# Abstract base class for language models
|
|
217
|
+
# This exists primarily for testing/mocking purposes
|
|
218
|
+
class LanguageModel(ABC):
    """Interface every language-model backend implements.

    Exists mainly so tests can substitute a fake backend. Application code
    should call the module-level ``generate()`` / ``stream()`` functions.
    """

    @abstractmethod
    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        """Run one generation call.

        Args:
            request: Model, messages, and configuration for the call.

        Returns:
            A ``GenerateResponse`` carrying text, usage, and any tool calls.
        """
        ...

    @abstractmethod
    async def stream(self, request: GenerateRequest) -> AsyncIterator[str]:
        """Run one generation call, producing text incrementally.

        Args:
            request: Model, messages, and configuration for the call.

        Yields:
            Text chunks in the order they are produced.
        """
        ...
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
# Internal wrapper for the Rust-backed implementation
|
|
252
|
+
# Users should use the module-level generate() and stream() functions instead
|
|
253
|
+
class _LanguageModel(LanguageModel):
    """Internal ``LanguageModel`` implementation backed by the Rust SDK core.

    Not public API: callers should use the module-level ``lm.generate()`` and
    ``lm.stream()`` functions, which build one of these per call.
    """

    def __init__(
        self,
        provider: Optional[str] = None,
        default_model: Optional[str] = None,
    ):
        """Create the Rust-backed client.

        Args:
            provider: Provider name (e.g., 'openai', 'anthropic', 'azure', 'bedrock', 'groq', 'openrouter').
                If None, the provider is expected to come from the model prefix (e.g., 'openai/gpt-4o').
            default_model: Forwarded to the Rust config as ``default_model``.

        Raises:
            ImportError: If the compiled Rust extension could not be imported.
        """
        if not _RUST_AVAILABLE:
            raise ImportError(
                "Rust extension not available. Please rebuild the SDK with: "
                "cd sdk/sdk-python && maturin develop"
            )

        self._provider = provider
        self._default_model = default_model

        config = RustLanguageModelConfig(
            default_model=default_model,
            default_provider=provider,
        )
        self._rust_lm = RustLanguageModel(config=config)

    def _prepare_model_name(self, model: str) -> str:
        """Prefix ``model`` with this client's provider when it has no prefix yet.

        Args:
            model: Model name (e.g., 'gpt-4o-mini' or 'openai/gpt-4o-mini').

        Returns:
            ``model`` unchanged if it already contains '/', otherwise
            '<provider>/<model>' when a provider is set, otherwise ``model``.
        """
        # Names that already carry a prefix are left alone. This covers gateway
        # providers such as OpenRouter, whose model names embed the upstream
        # provider (e.g. 'anthropic/claude-3.5-haiku').
        if '/' in model:
            return model
        if self._provider:
            return f"{self._provider}/{model}"
        # No provider to add: pass through and let the Rust core report the error.
        return model

    def _build_call_kwargs(self, request: GenerateRequest, model: str) -> Dict[str, Any]:
        """Build the keyword arguments shared by the Rust ``generate`` and ``stream`` calls.

        Only options that are actually set are included, so that anything left
        unset is not sent to the Rust core.

        Args:
            request: The generation request.
            model: Model name already passed through ``_prepare_model_name``.

        Returns:
            A kwargs dict for ``self._rust_lm.generate`` / ``self._rust_lm.stream``.
        """
        kwargs: Dict[str, Any] = {"model": model}

        # Pass the provider explicitly when set. Gateway providers such as
        # OpenRouter use it to handle models that carry their own provider
        # prefix (e.g. 'anthropic/claude-3.5-haiku').
        if self._provider:
            kwargs["provider"] = self._provider

        # The system prompt travels separately from the message list.
        if request.system_prompt:
            kwargs["system_prompt"] = request.system_prompt

        config = request.config
        if config.temperature is not None:
            kwargs["temperature"] = config.temperature
        if config.max_tokens is not None:
            kwargs["max_tokens"] = config.max_tokens
        if config.top_p is not None:
            kwargs["top_p"] = config.top_p

        # Tools and tool_choice are handed to the Rust core as JSON strings.
        if request.tools:
            kwargs["tools"] = json.dumps([
                {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                }
                for tool in request.tools
            ])
        if request.tool_choice:
            kwargs["tool_choice"] = json.dumps(request.tool_choice.value)

        return kwargs

    async def generate(self, request: GenerateRequest) -> GenerateResponse:
        """Generate a completion from the LLM.

        Inside a workflow, emits ``workflow.lm.started`` before the call and
        ``workflow.lm.completed`` (with token usage) or ``workflow.lm.error``
        after it.

        Args:
            request: Generation request with model, messages, and configuration.

        Returns:
            GenerateResponse with text, usage, and optional tool calls.

        Raises:
            Exception: Whatever the Rust call raises, re-raised after the error
                checkpoint has been emitted.
        """
        prompt = self._build_prompt_messages(request)
        model = self._prepare_model_name(request.model)
        kwargs = self._build_call_kwargs(request, model)

        # Structured output: the JSON-encoded schema is only sent on generate().
        if request.response_schema is not None:
            kwargs["response_schema_kw"] = request.response_schema

        # Forward the active runtime context, if any, so the Rust-side trace is
        # linked to the caller's trace.
        current_ctx = get_current_context()
        runtime_context = getattr(current_ctx, '_runtime_context', None) if current_ctx else None
        if runtime_context:
            kwargs["runtime_context"] = runtime_context

        from .context import get_workflow_context

        workflow_ctx = get_workflow_context()
        if workflow_ctx:
            workflow_ctx._send_checkpoint("workflow.lm.started", {
                "model": model,
                "provider": self._provider,
                "temperature": kwargs.get("temperature"),
                "max_tokens": kwargs.get("max_tokens"),
            })

        try:
            # The Rust binding returns an awaitable (pyo3-async-runtimes), so the
            # HTTP call does not block the event loop.
            rust_response = await self._rust_lm.generate(prompt=prompt, **kwargs)
            response = self._convert_response(rust_response)

            if workflow_ctx:
                usage = response.usage
                usage_dict = None
                if usage:
                    usage_dict = {
                        "prompt_tokens": usage.prompt_tokens,
                        "completion_tokens": usage.completion_tokens,
                        "total_tokens": usage.total_tokens,
                    }
                workflow_ctx._send_checkpoint("workflow.lm.completed", {
                    "model": model,
                    "usage": usage_dict,
                })

            return response
        except Exception as e:
            if workflow_ctx:
                workflow_ctx._send_checkpoint("workflow.lm.error", {
                    "model": model,
                    "error": str(e),
                    "error_type": type(e).__name__,
                })
            raise

    async def stream(self, request: GenerateRequest) -> AsyncIterator[str]:
        """Stream a completion from the LLM.

        Inside a workflow, emits ``workflow.lm.started``, then
        ``workflow.lm.completed`` or ``workflow.lm.error``, all tagged with
        ``"streaming": True``.

        Args:
            request: Generation request with model, messages, and configuration.

        Yields:
            The non-empty ``text`` of each chunk returned by the Rust core.

        Raises:
            Exception: Whatever the Rust call raises, re-raised after the error
                checkpoint has been emitted.
        """
        prompt = self._build_prompt_messages(request)
        model = self._prepare_model_name(request.model)
        kwargs = self._build_call_kwargs(request, model)
        # NOTE(review): unlike generate(), neither response_schema nor
        # runtime_context is forwarded here, so streamed calls are not trace-linked.
        # Confirm whether the Rust stream() accepts them.

        from .context import get_workflow_context

        workflow_ctx = get_workflow_context()
        if workflow_ctx:
            workflow_ctx._send_checkpoint("workflow.lm.started", {
                "model": model,
                "provider": self._provider,
                "temperature": kwargs.get("temperature"),
                "max_tokens": kwargs.get("max_tokens"),
                "streaming": True,
            })

        try:
            # NOTE(review): the awaited result is iterated synchronously. If the
            # binding returns a fully materialised list, chunks reach the caller
            # only after the whole response has arrived. Confirm against the Rust side.
            rust_chunks = await self._rust_lm.stream(prompt=prompt, **kwargs)
            for chunk in rust_chunks:
                if chunk.text:
                    yield chunk.text

            if workflow_ctx:
                workflow_ctx._send_checkpoint("workflow.lm.completed", {
                    "model": model,
                    "streaming": True,
                })
        except Exception as e:
            if workflow_ctx:
                workflow_ctx._send_checkpoint("workflow.lm.error", {
                    "model": model,
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "streaming": True,
                })
            raise

    def _build_prompt_messages(self, request: GenerateRequest) -> List[Dict[str, str]]:
        """Convert ``request.messages`` into the list of dicts the Rust core expects.

        The system prompt is not included; it is passed separately as a kwarg.

        Args:
            request: Generation request with messages.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts. If the request has
            neither messages nor a system prompt, this is a single empty user
            message.
        """
        messages = [
            {"role": msg.role.value, "content": msg.content}
            for msg in request.messages
        ]
        if not messages and not request.system_prompt:
            messages.append({"role": "user", "content": ""})
        return messages

    def _convert_response(self, rust_response: RustResponse) -> GenerateResponse:
        """Wrap a Rust response in a ``GenerateResponse``.

        The raw Rust object is kept on the result so that
        ``GenerateResponse.structured_output`` can read it later.
        """
        usage = None
        if rust_response.usage:
            usage = TokenUsage(
                prompt_tokens=rust_response.usage.prompt_tokens,
                completion_tokens=rust_response.usage.completion_tokens,
                total_tokens=rust_response.usage.total_tokens,
            )

        # Missing or empty tool_calls are normalised to None.
        tool_calls = getattr(rust_response, 'tool_calls', None) or None

        return GenerateResponse(
            text=rust_response.content,
            usage=usage,
            finish_reason=None,  # TODO: Add finish_reason to Rust response
            tool_calls=tool_calls,
            _rust_response=rust_response,
        )
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
# ============================================================================
|
|
566
|
+
# Simplified API (Recommended)
|
|
567
|
+
# ============================================================================
|
|
568
|
+
# This is the recommended simple interface for most use cases
|
|
569
|
+
|
|
570
|
+
def _validate_prompt_args(prompt: Optional[str], messages: Optional[List[Dict[str, str]]]) -> None:
    """Require exactly one of ``prompt`` / ``messages``; raise ValueError otherwise."""
    if not prompt and not messages:
        raise ValueError("Either 'prompt' or 'messages' must be provided")
    if prompt and messages:
        raise ValueError("Provide either 'prompt' or 'messages', not both")


def _provider_from_model(model: str) -> str:
    """Return the lowercased provider prefix of a 'provider/model-name' identifier.

    Raises:
        ValueError: If ``model`` has no '/' or the provider or model part is empty.
    """
    provider, sep, model_name = model.partition('/')
    if not sep:
        raise ValueError(
            f"Model must include provider prefix (e.g., 'openai/{model}'). "
            "Supported providers: openai, anthropic, groq, openrouter, azure, bedrock"
        )
    if not provider or not model_name:
        raise ValueError(
            f"Model must be of the form 'provider/model-name', got {model!r}. "
            "Supported providers: openai, anthropic, groq, openrouter, azure, bedrock"
        )
    return provider.lower()


def _to_message_objects(
    prompt: Optional[str],
    messages: Optional[List[Dict[str, str]]],
) -> List[Message]:
    """Turn a bare prompt or a list of role/content dicts into ``Message`` objects.

    Raises:
        ValueError: If a message's role is not a valid ``MessageRole`` value.
        KeyError: If a message dict lacks 'role' or 'content'.
    """
    raw = [{"role": "user", "content": prompt}] if prompt else messages
    return [
        Message(role=MessageRole(msg["role"]), content=msg["content"])
        for msg in raw
    ]


async def generate(
    model: str,
    prompt: Optional[str] = None,
    messages: Optional[List[Dict[str, str]]] = None,
    system_prompt: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    response_format: Optional[Any] = None,
) -> GenerateResponse:
    """Generate text using any LLM provider (simplified API).

    This is the recommended way to use the LLM API. The provider is taken
    from the model prefix (e.g., 'openai/gpt-4o-mini', 'anthropic/claude-3-5-haiku').

    Args:
        model: Model identifier with provider prefix (e.g., 'openai/gpt-4o-mini')
        prompt: Simple text prompt (for single-turn requests)
        messages: List of message dicts with 'role' and 'content' (for multi-turn)
        system_prompt: Optional system prompt
        temperature: Sampling temperature (0.0-2.0)
        max_tokens: Maximum tokens to generate
        top_p: Nucleus sampling parameter
        response_format: Pydantic model, dataclass, or JSON schema dict for structured output

    Returns:
        GenerateResponse with text, usage, and optional structured output

    Raises:
        ValueError: If neither or both of ``prompt``/``messages`` are given, if
            ``model`` is not of the form 'provider/model-name', or if a message
            role is invalid.

    Examples:
        Simple prompt:
            >>> response = await generate(
            ...     model="openai/gpt-4o-mini",
            ...     prompt="What is love?",
            ...     temperature=0.7
            ... )
            >>> print(response.text)

        Structured output with dataclass:
            >>> from dataclasses import dataclass
            >>>
            >>> @dataclass
            ... class CodeReview:
            ...     issues: list[str]
            ...     suggestions: list[str]
            ...     overall_quality: int
            >>>
            >>> response = await generate(
            ...     model="openai/gpt-4o",
            ...     prompt="Analyze this code...",
            ...     response_format=CodeReview
            ... )
            >>> review = response.structured_output  # Returns dict
    """
    _validate_prompt_args(prompt, messages)
    provider = _provider_from_model(model)

    # Convert response_format to a JSON schema string if provided.
    response_schema_json = None
    if response_format is not None:
        _format_type, json_schema = detect_format_type(response_format)
        response_schema_json = json.dumps(json_schema)

    lm = _LanguageModel(provider=provider, default_model=None)

    request = GenerateRequest(
        model=model,
        messages=_to_message_objects(prompt, messages),
        system_prompt=system_prompt,
        config=GenerationConfig(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        ),
        response_schema=response_schema_json,
    )

    # Workflow checkpoints are emitted by _LanguageModel.generate() itself;
    # emitting them here too would record every call twice.
    return await lm.generate(request)


async def stream(
    model: str,
    prompt: Optional[str] = None,
    messages: Optional[List[Dict[str, str]]] = None,
    system_prompt: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
) -> AsyncIterator[str]:
    """Stream text using any LLM provider (simplified API).

    This is the recommended way to use streaming. The provider is taken from
    the model prefix (e.g., 'openai/gpt-4o-mini', 'anthropic/claude-3-5-haiku').

    Args:
        model: Model identifier with provider prefix (e.g., 'openai/gpt-4o-mini')
        prompt: Simple text prompt (for single-turn requests)
        messages: List of message dicts with 'role' and 'content' (for multi-turn)
        system_prompt: Optional system prompt
        temperature: Sampling temperature (0.0-2.0)
        max_tokens: Maximum tokens to generate
        top_p: Nucleus sampling parameter

    Yields:
        Text chunks as they are generated

    Raises:
        ValueError: Same validation as ``generate()``. Because this is an async
            generator, it is raised on the first iteration.

    Examples:
        Simple streaming:
            >>> async for chunk in stream(
            ...     model="openai/gpt-4o-mini",
            ...     prompt="Write a story"
            ... ):
            ...     print(chunk, end="", flush=True)

        Streaming conversation:
            >>> async for chunk in stream(
            ...     model="groq/llama-3.3-70b-versatile",
            ...     messages=[
            ...         {"role": "user", "content": "Tell me a joke"}
            ...     ],
            ...     temperature=0.9
            ... ):
            ...     print(chunk, end="")
    """
    _validate_prompt_args(prompt, messages)
    provider = _provider_from_model(model)

    lm = _LanguageModel(provider=provider, default_model=None)

    request = GenerateRequest(
        model=model,
        messages=_to_message_objects(prompt, messages),
        system_prompt=system_prompt,
        config=GenerationConfig(
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        ),
    )

    # _LanguageModel.stream() already emits the workflow.lm.started / completed /
    # error checkpoints. Emitting them here as well (as before) recorded every
    # streamed call twice.
    async for chunk in lm.stream(request):
        yield chunk
|