dispatch_agents 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentservice/__init__.py +0 -0
- agentservice/py.typed +0 -0
- agentservice/v1/__init__.py +0 -0
- agentservice/v1/message_pb2.py +41 -0
- agentservice/v1/message_pb2.pyi +22 -0
- agentservice/v1/message_pb2_grpc.py +4 -0
- agentservice/v1/request_response_pb2.py +46 -0
- agentservice/v1/request_response_pb2.pyi +54 -0
- agentservice/v1/request_response_pb2_grpc.py +4 -0
- agentservice/v1/service_pb2.py +43 -0
- agentservice/v1/service_pb2.pyi +6 -0
- agentservice/v1/service_pb2_grpc.py +129 -0
- dispatch_agents/__init__.py +281 -0
- dispatch_agents/agent_service.py +135 -0
- dispatch_agents/config.py +490 -0
- dispatch_agents/contrib/__init__.py +1 -0
- dispatch_agents/contrib/claude/__init__.py +246 -0
- dispatch_agents/contrib/openai/__init__.py +167 -0
- dispatch_agents/events.py +986 -0
- dispatch_agents/grpc_server.py +565 -0
- dispatch_agents/instrument.py +217 -0
- dispatch_agents/integrations/__init__.py +1 -0
- dispatch_agents/integrations/github/README.md +9 -0
- dispatch_agents/integrations/github/__init__.py +4268 -0
- dispatch_agents/invocation.py +25 -0
- dispatch_agents/llm.py +1017 -0
- dispatch_agents/llm_langchain.py +394 -0
- dispatch_agents/logging_config.py +133 -0
- dispatch_agents/mcp.py +266 -0
- dispatch_agents/memory.py +264 -0
- dispatch_agents/models.py +748 -0
- dispatch_agents/proxy/__init__.py +6 -0
- dispatch_agents/proxy/server.py +1137 -0
- dispatch_agents/proxy/sse_utils.py +76 -0
- dispatch_agents/py.typed +0 -0
- dispatch_agents/resources.py +68 -0
- dispatch_agents/version.py +19 -0
- dispatch_agents-0.9.0.dist-info/METADATA +20 -0
- dispatch_agents-0.9.0.dist-info/RECORD +43 -0
- dispatch_agents-0.9.0.dist-info/WHEEL +4 -0
- dispatch_agents-0.9.0.dist-info/licenses/LICENSE +191 -0
- dispatch_agents-0.9.0.dist-info/licenses/LICENSE-3rdparty.csv +12 -0
- dispatch_agents-0.9.0.dist-info/licenses/NOTICE +5 -0
|
@@ -0,0 +1,986 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import inspect
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import uuid
|
|
8
|
+
from collections.abc import Awaitable, Callable
|
|
9
|
+
from contextvars import ContextVar
|
|
10
|
+
from typing import (
|
|
11
|
+
TYPE_CHECKING,
|
|
12
|
+
Any,
|
|
13
|
+
ParamSpec,
|
|
14
|
+
TypeVar,
|
|
15
|
+
get_args,
|
|
16
|
+
get_origin,
|
|
17
|
+
get_type_hints,
|
|
18
|
+
overload,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
import httpx
|
|
22
|
+
from pydantic import BaseModel, ValidationError
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from dispatch_agents.integrations.github import GitHubEventPayload
|
|
26
|
+
|
|
27
|
+
from dispatch_agents.models import (
|
|
28
|
+
BaseMessage,
|
|
29
|
+
ErrorPayload,
|
|
30
|
+
FunctionMessage,
|
|
31
|
+
InvokeFunctionRequest,
|
|
32
|
+
JsonSchema,
|
|
33
|
+
Message,
|
|
34
|
+
PublishEventBody,
|
|
35
|
+
StrictBaseModel,
|
|
36
|
+
SuccessPayload,
|
|
37
|
+
TopicMessage,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
from .version import get_sdk_version
|
|
41
|
+
|
|
42
|
+
logger = logging.getLogger(__name__)

# Signature-preserving decorator helpers: P captures the wrapped callable's
# parameters and R the value its awaitable resolves to (used by @fn / @on).
P = ParamSpec("P")
R = TypeVar("R")

# Generic payload / return type variables bounded to pydantic models.
PayloadT = TypeVar("PayloadT", bound=BaseModel)
ReturnT = TypeVar("ReturnT", bound=BaseModel | None)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Common root for event payload models (recommended base for handler inputs).
class BasePayload(StrictBaseModel):
    """Base class for dispatch agent event payloads.

    Subclass it to declare the fields a handler expects. The @on / @fn
    decorators read the handler's first-parameter annotation to derive its
    JSON input schema, and dispatch_message validates incoming payloads
    against that model before calling the handler.

    NOTE(review): the decorators only require a pydantic BaseModel subclass;
    inheriting from BasePayload is a convention, not something enforced here.

    Examples:
        >>> class MyEventPayload(BasePayload):
        ...     message: str
        ...     user_id: int
        ...
        >>> @on(topic="my.topic")
        ... async def my_handler(payload: MyEventPayload) -> str:
        ...     return f"Hello {payload.user_id}"
    """
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class HandlerMetadata(StrictBaseModel):
    """Serializable handler metadata for registration and introspection.

    Built by the @fn and @on decorators, stored in HANDLER_METADATA under the
    handler name and attached to the decorated function as
    ``_dispatch_metadata``. It carries the handler's input/output JSON
    schemas, topic subscriptions, and docstring.

    Note: the input/output model classes themselves are not stored here;
    they can be recovered from the handler function via get_type_hints()
    when needed for validation (see _get_input_model_from_handler).
    """

    # Registered name: the @fn ``name`` argument, or the function's __name__.
    handler_name: str
    # Topics routed to this handler; empty for @fn. on() appends to this list
    # in place when an already registered handler is decorated again.
    topics: list[str]
    # JSON schema (mode="serialization") of the handler's payload model.
    input_schema: JsonSchema
    # JSON schema (mode="serialization") of the returned model, or None when
    # _extract_return_model finds no model in the return annotation.
    output_schema: JsonSchema | None
    # The handler function's __doc__ (None if it has no docstring).
    handler_doc: str | None
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Shape every registered handler is treated as having: one validated pydantic
# payload in, an awaitable resolving to a model (or None) out.
AsyncHandler = Callable[[BaseModel], Awaitable[BaseModel | None]]

# Registry shared by @on (topic-triggered) and @fn (direct call), keyed by
# handler name. Every entry can be invoked directly by name; @on entries are
# additionally reachable through TOPIC_HANDLERS.
REGISTERED_HANDLERS: dict[str, AsyncHandler] = {}
HANDLER_METADATA: dict[str, HandlerMetadata] = {}

# Topic routing table: topic -> names of subscribed handlers, in subscription
# order. Several handlers may share one topic (fan-out); each name is resolved
# through REGISTERED_HANDLERS.
TOPIC_HANDLERS: dict[str, list[str]] = {}

# Optional @init coroutine function, awaited by run_init_hook() in the
# agent's event loop before requests are handled.
_INIT_HOOK: Callable[[], Awaitable[None]] | None = None
|
|
108
|
+
|
|
109
|
+
# Execution context populated by dispatch_message(). ContextVar values belong
# to the current contextvars.Context, so concurrently running asyncio tasks
# each observe their own values; every variable defaults to None.
_current_trace_id: ContextVar[str | None] = ContextVar(
    "current_trace_id", default=None
)
_current_invocation_id: ContextVar[str | None] = ContextVar(
    "current_invocation_id", default=None
)
_current_parent_id: ContextVar[str | None] = ContextVar(
    "current_parent_id", default=None
)
_current_message: ContextVar[BaseMessage | None] = ContextVar(
    "current_message", default=None
)
|
|
120
|
+
|
|
121
|
+
from collections import OrderedDict

# Fallback trace_id -> invocation_id index, consulted by
# get_invocation_id_for_trace() when the _current_* context variables are not
# propagated (e.g. an external SDK runs tool calls in its own async context).
# Insertion order doubles as recency order, so the oldest entry is evicted
# first once _TRACE_CONTEXT_MAX_SIZE is exceeded.
# NOTE(review): single OrderedDict operations are atomic under CPython's GIL,
# but register/unregister are multi-step; assumed to run on one event loop.
_TRACE_CONTEXT_MAX_SIZE = 100_000  # upper bound on retained trace mappings
_trace_invocation_context: OrderedDict[str, str] = OrderedDict()
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _register_trace_invocation(trace_id: str, invocation_id: str) -> None:
    """Record ``trace_id -> invocation_id`` in the bounded fallback index.

    Re-registering an existing trace overwrites its invocation ID and marks
    it most recently used. Adding a new trace may evict the least recently
    used entries so the index never exceeds ``_TRACE_CONTEXT_MAX_SIZE``.

    Args:
        trace_id: Trace to map.
        invocation_id: Invocation currently associated with that trace.
    """
    store = _trace_invocation_context
    already_known = trace_id in store
    store[trace_id] = invocation_id

    if already_known:
        # Assigning to an existing key keeps its position; bump it explicitly.
        store.move_to_end(trace_id)
        return

    # A new key may have pushed the index over its bound: drop the oldest.
    while len(store) > _TRACE_CONTEXT_MAX_SIZE:
        store.popitem(last=False)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _unregister_trace_invocation(trace_id: str) -> None:
    """Drop the fallback mapping for ``trace_id``, if any.

    Called when an invocation completes so the index is cleaned up
    deterministically instead of relying only on size-based eviction.
    Unknown trace IDs are ignored.

    Args:
        trace_id: Trace whose mapping should be removed.
    """
    try:
        del _trace_invocation_context[trace_id]
    except KeyError:
        pass  # already evicted or never registered: nothing to clean up
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _extract_return_model(return_type: Any) -> type[BaseModel] | None:
|
|
169
|
+
"""Extract BaseModel from return type, handling Optional/Union."""
|
|
170
|
+
if not return_type:
|
|
171
|
+
return None
|
|
172
|
+
|
|
173
|
+
# Check if it's Optional[Model] or Union[Model, None]
|
|
174
|
+
origin = get_origin(return_type)
|
|
175
|
+
if origin is not None:
|
|
176
|
+
args = get_args(return_type)
|
|
177
|
+
for arg in args:
|
|
178
|
+
if (
|
|
179
|
+
arg is not type(None)
|
|
180
|
+
and isinstance(arg, type)
|
|
181
|
+
and issubclass(arg, BaseModel)
|
|
182
|
+
):
|
|
183
|
+
return arg
|
|
184
|
+
|
|
185
|
+
# Direct BaseModel subclass
|
|
186
|
+
if isinstance(return_type, type) and issubclass(return_type, BaseModel):
|
|
187
|
+
return return_type
|
|
188
|
+
|
|
189
|
+
return None
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _get_input_model_from_handler(
    func: Callable[..., Any],
) -> type[BaseModel] | None:
    """Return the pydantic model annotated on ``func``'s first parameter.

    Used at dispatch time to validate incoming payloads, so the model class
    never has to be kept in the serializable HandlerMetadata.

    Returns None when:
    - the annotations cannot be resolved,
    - the function takes no parameters, or
    - the first parameter is not annotated with a BaseModel subclass.
    """
    try:
        annotations = get_type_hints(func)
    except Exception:
        # Unresolvable annotations are treated as "no model declared".
        return None

    first = next(iter(inspect.signature(func).parameters.values()), None)
    if first is None:
        return None

    candidate = annotations.get(first.name)
    is_model = (
        bool(candidate)
        and isinstance(candidate, type)
        and issubclass(candidate, BaseModel)
    )
    return candidate if is_model else None
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def fn(
    *, name: str | None = None
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
    """Register an async function so other agents can call it by name.

    Functions registered with @fn can be called directly using invoke().
    Unlike @on, no topic subscriptions are created (``topics`` is empty).
    The decorated function is returned unchanged. Its HandlerMetadata is
    attached as ``func._dispatch_metadata`` and stored, together with the
    function, in HANDLER_METADATA / REGISTERED_HANDLERS.

    Args:
        name: Name to register under; defaults to ``func.__name__``.

    Returns:
        A decorator that registers the callable while preserving its type hints.

    Raises:
        ValueError: From the decorator, if the name is already registered or
            the first parameter is not annotated with a pydantic BaseModel
            subclass.

    Examples:
        >>> @fn()
        ... async def get_weather(payload: WeatherRequest) -> WeatherResponse:
        ...     return WeatherResponse(temp=72)
        ...
        >>> # Called from another agent:
        >>> result = await invoke("weather-agent", "get_weather", {"city": "NYC"})
    """

    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
        registered_name = name or func.__name__
        if registered_name in REGISTERED_HANDLERS:
            raise ValueError(f"Handler already registered: {registered_name}")

        parameters = list(inspect.signature(func).parameters.values())
        try:
            type_hints = get_type_hints(func)
        except Exception:
            # Unresolvable annotations behave like missing annotations.
            type_hints = {}

        # The payload model must be the annotation of the first parameter.
        payload_type = type_hints.get(parameters[0].name) if parameters else None
        payload_model: type[BaseModel] | None = (
            payload_type
            if payload_type
            and isinstance(payload_type, type)
            and issubclass(payload_type, BaseModel)
            else None
        )
        if not payload_model:
            raise ValueError(
                f"Handler '{registered_name}' must have a first parameter "
                f"annotated with a Pydantic BaseModel subclass. "
                f"Example: async def {registered_name}(payload: MyPayload) -> Result: ..."
            )

        result_model = _extract_return_model(type_hints.get("return"))

        # Schemas are generated input-first, then output.
        input_schema = payload_model.model_json_schema(mode="serialization")
        output_schema = (
            result_model.model_json_schema(mode="serialization")
            if result_model
            else None
        )
        metadata = HandlerMetadata(
            handler_name=registered_name,
            topics=[],  # @fn handlers are only reachable by direct invocation
            input_schema=input_schema,
            output_schema=output_schema,
            handler_doc=func.__doc__,
        )

        func._dispatch_metadata = metadata  # type: ignore
        HANDLER_METADATA[registered_name] = metadata
        REGISTERED_HANDLERS[registered_name] = func  # type: ignore[assignment]
        return func

    return decorator
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def init(
    func: Callable[[], Awaitable[None]],
) -> Callable[[], Awaitable[None]]:
    """Register the agent's one-time async initialization function.

    The registered coroutine function is awaited by run_init_hook() in the
    agent's event loop before requests are handled. Use it for async setup
    such as connecting to MCP servers or opening database connections. Only
    one @init function may be registered per agent.

    Args:
        func: An async function taking no parameters.

    Returns:
        ``func`` itself, unmodified.

    Raises:
        TypeError: If ``func`` is not a coroutine function.
        ValueError: If an init function is already registered.

    Examples:
        >>> from dispatch_agents.contrib.openai import get_mcp_servers
        >>> from agents import Agent
        >>>
        >>> my_agent: Agent  # Module-level, initialized by @init
        >>>
        >>> @init
        ... async def setup():
        ...     mcp_servers = await get_mcp_servers()
        ...     global my_agent
        ...     my_agent = Agent(name="MyAgent", mcp_servers=mcp_servers)
        >>>
        >>> @on(topic="query")
        ... async def handle_query(payload: QueryRequest) -> QueryResponse:
        ...     result = await Runner.run(my_agent, payload.prompt)
        ...     return QueryResponse(result=result.final_output)
    """
    global _INIT_HOOK

    # NOTE(review): asyncio.iscoroutinefunction is deprecated as of Python
    # 3.14 in favour of inspect.iscoroutinefunction. The two may disagree for
    # some wrapped/mock callables on older versions, so confirm before
    # switching.
    if not asyncio.iscoroutinefunction(func):
        raise TypeError(f"@init function must be async: {func.__name__}")

    current = _INIT_HOOK
    if current is not None:
        raise ValueError(
            f"Only one @init function allowed. Already registered: {current.__name__}"
        )

    _INIT_HOOK = func
    return func
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def _validate_github_payload_compatibility(
    input_model: type[BaseModel],
    event_classes: list[type[GitHubEventPayload]],
    handler_name: str,
) -> None:
    """Ensure a handler's payload type can receive every subscribed GitHub event.

    ``input_model`` must be each event class itself or one of its base
    classes. That lets a single handler cover several event types through a
    shared base.

    Args:
        input_model: The handler's payload model (first-parameter annotation).
        event_classes: GitHub event classes the handler subscribes to.
        handler_name: Handler name, used only in the error message.

    Raises:
        TypeError: For the first event class that does not subclass
            ``input_model``.
    """
    offending = next(
        (cls for cls in event_classes if not issubclass(cls, input_model)),
        None,
    )
    if offending is None:
        return

    raise TypeError(
        f"Handler '{handler_name}' payload type {input_model.__name__} "
        f"is not compatible with {offending.__name__}. "
        "Use a common base class or the exact event class."
    )
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def on(
    *,
    topic: str | None = None,
    github_event: type[GitHubEventPayload]
    | list[type[GitHubEventPayload]]
    | None = None,
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
    """Register an event handler for a topic or GitHub event(s).

    The handler function must accept, as its first parameter, a payload
    annotated with a Pydantic BaseModel subclass. The decorator extracts the
    input/output schemas from the function's type hints and registers them
    for validation and API documentation.

    Handlers registered with @on can also be called directly using invoke()
    by their function name, just like @fn handlers.

    Decorating the *same* function again (stacked @on decorators, or @on on
    a function already registered by @fn under its own name) adds the new
    topics to the existing registration.

    Args:
        topic: The event topic to handle (e.g., "user.created").
        github_event: GitHub event(s) to subscribe to. Mutually exclusive
            with topic. Can be a single event class (e.g., PullRequestOpened)
            or a list of classes.

    Returns:
        A decorator that registers the handler while preserving type hints.

    Raises:
        ValueError: If both or neither of ``topic`` / ``github_event`` are
            given. From the decorator: if the handler's first parameter is
            not annotated with a BaseModel subclass, or if a *different*
            function is already registered under the same name.
        TypeError: If ``github_event`` contains something other than a GitHub
            event class. From the decorator: if the handler's payload type is
            not a base class of every subscribed GitHub event.

    Examples:
        # Subscribe to a custom topic
        >>> @on(topic="user.created")
        ... async def handle_user_created(payload: UserCreatedPayload) -> WelcomeEmailPayload:
        ...     return WelcomeEmailPayload(...)

        # Subscribe to a GitHub event
        >>> from dispatch_agents.integrations.github import PullRequestOpened
        >>> @on(github_event=PullRequestOpened)
        ... async def handle_pr(payload: PullRequestOpened) -> None:
        ...     print(f"PR opened: {payload.pull_request.title}")

        # Subscribe to multiple GitHub events
        >>> from dispatch_agents.integrations.github import (
        ...     PullRequestOpened, PullRequestSynchronize, PullRequestBase
        ... )
        >>> @on(github_event=[PullRequestOpened, PullRequestSynchronize])
        ... async def handle_pr_changes(payload: PullRequestBase) -> None:
        ...     ...
    """
    # Deferred import required to avoid circular import:
    # events.py -> github/__init__.py -> events.py (for BasePayload)
    from dispatch_agents.integrations.github import GitHubEventPayload

    # Validate parameters
    if topic and github_event:
        raise ValueError("Cannot specify both 'topic' and 'github_event'")
    if not topic and not github_event:
        raise ValueError("Must specify either 'topic' or 'github_event'")

    # Convert github_event to list of topics
    topics: list[str] = []
    github_event_classes: list[type[GitHubEventPayload]] = []

    if github_event:
        events = github_event if isinstance(github_event, list) else [github_event]
        for event in events:
            if isinstance(event, type) and issubclass(event, GitHubEventPayload):
                github_event_classes.append(event)
                topics.append(event.dispatch_topic())
            else:
                raise TypeError(
                    f"Invalid github_event type: {type(event)}. "
                    f"Expected a GitHub event class (e.g., PullRequestOpened)."
                )
    else:
        topics = [topic]  # type: ignore[list-item]

    def _subscribe(handler_name: str) -> None:
        """Route every topic in ``topics`` to ``handler_name`` (no duplicates)."""
        for t in topics:
            subscribers = TOPIC_HANDLERS.setdefault(t, [])
            if handler_name not in subscribers:
                subscribers.append(handler_name)

    def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
        handler_name = func.__name__

        if handler_name in REGISTERED_HANDLERS:
            if REGISTERED_HANDLERS[handler_name] is not func:
                # A different function already owns this name. Merging topics
                # into it would silently route them to the wrong code.
                raise ValueError(f"Handler already registered: {handler_name}")

            # Same function decorated again: extend its subscriptions. Run the
            # GitHub compatibility check first, so a failure leaves the
            # registries untouched.
            if github_event_classes:
                existing_input = _get_input_model_from_handler(func)
                if existing_input is not None:
                    _validate_github_payload_compatibility(
                        existing_input, github_event_classes, handler_name
                    )
            existing_topics = HANDLER_METADATA[handler_name].topics
            for t in topics:
                if t not in existing_topics:
                    existing_topics.append(t)
            _subscribe(handler_name)
            return func

        # Extract type information from function signature
        sig = inspect.signature(func)
        params = list(sig.parameters.values())

        try:
            hints = get_type_hints(func)
        except Exception:
            hints = {}

        # Extract input model (first parameter must be a BaseModel subclass)
        input_model: type[BaseModel] | None = None
        if params:
            first_param_type = hints.get(params[0].name)
            if (
                first_param_type
                and isinstance(first_param_type, type)
                and issubclass(first_param_type, BaseModel)
            ):
                input_model = first_param_type

        if not input_model:
            topic_desc = ", ".join(topics)
            raise ValueError(
                f"Handler for topic(s) '{topic_desc}' must have a first parameter "
                f"annotated with a Pydantic BaseModel subclass. "
                f"Example: async def handler(payload: MyPayload) -> Result: ..."
            )

        # Validate GitHub event payload compatibility
        if github_event_classes:
            _validate_github_payload_compatibility(
                input_model, github_event_classes, handler_name
            )

        # Extract output model (return type if it's a BaseModel subclass)
        output_model = _extract_return_model(hints.get("return"))

        metadata = HandlerMetadata(
            handler_name=handler_name,
            topics=topics,
            input_schema=input_model.model_json_schema(mode="serialization"),
            output_schema=output_model.model_json_schema(mode="serialization")
            if output_model
            else None,
            handler_doc=func.__doc__,
        )

        # Store metadata on the function for introspection
        func._dispatch_metadata = metadata  # type: ignore

        # Register in unified registries, then wire up topic routing
        HANDLER_METADATA[handler_name] = metadata
        REGISTERED_HANDLERS[handler_name] = func  # type: ignore[assignment]
        _subscribe(handler_name)

        return func

    return decorator
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
# JSON scalar types that can be embedded in an ErrorPayload unchanged.
_JSON_SCALAR_TYPES = (str, int, float, bool, type(None))


def _to_json_safe(value: Any) -> Any:
    """Recursively convert ``value`` into JSON-serializable data.

    - Dicts and lists are rebuilt element by element; non-string dict keys
      are stringified.
    - str/int/float/bool/None are kept as-is.
    - Anything else (tuples, exception instances stored in pydantic's
      ``ctx``, ...) is replaced with ``str(value)``.
    """
    if isinstance(value, dict):
        return {
            key if isinstance(key, str) else str(key): _to_json_safe(item)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_to_json_safe(item) for item in value]
    if isinstance(value, _JSON_SCALAR_TYPES):
        return value
    return str(value)


def _validation_error_details(error: ValidationError) -> list[dict[str, Any]]:
    """Return ``error.errors()`` with every entry made JSON-safe.

    ``errors()`` entries can nest non-serializable objects (for example, the
    ValueError raised by a validator is stored under ``ctx``), so the whole
    structure is sanitized, not just the top-level values.
    """
    return [_to_json_safe(entry) for entry in error.errors()]


def _resolve_handler_names(message: Message) -> list[str]:
    """Return the names of the handlers that should receive ``message``.

    Raises:
        ValueError: If nothing is registered for the topic/function, or the
            message type is not supported.
    """
    if isinstance(message, TopicMessage):
        subscribers = TOPIC_HANDLERS.get(message.topic)
        if not subscribers:
            raise ValueError(f"No handler registered for topic: {message.topic}")
        # Snapshot so registrations made by a running handler don't alter
        # this dispatch.
        return list(subscribers)
    if isinstance(message, FunctionMessage):
        if message.function_name not in REGISTERED_HANDLERS:
            raise ValueError(f"No handler registered: {message.function_name}")
        return [message.function_name]
    raise ValueError(f"Unsupported message type: {type(message).__name__}")


async def _invoke_handler(
    handler_name: str, payload: Any
) -> SuccessPayload | ErrorPayload:
    """Validate ``payload``, run one handler and wrap its outcome.

    Exceptions raised by validation or by the handler are logged and returned
    as an ErrorPayload rather than propagated.
    """
    import traceback

    func = REGISTERED_HANDLERS[handler_name]
    # The input model is recovered from the handler's type hints.
    input_model = _get_input_model_from_handler(func)

    try:
        if not input_model:
            raise ValueError(f"No input model found for handler: {handler_name}")

        payload_obj = input_model.model_validate(payload)
        raw_fn_return = await func(payload_obj)

        # Models are dumped to plain data; None and other values pass through.
        if isinstance(raw_fn_return, BaseModel):
            result = raw_fn_return.model_dump()
        else:
            result = raw_fn_return

        return SuccessPayload(result=result)

    except ValidationError as e:
        details = _validation_error_details(e)
        logger.error(
            "Validation error in event handler",
            extra={
                "handler": handler_name,
                "error_type": "ValidationError",
                "validation_errors": details,
            },
            exc_info=True,
        )
        return ErrorPayload(
            error=str(e),
            error_type="ValidationError",
            trace=traceback.format_exc(),
            details=details,
        )

    except Exception as e:
        logger.error(
            "Error in event handler",
            extra={"handler": handler_name, "error": str(e)},
            exc_info=True,
        )
        return ErrorPayload(
            error=str(e),
            error_type=type(e).__name__,
            trace=traceback.format_exc(),
        )


async def dispatch_message(message: Message) -> SuccessPayload | ErrorPayload:
    """
    Called by the agent's gRPC server when a message is received.
    Routes to the appropriate handler based on message type:
    - TopicMessage: every handler in TOPIC_HANDLERS[topic], in order (fan-out)
    - FunctionMessage: looks up handler directly via REGISTERED_HANDLERS[function_name]

    All handlers (from @on and @fn) are callable via FunctionMessage.
    TopicMessage routing is maintained for backwards compatibility with existing workflows.

    For fan-out, the last *successful* handler's result is returned. An
    ErrorPayload is returned only when every handler failed, in which case it
    is the last error.

    Side effects: sets the _current_trace_id / _current_invocation_id /
    _current_parent_id / _current_message context variables (they are not
    reset on return). It also keeps a trace_id -> uid fallback mapping in
    place while the handlers run.

    Returns:
        SuccessPayload: When a handler executes successfully, contains the return value
        ErrorPayload: When all handlers raised, contains error details

    Raises:
        ValueError: If no handler is registered for the topic/function or the
            message type is unsupported (raised before any handler runs).
    """
    # Set context for the duration of this message processing
    _current_trace_id.set(message.trace_id)
    _current_invocation_id.set(message.uid)
    _current_parent_id.set(message.parent_id)
    _current_message.set(message)

    # Register trace -> invocation mapping for fallback lookup. This helps
    # when external SDKs (OpenAI, Claude) don't propagate Python context
    # variables to their tool call contexts.
    registered_trace = bool(message.trace_id and message.uid)
    if registered_trace:
        _register_trace_invocation(message.trace_id, message.uid)

    try:
        handler_names = _resolve_handler_names(message)

        last_success: SuccessPayload | None = None
        last_error: ErrorPayload | None = None
        for handler_name in handler_names:
            outcome = await _invoke_handler(handler_name, message.payload)
            if isinstance(outcome, ErrorPayload):
                last_error = outcome
            else:
                last_success = outcome

        result = last_success if last_success is not None else last_error
        # handler_names is never empty, so at least one outcome exists.
        assert result is not None
        return result

    finally:
        # Remove the mapping only if it still points at this invocation.
        # Another concurrent invocation on the same trace may have replaced
        # it, and its mapping must survive until that invocation finishes.
        if (
            registered_trace
            and _trace_invocation_context.get(message.trace_id) == message.uid
        ):
            _unregister_trace_invocation(message.trace_id)
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
def get_current_trace_id() -> str | None:
    """Return the trace ID of the message being dispatched, or None if unset."""
    trace_id = _current_trace_id.get()
    return trace_id
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
def get_current_invocation_id() -> str | None:
    """Return the invocation ID of the message being dispatched, or None.

    The invocation ID identifies the specific message/request currently being
    processed. It is the most precise key for correlating downstream calls
    (such as MCP tool invocations) with the agent invocation that caused them.
    """
    invocation_id = _current_invocation_id.get()
    return invocation_id
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
def get_invocation_id_for_trace(trace_id: str | None) -> str | None:
    """Look up the invocation ID recorded for ``trace_id``.

    This is a fallback for code paths where Python context variables were
    not propagated, e.g. when SDKs such as OpenAI Agents or the Claude Agent
    SDK execute tool calls in separate async contexts.

    Args:
        trace_id: Trace to look up; None short-circuits to None.

    Returns:
        The associated invocation ID, or None if there is no mapping.
    """
    return None if trace_id is None else _trace_invocation_context.get(trace_id)
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
def get_current_parent_id() -> str | None:
    """Return the parent ID of the message being dispatched, or None.

    The parent ID names the message that triggered the current invocation,
    which lets callers follow chains of invocations.
    """
    parent_id = _current_parent_id.get()
    return parent_id
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
async def run_init_hook() -> None:
    """Await the function registered with @init, if there is one.

    Called by the gRPC server before it starts handling requests. Does
    nothing when no init hook is registered.

    Raises:
        Exception: Whatever the init hook raises is propagated unchanged.
    """
    if _INIT_HOOK is None:
        return

    logger.info(f"Running @init function: {_INIT_HOOK.__name__}")
    await _INIT_HOOK()
    logger.info(f"Completed @init function: {_INIT_HOOK.__name__}")
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
def get_handler_schemas() -> dict[str, HandlerMetadata]:
|
|
710
|
+
"""Get all registered handler schemas.
|
|
711
|
+
|
|
712
|
+
Returns a dictionary mapping handler names to their metadata, including
|
|
713
|
+
input/output schemas, topics (if any), and documentation.
|
|
714
|
+
|
|
715
|
+
This is useful for:
|
|
716
|
+
- Registering agent capabilities with the backend
|
|
717
|
+
- Displaying available handlers in UI
|
|
718
|
+
- Schema validation
|
|
719
|
+
|
|
720
|
+
Returns:
|
|
721
|
+
Dict mapping handler names to HandlerMetadata with fields:
|
|
722
|
+
- handler_name: Name of the handler function
|
|
723
|
+
- input_schema: JSON schema for input payload
|
|
724
|
+
- output_schema: JSON schema for output payload (or None)
|
|
725
|
+
- handler_doc: Docstring from the handler function
|
|
726
|
+
- topics: List of topics this handler subscribes to (empty for @fn)
|
|
727
|
+
"""
|
|
728
|
+
return dict(HANDLER_METADATA)
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
def get_handler_metadata(topic: str) -> HandlerMetadata | None:
|
|
732
|
+
"""Get metadata for a specific topic's handler.
|
|
733
|
+
|
|
734
|
+
Args:
|
|
735
|
+
topic: The topic to get metadata for
|
|
736
|
+
|
|
737
|
+
Returns:
|
|
738
|
+
HandlerMetadata for the handler, or None if topic not registered.
|
|
739
|
+
If multiple handlers are registered for the topic, returns the first one's metadata.
|
|
740
|
+
"""
|
|
741
|
+
handler_names = TOPIC_HANDLERS.get(topic)
|
|
742
|
+
if not handler_names:
|
|
743
|
+
return None
|
|
744
|
+
# Return metadata for the first handler
|
|
745
|
+
handler_name = handler_names[0]
|
|
746
|
+
return HANDLER_METADATA.get(handler_name)
|
|
747
|
+
|
|
748
|
+
|
|
749
|
+
def _get_router_url() -> str:
|
|
750
|
+
"""Get the dispatch router URL from environment or default."""
|
|
751
|
+
return os.getenv("BACKEND_URL", "http://dispatch.api:8000")
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
def _get_namespace() -> str | None:
|
|
755
|
+
"""Get the dispatch namespace from environment.
|
|
756
|
+
|
|
757
|
+
Returns None if not set - the caller should handle the missing namespace.
|
|
758
|
+
"""
|
|
759
|
+
return os.getenv("DISPATCH_NAMESPACE")
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
def _get_api_base_url() -> str:
|
|
763
|
+
"""Get the API base URL from environment or default.
|
|
764
|
+
|
|
765
|
+
Raises RuntimeError if DISPATCH_NAMESPACE is not set.
|
|
766
|
+
"""
|
|
767
|
+
namespace = _get_namespace()
|
|
768
|
+
if not namespace:
|
|
769
|
+
raise RuntimeError(
|
|
770
|
+
"DISPATCH_NAMESPACE environment variable is required. "
|
|
771
|
+
"Set it to the namespace your agent is deployed in."
|
|
772
|
+
)
|
|
773
|
+
return _get_router_url() + f"/api/unstable/namespace/{namespace}"
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
def _get_auth_headers() -> dict[str, str]:
|
|
777
|
+
"""Get authentication and version headers for API requests.
|
|
778
|
+
|
|
779
|
+
Returns headers including Authorization and SDK version information.
|
|
780
|
+
"""
|
|
781
|
+
|
|
782
|
+
headers = {
|
|
783
|
+
"x-dispatch-client": "sdk",
|
|
784
|
+
"x-dispatch-client-version": get_sdk_version(),
|
|
785
|
+
"x-dispatch-client-commit": os.getenv("GIT_COMMIT", "unknown")[:8],
|
|
786
|
+
}
|
|
787
|
+
|
|
788
|
+
api_key = os.getenv("DISPATCH_API_KEY")
|
|
789
|
+
if api_key:
|
|
790
|
+
headers["Authorization"] = f"Bearer {api_key}"
|
|
791
|
+
|
|
792
|
+
return headers
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
async def emit_event(topic: str, payload: Any, sender_id: str | None = None) -> str:
|
|
796
|
+
"""Emit an event to the dispatch router.
|
|
797
|
+
|
|
798
|
+
Args:
|
|
799
|
+
topic: The topic/event type to publish to
|
|
800
|
+
payload: The event payload data
|
|
801
|
+
sender_id: Optional sender identifier (defaults to current agent)
|
|
802
|
+
|
|
803
|
+
Returns:
|
|
804
|
+
The unique event ID (uid) of the published message
|
|
805
|
+
"""
|
|
806
|
+
if sender_id is None:
|
|
807
|
+
sender_id = os.getenv("DISPATCH_AGENT_NAME", "unknown-agent")
|
|
808
|
+
|
|
809
|
+
# Automatically inherit context from current execution
|
|
810
|
+
trace_id = _current_trace_id.get()
|
|
811
|
+
# Child events should point to the current handler (invocation_id) as their
|
|
812
|
+
# parent, not the current handler's own parent.
|
|
813
|
+
parent_id = _current_invocation_id.get()
|
|
814
|
+
|
|
815
|
+
event_body = PublishEventBody(
|
|
816
|
+
topic=topic,
|
|
817
|
+
payload=payload if isinstance(payload, dict) else {"data": payload},
|
|
818
|
+
sender_id=sender_id,
|
|
819
|
+
trace_id=trace_id,
|
|
820
|
+
parent_id=parent_id,
|
|
821
|
+
)
|
|
822
|
+
|
|
823
|
+
api_base_url = _get_api_base_url()
|
|
824
|
+
auth_headers = _get_auth_headers()
|
|
825
|
+
async with httpx.AsyncClient() as client:
|
|
826
|
+
response = await client.post(
|
|
827
|
+
f"{api_base_url}/events/publish",
|
|
828
|
+
json=event_body.model_dump(),
|
|
829
|
+
headers=auth_headers,
|
|
830
|
+
timeout=10.0,
|
|
831
|
+
)
|
|
832
|
+
response.raise_for_status()
|
|
833
|
+
result = response.json()
|
|
834
|
+
return result.get("event_uid", str(uuid.uuid4()))
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
# Type variable for response_model generic
|
|
838
|
+
ResponseT = TypeVar("ResponseT", bound=BaseModel)
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
@overload
|
|
842
|
+
async def invoke(
|
|
843
|
+
agent_name: str,
|
|
844
|
+
function_name: str,
|
|
845
|
+
payload: dict[str, Any] | BaseModel,
|
|
846
|
+
*,
|
|
847
|
+
response_model: type[ResponseT],
|
|
848
|
+
timeout: float = 60.0,
|
|
849
|
+
poll_interval: float = 0.5,
|
|
850
|
+
) -> ResponseT: ...
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
@overload
|
|
854
|
+
async def invoke(
|
|
855
|
+
agent_name: str,
|
|
856
|
+
function_name: str,
|
|
857
|
+
payload: dict[str, Any] | BaseModel,
|
|
858
|
+
*,
|
|
859
|
+
response_model: None = None,
|
|
860
|
+
timeout: float = 60.0,
|
|
861
|
+
poll_interval: float = 0.5,
|
|
862
|
+
) -> dict[str, Any]: ...
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
async def invoke(
|
|
866
|
+
agent_name: str,
|
|
867
|
+
function_name: str,
|
|
868
|
+
payload: dict[str, Any] | BaseModel,
|
|
869
|
+
*,
|
|
870
|
+
response_model: type[ResponseT] | None = None,
|
|
871
|
+
timeout: float = 60.0,
|
|
872
|
+
poll_interval: float = 0.5,
|
|
873
|
+
) -> ResponseT | dict[str, Any]:
|
|
874
|
+
"""Call a function on another agent and await the response.
|
|
875
|
+
|
|
876
|
+
This enables direct function calls between agents in the same namespace.
|
|
877
|
+
The target agent must have the function registered with @fn decorator.
|
|
878
|
+
|
|
879
|
+
The function uses a polling pattern:
|
|
880
|
+
1. POST to /invoke starts the invocation and returns an invocation_id
|
|
881
|
+
2. GET /invoke/{invocation_id} polls until status is "completed" or "error"
|
|
882
|
+
3. Returns the result when done
|
|
883
|
+
|
|
884
|
+
For fire-and-forget calls, wrap in asyncio.create_task():
|
|
885
|
+
asyncio.create_task(invoke("agent", "fn", payload))
|
|
886
|
+
|
|
887
|
+
Args:
|
|
888
|
+
agent_name: Name of the target agent
|
|
889
|
+
function_name: Name of the function to call
|
|
890
|
+
payload: Input data (dict or Pydantic BaseModel)
|
|
891
|
+
response_model: Optional Pydantic model to validate and parse the response.
|
|
892
|
+
When provided, returns an instance of the model instead of a dict.
|
|
893
|
+
timeout: Maximum time to wait for completion in seconds (default 60)
|
|
894
|
+
poll_interval: Time between status checks in seconds (default 0.5)
|
|
895
|
+
|
|
896
|
+
Returns:
|
|
897
|
+
The function's return value as a dict, or as an instance of response_model
|
|
898
|
+
if provided.
|
|
899
|
+
|
|
900
|
+
Raises:
|
|
901
|
+
httpx.HTTPStatusError: If the backend returns an error
|
|
902
|
+
RuntimeError: If the call fails or agent returns an error
|
|
903
|
+
TimeoutError: If the invocation doesn't complete within timeout
|
|
904
|
+
ValidationError: If response_model is provided and response doesn't match
|
|
905
|
+
|
|
906
|
+
Examples:
|
|
907
|
+
>>> # Untyped (returns dict)
|
|
908
|
+
>>> result = await invoke("weather-agent", "get_forecast", {"city": "NYC"})
|
|
909
|
+
>>> print(result["temperature"])
|
|
910
|
+
|
|
911
|
+
>>> # Typed (returns WeatherResponse with IDE autocomplete)
|
|
912
|
+
>>> result = await invoke("weather-agent", "get_forecast", {"city": "NYC"},
|
|
913
|
+
... response_model=WeatherResponse)
|
|
914
|
+
>>> print(result.temperature) # IDE knows this is WeatherResponse
|
|
915
|
+
"""
|
|
916
|
+
# Convert Pydantic model to dict if needed
|
|
917
|
+
if isinstance(payload, BaseModel):
|
|
918
|
+
payload_dict = payload.model_dump()
|
|
919
|
+
else:
|
|
920
|
+
payload_dict = payload
|
|
921
|
+
|
|
922
|
+
# Inherit context from current execution
|
|
923
|
+
trace_id = _current_trace_id.get() or str(uuid.uuid4())
|
|
924
|
+
# Child invocations should point to the current handler (invocation_id) as
|
|
925
|
+
# their parent, not the current handler's own parent.
|
|
926
|
+
parent_id = _current_invocation_id.get()
|
|
927
|
+
|
|
928
|
+
# Build request body using typed model for API consistency
|
|
929
|
+
invoke_request = InvokeFunctionRequest(
|
|
930
|
+
agent_name=agent_name,
|
|
931
|
+
function_name=function_name,
|
|
932
|
+
payload=payload_dict,
|
|
933
|
+
trace_id=trace_id,
|
|
934
|
+
parent_id=parent_id,
|
|
935
|
+
timeout_seconds=int(timeout),
|
|
936
|
+
)
|
|
937
|
+
invoke_body = invoke_request.model_dump(exclude_none=True)
|
|
938
|
+
|
|
939
|
+
api_base_url = _get_api_base_url()
|
|
940
|
+
auth_headers = _get_auth_headers()
|
|
941
|
+
|
|
942
|
+
async with httpx.AsyncClient() as client:
|
|
943
|
+
# Step 1: Start the invocation (returns immediately with invocation_id)
|
|
944
|
+
response = await client.post(
|
|
945
|
+
f"{api_base_url}/invoke",
|
|
946
|
+
json=invoke_body,
|
|
947
|
+
headers=auth_headers,
|
|
948
|
+
timeout=10.0,
|
|
949
|
+
)
|
|
950
|
+
response.raise_for_status()
|
|
951
|
+
start_result = response.json()
|
|
952
|
+
|
|
953
|
+
invocation_id = start_result["invocation_id"]
|
|
954
|
+
|
|
955
|
+
# Step 2: Poll for completion
|
|
956
|
+
loop = asyncio.get_running_loop()
|
|
957
|
+
start_time = loop.time()
|
|
958
|
+
while True:
|
|
959
|
+
elapsed = loop.time() - start_time
|
|
960
|
+
if elapsed >= timeout:
|
|
961
|
+
raise TimeoutError(
|
|
962
|
+
f"Invocation {invocation_id} did not complete within {timeout}s"
|
|
963
|
+
)
|
|
964
|
+
|
|
965
|
+
# Check status
|
|
966
|
+
status_response = await client.get(
|
|
967
|
+
f"{api_base_url}/invoke/{invocation_id}",
|
|
968
|
+
headers=auth_headers,
|
|
969
|
+
timeout=10.0,
|
|
970
|
+
)
|
|
971
|
+
status_response.raise_for_status()
|
|
972
|
+
status = status_response.json()
|
|
973
|
+
|
|
974
|
+
if status["status"] == "completed":
|
|
975
|
+
result = status.get("result") or {}
|
|
976
|
+
if response_model is not None:
|
|
977
|
+
return response_model.model_validate(result)
|
|
978
|
+
return result
|
|
979
|
+
|
|
980
|
+
if status["status"] == "error":
|
|
981
|
+
raise RuntimeError(
|
|
982
|
+
f"Invoke failed: {status.get('error', 'Unknown error')}"
|
|
983
|
+
)
|
|
984
|
+
|
|
985
|
+
# Wait before next poll
|
|
986
|
+
await asyncio.sleep(poll_interval)
|