jaf-py 2.5.9__py3-none-any.whl → 2.5.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92) hide show
  1. jaf/__init__.py +154 -57
  2. jaf/a2a/__init__.py +42 -21
  3. jaf/a2a/agent.py +79 -126
  4. jaf/a2a/agent_card.py +87 -78
  5. jaf/a2a/client.py +30 -66
  6. jaf/a2a/examples/client_example.py +12 -12
  7. jaf/a2a/examples/integration_example.py +38 -47
  8. jaf/a2a/examples/server_example.py +56 -53
  9. jaf/a2a/memory/__init__.py +0 -4
  10. jaf/a2a/memory/cleanup.py +28 -21
  11. jaf/a2a/memory/factory.py +155 -133
  12. jaf/a2a/memory/providers/composite.py +21 -26
  13. jaf/a2a/memory/providers/in_memory.py +89 -83
  14. jaf/a2a/memory/providers/postgres.py +117 -115
  15. jaf/a2a/memory/providers/redis.py +128 -121
  16. jaf/a2a/memory/serialization.py +77 -87
  17. jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
  18. jaf/a2a/memory/tests/test_cleanup.py +211 -94
  19. jaf/a2a/memory/tests/test_serialization.py +73 -68
  20. jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
  21. jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
  22. jaf/a2a/memory/types.py +91 -53
  23. jaf/a2a/protocol.py +95 -125
  24. jaf/a2a/server.py +90 -118
  25. jaf/a2a/standalone_client.py +30 -43
  26. jaf/a2a/tests/__init__.py +16 -33
  27. jaf/a2a/tests/run_tests.py +17 -53
  28. jaf/a2a/tests/test_agent.py +40 -140
  29. jaf/a2a/tests/test_client.py +54 -117
  30. jaf/a2a/tests/test_integration.py +28 -82
  31. jaf/a2a/tests/test_protocol.py +54 -139
  32. jaf/a2a/tests/test_types.py +50 -136
  33. jaf/a2a/types.py +58 -34
  34. jaf/cli.py +21 -41
  35. jaf/core/__init__.py +7 -1
  36. jaf/core/agent_tool.py +93 -72
  37. jaf/core/analytics.py +257 -207
  38. jaf/core/checkpoint.py +223 -0
  39. jaf/core/composition.py +249 -235
  40. jaf/core/engine.py +817 -519
  41. jaf/core/errors.py +55 -42
  42. jaf/core/guardrails.py +276 -202
  43. jaf/core/handoff.py +47 -31
  44. jaf/core/parallel_agents.py +69 -75
  45. jaf/core/performance.py +75 -73
  46. jaf/core/proxy.py +43 -44
  47. jaf/core/proxy_helpers.py +24 -27
  48. jaf/core/regeneration.py +220 -129
  49. jaf/core/state.py +68 -66
  50. jaf/core/streaming.py +115 -108
  51. jaf/core/tool_results.py +111 -101
  52. jaf/core/tools.py +114 -116
  53. jaf/core/tracing.py +269 -210
  54. jaf/core/types.py +371 -151
  55. jaf/core/workflows.py +209 -168
  56. jaf/exceptions.py +46 -38
  57. jaf/memory/__init__.py +1 -6
  58. jaf/memory/approval_storage.py +54 -77
  59. jaf/memory/factory.py +4 -4
  60. jaf/memory/providers/in_memory.py +216 -180
  61. jaf/memory/providers/postgres.py +216 -146
  62. jaf/memory/providers/redis.py +173 -116
  63. jaf/memory/types.py +70 -51
  64. jaf/memory/utils.py +36 -34
  65. jaf/plugins/__init__.py +12 -12
  66. jaf/plugins/base.py +105 -96
  67. jaf/policies/__init__.py +0 -1
  68. jaf/policies/handoff.py +37 -46
  69. jaf/policies/validation.py +76 -52
  70. jaf/providers/__init__.py +6 -3
  71. jaf/providers/mcp.py +97 -51
  72. jaf/providers/model.py +361 -280
  73. jaf/server/__init__.py +1 -1
  74. jaf/server/main.py +7 -11
  75. jaf/server/server.py +514 -359
  76. jaf/server/types.py +208 -52
  77. jaf/utils/__init__.py +17 -18
  78. jaf/utils/attachments.py +111 -116
  79. jaf/utils/document_processor.py +175 -174
  80. jaf/visualization/__init__.py +1 -1
  81. jaf/visualization/example.py +111 -110
  82. jaf/visualization/functional_core.py +46 -71
  83. jaf/visualization/graphviz.py +154 -189
  84. jaf/visualization/imperative_shell.py +7 -16
  85. jaf/visualization/types.py +8 -4
  86. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
  87. jaf_py-2.5.11.dist-info/RECORD +97 -0
  88. jaf_py-2.5.9.dist-info/RECORD +0 -96
  89. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
  90. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
  91. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
  92. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
jaf/core/types.py CHANGED
@@ -36,239 +36,274 @@ if TYPE_CHECKING:
36
36
  # Comprehensive enums for type safety and improved developer experience
37
37
  class Model(str, Enum):
38
38
  """Supported model identifiers."""
39
- GEMINI_2_0_FLASH = 'gemini-2.0-flash'
40
- GEMINI_2_5_PRO = 'gemini-2.5-pro'
41
- GEMINI_PRO = 'gemini-pro'
42
- GPT_4 = 'gpt-4'
43
- GPT_4_TURBO = 'gpt-4-turbo'
44
- GPT_3_5_TURBO = 'gpt-3.5-turbo'
45
- CLAUDE_3_SONNET = 'claude-3-sonnet'
46
- CLAUDE_3_HAIKU = 'claude-3-haiku'
47
- CLAUDE_3_OPUS = 'claude-3-opus'
39
+
40
+ GEMINI_2_0_FLASH = "gemini-2.0-flash"
41
+ GEMINI_2_5_PRO = "gemini-2.5-pro"
42
+ GEMINI_PRO = "gemini-pro"
43
+ GPT_4 = "gpt-4"
44
+ GPT_4_TURBO = "gpt-4-turbo"
45
+ GPT_3_5_TURBO = "gpt-3.5-turbo"
46
+ CLAUDE_3_SONNET = "claude-3-sonnet"
47
+ CLAUDE_3_HAIKU = "claude-3-haiku"
48
+ CLAUDE_3_OPUS = "claude-3-opus"
48
49
 
49
50
 
50
51
  class ToolParameterType(str, Enum):
51
52
  """Tool parameter types."""
52
- STRING = 'string'
53
- NUMBER = 'number'
54
- INTEGER = 'integer'
55
- BOOLEAN = 'boolean'
56
- ARRAY = 'array'
57
- OBJECT = 'object'
58
- NULL = 'null'
53
+
54
+ STRING = "string"
55
+ NUMBER = "number"
56
+ INTEGER = "integer"
57
+ BOOLEAN = "boolean"
58
+ ARRAY = "array"
59
+ OBJECT = "object"
60
+ NULL = "null"
59
61
 
60
62
 
61
63
  class ToolSource(str, Enum):
62
64
  """Source of tool definitions."""
63
- NATIVE = 'native'
64
- MCP = 'mcp'
65
- PLUGIN = 'plugin'
66
- EXTERNAL = 'external'
65
+
66
+ NATIVE = "native"
67
+ MCP = "mcp"
68
+ PLUGIN = "plugin"
69
+ EXTERNAL = "external"
67
70
 
68
71
 
69
72
  class ContentRole(str, Enum):
70
73
  """Message content roles."""
71
- USER = 'user'
72
- ASSISTANT = 'assistant'
73
- TOOL = 'tool'
74
- SYSTEM = 'system'
74
+
75
+ USER = "user"
76
+ ASSISTANT = "assistant"
77
+ TOOL = "tool"
78
+ SYSTEM = "system"
75
79
 
76
80
 
77
81
  class PartType(str, Enum):
78
82
  """Message part types."""
79
- TEXT = 'text'
80
- IMAGE = 'image'
81
- AUDIO = 'audio'
82
- VIDEO = 'video'
83
- FILE = 'file'
83
+
84
+ TEXT = "text"
85
+ IMAGE = "image"
86
+ AUDIO = "audio"
87
+ VIDEO = "video"
88
+ FILE = "file"
84
89
 
85
90
 
86
91
  # Branded types for type safety - using class-based approach for better type safety
87
92
  class TraceId(str):
88
93
  """Branded string type for trace IDs."""
89
- def __new__(cls, value: str) -> 'TraceId':
94
+
95
+ def __new__(cls, value: str) -> "TraceId":
90
96
  return str.__new__(cls, value)
91
97
 
98
+
92
99
  class RunId(str):
93
100
  """Branded string type for run IDs."""
94
- def __new__(cls, value: str) -> 'RunId':
101
+
102
+ def __new__(cls, value: str) -> "RunId":
95
103
  return str.__new__(cls, value)
96
104
 
105
+
97
106
  class MessageId(str):
98
107
  """Branded string type for message IDs."""
99
- def __new__(cls, value: str) -> 'MessageId':
108
+
109
+ def __new__(cls, value: str) -> "MessageId":
100
110
  return str.__new__(cls, value)
101
111
 
112
+
102
113
  def create_trace_id(id_str: str) -> TraceId:
103
114
  """Create a TraceId from a string."""
104
115
  return TraceId(id_str)
105
116
 
117
+
106
118
  def create_run_id(id_str: str) -> RunId:
107
119
  """Create a RunId from a string."""
108
120
  return RunId(id_str)
109
121
 
122
+
110
123
  def create_message_id(id_str: Union[str, MessageId]) -> MessageId:
111
124
  """
112
125
  Create a MessageId from a string or return existing MessageId.
113
-
126
+
114
127
  Args:
115
128
  id_str: Either a string to convert to MessageId or an existing MessageId
116
-
129
+
117
130
  Returns:
118
131
  MessageId: A validated MessageId instance
119
-
132
+
120
133
  Raises:
121
134
  ValueError: If the input is invalid or empty
122
135
  """
123
136
  # Handle None input
124
137
  if id_str is None:
125
138
  raise ValueError("Message ID cannot be None")
126
-
139
+
127
140
  # If already a MessageId, return as-is
128
141
  if isinstance(id_str, MessageId):
129
142
  return id_str
130
-
143
+
131
144
  # Convert string to MessageId with validation
132
145
  if isinstance(id_str, str):
133
146
  if not id_str.strip():
134
147
  raise ValueError("Message ID cannot be empty or whitespace")
135
148
  return MessageId(id_str.strip())
136
-
149
+
137
150
  # Handle any other type
138
151
  raise ValueError(f"Message ID must be a string or MessageId, got {type(id_str)}")
139
152
 
153
+
140
154
  def generate_run_id() -> RunId:
141
155
  """Generate a new unique run ID."""
142
156
  import time
143
157
  import uuid
158
+
144
159
  return RunId(f"run_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")
145
160
 
161
+
146
162
  def generate_trace_id() -> TraceId:
147
163
  """Generate a new unique trace ID."""
148
164
  import time
149
165
  import uuid
166
+
150
167
  return TraceId(f"trace_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")
151
168
 
169
+
152
170
  def generate_message_id() -> MessageId:
153
171
  """Generate a new unique message ID."""
154
172
  import time
155
173
  import uuid
174
+
156
175
  return MessageId(f"msg_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")
157
176
 
177
+
158
178
  # Type variables for generic contexts and outputs
159
- Ctx = TypeVar('Ctx')
160
- Out = TypeVar('Out')
161
- Args = TypeVar('Args')
179
+ Ctx = TypeVar("Ctx")
180
+ Out = TypeVar("Out")
181
+ Args = TypeVar("Args")
182
+
162
183
 
163
184
  # Discriminated union for ValidationResult to match TypeScript version
164
185
  @dataclass(frozen=True)
165
186
  class ValidValidationResult:
166
187
  """Valid validation result."""
188
+
167
189
  is_valid: Literal[True] = True
168
190
 
191
+
169
192
  @dataclass(frozen=True)
170
193
  class InvalidValidationResult:
171
194
  """Invalid validation result with error message."""
195
+
172
196
  is_valid: Literal[False] = False
173
197
  error_message: str = ""
174
198
 
199
+
175
200
  ValidationResult = Union[ValidValidationResult, InvalidValidationResult]
176
201
 
202
+
177
203
  @dataclass(frozen=True)
178
204
  class ToolCall:
179
205
  """Represents a tool call from the model."""
206
+
180
207
  id: str
181
- type: Literal['function']
182
- function: 'ToolCallFunction'
208
+ type: Literal["function"]
209
+ function: "ToolCallFunction"
210
+
183
211
 
184
212
  @dataclass(frozen=True)
185
213
  class ToolCallFunction:
186
214
  """Function information within a tool call."""
215
+
187
216
  name: str
188
217
  arguments: str
189
218
 
219
+
190
220
  @dataclass(frozen=True)
191
221
  class Attachment:
192
222
  """Represents an attachment with various content types."""
193
- kind: Literal['image', 'document', 'file']
223
+
224
+ kind: Literal["image", "document", "file"]
194
225
  mime_type: Optional[str] = None # e.g. image/png, application/pdf
195
- name: Optional[str] = None # Optional filename
196
- url: Optional[str] = None # Remote URL or data URL
197
- data: Optional[str] = None # Base64 without data: prefix
198
- format: Optional[str] = None # Optional short format like 'pdf', 'txt'
226
+ name: Optional[str] = None # Optional filename
227
+ url: Optional[str] = None # Remote URL or data URL
228
+ data: Optional[str] = None # Base64 without data: prefix
229
+ format: Optional[str] = None # Optional short format like 'pdf', 'txt'
199
230
  use_litellm_format: Optional[bool] = None # Use LiteLLM native file format
200
-
231
+
201
232
  def __post_init__(self):
202
233
  """Validate that at least one of url or data is provided."""
203
234
  if self.url is None and self.data is None:
204
235
  raise ValueError("At least one of 'url' or 'data' must be provided for an Attachment.")
205
236
 
237
+
206
238
  @dataclass(frozen=True)
207
239
  class MessageContentPart:
208
240
  """Part of multi-part message content."""
209
- type: Literal['text', 'image_url', 'file']
241
+
242
+ type: Literal["text", "image_url", "file"]
210
243
  text: Optional[str] = None
211
244
  image_url: Optional[Dict[str, Any]] = None # Contains url and optional detail
212
- file: Optional[Dict[str, Any]] = None # Contains file_id and optional format
245
+ file: Optional[Dict[str, Any]] = None # Contains file_id and optional format
246
+
213
247
 
214
248
  @dataclass(frozen=True)
215
249
  class Message:
216
250
  """
217
251
  A message in the conversation.
218
-
252
+
219
253
  BACKWARDS COMPATIBILITY:
220
254
  - Messages created with string content remain fully backwards compatible
221
255
  - Direct access to .content returns the original string when created with string
222
256
  - Use .text_content property for guaranteed string access in all cases
223
257
  - Use get_text_content() function to extract text from any content type
224
258
  - message_id is optional for backward compatibility
225
-
259
+
226
260
  Examples:
227
261
  # Original usage - still works exactly the same
228
262
  msg = Message(role='user', content='Hello')
229
263
  text = msg.content # Returns 'Hello' as string
230
-
264
+
231
265
  # New usage with message ID
232
266
  msg = Message(role='user', content='Hello', message_id='msg_123')
233
-
267
+
234
268
  # Guaranteed string access (recommended for new code)
235
269
  text = msg.text_content # Always returns string
236
-
270
+
237
271
  # Universal text extraction
238
272
  text = get_text_content(msg.content) # Works with any content type
239
273
  """
274
+
240
275
  role: ContentRole
241
276
  content: Union[str, List[MessageContentPart]]
242
277
  attachments: Optional[List[Attachment]] = None
243
278
  tool_call_id: Optional[str] = None
244
279
  tool_calls: Optional[List[ToolCall]] = None
245
280
  message_id: Optional[MessageId] = None # Optional for backward compatibility
246
-
281
+
247
282
  def __post_init__(self):
248
283
  """
249
284
  Auto-generate message ID if not provided.
250
-
285
+
251
286
  This implementation uses object.__setattr__ to bypass frozen dataclass restrictions,
252
287
  which is a recommended pattern for one-time initialization of computed fields in
253
288
  frozen dataclasses. This ensures:
254
-
289
+
255
290
  1. Backward compatibility - existing code with message_id=None continues to work
256
291
  2. Immutability - the dataclass remains frozen after initialization
257
292
  3. Guaranteed unique IDs - every message gets a unique identifier
258
293
  4. Clean API - users don't need to manually generate IDs in most cases
259
-
294
+
260
295
  This pattern is preferred over using field(default_factory=...) because it
261
296
  maintains the Optional[MessageId] type hint for backward compatibility while
262
297
  ensuring the field is never actually None after object creation.
263
298
  """
264
299
  if self.message_id is None:
265
- object.__setattr__(self, 'message_id', generate_message_id())
266
-
300
+ object.__setattr__(self, "message_id", generate_message_id())
301
+
267
302
  @property
268
303
  def text_content(self) -> str:
269
304
  """Get text content as string for backwards compatibility."""
270
305
  return get_text_content(self.content)
271
-
306
+
272
307
  @classmethod
273
308
  def create(
274
309
  cls,
@@ -277,8 +312,8 @@ class Message:
277
312
  attachments: Optional[List[Attachment]] = None,
278
313
  tool_call_id: Optional[str] = None,
279
314
  tool_calls: Optional[List[ToolCall]] = None,
280
- message_id: Optional[MessageId] = None
281
- ) -> 'Message':
315
+ message_id: Optional[MessageId] = None,
316
+ ) -> "Message":
282
317
  """Create a message with string content and optional attachments."""
283
318
  return cls(
284
319
  role=role,
@@ -286,32 +321,43 @@ class Message:
286
321
  attachments=attachments,
287
322
  tool_call_id=tool_call_id,
288
323
  tool_calls=tool_calls,
289
- message_id=message_id
324
+ message_id=message_id,
290
325
  )
291
326
 
327
+
292
328
  def get_text_content(content: Union[str, List[MessageContentPart]]) -> str:
293
329
  """Extract text content from message content."""
294
330
  if isinstance(content, str):
295
331
  return content
296
-
297
- text_parts = [part.text for part in content if part.type == 'text' and part.text]
298
- return ' '.join(text_parts)
332
+
333
+ text_parts = [part.text for part in content if part.type == "text" and part.text]
334
+ return " ".join(text_parts)
335
+
299
336
 
300
337
  @dataclass(frozen=True)
301
338
  class ModelConfig:
302
339
  """Configuration for model behavior."""
340
+
303
341
  name: Optional[str] = None
304
342
  temperature: Optional[float] = None
305
343
  max_tokens: Optional[int] = None
344
+ inline_tool_schemas: Optional[bool] = (
345
+ None # If True, resolve $refs and inline $defs in tool schemas
346
+ )
347
+
306
348
 
307
349
  @dataclass(frozen=True)
308
350
  class ToolSchema(Generic[Args]):
309
351
  """Schema definition for a tool."""
352
+
310
353
  name: str
311
354
  description: str
312
- parameters: Any # Should be a type that can validate Args (like Pydantic model or Zod equivalent)
355
+ parameters: (
356
+ Any # Should be a type that can validate Args (like Pydantic model or Zod equivalent)
357
+ )
313
358
  timeout: Optional[float] = None # Optional timeout in seconds for tool execution
314
359
 
360
+
315
361
  @runtime_checkable
316
362
  class Tool(Protocol[Args, Ctx]):
317
363
  """Protocol for tool implementations."""
@@ -321,7 +367,7 @@ class Tool(Protocol[Args, Ctx]):
321
367
  """Tool schema including name, description, and parameter validation."""
322
368
  ...
323
369
 
324
- async def execute(self, args: Args, context: Ctx) -> Union[str, 'ToolResult[Any]']:
370
+ async def execute(self, args: Args, context: Ctx) -> Union[str, "ToolResult[Any]"]:
325
371
  """Execute the tool with given arguments and context."""
326
372
  ...
327
373
 
@@ -334,9 +380,12 @@ class Tool(Protocol[Args, Ctx]):
334
380
  # Function tool configuration for improved DX
335
381
  class FunctionToolConfig(TypedDict):
336
382
  """Configuration for creating function-based tools with object-based API."""
383
+
337
384
  name: str
338
385
  description: str
339
- execute: Callable[[Any, Any], Union[str, 'ToolResult[Any]', Awaitable[Union[str, 'ToolResult[Any]']]]]
386
+ execute: Callable[
387
+ [Any, Any], Union[str, "ToolResult[Any]", Awaitable[Union[str, "ToolResult[Any]"]]]
388
+ ]
340
389
  parameters: Any # Pydantic model or similar for parameter validation
341
390
  metadata: Optional[Dict[str, Any]] # Optional metadata
342
391
  source: Optional[ToolSource] # Optional source tracking
@@ -344,34 +393,45 @@ class FunctionToolConfig(TypedDict):
344
393
 
345
394
 
346
395
  # Type alias for tool execution functions
347
- ToolExecuteFunction = Callable[[Any, Any], Union[str, 'ToolResult[Any]', Awaitable[Union[str, 'ToolResult[Any]']]]]
396
+ ToolExecuteFunction = Callable[
397
+ [Any, Any], Union[str, "ToolResult[Any]", Awaitable[Union[str, "ToolResult[Any]"]]]
398
+ ]
348
399
 
349
400
 
350
401
  @dataclass(frozen=True)
351
402
  class Agent(Generic[Ctx, Out]):
352
403
  """An agent definition with instructions, tools, and configuration."""
404
+
353
405
  name: str
354
- instructions: Callable[['RunState[Ctx]'], str]
406
+ instructions: Callable[["RunState[Ctx]"], str]
355
407
  tools: Optional[List[Tool[Any, Ctx]]] = None
356
- output_codec: Optional[Any] = None # Type that can validate Out (like Pydantic model or Zod equivalent)
408
+ output_codec: Optional[Any] = (
409
+ None # Type that can validate Out (like Pydantic model or Zod equivalent)
410
+ )
357
411
  handoffs: Optional[List[str]] = None
358
412
  model_config: Optional[ModelConfig] = None
359
- advanced_config: Optional['AdvancedConfig'] = None
413
+ advanced_config: Optional["AdvancedConfig"] = None
360
414
 
361
415
  def as_tool(
362
416
  self,
363
417
  tool_name: Optional[str] = None,
364
418
  tool_description: Optional[str] = None,
365
419
  max_turns: Optional[int] = None,
366
- custom_output_extractor: Optional[Callable[['RunResult[Out]'], Union[str, Awaitable[str]]]] = None,
367
- is_enabled: Union[bool, Callable[[Any, 'Agent[Ctx, Out]'], bool], Callable[[Any, 'Agent[Ctx, Out]'], Awaitable[bool]]] = True,
420
+ custom_output_extractor: Optional[
421
+ Callable[["RunResult[Out]"], Union[str, Awaitable[str]]]
422
+ ] = None,
423
+ is_enabled: Union[
424
+ bool,
425
+ Callable[[Any, "Agent[Ctx, Out]"], bool],
426
+ Callable[[Any, "Agent[Ctx, Out]"], Awaitable[bool]],
427
+ ] = True,
368
428
  metadata: Optional[Dict[str, Any]] = None,
369
429
  timeout: Optional[float] = None,
370
- preserve_session: bool = False
430
+ preserve_session: bool = False,
371
431
  ) -> Tool[Any, Ctx]:
372
432
  """
373
433
  Convert this agent into a tool that can be used by other agents.
374
-
434
+
375
435
  Args:
376
436
  tool_name: Optional custom name for the tool (defaults to agent name)
377
437
  tool_description: Optional custom description (defaults to generic description)
@@ -380,11 +440,12 @@ class Agent(Generic[Ctx, Out]):
380
440
  is_enabled: Whether the tool is enabled (bool, sync function, or async function)
381
441
  metadata: Optional metadata for the tool
382
442
  timeout: Optional timeout for the tool execution
383
-
443
+
384
444
  Returns:
385
445
  A Tool that wraps this agent's execution
386
446
  """
387
447
  from .agent_tool import create_agent_tool
448
+
388
449
  return create_agent_tool(
389
450
  agent=self,
390
451
  tool_name=tool_name,
@@ -394,90 +455,109 @@ class Agent(Generic[Ctx, Out]):
394
455
  is_enabled=is_enabled,
395
456
  metadata=metadata,
396
457
  timeout=timeout,
397
- preserve_session=preserve_session
458
+ preserve_session=preserve_session,
398
459
  )
399
460
 
461
+
400
462
  # Guardrail type
401
463
  Guardrail = Callable[[Any], Union[ValidationResult, Awaitable[ValidationResult]]]
402
464
 
465
+
403
466
  @dataclass(frozen=True)
404
467
  class AdvancedGuardrailsConfig:
405
468
  """Configuration for advanced guardrails with LLM-based validation."""
469
+
406
470
  input_prompt: Optional[str] = None
407
471
  output_prompt: Optional[str] = None
408
472
  require_citations: bool = False
409
473
  fast_model: Optional[str] = None
410
- fail_safe: Literal['allow', 'block'] = 'allow'
411
- execution_mode: Literal['parallel', 'sequential'] = 'parallel'
474
+ fail_safe: Literal["allow", "block"] = "allow"
475
+ execution_mode: Literal["parallel", "sequential"] = "parallel"
412
476
  timeout_ms: int = 30000
413
477
 
414
478
  def __post_init__(self):
415
479
  """Validate configuration."""
416
480
  if self.timeout_ms < 1000:
417
- object.__setattr__(self, 'timeout_ms', 1000)
481
+ object.__setattr__(self, "timeout_ms", 1000)
482
+
418
483
 
419
484
  @dataclass(frozen=True)
420
485
  class AdvancedConfig:
421
486
  """Advanced agent configuration including guardrails."""
487
+
422
488
  guardrails: Optional[AdvancedGuardrailsConfig] = None
423
489
 
424
- def validate_guardrails_config(config: Optional[AdvancedGuardrailsConfig]) -> AdvancedGuardrailsConfig:
490
+
491
+ def validate_guardrails_config(
492
+ config: Optional[AdvancedGuardrailsConfig],
493
+ ) -> AdvancedGuardrailsConfig:
425
494
  """Validate and provide defaults for guardrails configuration."""
426
495
  if config is None:
427
496
  return AdvancedGuardrailsConfig()
428
-
497
+
429
498
  return AdvancedGuardrailsConfig(
430
- input_prompt=config.input_prompt.strip() if isinstance(config.input_prompt, str) and config.input_prompt else None,
431
- output_prompt=config.output_prompt.strip() if isinstance(config.output_prompt, str) and config.output_prompt else None,
499
+ input_prompt=config.input_prompt.strip()
500
+ if isinstance(config.input_prompt, str) and config.input_prompt
501
+ else None,
502
+ output_prompt=config.output_prompt.strip()
503
+ if isinstance(config.output_prompt, str) and config.output_prompt
504
+ else None,
432
505
  require_citations=config.require_citations,
433
- fast_model=config.fast_model.strip() if isinstance(config.fast_model, str) and config.fast_model else None,
506
+ fast_model=config.fast_model.strip()
507
+ if isinstance(config.fast_model, str) and config.fast_model
508
+ else None,
434
509
  fail_safe=config.fail_safe,
435
510
  execution_mode=config.execution_mode,
436
- timeout_ms=max(1000, config.timeout_ms)
511
+ timeout_ms=max(1000, config.timeout_ms),
437
512
  )
438
513
 
514
+
439
515
  def json_parse_llm_output(text: str) -> Optional[Dict[str, Any]]:
440
516
  """Parse JSON from LLM output, handling common formatting issues."""
441
517
  import json
442
518
  import re
443
-
519
+
444
520
  if not text:
445
521
  return None
446
-
522
+
447
523
  # Try direct parsing first
448
524
  try:
449
525
  return json.loads(text)
450
526
  except json.JSONDecodeError:
451
527
  pass
452
-
528
+
453
529
  # Try to extract JSON from markdown code blocks
454
- json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
530
+ json_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
455
531
  if json_match:
456
532
  try:
457
533
  return json.loads(json_match.group(1))
458
534
  except json.JSONDecodeError:
459
535
  pass
460
-
536
+
461
537
  # Try to find the first JSON object in the text
462
- json_match = re.search(r'\{.*?\}', text, re.DOTALL)
538
+ json_match = re.search(r"\{.*?\}", text, re.DOTALL)
463
539
  if json_match:
464
540
  try:
465
541
  return json.loads(json_match.group(0))
466
542
  except json.JSONDecodeError:
467
543
  pass
468
-
544
+
469
545
  return None
470
546
 
547
+
471
548
  @dataclass(frozen=True)
472
549
  class ApprovalValue:
473
550
  """Represents an approval decision with context."""
551
+
474
552
  status: str # 'pending', 'approved', 'rejected'
475
553
  approved: bool
476
554
  additional_context: Optional[Dict[str, Any]] = None
477
555
 
556
+
478
557
  @dataclass(frozen=True)
479
558
  class RunState(Generic[Ctx]):
480
559
  """Immutable state of a run."""
560
+
481
561
  run_id: RunId
482
562
  trace_id: TraceId
483
563
  messages: List[Message]
@@ -486,51 +566,61 @@ class RunState(Generic[Ctx]):
486
566
  turn_count: int
487
567
  approvals: Dict[str, ApprovalValue] = field(default_factory=dict)
488
568
 
569
+
489
570
  # Error types using dataclasses for immutability
490
571
  @dataclass(frozen=True)
491
572
  class MaxTurnsExceeded:
492
573
  _tag: Literal["MaxTurnsExceeded"] = "MaxTurnsExceeded"
493
574
  turns: int = 0
494
575
 
576
+
495
577
  @dataclass(frozen=True)
496
578
  class ModelBehaviorError:
497
579
  _tag: Literal["ModelBehaviorError"] = "ModelBehaviorError"
498
580
  detail: str = ""
499
581
 
582
+
500
583
  @dataclass(frozen=True)
501
584
  class DecodeError:
502
585
  _tag: Literal["DecodeError"] = "DecodeError"
503
586
  errors: List[Dict[str, Any]] = field(default_factory=list)
504
587
 
588
+
505
589
  @dataclass(frozen=True)
506
590
  class InputGuardrailTripwire:
507
591
  _tag: Literal["InputGuardrailTripwire"] = "InputGuardrailTripwire"
508
592
  reason: str = ""
509
593
 
594
+
510
595
  @dataclass(frozen=True)
511
596
  class OutputGuardrailTripwire:
512
597
  _tag: Literal["OutputGuardrailTripwire"] = "OutputGuardrailTripwire"
513
598
  reason: str = ""
514
599
 
600
+
515
601
  @dataclass(frozen=True)
516
602
  class ToolCallError:
517
603
  _tag: Literal["ToolCallError"] = "ToolCallError"
518
604
  tool: str = ""
519
605
  detail: str = ""
520
606
 
607
+
521
608
  @dataclass(frozen=True)
522
609
  class HandoffError:
523
610
  _tag: Literal["HandoffError"] = "HandoffError"
524
611
  detail: str = ""
525
612
 
613
+
526
614
  @dataclass(frozen=True)
527
615
  class AgentNotFound:
528
616
  _tag: Literal["AgentNotFound"] = "AgentNotFound"
529
617
  agent_name: str = ""
530
618
 
619
+
531
620
  @dataclass(frozen=True)
532
621
  class RecoverableError:
533
622
  """Error that can be recovered from with retry strategies."""
623
+
534
624
  _tag: Literal["RecoverableError"] = "RecoverableError"
535
625
  detail: str = ""
536
626
  retry_strategy: Optional[str] = None
@@ -538,40 +628,51 @@ class RecoverableError:
538
628
  current_attempt: int = 1
539
629
  backoff_seconds: float = 1.0
540
630
 
631
+
541
632
  @dataclass(frozen=True)
542
633
  class RateLimitError:
543
634
  """Error indicating rate limit has been exceeded."""
635
+
544
636
  _tag: Literal["RateLimitError"] = "RateLimitError"
545
637
  detail: str = ""
546
638
  retry_after_seconds: Optional[float] = None
547
639
  limit_type: str = "requests" # "requests", "tokens", "concurrent"
548
640
 
641
+
549
642
  @dataclass(frozen=True)
550
643
  class ValidationError:
551
644
  """Enhanced validation error with detailed context."""
645
+
552
646
  _tag: Literal["ValidationError"] = "ValidationError"
553
647
  detail: str = ""
554
648
  field_errors: List[Dict[str, Any]] = field(default_factory=list)
555
649
  error_code: str = "validation_failed"
556
650
 
651
+
557
652
  @dataclass(frozen=True)
558
653
  class NetworkError:
559
654
  """Network-related errors with retry information."""
655
+
560
656
  _tag: Literal["NetworkError"] = "NetworkError"
561
657
  detail: str = ""
562
658
  status_code: Optional[int] = None
563
659
  is_retryable: bool = True
564
660
  endpoint: Optional[str] = None
565
661
 
662
+
566
663
  # Interruption types for HITL
567
664
  @dataclass(frozen=True)
568
665
  class ToolApprovalInterruption(Generic[Ctx]):
569
666
  """Interruption for tool approval."""
570
- type: Literal['tool_approval'] = 'tool_approval'
571
- tool_call: ToolCall = field(default_factory=lambda: ToolCall("", "function", ToolCallFunction("", "")))
572
- agent: 'Agent[Ctx, Any]' = None
667
+
668
+ type: Literal["tool_approval"] = "tool_approval"
669
+ tool_call: ToolCall = field(
670
+ default_factory=lambda: ToolCall("", "function", ToolCallFunction("", ""))
671
+ )
672
+ agent: "Agent[Ctx, Any]" = None
573
673
  session_id: Optional[str] = None
574
674
 
675
+
575
676
  # Union type for all interruptions
576
677
  Interruption = Union[ToolApprovalInterruption[Any]]
577
678
 
@@ -588,40 +689,51 @@ JAFError = Union[
588
689
  RecoverableError,
589
690
  RateLimitError,
590
691
  ValidationError,
591
- NetworkError
692
+ NetworkError,
592
693
  ]
593
694
 
695
+
594
696
  @dataclass(frozen=True)
595
697
  class CompletedOutcome(Generic[Out]):
596
698
  """Successful completion outcome."""
597
- status: Literal['completed'] = 'completed'
699
+
700
+ status: Literal["completed"] = "completed"
598
701
  output: Out = field(default=None)
599
702
 
703
+
600
704
  @dataclass(frozen=True)
601
705
  class ErrorOutcome:
602
706
  """Error outcome."""
603
- status: Literal['error'] = 'error'
707
+
708
+ status: Literal["error"] = "error"
604
709
  error: JAFError = field(default=None)
605
710
 
711
+
606
712
  @dataclass(frozen=True)
607
713
  class InterruptedOutcome:
608
714
  """Interrupted outcome for HITL."""
609
- status: Literal['interrupted'] = 'interrupted'
715
+
716
+ status: Literal["interrupted"] = "interrupted"
610
717
  interruptions: List[Interruption] = field(default_factory=list)
611
718
 
719
+
612
720
  # Union type for outcomes
613
721
  RunOutcome = Union[CompletedOutcome[Out], ErrorOutcome, InterruptedOutcome]
614
722
 
723
+
615
724
  @dataclass(frozen=True)
616
725
  class RunResult(Generic[Out]):
617
726
  """Result of a run execution."""
727
+
618
728
  final_state: RunState[Any]
619
729
  outcome: RunOutcome[Out]
620
730
 
731
+
621
732
  # Trace event types with specific data structures to match TypeScript
622
733
  @dataclass(frozen=True)
623
734
  class RunStartEventData:
624
735
  """Data for run start events."""
736
+
625
737
  run_id: RunId
626
738
  trace_id: TraceId
627
739
  session_id: Optional[str] = None
@@ -629,14 +741,19 @@ class RunStartEventData:
629
741
  messages: Optional[List[Message]] = None
630
742
  agent_name: Optional[str] = None
631
743
 
744
+
632
745
  @dataclass(frozen=True)
633
746
  class RunStartEvent:
634
- type: Literal['run_start'] = 'run_start'
635
- data: RunStartEventData = field(default_factory=lambda: RunStartEventData(RunId(""), TraceId("")))
747
+ type: Literal["run_start"] = "run_start"
748
+ data: RunStartEventData = field(
749
+ default_factory=lambda: RunStartEventData(RunId(""), TraceId(""))
750
+ )
751
+
636
752
 
637
753
  @dataclass(frozen=True)
638
754
  class LLMCallStartEventData:
639
755
  """Data for LLM call start events."""
756
+
640
757
  agent_name: str
641
758
  model: str
642
759
  trace_id: TraceId
@@ -644,46 +761,68 @@ class LLMCallStartEventData:
644
761
  context: Optional[Any] = None
645
762
  messages: Optional[List[Message]] = None
646
763
 
764
+
647
765
  @dataclass(frozen=True)
648
766
  class LLMCallStartEvent:
649
- type: Literal['llm_call_start'] = 'llm_call_start'
650
- data: LLMCallStartEventData = field(default_factory=lambda: LLMCallStartEventData("", "", TraceId(""), RunId("")))
767
+ type: Literal["llm_call_start"] = "llm_call_start"
768
+ data: LLMCallStartEventData = field(
769
+ default_factory=lambda: LLMCallStartEventData("", "", TraceId(""), RunId(""))
770
+ )
771
+
651
772
 
652
773
  @dataclass(frozen=True)
653
774
  class LLMCallEndEventData:
654
775
  """Data for LLM call end events."""
776
+
655
777
  choice: Any
656
778
  trace_id: TraceId
657
779
  run_id: RunId
658
780
  usage: Optional[Dict[str, int]] = None
659
781
 
782
+
660
783
  @dataclass(frozen=True)
661
784
  class LLMCallEndEvent:
662
- type: Literal['llm_call_end'] = 'llm_call_end'
663
- data: LLMCallEndEventData = field(default_factory=lambda: LLMCallEndEventData(None, TraceId(""), RunId("")))
785
+ type: Literal["llm_call_end"] = "llm_call_end"
786
+ data: LLMCallEndEventData = field(
787
+ default_factory=lambda: LLMCallEndEventData(None, TraceId(""), RunId(""))
788
+ )
789
+
664
790
 
665
791
  @dataclass(frozen=True)
666
792
  class AssistantMessageEventData:
667
793
  """Data for assistant message events (partial or complete)."""
794
+
668
795
  message: Message
669
796
 
797
+
670
798
  @dataclass(frozen=True)
671
799
  class AssistantMessageEvent:
672
- type: Literal['assistant_message'] = 'assistant_message'
673
- data: AssistantMessageEventData = field(default_factory=lambda: AssistantMessageEventData(Message(role=ContentRole.ASSISTANT, content="")))
800
+ type: Literal["assistant_message"] = "assistant_message"
801
+ data: AssistantMessageEventData = field(
802
+ default_factory=lambda: AssistantMessageEventData(
803
+ Message(role=ContentRole.ASSISTANT, content="")
804
+ )
805
+ )
806
+
807
+
674
808
  @dataclass(frozen=True)
675
809
  class ToolCallStartEventData:
676
810
  """Data for tool call start events."""
811
+
677
812
  tool_name: str
678
813
  args: Any
679
814
  trace_id: TraceId
680
815
  run_id: RunId
681
816
  call_id: Optional[str] = None
682
817
 
818
+
683
819
  @dataclass(frozen=True)
684
820
  class ToolCallStartEvent:
685
- type: Literal['tool_call_start'] = 'tool_call_start'
686
- data: ToolCallStartEventData = field(default_factory=lambda: ToolCallStartEventData("", None, TraceId(""), RunId(""), None))
821
+ type: Literal["tool_call_start"] = "tool_call_start"
822
+ data: ToolCallStartEventData = field(
823
+ default_factory=lambda: ToolCallStartEventData("", None, TraceId(""), RunId(""), None)
824
+ )
825
+
687
826
 
688
827
  @dataclass(frozen=True)
689
828
  class ToolCallEndEventData:
@@ -703,18 +842,27 @@ class ToolCallEndEventData:
703
842
  - 'rejected': Tool was rejected by user
704
843
  - 'execution_error', 'validation_error', etc.: Various error states
705
844
  """
845
+
706
846
  tool_name: str
707
847
  result: str
708
848
  trace_id: TraceId
709
849
  run_id: RunId
710
850
  tool_result: Optional[Any] = None
711
- execution_status: Optional[str] = None # success/error/timeout - indicates if tool executed successfully
712
- status: Optional[str] = None # DEPRECATED: maintained for backward-compatible initialization/serialization
851
+ execution_status: Optional[str] = (
852
+ None # success/error/timeout - indicates if tool executed successfully
853
+ )
854
+ status: Optional[str] = (
855
+ None # DEPRECATED: maintained for backward-compatible initialization/serialization
856
+ )
713
857
  call_id: Optional[str] = None
714
858
 
715
859
  def __post_init__(self) -> None:
716
860
  # Handle backward compatibility with explicit conflict detection
717
- if self.execution_status is not None and self.status is not None and self.execution_status != self.status:
861
+ if (
862
+ self.execution_status is not None
863
+ and self.status is not None
864
+ and self.execution_status != self.status
865
+ ):
718
866
  raise ValueError(
719
867
  f"Conflicting values for execution_status ('{self.execution_status}') and status ('{self.status}'). "
720
868
  f"Please use only execution_status for new code."
@@ -722,85 +870,111 @@ class ToolCallEndEventData:
722
870
 
723
871
  # Prefer execution_status (new field) over status (deprecated field)
724
872
  canonical = self.execution_status if self.execution_status is not None else self.status
725
- object.__setattr__(self, 'execution_status', canonical)
726
- object.__setattr__(self, 'status', canonical)
873
+ object.__setattr__(self, "execution_status", canonical)
874
+ object.__setattr__(self, "status", canonical)
875
+
876
+
727
877
  @dataclass(frozen=True)
728
878
  class ToolCallEndEvent:
729
- type: Literal['tool_call_end'] = 'tool_call_end'
730
- data: ToolCallEndEventData = field(default_factory=lambda: ToolCallEndEventData("", "", TraceId(""), RunId(""), None, None))
879
+ type: Literal["tool_call_end"] = "tool_call_end"
880
+ data: ToolCallEndEventData = field(
881
+ default_factory=lambda: ToolCallEndEventData("", "", TraceId(""), RunId(""), None, None)
882
+ )
883
+
731
884
 
732
885
  @dataclass(frozen=True)
733
886
  class HandoffEventData:
734
887
  """Data for handoff events."""
735
- from_: str = field(metadata={'alias': 'from'}) # Using from_ since 'from' is a Python keyword
888
+
889
+ from_: str = field(metadata={"alias": "from"}) # Using from_ since 'from' is a Python keyword
736
890
  to: str
737
891
 
892
+
738
893
  @dataclass(frozen=True)
739
894
  class HandoffEvent:
740
- type: Literal['handoff'] = 'handoff'
895
+ type: Literal["handoff"] = "handoff"
741
896
  data: HandoffEventData = field(default_factory=lambda: HandoffEventData("", ""))
742
897
 
898
+
743
899
  @dataclass(frozen=True)
744
900
  class RunEndEventData:
745
901
  """Data for run end events."""
746
- outcome: 'RunOutcome[Any]'
902
+
903
+ outcome: "RunOutcome[Any]"
747
904
  trace_id: TraceId
748
905
  run_id: RunId
749
906
 
907
+
750
908
  @dataclass(frozen=True)
751
909
  class RunEndEvent:
752
- type: Literal['run_end'] = 'run_end'
753
- data: RunEndEventData = field(default_factory=lambda: RunEndEventData(None, TraceId(""), RunId("")))
910
+ type: Literal["run_end"] = "run_end"
911
+ data: RunEndEventData = field(
912
+ default_factory=lambda: RunEndEventData(None, TraceId(""), RunId(""))
913
+ )
914
+
754
915
 
755
916
  @dataclass(frozen=True)
756
917
  class GuardrailEventData:
757
918
  """Data for guardrail check events."""
919
+
758
920
  guardrail_name: str
759
921
  content: Any
760
922
  is_valid: Optional[bool] = None
761
923
  error_message: Optional[str] = None
762
924
 
925
+
763
926
  @dataclass(frozen=True)
764
927
  class GuardrailEvent:
765
- type: Literal['guardrail_check'] = 'guardrail_check'
928
+ type: Literal["guardrail_check"] = "guardrail_check"
766
929
  data: GuardrailEventData = field(default_factory=lambda: GuardrailEventData(""))
767
930
 
931
+
768
932
  @dataclass(frozen=True)
769
933
  class GuardrailViolationEventData:
770
934
  """Data for guardrail violation events."""
771
- stage: Literal['input', 'output']
935
+
936
+ stage: Literal["input", "output"]
772
937
  reason: str
773
938
 
939
+
774
940
  @dataclass(frozen=True)
775
941
  class GuardrailViolationEvent:
776
- type: Literal['guardrail_violation'] = 'guardrail_violation'
777
- data: GuardrailViolationEventData = field(default_factory=lambda: GuardrailViolationEventData("input", ""))
942
+ type: Literal["guardrail_violation"] = "guardrail_violation"
943
+ data: GuardrailViolationEventData = field(
944
+ default_factory=lambda: GuardrailViolationEventData("input", "")
945
+ )
946
+
778
947
 
779
948
  @dataclass(frozen=True)
780
949
  class MemoryEventData:
781
950
  """Data for memory operation events."""
782
- operation: Literal['load', 'store']
951
+
952
+ operation: Literal["load", "store"]
783
953
  conversation_id: str
784
- status: Literal['start', 'end', 'fail']
954
+ status: Literal["start", "end", "fail"]
785
955
  error: Optional[str] = None
786
956
  message_count: Optional[int] = None
787
957
 
958
+
788
959
  @dataclass(frozen=True)
789
960
  class MemoryEvent:
790
- type: Literal['memory_operation'] = 'memory_operation'
961
+ type: Literal["memory_operation"] = "memory_operation"
791
962
  data: MemoryEventData = field(default_factory=lambda: MemoryEventData("load", "", "start"))
792
963
 
964
+
793
965
  @dataclass(frozen=True)
794
966
  class OutputParseEventData:
795
967
  """Data for output parsing events."""
968
+
796
969
  content: str
797
- status: Literal['start', 'end', 'fail']
970
+ status: Literal["start", "end", "fail"]
798
971
  parsed_output: Optional[Any] = None
799
972
  error: Optional[str] = None
800
973
 
974
+
801
975
  @dataclass(frozen=True)
802
976
  class OutputParseEvent:
803
- type: Literal['output_parse'] = 'output_parse'
977
+ type: Literal["output_parse"] = "output_parse"
804
978
  data: OutputParseEventData = field(default_factory=lambda: OutputParseEventData("", "start"))
805
979
 
806
980
 
@@ -817,72 +991,80 @@ TraceEvent = Union[
817
991
  ToolCallStartEvent,
818
992
  ToolCallEndEvent,
819
993
  HandoffEvent,
820
- RunEndEvent
994
+ RunEndEvent,
821
995
  ]
822
996
 
997
+
823
998
  @dataclass(frozen=True)
824
999
  class ModelCompletionMessage:
825
1000
  """Message structure returned by model completion."""
1001
+
826
1002
  content: Optional[str] = None
827
1003
  tool_calls: Optional[List[ToolCall]] = None
828
1004
 
1005
+
829
1006
  @dataclass(frozen=True)
830
1007
  class ModelCompletionResponse:
831
1008
  """Response structure from model completion."""
1009
+
832
1010
  message: Optional[ModelCompletionMessage] = None
833
1011
 
1012
+
834
1013
  # Streaming chunk structures for provider-level streaming support
835
1014
  @dataclass(frozen=True)
836
1015
  class ToolCallFunctionDelta:
837
1016
  """Function fields that may stream as deltas."""
1017
+
838
1018
  name: Optional[str] = None
839
1019
  arguments_delta: Optional[str] = None
840
1020
 
1021
+
841
1022
  @dataclass(frozen=True)
842
1023
  class ToolCallDelta:
843
1024
  """Represents a partial tool call delta in a streamed response."""
1025
+
844
1026
  index: int
845
1027
  id: Optional[str] = None
846
- type: Literal['function'] = 'function'
1028
+ type: Literal["function"] = "function"
847
1029
  function: Optional[ToolCallFunctionDelta] = None
848
1030
 
1031
+
849
1032
  @dataclass(frozen=True)
850
1033
  class CompletionStreamChunk:
851
1034
  """A streamed chunk from the model provider."""
1035
+
852
1036
  delta: Optional[str] = None
853
1037
  tool_call_delta: Optional[ToolCallDelta] = None
854
1038
  is_done: Optional[bool] = False
855
1039
  finish_reason: Optional[str] = None
856
1040
  raw: Optional[Any] = None
857
1041
 
1042
+
858
1043
  @runtime_checkable
859
1044
  class ModelProvider(Protocol[Ctx]):
860
1045
  """Protocol for model providers."""
861
1046
 
862
1047
  async def get_completion(
863
- self,
864
- state: RunState[Ctx],
865
- agent: Agent[Ctx, Any],
866
- config: 'RunConfig[Ctx]'
1048
+ self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: "RunConfig[Ctx]"
867
1049
  ) -> ModelCompletionResponse:
868
1050
  """Get completion from the model."""
869
1051
  ...
870
1052
 
871
1053
  async def get_completion_stream(
872
- self,
873
- state: RunState[Ctx],
874
- agent: Agent[Ctx, Any],
875
- config: 'RunConfig[Ctx]'
1054
+ self, state: RunState[Ctx], agent: Agent[Ctx, Any], config: "RunConfig[Ctx]"
876
1055
  ) -> AsyncIterator[CompletionStreamChunk]:
877
1056
  """Optional streaming API: yields incremental deltas while generating."""
878
1057
  ...
879
1058
 
1059
+
880
1060
  @dataclass(frozen=True)
881
1061
  class RunConfig(Generic[Ctx]):
882
1062
  """Configuration for running agents."""
1063
+
883
1064
  agent_registry: Dict[str, Agent[Ctx, Any]]
884
1065
  model_provider: ModelProvider[Ctx]
885
1066
  max_turns: Optional[int] = 50
1067
+ max_tokens: Optional[int] = None # Default max_tokens for all agents (can be overridden per agent)
886
1068
  model_override: Optional[str] = None
887
1069
  initial_input_guardrails: Optional[List[Guardrail]] = None
888
1070
  final_output_guardrails: Optional[List[Guardrail]] = None
@@ -891,32 +1073,68 @@ class RunConfig(Generic[Ctx]):
891
1073
  conversation_id: Optional[str] = None
892
1074
  default_fast_model: Optional[str] = None # Default model for fast operations like guardrails
893
1075
  default_tool_timeout: Optional[float] = 300.0 # Default timeout for tool execution in seconds
894
- approval_storage: Optional['ApprovalStorage'] = None # Storage for approval decisions
895
- before_llm_call: Optional[Callable[[RunState[Ctx], Agent[Ctx, Any]], Union[RunState[Ctx], Awaitable[RunState[Ctx]]]]] = None # Callback before LLM call - can modify context/messages
896
- after_llm_call: Optional[Callable[[RunState[Ctx], ModelCompletionResponse], Union[ModelCompletionResponse, Awaitable[ModelCompletionResponse]]]] = None # Callback after LLM call - can process response
1076
+ approval_storage: Optional["ApprovalStorage"] = None # Storage for approval decisions
1077
+ before_llm_call: Optional[
1078
+ Callable[[RunState[Ctx], Agent[Ctx, Any]], Union[RunState[Ctx], Awaitable[RunState[Ctx]]]]
1079
+ ] = None # Callback before LLM call - can modify context/messages
1080
+ after_llm_call: Optional[
1081
+ Callable[
1082
+ [RunState[Ctx], ModelCompletionResponse],
1083
+ Union[ModelCompletionResponse, Awaitable[ModelCompletionResponse]],
1084
+ ]
1085
+ ] = None # Callback after LLM call - can process response
897
1086
  max_empty_response_retries: int = 3 # Maximum retries when LLM returns empty response
898
- empty_response_retry_delay: float = 1.0 # Initial delay in seconds before retrying empty response (uses exponential backoff)
1087
+ empty_response_retry_delay: float = (
1088
+ 1.0 # Initial delay in seconds before retrying empty response (uses exponential backoff)
1089
+ )
899
1090
  log_empty_responses: bool = True # Whether to log diagnostic info for empty responses
900
- prefer_streaming: Optional[bool] = None # Whether to prefer streaming responses. None (default) = use streaming if available, True = prefer streaming, False = disable streaming
1091
+ prefer_streaming: Optional[bool] = (
1092
+ None # Whether to prefer streaming responses. None (default) = use streaming if available, True = prefer streaming, False = disable streaming
1093
+ )
901
1094
 
902
1095
 
903
1096
  # Regeneration types for conversation management
904
1097
  @dataclass(frozen=True)
905
1098
  class RegenerationRequest:
906
1099
  """Request to regenerate a conversation from a specific message."""
1100
+
907
1101
  conversation_id: str
908
1102
  message_id: MessageId # ID of the message to regenerate from
909
1103
  context: Optional[Dict[str, Any]] = None # Optional context override
910
1104
 
1105
+
911
1106
  @dataclass(frozen=True)
912
1107
  class RegenerationContext:
913
1108
  """Context information for a regeneration operation."""
1109
+
914
1110
  original_message_count: int
915
1111
  truncated_at_index: int
916
1112
  regenerated_message_id: MessageId
917
1113
  regeneration_id: str # Unique ID for this regeneration operation
918
1114
  timestamp: int # Unix timestamp in milliseconds
919
1115
 
1116
+
1117
+ # Checkpoint types for conversation management
1118
+ @dataclass(frozen=True)
1119
+ class CheckpointRequest:
1120
+ """Request to checkpoint a conversation after a specific message."""
1121
+
1122
+ conversation_id: str
1123
+ message_id: MessageId # ID of the message to checkpoint after (this message is kept)
1124
+ context: Optional[Dict[str, Any]] = None # Optional context for the checkpoint
1125
+
1126
+
1127
+ @dataclass(frozen=True)
1128
+ class CheckpointContext:
1129
+ """Context information for a checkpoint operation."""
1130
+
1131
+ original_message_count: int
1132
+ checkpointed_at_index: int
1133
+ checkpointed_message_id: MessageId
1134
+ checkpoint_id: str # Unique ID for this checkpoint operation
1135
+ timestamp: int # Unix timestamp in milliseconds
1136
+
1137
+
920
1138
  # Message utility functions
921
1139
  def find_message_index(messages: List[Message], message_id: MessageId) -> Optional[int]:
922
1140
  """Find the index of a message by its ID."""
@@ -925,6 +1143,7 @@ def find_message_index(messages: List[Message], message_id: MessageId) -> Option
925
1143
  return i
926
1144
  return None
927
1145
 
1146
+
928
1147
  def truncate_messages_after(messages: List[Message], message_id: MessageId) -> List[Message]:
929
1148
  """Truncate messages after (and including) the specified message ID."""
930
1149
  index = find_message_index(messages, message_id)
@@ -932,6 +1151,7 @@ def truncate_messages_after(messages: List[Message], message_id: MessageId) -> L
932
1151
  return messages # Message not found, return unchanged
933
1152
  return messages[:index]
934
1153
 
1154
+
935
1155
  def get_message_by_id(messages: List[Message], message_id: MessageId) -> Optional[Message]:
936
1156
  """Get a message by its ID."""
937
1157
  for msg in messages: