jaf-py 2.5.9__py3-none-any.whl → 2.5.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +154 -57
- jaf/a2a/__init__.py +42 -21
- jaf/a2a/agent.py +79 -126
- jaf/a2a/agent_card.py +87 -78
- jaf/a2a/client.py +30 -66
- jaf/a2a/examples/client_example.py +12 -12
- jaf/a2a/examples/integration_example.py +38 -47
- jaf/a2a/examples/server_example.py +56 -53
- jaf/a2a/memory/__init__.py +0 -4
- jaf/a2a/memory/cleanup.py +28 -21
- jaf/a2a/memory/factory.py +155 -133
- jaf/a2a/memory/providers/composite.py +21 -26
- jaf/a2a/memory/providers/in_memory.py +89 -83
- jaf/a2a/memory/providers/postgres.py +117 -115
- jaf/a2a/memory/providers/redis.py +128 -121
- jaf/a2a/memory/serialization.py +77 -87
- jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
- jaf/a2a/memory/tests/test_cleanup.py +211 -94
- jaf/a2a/memory/tests/test_serialization.py +73 -68
- jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
- jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
- jaf/a2a/memory/types.py +91 -53
- jaf/a2a/protocol.py +95 -125
- jaf/a2a/server.py +90 -118
- jaf/a2a/standalone_client.py +30 -43
- jaf/a2a/tests/__init__.py +16 -33
- jaf/a2a/tests/run_tests.py +17 -53
- jaf/a2a/tests/test_agent.py +40 -140
- jaf/a2a/tests/test_client.py +54 -117
- jaf/a2a/tests/test_integration.py +28 -82
- jaf/a2a/tests/test_protocol.py +54 -139
- jaf/a2a/tests/test_types.py +50 -136
- jaf/a2a/types.py +58 -34
- jaf/cli.py +21 -41
- jaf/core/__init__.py +7 -1
- jaf/core/agent_tool.py +93 -72
- jaf/core/analytics.py +257 -207
- jaf/core/checkpoint.py +223 -0
- jaf/core/composition.py +249 -235
- jaf/core/engine.py +817 -519
- jaf/core/errors.py +55 -42
- jaf/core/guardrails.py +276 -202
- jaf/core/handoff.py +47 -31
- jaf/core/parallel_agents.py +69 -75
- jaf/core/performance.py +75 -73
- jaf/core/proxy.py +43 -44
- jaf/core/proxy_helpers.py +24 -27
- jaf/core/regeneration.py +220 -129
- jaf/core/state.py +68 -66
- jaf/core/streaming.py +115 -108
- jaf/core/tool_results.py +111 -101
- jaf/core/tools.py +114 -116
- jaf/core/tracing.py +269 -210
- jaf/core/types.py +371 -151
- jaf/core/workflows.py +209 -168
- jaf/exceptions.py +46 -38
- jaf/memory/__init__.py +1 -6
- jaf/memory/approval_storage.py +54 -77
- jaf/memory/factory.py +4 -4
- jaf/memory/providers/in_memory.py +216 -180
- jaf/memory/providers/postgres.py +216 -146
- jaf/memory/providers/redis.py +173 -116
- jaf/memory/types.py +70 -51
- jaf/memory/utils.py +36 -34
- jaf/plugins/__init__.py +12 -12
- jaf/plugins/base.py +105 -96
- jaf/policies/__init__.py +0 -1
- jaf/policies/handoff.py +37 -46
- jaf/policies/validation.py +76 -52
- jaf/providers/__init__.py +6 -3
- jaf/providers/mcp.py +97 -51
- jaf/providers/model.py +361 -280
- jaf/server/__init__.py +1 -1
- jaf/server/main.py +7 -11
- jaf/server/server.py +514 -359
- jaf/server/types.py +208 -52
- jaf/utils/__init__.py +17 -18
- jaf/utils/attachments.py +111 -116
- jaf/utils/document_processor.py +175 -174
- jaf/visualization/__init__.py +1 -1
- jaf/visualization/example.py +111 -110
- jaf/visualization/functional_core.py +46 -71
- jaf/visualization/graphviz.py +154 -189
- jaf/visualization/imperative_shell.py +7 -16
- jaf/visualization/types.py +8 -4
- {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
- jaf_py-2.5.11.dist-info/RECORD +97 -0
- jaf_py-2.5.9.dist-info/RECORD +0 -96
- {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
- {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
jaf/core/types.py
CHANGED

@@ -36,239 +36,274 @@ if TYPE_CHECKING:

# Comprehensive enums for type safety and improved developer experience
class Model(str, Enum):
    """Supported model identifiers."""
    GEMINI_2_0_FLASH = "gemini-2.0-flash"
    GEMINI_2_5_PRO = "gemini-2.5-pro"
    GEMINI_PRO = "gemini-pro"
    GPT_4 = "gpt-4"
    GPT_4_TURBO = "gpt-4-turbo"
    GPT_3_5_TURBO = "gpt-3.5-turbo"
    CLAUDE_3_SONNET = "claude-3-sonnet"
    CLAUDE_3_HAIKU = "claude-3-haiku"
    CLAUDE_3_OPUS = "claude-3-opus"

class ToolParameterType(str, Enum):
    """Tool parameter types."""
    STRING = "string"
    NUMBER = "number"
    INTEGER = "integer"
    BOOLEAN = "boolean"
    ARRAY = "array"
    OBJECT = "object"
    NULL = "null"

class ToolSource(str, Enum):
    """Source of tool definitions."""
    NATIVE = "native"
    MCP = "mcp"
    PLUGIN = "plugin"
    EXTERNAL = "external"

class ContentRole(str, Enum):
    """Message content roles."""
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
    SYSTEM = "system"

class PartType(str, Enum):
    """Message part types."""
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    FILE = "file"

# Branded types for type safety - using class-based approach for better type safety
class TraceId(str):
    """Branded string type for trace IDs."""
    def __new__(cls, value: str) -> "TraceId":
        return str.__new__(cls, value)

class RunId(str):
    """Branded string type for run IDs."""
    def __new__(cls, value: str) -> "RunId":
        return str.__new__(cls, value)

class MessageId(str):
    """Branded string type for message IDs."""
    def __new__(cls, value: str) -> "MessageId":
        return str.__new__(cls, value)

def create_trace_id(id_str: str) -> TraceId:
    """Create a TraceId from a string."""
    return TraceId(id_str)

def create_run_id(id_str: str) -> RunId:
    """Create a RunId from a string."""
    return RunId(id_str)

def create_message_id(id_str: Union[str, MessageId]) -> MessageId:
    """Create a MessageId from a string or return an existing MessageId; raises ValueError for None, empty, or non-string input."""
    if id_str is None:
        raise ValueError("Message ID cannot be None")
    if isinstance(id_str, MessageId):
        return id_str
    if isinstance(id_str, str):
        if not id_str.strip():
            raise ValueError("Message ID cannot be empty or whitespace")
        return MessageId(id_str.strip())
    raise ValueError(f"Message ID must be a string or MessageId, got {type(id_str)}")

def generate_run_id() -> RunId:
    """Generate a new unique run ID."""
    import time
    import uuid
    return RunId(f"run_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")

def generate_trace_id() -> TraceId:
    """Generate a new unique trace ID."""
    import time
    import uuid
    return TraceId(f"trace_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")

def generate_message_id() -> MessageId:
    """Generate a new unique message ID."""
    import time
    import uuid
    return MessageId(f"msg_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")

# Type variables for generic contexts and outputs
Ctx = TypeVar("Ctx")
Out = TypeVar("Out")
Args = TypeVar("Args")

# Discriminated union for ValidationResult to match TypeScript version
@dataclass(frozen=True)
class ValidValidationResult:
    """Valid validation result."""
    is_valid: Literal[True] = True

@dataclass(frozen=True)
class InvalidValidationResult:
    """Invalid validation result with error message."""
    is_valid: Literal[False] = False
    error_message: str = ""

ValidationResult = Union[ValidValidationResult, InvalidValidationResult]

@dataclass(frozen=True)
class ToolCall:
    """Represents a tool call from the model."""
    id: str
    type: Literal["function"]
    function: "ToolCallFunction"

@dataclass(frozen=True)
class ToolCallFunction:
    """Function information within a tool call."""
    name: str
    arguments: str

@dataclass(frozen=True)
class Attachment:
    """Represents an attachment with various content types."""
    kind: Literal["image", "document", "file"]
    mime_type: Optional[str] = None  # e.g. image/png, application/pdf
    name: Optional[str] = None  # Optional filename
    url: Optional[str] = None  # Remote URL or data URL
    data: Optional[str] = None  # Base64 without data: prefix
    format: Optional[str] = None  # Optional short format like 'pdf', 'txt'
    use_litellm_format: Optional[bool] = None  # Use LiteLLM native file format

    def __post_init__(self):
        """Validate that at least one of url or data is provided."""
        if self.url is None and self.data is None:
            raise ValueError("At least one of 'url' or 'data' must be provided for an Attachment.")

@dataclass(frozen=True)
class MessageContentPart:
    """Part of multi-part message content."""
    type: Literal["text", "image_url", "file"]
    text: Optional[str] = None
    image_url: Optional[Dict[str, Any]] = None  # Contains url and optional detail
    file: Optional[Dict[str, Any]] = None  # Contains file_id and optional format

@dataclass(frozen=True)
class Message:
    """
    A message in the conversation.

    BACKWARDS COMPATIBILITY:
    - Messages created with string content remain fully backwards compatible
    - Direct access to .content returns the original string when created with string
    - Use .text_content property for guaranteed string access in all cases
    - Use get_text_content() function to extract text from any content type
    - message_id is optional for backward compatibility
    """
    role: ContentRole
    content: Union[str, List[MessageContentPart]]
    attachments: Optional[List[Attachment]] = None
    tool_call_id: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    message_id: Optional[MessageId] = None  # Optional for backward compatibility

    def __post_init__(self):
        """Auto-generate a message ID if not provided (object.__setattr__ keeps the dataclass frozen while guaranteeing a unique ID)."""
        if self.message_id is None:
            object.__setattr__(self, "message_id", generate_message_id())

    @property
    def text_content(self) -> str:
        """Get text content as string for backwards compatibility."""
        return get_text_content(self.content)

    @classmethod
    def create(
        cls,

@@ -277,8 +312,8 @@ class Message:
        attachments: Optional[List[Attachment]] = None,
        tool_call_id: Optional[str] = None,
        tool_calls: Optional[List[ToolCall]] = None,
        message_id: Optional[MessageId] = None,
    ) -> "Message":
        """Create a message with string content and optional attachments."""
        return cls(
            role=role,

@@ -286,32 +321,43 @@ class Message:
            attachments=attachments,
            tool_call_id=tool_call_id,
            tool_calls=tool_calls,
            message_id=message_id,
        )

def get_text_content(content: Union[str, List[MessageContentPart]]) -> str:
    """Extract text content from message content."""
    if isinstance(content, str):
        return content
    text_parts = [part.text for part in content if part.type == "text" and part.text]
    return " ".join(text_parts)

@dataclass(frozen=True)
class ModelConfig:
    """Configuration for model behavior."""
    name: Optional[str] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    inline_tool_schemas: Optional[bool] = None  # If True, resolve $refs and inline $defs in tool schemas

@dataclass(frozen=True)
class ToolSchema(Generic[Args]):
    """Schema definition for a tool."""
    name: str
    description: str
    parameters: Any  # Should be a type that can validate Args (like Pydantic model or Zod equivalent)
    timeout: Optional[float] = None  # Optional timeout in seconds for tool execution

@runtime_checkable
class Tool(Protocol[Args, Ctx]):
    """Protocol for tool implementations."""

@@ -321,7 +367,7 @@ class Tool(Protocol[Args, Ctx]):
        """Tool schema including name, description, and parameter validation."""
        ...

    async def execute(self, args: Args, context: Ctx) -> Union[str, "ToolResult[Any]"]:
        """Execute the tool with given arguments and context."""
        ...
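Read together, ToolSchema and the Tool protocol above imply tool implementations of roughly the following shape. This is only a sketch, assuming a Pydantic model is used for the parameters; the weather tool and its fields are invented for illustration:

from pydantic import BaseModel
from jaf.core.types import ToolSchema

class WeatherArgs(BaseModel):
    city: str

class WeatherTool:
    # Structural match for the Tool protocol: a schema plus an async execute method.
    schema = ToolSchema(
        name="get_weather",
        description="Look up the current weather for a city",
        parameters=WeatherArgs,
        timeout=10.0,
    )

    async def execute(self, args: WeatherArgs, context: object) -> str:
        # A real tool would call a weather service here.
        return f"It is sunny in {args.city}"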
@@ -334,9 +380,12 @@ class Tool(Protocol[Args, Ctx]):
# Function tool configuration for improved DX
class FunctionToolConfig(TypedDict):
    """Configuration for creating function-based tools with object-based API."""
    name: str
    description: str
    execute: Callable[
        [Any, Any], Union[str, "ToolResult[Any]", Awaitable[Union[str, "ToolResult[Any]"]]]
    ]
    parameters: Any  # Pydantic model or similar for parameter validation
    metadata: Optional[Dict[str, Any]]  # Optional metadata
    source: Optional[ToolSource]  # Optional source tracking

@@ -344,34 +393,45 @@ class FunctionToolConfig(TypedDict):
# Type alias for tool execution functions
ToolExecuteFunction = Callable[
    [Any, Any], Union[str, "ToolResult[Any]", Awaitable[Union[str, "ToolResult[Any]"]]]
]

@dataclass(frozen=True)
class Agent(Generic[Ctx, Out]):
    """An agent definition with instructions, tools, and configuration."""
    name: str
    instructions: Callable[["RunState[Ctx]"], str]
    tools: Optional[List[Tool[Any, Ctx]]] = None
    output_codec: Optional[Any] = None  # Type that can validate Out (like Pydantic model or Zod equivalent)
    handoffs: Optional[List[str]] = None
    model_config: Optional[ModelConfig] = None
    advanced_config: Optional["AdvancedConfig"] = None

    def as_tool(
        self,
        tool_name: Optional[str] = None,
        tool_description: Optional[str] = None,
        max_turns: Optional[int] = None,
        custom_output_extractor: Optional[
            Callable[["RunResult[Out]"], Union[str, Awaitable[str]]]
        ] = None,
        is_enabled: Union[
            bool,
            Callable[[Any, "Agent[Ctx, Out]"], bool],
            Callable[[Any, "Agent[Ctx, Out]"], Awaitable[bool]],
        ] = True,
        metadata: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
        preserve_session: bool = False,
    ) -> Tool[Any, Ctx]:
        """
        Convert this agent into a tool that can be used by other agents.

        Args:
            tool_name: Optional custom name for the tool (defaults to agent name)
            tool_description: Optional custom description (defaults to generic description)

@@ -380,11 +440,12 @@ class Agent(Generic[Ctx, Out]):
            is_enabled: Whether the tool is enabled (bool, sync function, or async function)
            metadata: Optional metadata for the tool
            timeout: Optional timeout for the tool execution

        Returns:
            A Tool that wraps this agent's execution
        """
        from .agent_tool import create_agent_tool

        return create_agent_tool(
            agent=self,
            tool_name=tool_name,

@@ -394,90 +455,109 @@
            is_enabled=is_enabled,
            metadata=metadata,
            timeout=timeout,
            preserve_session=preserve_session,
        )
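The widened as_tool() signature above (notably the new is_enabled parameter) suggests usage along these lines; the agent names, instructions, and the context flag are invented for illustration:

from jaf.core.types import Agent

search_agent = Agent(
    name="doc_search",
    instructions=lambda state: "Answer documentation questions concisely.",
)

search_tool = search_agent.as_tool(
    tool_name="search_docs",
    tool_description="Delegate documentation questions to the search agent",
    max_turns=5,
    is_enabled=lambda context, agent: getattr(context, "allow_delegation", True),
    timeout=60.0,
)

triage_agent = Agent(
    name="triage",
    instructions=lambda state: "Route each question to the most suitable tool.",
    tools=[search_tool],
)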
# Guardrail type
Guardrail = Callable[[Any], Union[ValidationResult, Awaitable[ValidationResult]]]

@dataclass(frozen=True)
class AdvancedGuardrailsConfig:
    """Configuration for advanced guardrails with LLM-based validation."""
    input_prompt: Optional[str] = None
    output_prompt: Optional[str] = None
    require_citations: bool = False
    fast_model: Optional[str] = None
    fail_safe: Literal["allow", "block"] = "allow"
    execution_mode: Literal["parallel", "sequential"] = "parallel"
    timeout_ms: int = 30000

    def __post_init__(self):
        """Validate configuration."""
        if self.timeout_ms < 1000:
            object.__setattr__(self, "timeout_ms", 1000)

@dataclass(frozen=True)
class AdvancedConfig:
    """Advanced agent configuration including guardrails."""
    guardrails: Optional[AdvancedGuardrailsConfig] = None

def validate_guardrails_config(
    config: Optional[AdvancedGuardrailsConfig],
) -> AdvancedGuardrailsConfig:
    """Validate and provide defaults for guardrails configuration."""
    if config is None:
        return AdvancedGuardrailsConfig()
    return AdvancedGuardrailsConfig(
        input_prompt=config.input_prompt.strip()
        if isinstance(config.input_prompt, str) and config.input_prompt
        else None,
        output_prompt=config.output_prompt.strip()
        if isinstance(config.output_prompt, str) and config.output_prompt
        else None,
        require_citations=config.require_citations,
        fast_model=config.fast_model.strip()
        if isinstance(config.fast_model, str) and config.fast_model
        else None,
        fail_safe=config.fail_safe,
        execution_mode=config.execution_mode,
        timeout_ms=max(1000, config.timeout_ms),
    )

def json_parse_llm_output(text: str) -> Optional[Dict[str, Any]]:
    """Parse JSON from LLM output, handling common formatting issues."""
    import json
    import re

    if not text:
        return None

    # Try direct parsing first
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Try to extract JSON from markdown code blocks
    json_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group(1))
        except json.JSONDecodeError:
            pass

    # Try to find the first JSON object in the text
    json_match = re.search(r"\{.*?\}", text, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group(0))
        except json.JSONDecodeError:
            pass

    return None

@dataclass(frozen=True)
class ApprovalValue:
    """Represents an approval decision with context."""
    status: str  # 'pending', 'approved', 'rejected'
    approved: bool
    additional_context: Optional[Dict[str, Any]] = None

@dataclass(frozen=True)
class RunState(Generic[Ctx]):
    """Immutable state of a run."""
    run_id: RunId
    trace_id: TraceId
    messages: List[Message]

@@ -486,51 +566,61 @@ class RunState(Generic[Ctx]):
    turn_count: int
    approvals: Dict[str, ApprovalValue] = field(default_factory=dict)

# Error types using dataclasses for immutability
@dataclass(frozen=True)
class MaxTurnsExceeded:
    _tag: Literal["MaxTurnsExceeded"] = "MaxTurnsExceeded"
    turns: int = 0

@dataclass(frozen=True)
class ModelBehaviorError:
    _tag: Literal["ModelBehaviorError"] = "ModelBehaviorError"
    detail: str = ""

@dataclass(frozen=True)
class DecodeError:
    _tag: Literal["DecodeError"] = "DecodeError"
    errors: List[Dict[str, Any]] = field(default_factory=list)

@dataclass(frozen=True)
class InputGuardrailTripwire:
    _tag: Literal["InputGuardrailTripwire"] = "InputGuardrailTripwire"
    reason: str = ""

@dataclass(frozen=True)
class OutputGuardrailTripwire:
    _tag: Literal["OutputGuardrailTripwire"] = "OutputGuardrailTripwire"
    reason: str = ""

@dataclass(frozen=True)
class ToolCallError:
    _tag: Literal["ToolCallError"] = "ToolCallError"
    tool: str = ""
    detail: str = ""

@dataclass(frozen=True)
class HandoffError:
    _tag: Literal["HandoffError"] = "HandoffError"
    detail: str = ""

@dataclass(frozen=True)
class AgentNotFound:
    _tag: Literal["AgentNotFound"] = "AgentNotFound"
    agent_name: str = ""

@dataclass(frozen=True)
class RecoverableError:
    """Error that can be recovered from with retry strategies."""
    _tag: Literal["RecoverableError"] = "RecoverableError"
    detail: str = ""
    retry_strategy: Optional[str] = None

@@ -538,40 +628,51 @@ class RecoverableError:
    current_attempt: int = 1
    backoff_seconds: float = 1.0
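A short sketch of how validate_guardrails_config (defined above) normalizes a raw configuration; the prompt text is a placeholder:

from jaf.core.types import AdvancedGuardrailsConfig, validate_guardrails_config

raw = AdvancedGuardrailsConfig(
    input_prompt="  Reject requests that contain personal data.  ",
    fast_model="",   # empty string is normalized to None
    timeout_ms=250,  # below the floor; __post_init__ already clamps this to 1000
)
cfg = validate_guardrails_config(raw)
# cfg.input_prompt == "Reject requests that contain personal data."
# cfg.fast_model is None, cfg.timeout_ms == 1000, cfg.fail_safe == "allow"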
@dataclass(frozen=True)
class RateLimitError:
    """Error indicating rate limit has been exceeded."""
    _tag: Literal["RateLimitError"] = "RateLimitError"
    detail: str = ""
    retry_after_seconds: Optional[float] = None
    limit_type: str = "requests"  # "requests", "tokens", "concurrent"

@dataclass(frozen=True)
class ValidationError:
    """Enhanced validation error with detailed context."""
    _tag: Literal["ValidationError"] = "ValidationError"
    detail: str = ""
    field_errors: List[Dict[str, Any]] = field(default_factory=list)
    error_code: str = "validation_failed"

@dataclass(frozen=True)
class NetworkError:
    """Network-related errors with retry information."""
    _tag: Literal["NetworkError"] = "NetworkError"
    detail: str = ""
    status_code: Optional[int] = None
    is_retryable: bool = True
    endpoint: Optional[str] = None

# Interruption types for HITL
@dataclass(frozen=True)
class ToolApprovalInterruption(Generic[Ctx]):
    """Interruption for tool approval."""
    type: Literal["tool_approval"] = "tool_approval"
    tool_call: ToolCall = field(
        default_factory=lambda: ToolCall("", "function", ToolCallFunction("", ""))
    )
    agent: "Agent[Ctx, Any]" = None
    session_id: Optional[str] = None

# Union type for all interruptions
Interruption = Union[ToolApprovalInterruption[Any]]
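As a rough sketch of how these HITL types compose, a pending tool approval and the matching decision might look like the following; the tool call, agent, and session values are invented, and keying approvals by the tool call id is an assumption:

from jaf.core.types import ApprovalValue, ToolApprovalInterruption, ToolCall, ToolCallFunction

pending = ToolApprovalInterruption(
    tool_call=ToolCall(
        id="call_1",
        type="function",
        function=ToolCallFunction(name="delete_record", arguments='{"record_id": 7}'),
    ),
    agent=admin_agent,        # an Agent defined elsewhere
    session_id="session_abc",
)

# The caller records the decision, e.g. in RunState.approvals keyed by the call id.
approvals = {"call_1": ApprovalValue(status="approved", approved=True)}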
@@ -588,40 +689,51 @@ JAFError = Union[
    RecoverableError,
    RateLimitError,
    ValidationError,
    NetworkError,
]

@dataclass(frozen=True)
class CompletedOutcome(Generic[Out]):
    """Successful completion outcome."""
    status: Literal["completed"] = "completed"
    output: Out = field(default=None)

@dataclass(frozen=True)
class ErrorOutcome:
    """Error outcome."""
    status: Literal["error"] = "error"
    error: JAFError = field(default=None)

@dataclass(frozen=True)
class InterruptedOutcome:
    """Interrupted outcome for HITL."""
    status: Literal["interrupted"] = "interrupted"
    interruptions: List[Interruption] = field(default_factory=list)

# Union type for outcomes
RunOutcome = Union[CompletedOutcome[Out], ErrorOutcome, InterruptedOutcome]

@dataclass(frozen=True)
class RunResult(Generic[Out]):
    """Result of a run execution."""
    final_state: RunState[Any]
    outcome: RunOutcome[Out]

# Trace event types with specific data structures to match TypeScript
@dataclass(frozen=True)
class RunStartEventData:
    """Data for run start events."""
    run_id: RunId
    trace_id: TraceId
    session_id: Optional[str] = None

@@ -629,14 +741,19 @@
    messages: Optional[List[Message]] = None
    agent_name: Optional[str] = None

@dataclass(frozen=True)
class RunStartEvent:
    type: Literal["run_start"] = "run_start"
    data: RunStartEventData = field(
        default_factory=lambda: RunStartEventData(RunId(""), TraceId(""))
    )

@dataclass(frozen=True)
class LLMCallStartEventData:
    """Data for LLM call start events."""
    agent_name: str
    model: str
    trace_id: TraceId

@@ -644,46 +761,68 @@
    context: Optional[Any] = None
    messages: Optional[List[Message]] = None

@dataclass(frozen=True)
class LLMCallStartEvent:
    type: Literal["llm_call_start"] = "llm_call_start"
    data: LLMCallStartEventData = field(
        default_factory=lambda: LLMCallStartEventData("", "", TraceId(""), RunId(""))
    )

@dataclass(frozen=True)
class LLMCallEndEventData:
    """Data for LLM call end events."""
    choice: Any
    trace_id: TraceId
    run_id: RunId
    usage: Optional[Dict[str, int]] = None

@dataclass(frozen=True)
class LLMCallEndEvent:
    type: Literal["llm_call_end"] = "llm_call_end"
    data: LLMCallEndEventData = field(
        default_factory=lambda: LLMCallEndEventData(None, TraceId(""), RunId(""))
    )

@dataclass(frozen=True)
class AssistantMessageEventData:
    """Data for assistant message events (partial or complete)."""
    message: Message

@dataclass(frozen=True)
class AssistantMessageEvent:
    type: Literal["assistant_message"] = "assistant_message"
    data: AssistantMessageEventData = field(
        default_factory=lambda: AssistantMessageEventData(
            Message(role=ContentRole.ASSISTANT, content="")
        )
    )

@dataclass(frozen=True)
class ToolCallStartEventData:
    """Data for tool call start events."""
    tool_name: str
    args: Any
    trace_id: TraceId
    run_id: RunId
    call_id: Optional[str] = None

@dataclass(frozen=True)
class ToolCallStartEvent:
    type: Literal["tool_call_start"] = "tool_call_start"
    data: ToolCallStartEventData = field(
        default_factory=lambda: ToolCallStartEventData("", None, TraceId(""), RunId(""), None)
    )

@dataclass(frozen=True)
class ToolCallEndEventData:

@@ -703,18 +842,27 @@
    - 'rejected': Tool was rejected by user
    - 'execution_error', 'validation_error', etc.: Various error states
    """
    tool_name: str
    result: str
    trace_id: TraceId
    run_id: RunId
    tool_result: Optional[Any] = None
    execution_status: Optional[str] = None  # success/error/timeout - indicates if tool executed successfully
    status: Optional[str] = None  # DEPRECATED: maintained for backward-compatible initialization/serialization
    call_id: Optional[str] = None

    def __post_init__(self) -> None:
        # Handle backward compatibility with explicit conflict detection
        if (
            self.execution_status is not None
            and self.status is not None
            and self.execution_status != self.status
        ):
            raise ValueError(
                f"Conflicting values for execution_status ('{self.execution_status}') and status ('{self.status}'). "
                f"Please use only execution_status for new code."

@@ -722,85 +870,111 @@ class ToolCallEndEventData:
        # Prefer execution_status (new field) over status (deprecated field)
        canonical = self.execution_status if self.execution_status is not None else self.status
        object.__setattr__(self, "execution_status", canonical)
        object.__setattr__(self, "status", canonical)
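The __post_init__ above keeps the deprecated status field in sync with execution_status; a small sketch (identifiers are placeholders):

evt = ToolCallEndEventData(
    tool_name="search_docs",
    result="ok",
    trace_id=TraceId("trace_1"),
    run_id=RunId("run_1"),
    status="success",  # legacy callers can still pass status
)
assert evt.execution_status == "success"  # mirrored onto the preferred field
assert evt.status == "success"
# Passing both fields with different values raises ValueError rather than silently
# preferring one of them.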
@dataclass(frozen=True)
class ToolCallEndEvent:
    type: Literal["tool_call_end"] = "tool_call_end"
    data: ToolCallEndEventData = field(
        default_factory=lambda: ToolCallEndEventData("", "", TraceId(""), RunId(""), None, None)
    )

@dataclass(frozen=True)
class HandoffEventData:
    """Data for handoff events."""
    from_: str = field(metadata={"alias": "from"})  # Using from_ since 'from' is a Python keyword
    to: str

@dataclass(frozen=True)
class HandoffEvent:
    type: Literal["handoff"] = "handoff"
    data: HandoffEventData = field(default_factory=lambda: HandoffEventData("", ""))

@dataclass(frozen=True)
class RunEndEventData:
    """Data for run end events."""
    outcome: "RunOutcome[Any]"
    trace_id: TraceId
    run_id: RunId

@dataclass(frozen=True)
class RunEndEvent:
    type: Literal["run_end"] = "run_end"
    data: RunEndEventData = field(
        default_factory=lambda: RunEndEventData(None, TraceId(""), RunId(""))
    )

@dataclass(frozen=True)
class GuardrailEventData:
    """Data for guardrail check events."""
    guardrail_name: str
    content: Any
    is_valid: Optional[bool] = None
    error_message: Optional[str] = None

@dataclass(frozen=True)
class GuardrailEvent:
    type: Literal["guardrail_check"] = "guardrail_check"
    data: GuardrailEventData = field(default_factory=lambda: GuardrailEventData(""))

@dataclass(frozen=True)
class GuardrailViolationEventData:
    """Data for guardrail violation events."""
    stage: Literal["input", "output"]
    reason: str

@dataclass(frozen=True)
class GuardrailViolationEvent:
    type: Literal["guardrail_violation"] = "guardrail_violation"
    data: GuardrailViolationEventData = field(
        default_factory=lambda: GuardrailViolationEventData("input", "")
    )

@dataclass(frozen=True)
class MemoryEventData:
    """Data for memory operation events."""
    operation: Literal["load", "store"]
    conversation_id: str
    status: Literal["start", "end", "fail"]
    error: Optional[str] = None
    message_count: Optional[int] = None

@dataclass(frozen=True)
class MemoryEvent:
    type: Literal["memory_operation"] = "memory_operation"
    data: MemoryEventData = field(default_factory=lambda: MemoryEventData("load", "", "start"))

@dataclass(frozen=True)
class OutputParseEventData:
    """Data for output parsing events."""
    content: str
    status: Literal["start", "end", "fail"]
    parsed_output: Optional[Any] = None
    error: Optional[str] = None

@dataclass(frozen=True)
class OutputParseEvent:
    type: Literal["output_parse"] = "output_parse"
    data: OutputParseEventData = field(default_factory=lambda: OutputParseEventData("", "start"))
@@ -817,72 +991,80 @@ TraceEvent = Union[
    ToolCallStartEvent,
    ToolCallEndEvent,
    HandoffEvent,
    RunEndEvent,
]

@dataclass(frozen=True)
class ModelCompletionMessage:
    """Message structure returned by model completion."""
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None

@dataclass(frozen=True)
class ModelCompletionResponse:
    """Response structure from model completion."""
    message: Optional[ModelCompletionMessage] = None

# Streaming chunk structures for provider-level streaming support
@dataclass(frozen=True)
class ToolCallFunctionDelta:
    """Function fields that may stream as deltas."""
    name: Optional[str] = None
    arguments_delta: Optional[str] = None

@dataclass(frozen=True)
class ToolCallDelta:
    """Represents a partial tool call delta in a streamed response."""
    index: int
    id: Optional[str] = None
    type: Literal["function"] = "function"
    function: Optional[ToolCallFunctionDelta] = None

@dataclass(frozen=True)
class CompletionStreamChunk:
    """A streamed chunk from the model provider."""
    delta: Optional[str] = None
    tool_call_delta: Optional[ToolCallDelta] = None
    is_done: Optional[bool] = False
    finish_reason: Optional[str] = None
    raw: Optional[Any] = None

@runtime_checkable
class ModelProvider(Protocol[Ctx]):
    """Protocol for model providers."""

    async def get_completion(
        self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: "RunConfig[Ctx]"
    ) -> ModelCompletionResponse:
        """Get completion from the model."""
        ...

    async def get_completion_stream(
        self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: "RunConfig[Ctx]"
    ) -> AsyncIterator[CompletionStreamChunk]:
        """Optional streaming API: yields incremental deltas while generating."""
        ...
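Providers implementing the optional streaming API can be consumed roughly as below, assuming get_completion_stream is written as an async generator; the provider, state, agent, and config objects are assumed to exist elsewhere:

async def collect_text(provider, state, agent, config) -> str:
    parts = []
    async for chunk in provider.get_completion_stream(state, agent, config):
        if chunk.delta:
            parts.append(chunk.delta)  # incremental assistant text
        if chunk.is_done:
            break
    return "".join(parts)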
@dataclass(frozen=True)
class RunConfig(Generic[Ctx]):
    """Configuration for running agents."""
    agent_registry: Dict[str, Agent[Ctx, Any]]
    model_provider: ModelProvider[Ctx]
    max_turns: Optional[int] = 50
    max_tokens: Optional[int] = None  # Default max_tokens for all agents (can be overridden per agent)
    model_override: Optional[str] = None
    initial_input_guardrails: Optional[List[Guardrail]] = None
    final_output_guardrails: Optional[List[Guardrail]] = None

@@ -891,32 +1073,68 @@ class RunConfig(Generic[Ctx]):
    conversation_id: Optional[str] = None
    default_fast_model: Optional[str] = None  # Default model for fast operations like guardrails
    default_tool_timeout: Optional[float] = 300.0  # Default timeout for tool execution in seconds
    approval_storage: Optional["ApprovalStorage"] = None  # Storage for approval decisions
    before_llm_call: Optional[
        Callable[[RunState[Ctx], Agent[Ctx, Any]], Union[RunState[Ctx], Awaitable[RunState[Ctx]]]]
    ] = None  # Callback before LLM call - can modify context/messages
    after_llm_call: Optional[
        Callable[
            [RunState[Ctx], ModelCompletionResponse],
            Union[ModelCompletionResponse, Awaitable[ModelCompletionResponse]],
        ]
    ] = None  # Callback after LLM call - can process response
    max_empty_response_retries: int = 3  # Maximum retries when LLM returns empty response
    empty_response_retry_delay: float = 1.0  # Initial delay in seconds before retrying empty response (uses exponential backoff)
    log_empty_responses: bool = True  # Whether to log diagnostic info for empty responses
    prefer_streaming: Optional[bool] = None  # Whether to prefer streaming responses. None (default) = use streaming if available, True = prefer streaming, False = disable streaming

# Regeneration types for conversation management
@dataclass(frozen=True)
class RegenerationRequest:
    """Request to regenerate a conversation from a specific message."""
    conversation_id: str
    message_id: MessageId  # ID of the message to regenerate from
    context: Optional[Dict[str, Any]] = None  # Optional context override

@dataclass(frozen=True)
class RegenerationContext:
    """Context information for a regeneration operation."""
    original_message_count: int
    truncated_at_index: int
    regenerated_message_id: MessageId
    regeneration_id: str  # Unique ID for this regeneration operation
    timestamp: int  # Unix timestamp in milliseconds

# Checkpoint types for conversation management
@dataclass(frozen=True)
class CheckpointRequest:
    """Request to checkpoint a conversation after a specific message."""
    conversation_id: str
    message_id: MessageId  # ID of the message to checkpoint after (this message is kept)
    context: Optional[Dict[str, Any]] = None  # Optional context for the checkpoint

@dataclass(frozen=True)
class CheckpointContext:
    """Context information for a checkpoint operation."""
    original_message_count: int
    checkpointed_at_index: int
    checkpointed_message_id: MessageId
    checkpoint_id: str  # Unique ID for this checkpoint operation
    timestamp: int  # Unix timestamp in milliseconds
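A sketch of a RunConfig exercising the two fields added in this diff (max_tokens and after_llm_call), plus a CheckpointRequest; the agent, provider, and IDs are invented, and the checkpoint types are presumably consumed by the new jaf/core/checkpoint.py module listed at the top of this diff:

from jaf.core.types import CheckpointRequest, RunConfig, create_message_id

async def audit_llm_response(state, response):
    # Inspect (or rewrite) the completion before the engine consumes it.
    return response

config = RunConfig(
    agent_registry={"triage": triage_agent},   # defined elsewhere
    model_provider=provider,                   # defined elsewhere
    max_turns=20,
    max_tokens=1024,                    # added in this diff: default cap for all agents
    after_llm_call=audit_llm_response,  # added in this diff: post-LLM hook
    prefer_streaming=True,
)

checkpoint = CheckpointRequest(
    conversation_id="conv_123",
    message_id=create_message_id("msg_42"),  # this message and everything before it is kept
)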
# Message utility functions
def find_message_index(messages: List[Message], message_id: MessageId) -> Optional[int]:
    """Find the index of a message by its ID."""

@@ -925,6 +1143,7 @@ def find_message_index(messages: List[Message], message_id: MessageId) -> Optional[int]:
            return i
    return None

def truncate_messages_after(messages: List[Message], message_id: MessageId) -> List[Message]:
    """Truncate messages after (and including) the specified message ID."""
    index = find_message_index(messages, message_id)

@@ -932,6 +1151,7 @@ def truncate_messages_after(messages: List[Message], message_id: MessageId) -> List[Message]:
        return messages  # Message not found, return unchanged
    return messages[:index]

def get_message_by_id(messages: List[Message], message_id: MessageId) -> Optional[Message]:
    """Get a message by its ID."""
    for msg in messages:
|