jaf-py 2.5.10__py3-none-any.whl → 2.5.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +154 -57
- jaf/a2a/__init__.py +42 -21
- jaf/a2a/agent.py +79 -126
- jaf/a2a/agent_card.py +87 -78
- jaf/a2a/client.py +30 -66
- jaf/a2a/examples/client_example.py +12 -12
- jaf/a2a/examples/integration_example.py +38 -47
- jaf/a2a/examples/server_example.py +56 -53
- jaf/a2a/memory/__init__.py +0 -4
- jaf/a2a/memory/cleanup.py +28 -21
- jaf/a2a/memory/factory.py +155 -133
- jaf/a2a/memory/providers/composite.py +21 -26
- jaf/a2a/memory/providers/in_memory.py +89 -83
- jaf/a2a/memory/providers/postgres.py +117 -115
- jaf/a2a/memory/providers/redis.py +128 -121
- jaf/a2a/memory/serialization.py +77 -87
- jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
- jaf/a2a/memory/tests/test_cleanup.py +211 -94
- jaf/a2a/memory/tests/test_serialization.py +73 -68
- jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
- jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
- jaf/a2a/memory/types.py +91 -53
- jaf/a2a/protocol.py +95 -125
- jaf/a2a/server.py +90 -118
- jaf/a2a/standalone_client.py +30 -43
- jaf/a2a/tests/__init__.py +16 -33
- jaf/a2a/tests/run_tests.py +17 -53
- jaf/a2a/tests/test_agent.py +40 -140
- jaf/a2a/tests/test_client.py +54 -117
- jaf/a2a/tests/test_integration.py +28 -82
- jaf/a2a/tests/test_protocol.py +54 -139
- jaf/a2a/tests/test_types.py +50 -136
- jaf/a2a/types.py +58 -34
- jaf/cli.py +21 -41
- jaf/core/__init__.py +7 -1
- jaf/core/agent_tool.py +93 -72
- jaf/core/analytics.py +257 -207
- jaf/core/checkpoint.py +223 -0
- jaf/core/composition.py +249 -235
- jaf/core/engine.py +817 -519
- jaf/core/errors.py +55 -42
- jaf/core/guardrails.py +276 -202
- jaf/core/handoff.py +47 -31
- jaf/core/parallel_agents.py +69 -75
- jaf/core/performance.py +75 -73
- jaf/core/proxy.py +43 -44
- jaf/core/proxy_helpers.py +24 -27
- jaf/core/regeneration.py +220 -129
- jaf/core/state.py +68 -66
- jaf/core/streaming.py +115 -108
- jaf/core/tool_results.py +111 -101
- jaf/core/tools.py +114 -116
- jaf/core/tracing.py +310 -210
- jaf/core/types.py +403 -151
- jaf/core/workflows.py +209 -168
- jaf/exceptions.py +46 -38
- jaf/memory/__init__.py +1 -6
- jaf/memory/approval_storage.py +54 -77
- jaf/memory/factory.py +4 -4
- jaf/memory/providers/in_memory.py +216 -180
- jaf/memory/providers/postgres.py +216 -146
- jaf/memory/providers/redis.py +173 -116
- jaf/memory/types.py +70 -51
- jaf/memory/utils.py +36 -34
- jaf/plugins/__init__.py +12 -12
- jaf/plugins/base.py +105 -96
- jaf/policies/__init__.py +0 -1
- jaf/policies/handoff.py +37 -46
- jaf/policies/validation.py +76 -52
- jaf/providers/__init__.py +6 -3
- jaf/providers/mcp.py +97 -51
- jaf/providers/model.py +475 -283
- jaf/server/__init__.py +1 -1
- jaf/server/main.py +7 -11
- jaf/server/server.py +514 -359
- jaf/server/types.py +208 -52
- jaf/utils/__init__.py +17 -18
- jaf/utils/attachments.py +111 -116
- jaf/utils/document_processor.py +175 -174
- jaf/visualization/__init__.py +1 -1
- jaf/visualization/example.py +111 -110
- jaf/visualization/functional_core.py +46 -71
- jaf/visualization/graphviz.py +154 -189
- jaf/visualization/imperative_shell.py +7 -16
- jaf/visualization/types.py +8 -4
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/METADATA +2 -2
- jaf_py-2.5.12.dist-info/RECORD +97 -0
- jaf_py-2.5.10.dist-info/RECORD +0 -96
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/WHEEL +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/top_level.txt +0 -0
jaf/core/types.py
CHANGED
|
@@ -36,239 +36,274 @@ if TYPE_CHECKING:
|
|
|
36
36
|
# Comprehensive enums for type safety and improved developer experience
|
|
37
37
|
class Model(str, Enum):
|
|
38
38
|
"""Supported model identifiers."""
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
39
|
+
|
|
40
|
+
GEMINI_2_0_FLASH = "gemini-2.0-flash"
|
|
41
|
+
GEMINI_2_5_PRO = "gemini-2.5-pro"
|
|
42
|
+
GEMINI_PRO = "gemini-pro"
|
|
43
|
+
GPT_4 = "gpt-4"
|
|
44
|
+
GPT_4_TURBO = "gpt-4-turbo"
|
|
45
|
+
GPT_3_5_TURBO = "gpt-3.5-turbo"
|
|
46
|
+
CLAUDE_3_SONNET = "claude-3-sonnet"
|
|
47
|
+
CLAUDE_3_HAIKU = "claude-3-haiku"
|
|
48
|
+
CLAUDE_3_OPUS = "claude-3-opus"
|
|
48
49
|
|
|
49
50
|
|
|
50
51
|
class ToolParameterType(str, Enum):
|
|
51
52
|
"""Tool parameter types."""
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
53
|
+
|
|
54
|
+
STRING = "string"
|
|
55
|
+
NUMBER = "number"
|
|
56
|
+
INTEGER = "integer"
|
|
57
|
+
BOOLEAN = "boolean"
|
|
58
|
+
ARRAY = "array"
|
|
59
|
+
OBJECT = "object"
|
|
60
|
+
NULL = "null"
|
|
59
61
|
|
|
60
62
|
|
|
61
63
|
class ToolSource(str, Enum):
|
|
62
64
|
"""Source of tool definitions."""
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
65
|
+
|
|
66
|
+
NATIVE = "native"
|
|
67
|
+
MCP = "mcp"
|
|
68
|
+
PLUGIN = "plugin"
|
|
69
|
+
EXTERNAL = "external"
|
|
67
70
|
|
|
68
71
|
|
|
69
72
|
class ContentRole(str, Enum):
|
|
70
73
|
"""Message content roles."""
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
74
|
+
|
|
75
|
+
USER = "user"
|
|
76
|
+
ASSISTANT = "assistant"
|
|
77
|
+
TOOL = "tool"
|
|
78
|
+
SYSTEM = "system"
|
|
75
79
|
|
|
76
80
|
|
|
77
81
|
class PartType(str, Enum):
|
|
78
82
|
"""Message part types."""
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
83
|
+
|
|
84
|
+
TEXT = "text"
|
|
85
|
+
IMAGE = "image"
|
|
86
|
+
AUDIO = "audio"
|
|
87
|
+
VIDEO = "video"
|
|
88
|
+
FILE = "file"
|
|
84
89
|
|
|
85
90
|
|
|
86
91
|
# Branded types for type safety - using class-based approach for better type safety
|
|
87
92
|
class TraceId(str):
|
|
88
93
|
"""Branded string type for trace IDs."""
|
|
89
|
-
|
|
94
|
+
|
|
95
|
+
def __new__(cls, value: str) -> "TraceId":
|
|
90
96
|
return str.__new__(cls, value)
|
|
91
97
|
|
|
98
|
+
|
|
92
99
|
class RunId(str):
|
|
93
100
|
"""Branded string type for run IDs."""
|
|
94
|
-
|
|
101
|
+
|
|
102
|
+
def __new__(cls, value: str) -> "RunId":
|
|
95
103
|
return str.__new__(cls, value)
|
|
96
104
|
|
|
105
|
+
|
|
97
106
|
class MessageId(str):
|
|
98
107
|
"""Branded string type for message IDs."""
|
|
99
|
-
|
|
108
|
+
|
|
109
|
+
def __new__(cls, value: str) -> "MessageId":
|
|
100
110
|
return str.__new__(cls, value)
|
|
101
111
|
|
|
112
|
+
|
|
102
113
|
def create_trace_id(id_str: str) -> TraceId:
|
|
103
114
|
"""Create a TraceId from a string."""
|
|
104
115
|
return TraceId(id_str)
|
|
105
116
|
|
|
117
|
+
|
|
106
118
|
def create_run_id(id_str: str) -> RunId:
|
|
107
119
|
"""Create a RunId from a string."""
|
|
108
120
|
return RunId(id_str)
|
|
109
121
|
|
|
122
|
+
|
|
110
123
|
def create_message_id(id_str: Union[str, MessageId]) -> MessageId:
|
|
111
124
|
"""
|
|
112
125
|
Create a MessageId from a string or return existing MessageId.
|
|
113
|
-
|
|
126
|
+
|
|
114
127
|
Args:
|
|
115
128
|
id_str: Either a string to convert to MessageId or an existing MessageId
|
|
116
|
-
|
|
129
|
+
|
|
117
130
|
Returns:
|
|
118
131
|
MessageId: A validated MessageId instance
|
|
119
|
-
|
|
132
|
+
|
|
120
133
|
Raises:
|
|
121
134
|
ValueError: If the input is invalid or empty
|
|
122
135
|
"""
|
|
123
136
|
# Handle None input
|
|
124
137
|
if id_str is None:
|
|
125
138
|
raise ValueError("Message ID cannot be None")
|
|
126
|
-
|
|
139
|
+
|
|
127
140
|
# If already a MessageId, return as-is
|
|
128
141
|
if isinstance(id_str, MessageId):
|
|
129
142
|
return id_str
|
|
130
|
-
|
|
143
|
+
|
|
131
144
|
# Convert string to MessageId with validation
|
|
132
145
|
if isinstance(id_str, str):
|
|
133
146
|
if not id_str.strip():
|
|
134
147
|
raise ValueError("Message ID cannot be empty or whitespace")
|
|
135
148
|
return MessageId(id_str.strip())
|
|
136
|
-
|
|
149
|
+
|
|
137
150
|
# Handle any other type
|
|
138
151
|
raise ValueError(f"Message ID must be a string or MessageId, got {type(id_str)}")
|
|
139
152
|
|
|
153
|
+
|
|
140
154
|
def generate_run_id() -> RunId:
|
|
141
155
|
"""Generate a new unique run ID."""
|
|
142
156
|
import time
|
|
143
157
|
import uuid
|
|
158
|
+
|
|
144
159
|
return RunId(f"run_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")
|
|
145
160
|
|
|
161
|
+
|
|
146
162
|
def generate_trace_id() -> TraceId:
|
|
147
163
|
"""Generate a new unique trace ID."""
|
|
148
164
|
import time
|
|
149
165
|
import uuid
|
|
166
|
+
|
|
150
167
|
return TraceId(f"trace_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")
|
|
151
168
|
|
|
169
|
+
|
|
152
170
|
def generate_message_id() -> MessageId:
|
|
153
171
|
"""Generate a new unique message ID."""
|
|
154
172
|
import time
|
|
155
173
|
import uuid
|
|
174
|
+
|
|
156
175
|
return MessageId(f"msg_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")
|
|
157
176
|
|
|
177
|
+
|
|
158
178
|
# Type variables for generic contexts and outputs
|
|
159
|
-
Ctx = TypeVar(
|
|
160
|
-
Out = TypeVar(
|
|
161
|
-
Args = TypeVar(
|
|
179
|
+
Ctx = TypeVar("Ctx")
|
|
180
|
+
Out = TypeVar("Out")
|
|
181
|
+
Args = TypeVar("Args")
|
|
182
|
+
|
|
162
183
|
|
|
163
184
|
# Discriminated union for ValidationResult to match TypeScript version
|
|
164
185
|
@dataclass(frozen=True)
|
|
165
186
|
class ValidValidationResult:
|
|
166
187
|
"""Valid validation result."""
|
|
188
|
+
|
|
167
189
|
is_valid: Literal[True] = True
|
|
168
190
|
|
|
191
|
+
|
|
169
192
|
@dataclass(frozen=True)
|
|
170
193
|
class InvalidValidationResult:
|
|
171
194
|
"""Invalid validation result with error message."""
|
|
195
|
+
|
|
172
196
|
is_valid: Literal[False] = False
|
|
173
197
|
error_message: str = ""
|
|
174
198
|
|
|
199
|
+
|
|
175
200
|
ValidationResult = Union[ValidValidationResult, InvalidValidationResult]
|
|
176
201
|
|
|
202
|
+
|
|
177
203
|
@dataclass(frozen=True)
|
|
178
204
|
class ToolCall:
|
|
179
205
|
"""Represents a tool call from the model."""
|
|
206
|
+
|
|
180
207
|
id: str
|
|
181
|
-
type: Literal[
|
|
182
|
-
function:
|
|
208
|
+
type: Literal["function"]
|
|
209
|
+
function: "ToolCallFunction"
|
|
210
|
+
|
|
183
211
|
|
|
184
212
|
@dataclass(frozen=True)
|
|
185
213
|
class ToolCallFunction:
|
|
186
214
|
"""Function information within a tool call."""
|
|
215
|
+
|
|
187
216
|
name: str
|
|
188
217
|
arguments: str
|
|
189
218
|
|
|
219
|
+
|
|
190
220
|
@dataclass(frozen=True)
|
|
191
221
|
class Attachment:
|
|
192
222
|
"""Represents an attachment with various content types."""
|
|
193
|
-
|
|
223
|
+
|
|
224
|
+
kind: Literal["image", "document", "file"]
|
|
194
225
|
mime_type: Optional[str] = None # e.g. image/png, application/pdf
|
|
195
|
-
name: Optional[str] = None
|
|
196
|
-
url: Optional[str] = None
|
|
197
|
-
data: Optional[str] = None
|
|
198
|
-
format: Optional[str] = None
|
|
226
|
+
name: Optional[str] = None # Optional filename
|
|
227
|
+
url: Optional[str] = None # Remote URL or data URL
|
|
228
|
+
data: Optional[str] = None # Base64 without data: prefix
|
|
229
|
+
format: Optional[str] = None # Optional short format like 'pdf', 'txt'
|
|
199
230
|
use_litellm_format: Optional[bool] = None # Use LiteLLM native file format
|
|
200
|
-
|
|
231
|
+
|
|
201
232
|
def __post_init__(self):
|
|
202
233
|
"""Validate that at least one of url or data is provided."""
|
|
203
234
|
if self.url is None and self.data is None:
|
|
204
235
|
raise ValueError("At least one of 'url' or 'data' must be provided for an Attachment.")
|
|
205
236
|
|
|
237
|
+
|
|
206
238
|
@dataclass(frozen=True)
|
|
207
239
|
class MessageContentPart:
|
|
208
240
|
"""Part of multi-part message content."""
|
|
209
|
-
|
|
241
|
+
|
|
242
|
+
type: Literal["text", "image_url", "file"]
|
|
210
243
|
text: Optional[str] = None
|
|
211
244
|
image_url: Optional[Dict[str, Any]] = None # Contains url and optional detail
|
|
212
|
-
file: Optional[Dict[str, Any]] = None
|
|
245
|
+
file: Optional[Dict[str, Any]] = None # Contains file_id and optional format
|
|
246
|
+
|
|
213
247
|
|
|
214
248
|
@dataclass(frozen=True)
|
|
215
249
|
class Message:
|
|
216
250
|
"""
|
|
217
251
|
A message in the conversation.
|
|
218
|
-
|
|
252
|
+
|
|
219
253
|
BACKWARDS COMPATIBILITY:
|
|
220
254
|
- Messages created with string content remain fully backwards compatible
|
|
221
255
|
- Direct access to .content returns the original string when created with string
|
|
222
256
|
- Use .text_content property for guaranteed string access in all cases
|
|
223
257
|
- Use get_text_content() function to extract text from any content type
|
|
224
258
|
- message_id is optional for backward compatibility
|
|
225
|
-
|
|
259
|
+
|
|
226
260
|
Examples:
|
|
227
261
|
# Original usage - still works exactly the same
|
|
228
262
|
msg = Message(role='user', content='Hello')
|
|
229
263
|
text = msg.content # Returns 'Hello' as string
|
|
230
|
-
|
|
264
|
+
|
|
231
265
|
# New usage with message ID
|
|
232
266
|
msg = Message(role='user', content='Hello', message_id='msg_123')
|
|
233
|
-
|
|
267
|
+
|
|
234
268
|
# Guaranteed string access (recommended for new code)
|
|
235
269
|
text = msg.text_content # Always returns string
|
|
236
|
-
|
|
270
|
+
|
|
237
271
|
# Universal text extraction
|
|
238
272
|
text = get_text_content(msg.content) # Works with any content type
|
|
239
273
|
"""
|
|
274
|
+
|
|
240
275
|
role: ContentRole
|
|
241
276
|
content: Union[str, List[MessageContentPart]]
|
|
242
277
|
attachments: Optional[List[Attachment]] = None
|
|
243
278
|
tool_call_id: Optional[str] = None
|
|
244
279
|
tool_calls: Optional[List[ToolCall]] = None
|
|
245
280
|
message_id: Optional[MessageId] = None # Optional for backward compatibility
|
|
246
|
-
|
|
281
|
+
|
|
247
282
|
def __post_init__(self):
|
|
248
283
|
"""
|
|
249
284
|
Auto-generate message ID if not provided.
|
|
250
|
-
|
|
285
|
+
|
|
251
286
|
This implementation uses object.__setattr__ to bypass frozen dataclass restrictions,
|
|
252
287
|
which is a recommended pattern for one-time initialization of computed fields in
|
|
253
288
|
frozen dataclasses. This ensures:
|
|
254
|
-
|
|
289
|
+
|
|
255
290
|
1. Backward compatibility - existing code with message_id=None continues to work
|
|
256
291
|
2. Immutability - the dataclass remains frozen after initialization
|
|
257
292
|
3. Guaranteed unique IDs - every message gets a unique identifier
|
|
258
293
|
4. Clean API - users don't need to manually generate IDs in most cases
|
|
259
|
-
|
|
294
|
+
|
|
260
295
|
This pattern is preferred over using field(default_factory=...) because it
|
|
261
296
|
maintains the Optional[MessageId] type hint for backward compatibility while
|
|
262
297
|
ensuring the field is never actually None after object creation.
|
|
263
298
|
"""
|
|
264
299
|
if self.message_id is None:
|
|
265
|
-
object.__setattr__(self,
|
|
266
|
-
|
|
300
|
+
object.__setattr__(self, "message_id", generate_message_id())
|
|
301
|
+
|
|
267
302
|
@property
|
|
268
303
|
def text_content(self) -> str:
|
|
269
304
|
"""Get text content as string for backwards compatibility."""
|
|
270
305
|
return get_text_content(self.content)
|
|
271
|
-
|
|
306
|
+
|
|
272
307
|
@classmethod
|
|
273
308
|
def create(
|
|
274
309
|
cls,
|
|
@@ -277,8 +312,8 @@ class Message:
|
|
|
277
312
|
attachments: Optional[List[Attachment]] = None,
|
|
278
313
|
tool_call_id: Optional[str] = None,
|
|
279
314
|
tool_calls: Optional[List[ToolCall]] = None,
|
|
280
|
-
message_id: Optional[MessageId] = None
|
|
281
|
-
) ->
|
|
315
|
+
message_id: Optional[MessageId] = None,
|
|
316
|
+
) -> "Message":
|
|
282
317
|
"""Create a message with string content and optional attachments."""
|
|
283
318
|
return cls(
|
|
284
319
|
role=role,
|
|
@@ -286,32 +321,43 @@ class Message:
|
|
|
286
321
|
attachments=attachments,
|
|
287
322
|
tool_call_id=tool_call_id,
|
|
288
323
|
tool_calls=tool_calls,
|
|
289
|
-
message_id=message_id
|
|
324
|
+
message_id=message_id,
|
|
290
325
|
)
|
|
291
326
|
|
|
327
|
+
|
|
292
328
|
def get_text_content(content: Union[str, List[MessageContentPart]]) -> str:
|
|
293
329
|
"""Extract text content from message content."""
|
|
294
330
|
if isinstance(content, str):
|
|
295
331
|
return content
|
|
296
|
-
|
|
297
|
-
text_parts = [part.text for part in content if part.type ==
|
|
298
|
-
return
|
|
332
|
+
|
|
333
|
+
text_parts = [part.text for part in content if part.type == "text" and part.text]
|
|
334
|
+
return " ".join(text_parts)
|
|
335
|
+
|
|
299
336
|
|
|
300
337
|
@dataclass(frozen=True)
|
|
301
338
|
class ModelConfig:
|
|
302
339
|
"""Configuration for model behavior."""
|
|
340
|
+
|
|
303
341
|
name: Optional[str] = None
|
|
304
342
|
temperature: Optional[float] = None
|
|
305
343
|
max_tokens: Optional[int] = None
|
|
344
|
+
inline_tool_schemas: Optional[bool] = (
|
|
345
|
+
None # If True, resolve $refs and inline $defs in tool schemas
|
|
346
|
+
)
|
|
347
|
+
|
|
306
348
|
|
|
307
349
|
@dataclass(frozen=True)
|
|
308
350
|
class ToolSchema(Generic[Args]):
|
|
309
351
|
"""Schema definition for a tool."""
|
|
352
|
+
|
|
310
353
|
name: str
|
|
311
354
|
description: str
|
|
312
|
-
parameters:
|
|
355
|
+
parameters: (
|
|
356
|
+
Any # Should be a type that can validate Args (like Pydantic model or Zod equivalent)
|
|
357
|
+
)
|
|
313
358
|
timeout: Optional[float] = None # Optional timeout in seconds for tool execution
|
|
314
359
|
|
|
360
|
+
|
|
315
361
|
@runtime_checkable
|
|
316
362
|
class Tool(Protocol[Args, Ctx]):
|
|
317
363
|
"""Protocol for tool implementations."""
|
|
@@ -321,7 +367,7 @@ class Tool(Protocol[Args, Ctx]):
|
|
|
321
367
|
"""Tool schema including name, description, and parameter validation."""
|
|
322
368
|
...
|
|
323
369
|
|
|
324
|
-
async def execute(self, args: Args, context: Ctx) -> Union[str,
|
|
370
|
+
async def execute(self, args: Args, context: Ctx) -> Union[str, "ToolResult[Any]"]:
|
|
325
371
|
"""Execute the tool with given arguments and context."""
|
|
326
372
|
...
|
|
327
373
|
|
|
@@ -334,9 +380,12 @@ class Tool(Protocol[Args, Ctx]):
|
|
|
334
380
|
# Function tool configuration for improved DX
|
|
335
381
|
class FunctionToolConfig(TypedDict):
|
|
336
382
|
"""Configuration for creating function-based tools with object-based API."""
|
|
383
|
+
|
|
337
384
|
name: str
|
|
338
385
|
description: str
|
|
339
|
-
execute: Callable[
|
|
386
|
+
execute: Callable[
|
|
387
|
+
[Any, Any], Union[str, "ToolResult[Any]", Awaitable[Union[str, "ToolResult[Any]"]]]
|
|
388
|
+
]
|
|
340
389
|
parameters: Any # Pydantic model or similar for parameter validation
|
|
341
390
|
metadata: Optional[Dict[str, Any]] # Optional metadata
|
|
342
391
|
source: Optional[ToolSource] # Optional source tracking
|
|
@@ -344,34 +393,45 @@ class FunctionToolConfig(TypedDict):
|
|
|
344
393
|
|
|
345
394
|
|
|
346
395
|
# Type alias for tool execution functions
|
|
347
|
-
ToolExecuteFunction = Callable[
|
|
396
|
+
ToolExecuteFunction = Callable[
|
|
397
|
+
[Any, Any], Union[str, "ToolResult[Any]", Awaitable[Union[str, "ToolResult[Any]"]]]
|
|
398
|
+
]
|
|
348
399
|
|
|
349
400
|
|
|
350
401
|
@dataclass(frozen=True)
|
|
351
402
|
class Agent(Generic[Ctx, Out]):
|
|
352
403
|
"""An agent definition with instructions, tools, and configuration."""
|
|
404
|
+
|
|
353
405
|
name: str
|
|
354
|
-
instructions: Callable[[
|
|
406
|
+
instructions: Callable[["RunState[Ctx]"], str]
|
|
355
407
|
tools: Optional[List[Tool[Any, Ctx]]] = None
|
|
356
|
-
output_codec: Optional[Any] =
|
|
408
|
+
output_codec: Optional[Any] = (
|
|
409
|
+
None # Type that can validate Out (like Pydantic model or Zod equivalent)
|
|
410
|
+
)
|
|
357
411
|
handoffs: Optional[List[str]] = None
|
|
358
412
|
model_config: Optional[ModelConfig] = None
|
|
359
|
-
advanced_config: Optional[
|
|
413
|
+
advanced_config: Optional["AdvancedConfig"] = None
|
|
360
414
|
|
|
361
415
|
def as_tool(
|
|
362
416
|
self,
|
|
363
417
|
tool_name: Optional[str] = None,
|
|
364
418
|
tool_description: Optional[str] = None,
|
|
365
419
|
max_turns: Optional[int] = None,
|
|
366
|
-
custom_output_extractor: Optional[
|
|
367
|
-
|
|
420
|
+
custom_output_extractor: Optional[
|
|
421
|
+
Callable[["RunResult[Out]"], Union[str, Awaitable[str]]]
|
|
422
|
+
] = None,
|
|
423
|
+
is_enabled: Union[
|
|
424
|
+
bool,
|
|
425
|
+
Callable[[Any, "Agent[Ctx, Out]"], bool],
|
|
426
|
+
Callable[[Any, "Agent[Ctx, Out]"], Awaitable[bool]],
|
|
427
|
+
] = True,
|
|
368
428
|
metadata: Optional[Dict[str, Any]] = None,
|
|
369
429
|
timeout: Optional[float] = None,
|
|
370
|
-
preserve_session: bool = False
|
|
430
|
+
preserve_session: bool = False,
|
|
371
431
|
) -> Tool[Any, Ctx]:
|
|
372
432
|
"""
|
|
373
433
|
Convert this agent into a tool that can be used by other agents.
|
|
374
|
-
|
|
434
|
+
|
|
375
435
|
Args:
|
|
376
436
|
tool_name: Optional custom name for the tool (defaults to agent name)
|
|
377
437
|
tool_description: Optional custom description (defaults to generic description)
|
|
@@ -380,11 +440,12 @@ class Agent(Generic[Ctx, Out]):
|
|
|
380
440
|
is_enabled: Whether the tool is enabled (bool, sync function, or async function)
|
|
381
441
|
metadata: Optional metadata for the tool
|
|
382
442
|
timeout: Optional timeout for the tool execution
|
|
383
|
-
|
|
443
|
+
|
|
384
444
|
Returns:
|
|
385
445
|
A Tool that wraps this agent's execution
|
|
386
446
|
"""
|
|
387
447
|
from .agent_tool import create_agent_tool
|
|
448
|
+
|
|
388
449
|
return create_agent_tool(
|
|
389
450
|
agent=self,
|
|
390
451
|
tool_name=tool_name,
|
|
@@ -394,90 +455,109 @@ class Agent(Generic[Ctx, Out]):
|
|
|
394
455
|
is_enabled=is_enabled,
|
|
395
456
|
metadata=metadata,
|
|
396
457
|
timeout=timeout,
|
|
397
|
-
preserve_session=preserve_session
|
|
458
|
+
preserve_session=preserve_session,
|
|
398
459
|
)
|
|
399
460
|
|
|
461
|
+
|
|
400
462
|
# Guardrail type
|
|
401
463
|
Guardrail = Callable[[Any], Union[ValidationResult, Awaitable[ValidationResult]]]
|
|
402
464
|
|
|
465
|
+
|
|
403
466
|
@dataclass(frozen=True)
|
|
404
467
|
class AdvancedGuardrailsConfig:
|
|
405
468
|
"""Configuration for advanced guardrails with LLM-based validation."""
|
|
469
|
+
|
|
406
470
|
input_prompt: Optional[str] = None
|
|
407
471
|
output_prompt: Optional[str] = None
|
|
408
472
|
require_citations: bool = False
|
|
409
473
|
fast_model: Optional[str] = None
|
|
410
|
-
fail_safe: Literal[
|
|
411
|
-
execution_mode: Literal[
|
|
474
|
+
fail_safe: Literal["allow", "block"] = "allow"
|
|
475
|
+
execution_mode: Literal["parallel", "sequential"] = "parallel"
|
|
412
476
|
timeout_ms: int = 30000
|
|
413
477
|
|
|
414
478
|
def __post_init__(self):
|
|
415
479
|
"""Validate configuration."""
|
|
416
480
|
if self.timeout_ms < 1000:
|
|
417
|
-
object.__setattr__(self,
|
|
481
|
+
object.__setattr__(self, "timeout_ms", 1000)
|
|
482
|
+
|
|
418
483
|
|
|
419
484
|
@dataclass(frozen=True)
|
|
420
485
|
class AdvancedConfig:
|
|
421
486
|
"""Advanced agent configuration including guardrails."""
|
|
487
|
+
|
|
422
488
|
guardrails: Optional[AdvancedGuardrailsConfig] = None
|
|
423
489
|
|
|
424
|
-
|
|
490
|
+
|
|
491
|
+
def validate_guardrails_config(
|
|
492
|
+
config: Optional[AdvancedGuardrailsConfig],
|
|
493
|
+
) -> AdvancedGuardrailsConfig:
|
|
425
494
|
"""Validate and provide defaults for guardrails configuration."""
|
|
426
495
|
if config is None:
|
|
427
496
|
return AdvancedGuardrailsConfig()
|
|
428
|
-
|
|
497
|
+
|
|
429
498
|
return AdvancedGuardrailsConfig(
|
|
430
|
-
input_prompt=config.input_prompt.strip()
|
|
431
|
-
|
|
499
|
+
input_prompt=config.input_prompt.strip()
|
|
500
|
+
if isinstance(config.input_prompt, str) and config.input_prompt
|
|
501
|
+
else None,
|
|
502
|
+
output_prompt=config.output_prompt.strip()
|
|
503
|
+
if isinstance(config.output_prompt, str) and config.output_prompt
|
|
504
|
+
else None,
|
|
432
505
|
require_citations=config.require_citations,
|
|
433
|
-
fast_model=config.fast_model.strip()
|
|
506
|
+
fast_model=config.fast_model.strip()
|
|
507
|
+
if isinstance(config.fast_model, str) and config.fast_model
|
|
508
|
+
else None,
|
|
434
509
|
fail_safe=config.fail_safe,
|
|
435
510
|
execution_mode=config.execution_mode,
|
|
436
|
-
timeout_ms=max(1000, config.timeout_ms)
|
|
511
|
+
timeout_ms=max(1000, config.timeout_ms),
|
|
437
512
|
)
|
|
438
513
|
|
|
514
|
+
|
|
439
515
|
def json_parse_llm_output(text: str) -> Optional[Dict[str, Any]]:
|
|
440
516
|
"""Parse JSON from LLM output, handling common formatting issues."""
|
|
441
517
|
import json
|
|
442
518
|
import re
|
|
443
|
-
|
|
519
|
+
|
|
444
520
|
if not text:
|
|
445
521
|
return None
|
|
446
|
-
|
|
522
|
+
|
|
447
523
|
# Try direct parsing first
|
|
448
524
|
try:
|
|
449
525
|
return json.loads(text)
|
|
450
526
|
except json.JSONDecodeError:
|
|
451
527
|
pass
|
|
452
|
-
|
|
528
|
+
|
|
453
529
|
# Try to extract JSON from markdown code blocks
|
|
454
|
-
json_match = re.search(r
|
|
530
|
+
json_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
|
|
455
531
|
if json_match:
|
|
456
532
|
try:
|
|
457
533
|
return json.loads(json_match.group(1))
|
|
458
534
|
except json.JSONDecodeError:
|
|
459
535
|
pass
|
|
460
|
-
|
|
536
|
+
|
|
461
537
|
# Try to find the first JSON object in the text
|
|
462
|
-
json_match = re.search(r
|
|
538
|
+
json_match = re.search(r"\{.*?\}", text, re.DOTALL)
|
|
463
539
|
if json_match:
|
|
464
540
|
try:
|
|
465
541
|
return json.loads(json_match.group(0))
|
|
466
542
|
except json.JSONDecodeError:
|
|
467
543
|
pass
|
|
468
|
-
|
|
544
|
+
|
|
469
545
|
return None
|
|
470
546
|
|
|
547
|
+
|
|
471
548
|
@dataclass(frozen=True)
|
|
472
549
|
class ApprovalValue:
|
|
473
550
|
"""Represents an approval decision with context."""
|
|
551
|
+
|
|
474
552
|
status: str # 'pending', 'approved', 'rejected'
|
|
475
553
|
approved: bool
|
|
476
554
|
additional_context: Optional[Dict[str, Any]] = None
|
|
477
555
|
|
|
556
|
+
|
|
478
557
|
@dataclass(frozen=True)
|
|
479
558
|
class RunState(Generic[Ctx]):
|
|
480
559
|
"""Immutable state of a run."""
|
|
560
|
+
|
|
481
561
|
run_id: RunId
|
|
482
562
|
trace_id: TraceId
|
|
483
563
|
messages: List[Message]
|
|
@@ -486,51 +566,61 @@ class RunState(Generic[Ctx]):
|
|
|
486
566
|
turn_count: int
|
|
487
567
|
approvals: Dict[str, ApprovalValue] = field(default_factory=dict)
|
|
488
568
|
|
|
569
|
+
|
|
489
570
|
# Error types using dataclasses for immutability
|
|
490
571
|
@dataclass(frozen=True)
|
|
491
572
|
class MaxTurnsExceeded:
|
|
492
573
|
_tag: Literal["MaxTurnsExceeded"] = "MaxTurnsExceeded"
|
|
493
574
|
turns: int = 0
|
|
494
575
|
|
|
576
|
+
|
|
495
577
|
@dataclass(frozen=True)
|
|
496
578
|
class ModelBehaviorError:
|
|
497
579
|
_tag: Literal["ModelBehaviorError"] = "ModelBehaviorError"
|
|
498
580
|
detail: str = ""
|
|
499
581
|
|
|
582
|
+
|
|
500
583
|
@dataclass(frozen=True)
|
|
501
584
|
class DecodeError:
|
|
502
585
|
_tag: Literal["DecodeError"] = "DecodeError"
|
|
503
586
|
errors: List[Dict[str, Any]] = field(default_factory=list)
|
|
504
587
|
|
|
588
|
+
|
|
505
589
|
@dataclass(frozen=True)
|
|
506
590
|
class InputGuardrailTripwire:
|
|
507
591
|
_tag: Literal["InputGuardrailTripwire"] = "InputGuardrailTripwire"
|
|
508
592
|
reason: str = ""
|
|
509
593
|
|
|
594
|
+
|
|
510
595
|
@dataclass(frozen=True)
|
|
511
596
|
class OutputGuardrailTripwire:
|
|
512
597
|
_tag: Literal["OutputGuardrailTripwire"] = "OutputGuardrailTripwire"
|
|
513
598
|
reason: str = ""
|
|
514
599
|
|
|
600
|
+
|
|
515
601
|
@dataclass(frozen=True)
|
|
516
602
|
class ToolCallError:
|
|
517
603
|
_tag: Literal["ToolCallError"] = "ToolCallError"
|
|
518
604
|
tool: str = ""
|
|
519
605
|
detail: str = ""
|
|
520
606
|
|
|
607
|
+
|
|
521
608
|
@dataclass(frozen=True)
|
|
522
609
|
class HandoffError:
|
|
523
610
|
_tag: Literal["HandoffError"] = "HandoffError"
|
|
524
611
|
detail: str = ""
|
|
525
612
|
|
|
613
|
+
|
|
526
614
|
@dataclass(frozen=True)
|
|
527
615
|
class AgentNotFound:
|
|
528
616
|
_tag: Literal["AgentNotFound"] = "AgentNotFound"
|
|
529
617
|
agent_name: str = ""
|
|
530
618
|
|
|
619
|
+
|
|
531
620
|
@dataclass(frozen=True)
|
|
532
621
|
class RecoverableError:
|
|
533
622
|
"""Error that can be recovered from with retry strategies."""
|
|
623
|
+
|
|
534
624
|
_tag: Literal["RecoverableError"] = "RecoverableError"
|
|
535
625
|
detail: str = ""
|
|
536
626
|
retry_strategy: Optional[str] = None
|
|
@@ -538,40 +628,51 @@ class RecoverableError:
|
|
|
538
628
|
current_attempt: int = 1
|
|
539
629
|
backoff_seconds: float = 1.0
|
|
540
630
|
|
|
631
|
+
|
|
541
632
|
@dataclass(frozen=True)
|
|
542
633
|
class RateLimitError:
|
|
543
634
|
"""Error indicating rate limit has been exceeded."""
|
|
635
|
+
|
|
544
636
|
_tag: Literal["RateLimitError"] = "RateLimitError"
|
|
545
637
|
detail: str = ""
|
|
546
638
|
retry_after_seconds: Optional[float] = None
|
|
547
639
|
limit_type: str = "requests" # "requests", "tokens", "concurrent"
|
|
548
640
|
|
|
641
|
+
|
|
549
642
|
@dataclass(frozen=True)
|
|
550
643
|
class ValidationError:
|
|
551
644
|
"""Enhanced validation error with detailed context."""
|
|
645
|
+
|
|
552
646
|
_tag: Literal["ValidationError"] = "ValidationError"
|
|
553
647
|
detail: str = ""
|
|
554
648
|
field_errors: List[Dict[str, Any]] = field(default_factory=list)
|
|
555
649
|
error_code: str = "validation_failed"
|
|
556
650
|
|
|
651
|
+
|
|
557
652
|
@dataclass(frozen=True)
|
|
558
653
|
class NetworkError:
|
|
559
654
|
"""Network-related errors with retry information."""
|
|
655
|
+
|
|
560
656
|
_tag: Literal["NetworkError"] = "NetworkError"
|
|
561
657
|
detail: str = ""
|
|
562
658
|
status_code: Optional[int] = None
|
|
563
659
|
is_retryable: bool = True
|
|
564
660
|
endpoint: Optional[str] = None
|
|
565
661
|
|
|
662
|
+
|
|
566
663
|
# Interruption types for HITL
|
|
567
664
|
@dataclass(frozen=True)
|
|
568
665
|
class ToolApprovalInterruption(Generic[Ctx]):
|
|
569
666
|
"""Interruption for tool approval."""
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
667
|
+
|
|
668
|
+
type: Literal["tool_approval"] = "tool_approval"
|
|
669
|
+
tool_call: ToolCall = field(
|
|
670
|
+
default_factory=lambda: ToolCall("", "function", ToolCallFunction("", ""))
|
|
671
|
+
)
|
|
672
|
+
agent: "Agent[Ctx, Any]" = None
|
|
573
673
|
session_id: Optional[str] = None
|
|
574
674
|
|
|
675
|
+
|
|
575
676
|
# Union type for all interruptions
|
|
576
677
|
Interruption = Union[ToolApprovalInterruption[Any]]
|
|
577
678
|
|
|
@@ -588,40 +689,51 @@ JAFError = Union[
|
|
|
588
689
|
RecoverableError,
|
|
589
690
|
RateLimitError,
|
|
590
691
|
ValidationError,
|
|
591
|
-
NetworkError
|
|
692
|
+
NetworkError,
|
|
592
693
|
]
|
|
593
694
|
|
|
695
|
+
|
|
594
696
|
@dataclass(frozen=True)
|
|
595
697
|
class CompletedOutcome(Generic[Out]):
|
|
596
698
|
"""Successful completion outcome."""
|
|
597
|
-
|
|
699
|
+
|
|
700
|
+
status: Literal["completed"] = "completed"
|
|
598
701
|
output: Out = field(default=None)
|
|
599
702
|
|
|
703
|
+
|
|
600
704
|
@dataclass(frozen=True)
|
|
601
705
|
class ErrorOutcome:
|
|
602
706
|
"""Error outcome."""
|
|
603
|
-
|
|
707
|
+
|
|
708
|
+
status: Literal["error"] = "error"
|
|
604
709
|
error: JAFError = field(default=None)
|
|
605
710
|
|
|
711
|
+
|
|
606
712
|
@dataclass(frozen=True)
|
|
607
713
|
class InterruptedOutcome:
|
|
608
714
|
"""Interrupted outcome for HITL."""
|
|
609
|
-
|
|
715
|
+
|
|
716
|
+
status: Literal["interrupted"] = "interrupted"
|
|
610
717
|
interruptions: List[Interruption] = field(default_factory=list)
|
|
611
718
|
|
|
719
|
+
|
|
612
720
|
# Union type for outcomes
|
|
613
721
|
RunOutcome = Union[CompletedOutcome[Out], ErrorOutcome, InterruptedOutcome]
|
|
614
722
|
|
|
723
|
+
|
|
615
724
|
@dataclass(frozen=True)
|
|
616
725
|
class RunResult(Generic[Out]):
|
|
617
726
|
"""Result of a run execution."""
|
|
727
|
+
|
|
618
728
|
final_state: RunState[Any]
|
|
619
729
|
outcome: RunOutcome[Out]
|
|
620
730
|
|
|
731
|
+
|
|
621
732
|
# Trace event types with specific data structures to match TypeScript
|
|
622
733
|
@dataclass(frozen=True)
|
|
623
734
|
class RunStartEventData:
|
|
624
735
|
"""Data for run start events."""
|
|
736
|
+
|
|
625
737
|
run_id: RunId
|
|
626
738
|
trace_id: TraceId
|
|
627
739
|
session_id: Optional[str] = None
|
|
@@ -629,14 +741,19 @@ class RunStartEventData:
|
|
|
629
741
|
messages: Optional[List[Message]] = None
|
|
630
742
|
agent_name: Optional[str] = None
|
|
631
743
|
|
|
744
|
+
|
|
632
745
|
@dataclass(frozen=True)
|
|
633
746
|
class RunStartEvent:
|
|
634
|
-
type: Literal[
|
|
635
|
-
data: RunStartEventData = field(
|
|
747
|
+
type: Literal["run_start"] = "run_start"
|
|
748
|
+
data: RunStartEventData = field(
|
|
749
|
+
default_factory=lambda: RunStartEventData(RunId(""), TraceId(""))
|
|
750
|
+
)
|
|
751
|
+
|
|
636
752
|
|
|
637
753
|
@dataclass(frozen=True)
|
|
638
754
|
class LLMCallStartEventData:
|
|
639
755
|
"""Data for LLM call start events."""
|
|
756
|
+
|
|
640
757
|
agent_name: str
|
|
641
758
|
model: str
|
|
642
759
|
trace_id: TraceId
|
|
@@ -644,46 +761,68 @@ class LLMCallStartEventData:
|
|
|
644
761
|
context: Optional[Any] = None
|
|
645
762
|
messages: Optional[List[Message]] = None
|
|
646
763
|
|
|
764
|
+
|
|
647
765
|
@dataclass(frozen=True)
|
|
648
766
|
class LLMCallStartEvent:
|
|
649
|
-
type: Literal[
|
|
650
|
-
data: LLMCallStartEventData = field(
|
|
767
|
+
type: Literal["llm_call_start"] = "llm_call_start"
|
|
768
|
+
data: LLMCallStartEventData = field(
|
|
769
|
+
default_factory=lambda: LLMCallStartEventData("", "", TraceId(""), RunId(""))
|
|
770
|
+
)
|
|
771
|
+
|
|
651
772
|
|
|
652
773
|
@dataclass(frozen=True)
|
|
653
774
|
class LLMCallEndEventData:
|
|
654
775
|
"""Data for LLM call end events."""
|
|
776
|
+
|
|
655
777
|
choice: Any
|
|
656
778
|
trace_id: TraceId
|
|
657
779
|
run_id: RunId
|
|
658
780
|
usage: Optional[Dict[str, int]] = None
|
|
659
781
|
|
|
782
|
+
|
|
660
783
|
@dataclass(frozen=True)
|
|
661
784
|
class LLMCallEndEvent:
|
|
662
|
-
type: Literal[
|
|
663
|
-
data: LLMCallEndEventData = field(
|
|
785
|
+
type: Literal["llm_call_end"] = "llm_call_end"
|
|
786
|
+
data: LLMCallEndEventData = field(
|
|
787
|
+
default_factory=lambda: LLMCallEndEventData(None, TraceId(""), RunId(""))
|
|
788
|
+
)
|
|
789
|
+
|
|
664
790
|
|
|
665
791
|
@dataclass(frozen=True)
|
|
666
792
|
class AssistantMessageEventData:
|
|
667
793
|
"""Data for assistant message events (partial or complete)."""
|
|
794
|
+
|
|
668
795
|
message: Message
|
|
669
796
|
|
|
797
|
+
|
|
670
798
|
@dataclass(frozen=True)
|
|
671
799
|
class AssistantMessageEvent:
|
|
672
|
-
type: Literal[
|
|
673
|
-
data: AssistantMessageEventData = field(
|
|
800
|
+
type: Literal["assistant_message"] = "assistant_message"
|
|
801
|
+
data: AssistantMessageEventData = field(
|
|
802
|
+
default_factory=lambda: AssistantMessageEventData(
|
|
803
|
+
Message(role=ContentRole.ASSISTANT, content="")
|
|
804
|
+
)
|
|
805
|
+
)
|
|
806
|
+
|
|
807
|
+
|
|
674
808
|
@dataclass(frozen=True)
|
|
675
809
|
class ToolCallStartEventData:
|
|
676
810
|
"""Data for tool call start events."""
|
|
811
|
+
|
|
677
812
|
tool_name: str
|
|
678
813
|
args: Any
|
|
679
814
|
trace_id: TraceId
|
|
680
815
|
run_id: RunId
|
|
681
816
|
call_id: Optional[str] = None
|
|
682
817
|
|
|
818
|
+
|
|
683
819
|
@dataclass(frozen=True)
|
|
684
820
|
class ToolCallStartEvent:
|
|
685
|
-
type: Literal[
|
|
686
|
-
data: ToolCallStartEventData = field(
|
|
821
|
+
type: Literal["tool_call_start"] = "tool_call_start"
|
|
822
|
+
data: ToolCallStartEventData = field(
|
|
823
|
+
default_factory=lambda: ToolCallStartEventData("", None, TraceId(""), RunId(""), None)
|
|
824
|
+
)
|
|
825
|
+
|
|
687
826
|
|
|
688
827
|
@dataclass(frozen=True)
|
|
689
828
|
class ToolCallEndEventData:
|
|
@@ -703,18 +842,27 @@ class ToolCallEndEventData:
|
|
|
703
842
|
- 'rejected': Tool was rejected by user
|
|
704
843
|
- 'execution_error', 'validation_error', etc.: Various error states
|
|
705
844
|
"""
|
|
845
|
+
|
|
706
846
|
tool_name: str
|
|
707
847
|
result: str
|
|
708
848
|
trace_id: TraceId
|
|
709
849
|
run_id: RunId
|
|
710
850
|
tool_result: Optional[Any] = None
|
|
711
|
-
execution_status: Optional[str] =
|
|
712
|
-
|
|
851
|
+
execution_status: Optional[str] = (
|
|
852
|
+
None # success/error/timeout - indicates if tool executed successfully
|
|
853
|
+
)
|
|
854
|
+
status: Optional[str] = (
|
|
855
|
+
None # DEPRECATED: maintained for backward-compatible initialization/serialization
|
|
856
|
+
)
|
|
713
857
|
call_id: Optional[str] = None
|
|
714
858
|
|
|
715
859
|
def __post_init__(self) -> None:
|
|
716
860
|
# Handle backward compatibility with explicit conflict detection
|
|
717
|
-
if
|
|
861
|
+
if (
|
|
862
|
+
self.execution_status is not None
|
|
863
|
+
and self.status is not None
|
|
864
|
+
and self.execution_status != self.status
|
|
865
|
+
):
|
|
718
866
|
raise ValueError(
|
|
719
867
|
f"Conflicting values for execution_status ('{self.execution_status}') and status ('{self.status}'). "
|
|
720
868
|
f"Please use only execution_status for new code."
|
|
@@ -722,88 +870,145 @@ class ToolCallEndEventData:
|
|
|
722
870
|
|
|
723
871
|
# Prefer execution_status (new field) over status (deprecated field)
|
|
724
872
|
canonical = self.execution_status if self.execution_status is not None else self.status
|
|
725
|
-
object.__setattr__(self,
|
|
726
|
-
object.__setattr__(self,
|
|
873
|
+
object.__setattr__(self, "execution_status", canonical)
|
|
874
|
+
object.__setattr__(self, "status", canonical)
|
|
875
|
+
|
|
876
|
+
|
|
727
877
|
@dataclass(frozen=True)
|
|
728
878
|
class ToolCallEndEvent:
|
|
729
|
-
type: Literal[
|
|
730
|
-
data: ToolCallEndEventData = field(
|
|
879
|
+
type: Literal["tool_call_end"] = "tool_call_end"
|
|
880
|
+
data: ToolCallEndEventData = field(
|
|
881
|
+
default_factory=lambda: ToolCallEndEventData("", "", TraceId(""), RunId(""), None, None)
|
|
882
|
+
)
|
|
883
|
+
|
|
731
884
|
|
|
732
885
|
@dataclass(frozen=True)
|
|
733
886
|
class HandoffEventData:
|
|
734
887
|
"""Data for handoff events."""
|
|
735
|
-
|
|
888
|
+
|
|
889
|
+
from_: str = field(metadata={"alias": "from"}) # Using from_ since 'from' is a Python keyword
|
|
736
890
|
to: str
|
|
737
891
|
|
|
892
|
+
|
|
738
893
|
@dataclass(frozen=True)
|
|
739
894
|
class HandoffEvent:
|
|
740
|
-
type: Literal[
|
|
895
|
+
type: Literal["handoff"] = "handoff"
|
|
741
896
|
data: HandoffEventData = field(default_factory=lambda: HandoffEventData("", ""))
|
|
742
897
|
|
|
898
|
+
|
|
743
899
|
@dataclass(frozen=True)
|
|
744
900
|
class RunEndEventData:
|
|
745
901
|
"""Data for run end events."""
|
|
746
|
-
|
|
902
|
+
|
|
903
|
+
outcome: "RunOutcome[Any]"
|
|
747
904
|
trace_id: TraceId
|
|
748
905
|
run_id: RunId
|
|
749
906
|
|
|
907
|
+
|
|
750
908
|
@dataclass(frozen=True)
|
|
751
909
|
class RunEndEvent:
|
|
752
|
-
type: Literal[
|
|
753
|
-
data: RunEndEventData = field(
|
|
910
|
+
type: Literal["run_end"] = "run_end"
|
|
911
|
+
data: RunEndEventData = field(
|
|
912
|
+
default_factory=lambda: RunEndEventData(None, TraceId(""), RunId(""))
|
|
913
|
+
)
|
|
914
|
+
|
|
754
915
|
|
|
755
916
|
@dataclass(frozen=True)
|
|
756
917
|
class GuardrailEventData:
|
|
757
918
|
"""Data for guardrail check events."""
|
|
919
|
+
|
|
758
920
|
guardrail_name: str
|
|
759
921
|
content: Any
|
|
760
922
|
is_valid: Optional[bool] = None
|
|
761
923
|
error_message: Optional[str] = None
|
|
762
924
|
|
|
925
|
+
|
|
763
926
|
@dataclass(frozen=True)
|
|
764
927
|
class GuardrailEvent:
|
|
765
|
-
type: Literal[
|
|
928
|
+
type: Literal["guardrail_check"] = "guardrail_check"
|
|
766
929
|
data: GuardrailEventData = field(default_factory=lambda: GuardrailEventData(""))
|
|
767
930
|
|
|
931
|
+
|
|
768
932
|
@dataclass(frozen=True)
|
|
769
933
|
class GuardrailViolationEventData:
|
|
770
934
|
"""Data for guardrail violation events."""
|
|
771
|
-
|
|
935
|
+
|
|
936
|
+
stage: Literal["input", "output"]
|
|
772
937
|
reason: str
|
|
773
938
|
|
|
939
|
+
|
|
774
940
|
@dataclass(frozen=True)
|
|
775
941
|
class GuardrailViolationEvent:
|
|
776
|
-
type: Literal[
|
|
777
|
-
data: GuardrailViolationEventData = field(
|
|
942
|
+
type: Literal["guardrail_violation"] = "guardrail_violation"
|
|
943
|
+
data: GuardrailViolationEventData = field(
|
|
944
|
+
default_factory=lambda: GuardrailViolationEventData("input", "")
|
|
945
|
+
)
|
|
946
|
+
|
|
778
947
|
|
|
779
948
|
@dataclass(frozen=True)
|
|
780
949
|
class MemoryEventData:
|
|
781
950
|
"""Data for memory operation events."""
|
|
782
|
-
|
|
951
|
+
|
|
952
|
+
operation: Literal["load", "store"]
|
|
783
953
|
conversation_id: str
|
|
784
|
-
status: Literal[
|
|
954
|
+
status: Literal["start", "end", "fail"]
|
|
785
955
|
error: Optional[str] = None
|
|
786
956
|
message_count: Optional[int] = None
|
|
787
957
|
|
|
958
|
+
|
|
788
959
|
@dataclass(frozen=True)
|
|
789
960
|
class MemoryEvent:
|
|
790
|
-
type: Literal[
|
|
961
|
+
type: Literal["memory_operation"] = "memory_operation"
|
|
791
962
|
data: MemoryEventData = field(default_factory=lambda: MemoryEventData("load", "", "start"))
|
|
792
963
|
|
|
964
|
+
|
|
793
965
|
@dataclass(frozen=True)
|
|
794
966
|
class OutputParseEventData:
|
|
795
967
|
"""Data for output parsing events."""
|
|
968
|
+
|
|
796
969
|
content: str
|
|
797
|
-
status: Literal[
|
|
970
|
+
status: Literal["start", "end", "fail"]
|
|
798
971
|
parsed_output: Optional[Any] = None
|
|
799
972
|
error: Optional[str] = None
|
|
800
973
|
|
|
974
|
+
|
|
801
975
|
@dataclass(frozen=True)
|
|
802
976
|
class OutputParseEvent:
|
|
803
|
-
type: Literal[
|
|
977
|
+
type: Literal["output_parse"] = "output_parse"
|
|
804
978
|
data: OutputParseEventData = field(default_factory=lambda: OutputParseEventData("", "start"))
|
|
805
979
|
|
|
806
980
|
|
|
981
|
+
@dataclass(frozen=True)
|
|
982
|
+
class RetryEventData:
|
|
983
|
+
"""Data for retry events."""
|
|
984
|
+
|
|
985
|
+
attempt: int # Current retry attempt (1-indexed)
|
|
986
|
+
max_retries: int # Maximum number of retries configured
|
|
987
|
+
reason: str # Reason for retry (e.g., "HTTP 429 - Rate Limit", "HTTP 500 - Server Error")
|
|
988
|
+
operation: Literal["llm_call", "tool_call", "workflow_step"] # What operation is being retried
|
|
989
|
+
trace_id: TraceId
|
|
990
|
+
run_id: RunId
|
|
991
|
+
delay: Optional[float] = None # Backoff delay in seconds before next retry
|
|
992
|
+
error_details: Optional[Dict[str, Any]] = None # Additional error context
|
|
993
|
+
|
|
994
|
+
|
|
995
|
+
@dataclass(frozen=True)
|
|
996
|
+
class RetryEvent:
|
|
997
|
+
"""Event emitted when a retry occurs."""
|
|
998
|
+
|
|
999
|
+
type: Literal["retry"] = "retry"
|
|
1000
|
+
data: RetryEventData = field(
|
|
1001
|
+
default_factory=lambda: RetryEventData(
|
|
1002
|
+
attempt=1,
|
|
1003
|
+
max_retries=3,
|
|
1004
|
+
reason="",
|
|
1005
|
+
operation="llm_call",
|
|
1006
|
+
trace_id=TraceId(""),
|
|
1007
|
+
run_id=RunId(""),
|
|
1008
|
+
)
|
|
1009
|
+
)
|
|
1010
|
+
|
|
1011
|
+
|
|
807
1012
|
# Union type for all trace events
|
|
808
1013
|
TraceEvent = Union[
|
|
809
1014
|
RunStartEvent,
|
|
@@ -817,72 +1022,81 @@ TraceEvent = Union[
|
|
|
817
1022
|
ToolCallStartEvent,
|
|
818
1023
|
ToolCallEndEvent,
|
|
819
1024
|
HandoffEvent,
|
|
820
|
-
RunEndEvent
|
|
1025
|
+
RunEndEvent,
|
|
1026
|
+
RetryEvent,
|
|
821
1027
|
]
|
|
822
1028
|
|
|
1029
|
+
|
|
823
1030
|
@dataclass(frozen=True)
|
|
824
1031
|
class ModelCompletionMessage:
|
|
825
1032
|
"""Message structure returned by model completion."""
|
|
1033
|
+
|
|
826
1034
|
content: Optional[str] = None
|
|
827
1035
|
tool_calls: Optional[List[ToolCall]] = None
|
|
828
1036
|
|
|
1037
|
+
|
|
829
1038
|
@dataclass(frozen=True)
|
|
830
1039
|
class ModelCompletionResponse:
|
|
831
1040
|
"""Response structure from model completion."""
|
|
1041
|
+
|
|
832
1042
|
message: Optional[ModelCompletionMessage] = None
|
|
833
1043
|
|
|
1044
|
+
|
|
834
1045
|
# Streaming chunk structures for provider-level streaming support
|
|
835
1046
|
@dataclass(frozen=True)
|
|
836
1047
|
class ToolCallFunctionDelta:
|
|
837
1048
|
"""Function fields that may stream as deltas."""
|
|
1049
|
+
|
|
838
1050
|
name: Optional[str] = None
|
|
839
1051
|
arguments_delta: Optional[str] = None
|
|
840
1052
|
|
|
1053
|
+
|
|
841
1054
|
@dataclass(frozen=True)
|
|
842
1055
|
class ToolCallDelta:
|
|
843
1056
|
"""Represents a partial tool call delta in a streamed response."""
|
|
1057
|
+
|
|
844
1058
|
index: int
|
|
845
1059
|
id: Optional[str] = None
|
|
846
|
-
type: Literal[
|
|
1060
|
+
type: Literal["function"] = "function"
|
|
847
1061
|
function: Optional[ToolCallFunctionDelta] = None
|
|
848
1062
|
|
|
1063
|
+
|
|
849
1064
|
@dataclass(frozen=True)
|
|
850
1065
|
class CompletionStreamChunk:
|
|
851
1066
|
"""A streamed chunk from the model provider."""
|
|
1067
|
+
|
|
852
1068
|
delta: Optional[str] = None
|
|
853
1069
|
tool_call_delta: Optional[ToolCallDelta] = None
|
|
854
1070
|
is_done: Optional[bool] = False
|
|
855
1071
|
finish_reason: Optional[str] = None
|
|
856
1072
|
raw: Optional[Any] = None
|
|
857
1073
|
|
|
1074
|
+
|
|
858
1075
|
@runtime_checkable
|
|
859
1076
|
class ModelProvider(Protocol[Ctx]):
|
|
860
1077
|
"""Protocol for model providers."""
|
|
861
1078
|
|
|
862
1079
|
async def get_completion(
|
|
863
|
-
self,
|
|
864
|
-
state: RunState[Ctx],
|
|
865
|
-
agent: Agent[Ctx, Any],
|
|
866
|
-
config: 'RunConfig[Ctx]'
|
|
1080
|
+
self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: "RunConfig[Ctx]"
|
|
867
1081
|
) -> ModelCompletionResponse:
|
|
868
1082
|
"""Get completion from the model."""
|
|
869
1083
|
...
|
|
870
1084
|
|
|
871
1085
|
async def get_completion_stream(
|
|
872
|
-
self,
|
|
873
|
-
state: RunState[Ctx],
|
|
874
|
-
agent: Agent[Ctx, Any],
|
|
875
|
-
config: 'RunConfig[Ctx]'
|
|
1086
|
+
self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: "RunConfig[Ctx]"
|
|
876
1087
|
) -> AsyncIterator[CompletionStreamChunk]:
|
|
877
1088
|
"""Optional streaming API: yields incremental deltas while generating."""
|
|
878
1089
|
...
|
|
879
1090
|
|
|
1091
|
+
|
|
880
1092
|
@dataclass(frozen=True)
|
|
881
1093
|
class RunConfig(Generic[Ctx]):
|
|
882
1094
|
"""Configuration for running agents."""
|
|
1095
|
+
|
|
883
1096
|
agent_registry: Dict[str, Agent[Ctx, Any]]
|
|
884
1097
|
model_provider: ModelProvider[Ctx]
|
|
885
1098
|
max_turns: Optional[int] = 50
|
|
1099
|
+
max_tokens: Optional[int] = None # Default max_tokens for all agents (can be overridden per agent)
|
|
886
1100
|
model_override: Optional[str] = None
|
|
887
1101
|
initial_input_guardrails: Optional[List[Guardrail]] = None
|
|
888
1102
|
final_output_guardrails: Optional[List[Guardrail]] = None
|
|
@@ -891,32 +1105,68 @@ class RunConfig(Generic[Ctx]):
|
|
|
891
1105
|
conversation_id: Optional[str] = None
|
|
892
1106
|
default_fast_model: Optional[str] = None # Default model for fast operations like guardrails
|
|
893
1107
|
default_tool_timeout: Optional[float] = 300.0 # Default timeout for tool execution in seconds
|
|
894
|
-
approval_storage: Optional[
|
|
895
|
-
before_llm_call: Optional[
|
|
896
|
-
|
|
1108
|
+
approval_storage: Optional["ApprovalStorage"] = None # Storage for approval decisions
|
|
1109
|
+
before_llm_call: Optional[
|
|
1110
|
+
Callable[[RunState[Ctx], Agent[Ctx, Any]], Union[RunState[Ctx], Awaitable[RunState[Ctx]]]]
|
|
1111
|
+
] = None # Callback before LLM call - can modify context/messages
|
|
1112
|
+
after_llm_call: Optional[
|
|
1113
|
+
Callable[
|
|
1114
|
+
[RunState[Ctx], ModelCompletionResponse],
|
|
1115
|
+
Union[ModelCompletionResponse, Awaitable[ModelCompletionResponse]],
|
|
1116
|
+
]
|
|
1117
|
+
] = None # Callback after LLM call - can process response
|
|
897
1118
|
max_empty_response_retries: int = 3 # Maximum retries when LLM returns empty response
|
|
898
|
-
empty_response_retry_delay: float =
|
|
1119
|
+
empty_response_retry_delay: float = (
|
|
1120
|
+
1.0 # Initial delay in seconds before retrying empty response (uses exponential backoff)
|
|
1121
|
+
)
|
|
899
1122
|
log_empty_responses: bool = True # Whether to log diagnostic info for empty responses
|
|
900
|
-
prefer_streaming: Optional[bool] =
|
|
1123
|
+
prefer_streaming: Optional[bool] = (
|
|
1124
|
+
None # Whether to prefer streaming responses. None (default) = use streaming if available, True = prefer streaming, False = disable streaming
|
|
1125
|
+
)
|
|
901
1126
|
|
|
902
1127
|
|
|
903
1128
|
# Regeneration types for conversation management
|
|
904
1129
|
@dataclass(frozen=True)
|
|
905
1130
|
class RegenerationRequest:
|
|
906
1131
|
"""Request to regenerate a conversation from a specific message."""
|
|
1132
|
+
|
|
907
1133
|
conversation_id: str
|
|
908
1134
|
message_id: MessageId # ID of the message to regenerate from
|
|
909
1135
|
context: Optional[Dict[str, Any]] = None # Optional context override
|
|
910
1136
|
|
|
1137
|
+
|
|
911
1138
|
@dataclass(frozen=True)
|
|
912
1139
|
class RegenerationContext:
|
|
913
1140
|
"""Context information for a regeneration operation."""
|
|
1141
|
+
|
|
914
1142
|
original_message_count: int
|
|
915
1143
|
truncated_at_index: int
|
|
916
1144
|
regenerated_message_id: MessageId
|
|
917
1145
|
regeneration_id: str # Unique ID for this regeneration operation
|
|
918
1146
|
timestamp: int # Unix timestamp in milliseconds
|
|
919
1147
|
|
|
1148
|
+
|
|
1149
|
+
# Checkpoint types for conversation management
|
|
1150
|
+
@dataclass(frozen=True)
|
|
1151
|
+
class CheckpointRequest:
|
|
1152
|
+
"""Request to checkpoint a conversation after a specific message."""
|
|
1153
|
+
|
|
1154
|
+
conversation_id: str
|
|
1155
|
+
message_id: MessageId # ID of the message to checkpoint after (this message is kept)
|
|
1156
|
+
context: Optional[Dict[str, Any]] = None # Optional context for the checkpoint
|
|
1157
|
+
|
|
1158
|
+
|
|
1159
|
+
@dataclass(frozen=True)
|
|
1160
|
+
class CheckpointContext:
|
|
1161
|
+
"""Context information for a checkpoint operation."""
|
|
1162
|
+
|
|
1163
|
+
original_message_count: int
|
|
1164
|
+
checkpointed_at_index: int
|
|
1165
|
+
checkpointed_message_id: MessageId
|
|
1166
|
+
checkpoint_id: str # Unique ID for this checkpoint operation
|
|
1167
|
+
timestamp: int # Unix timestamp in milliseconds
|
|
1168
|
+
|
|
1169
|
+
|
|
920
1170
|
# Message utility functions
|
|
921
1171
|
def find_message_index(messages: List[Message], message_id: MessageId) -> Optional[int]:
|
|
922
1172
|
"""Find the index of a message by its ID."""
|
|
@@ -925,6 +1175,7 @@ def find_message_index(messages: List[Message], message_id: MessageId) -> Option
|
|
|
925
1175
|
return i
|
|
926
1176
|
return None
|
|
927
1177
|
|
|
1178
|
+
|
|
928
1179
|
def truncate_messages_after(messages: List[Message], message_id: MessageId) -> List[Message]:
|
|
929
1180
|
"""Truncate messages after (and including) the specified message ID."""
|
|
930
1181
|
index = find_message_index(messages, message_id)
|
|
@@ -932,6 +1183,7 @@ def truncate_messages_after(messages: List[Message], message_id: MessageId) -> L
|
|
|
932
1183
|
return messages # Message not found, return unchanged
|
|
933
1184
|
return messages[:index]
|
|
934
1185
|
|
|
1186
|
+
|
|
935
1187
|
def get_message_by_id(messages: List[Message], message_id: MessageId) -> Optional[Message]:
|
|
936
1188
|
"""Get a message by its ID."""
|
|
937
1189
|
for msg in messages:
|