dispatch_agents 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentservice/__init__.py +0 -0
- agentservice/py.typed +0 -0
- agentservice/v1/__init__.py +0 -0
- agentservice/v1/message_pb2.py +41 -0
- agentservice/v1/message_pb2.pyi +22 -0
- agentservice/v1/message_pb2_grpc.py +4 -0
- agentservice/v1/request_response_pb2.py +46 -0
- agentservice/v1/request_response_pb2.pyi +54 -0
- agentservice/v1/request_response_pb2_grpc.py +4 -0
- agentservice/v1/service_pb2.py +43 -0
- agentservice/v1/service_pb2.pyi +6 -0
- agentservice/v1/service_pb2_grpc.py +129 -0
- dispatch_agents/__init__.py +281 -0
- dispatch_agents/agent_service.py +135 -0
- dispatch_agents/config.py +490 -0
- dispatch_agents/contrib/__init__.py +1 -0
- dispatch_agents/contrib/claude/__init__.py +246 -0
- dispatch_agents/contrib/openai/__init__.py +167 -0
- dispatch_agents/events.py +986 -0
- dispatch_agents/grpc_server.py +565 -0
- dispatch_agents/instrument.py +217 -0
- dispatch_agents/integrations/__init__.py +1 -0
- dispatch_agents/integrations/github/README.md +9 -0
- dispatch_agents/integrations/github/__init__.py +4268 -0
- dispatch_agents/invocation.py +25 -0
- dispatch_agents/llm.py +1017 -0
- dispatch_agents/llm_langchain.py +394 -0
- dispatch_agents/logging_config.py +133 -0
- dispatch_agents/mcp.py +266 -0
- dispatch_agents/memory.py +264 -0
- dispatch_agents/models.py +748 -0
- dispatch_agents/proxy/__init__.py +6 -0
- dispatch_agents/proxy/server.py +1137 -0
- dispatch_agents/proxy/sse_utils.py +76 -0
- dispatch_agents/py.typed +0 -0
- dispatch_agents/resources.py +68 -0
- dispatch_agents/version.py +19 -0
- dispatch_agents-0.9.0.dist-info/METADATA +20 -0
- dispatch_agents-0.9.0.dist-info/RECORD +43 -0
- dispatch_agents-0.9.0.dist-info/WHEEL +4 -0
- dispatch_agents-0.9.0.dist-info/licenses/LICENSE +191 -0
- dispatch_agents-0.9.0.dist-info/licenses/LICENSE-3rdparty.csv +12 -0
- dispatch_agents-0.9.0.dist-info/licenses/NOTICE +5 -0
|
@@ -0,0 +1,565 @@
|
|
|
1
|
+
"""gRPC server implementation for dispatch agents.
|
|
2
|
+
|
|
3
|
+
This module provides a gRPC server that implements the AgentService interface,
|
|
4
|
+
allowing agents to be invoked via gRPC instead of HTTP.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import signal
|
|
12
|
+
import time
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
|
|
15
|
+
import grpc
|
|
16
|
+
import httpx
|
|
17
|
+
from grpc import aio
|
|
18
|
+
|
|
19
|
+
from agentservice.v1 import (
|
|
20
|
+
message_pb2,
|
|
21
|
+
request_response_pb2,
|
|
22
|
+
service_pb2_grpc,
|
|
23
|
+
)
|
|
24
|
+
from dispatch_agents.events import (
|
|
25
|
+
HANDLER_METADATA,
|
|
26
|
+
REGISTERED_HANDLERS,
|
|
27
|
+
TOPIC_HANDLERS,
|
|
28
|
+
dispatch_message,
|
|
29
|
+
run_init_hook,
|
|
30
|
+
)
|
|
31
|
+
from dispatch_agents.logging_config import get_logger
|
|
32
|
+
from dispatch_agents.models import ErrorPayload, FunctionMessage, Message, TopicMessage
|
|
33
|
+
|
|
34
|
+
logger = get_logger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class _SubscribeLogFilter(logging.Filter):
|
|
38
|
+
"""Filter to suppress successful subscription heartbeat logs from httpx.
|
|
39
|
+
|
|
40
|
+
Only suppresses httpx logs for /events/subscribe that return 200 OK.
|
|
41
|
+
Failed subscriptions (4xx, 5xx) are still logged for debugging.
|
|
42
|
+
All other HTTP requests (emit_event, invoke, memory API) remain visible.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
def filter(self, record: logging.LogRecord) -> bool:
|
|
46
|
+
msg = str(record.getMessage())
|
|
47
|
+
# Only suppress successful subscription heartbeats
|
|
48
|
+
if "/events/subscribe" in msg and "200" in msg:
|
|
49
|
+
return False # Suppress successful subscription logs
|
|
50
|
+
return True # Allow failures and all other logs
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Apply filter to httpx logger - only suppresses successful subscription heartbeats
|
|
54
|
+
logging.getLogger("httpx").addFilter(_SubscribeLogFilter())
|
|
55
|
+
|
|
56
|
+
# File-based health signal for ECS container health checks.
|
|
57
|
+
# Contains a Unix timestamp updated on each successful subscription.
|
|
58
|
+
# The ECS health check verifies the timestamp is recent (< 90s old),
|
|
59
|
+
# so a stale file from a crashed process won't fool the check.
|
|
60
|
+
# Uses /tmp/ which is a writable tmpfs mount on containers with
|
|
61
|
+
# readonlyRootFilesystem. The /app/ directory is read-only.
|
|
62
|
+
_HEALTH_FILE = Path("/tmp/.dispatch_healthy")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _mark_healthy() -> None:
    """Best-effort: write current timestamp to health marker file.

    The timestamp is first written to a sibling temporary file, which is then
    renamed over the marker with ``Path.replace`` (an atomic rename on POSIX
    within one filesystem). A health check reading the marker concurrently
    therefore sees either the previous or the new timestamp. It never sees the
    empty or partial file that a truncate-then-write ``write_text`` on the
    marker itself could expose.

    Filesystem errors are logged as warnings and never raised.
    """
    tmp_file = _HEALTH_FILE.with_name(_HEALTH_FILE.name + ".tmp")
    try:
        tmp_file.write_text(str(int(time.time())))
        tmp_file.replace(_HEALTH_FILE)
    except OSError as exc:
        logger.warning("Could not write health file %s: %s", _HEALTH_FILE, exc)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _mark_unhealthy() -> None:
    """Delete the health marker file so the container health check fails.

    Best effort: a missing file is not an error, and any other filesystem
    error is logged as a warning instead of being raised.
    """
    marker = _HEALTH_FILE
    try:
        marker.unlink(missing_ok=True)
    except OSError as err:
        logger.warning("Could not remove health file %s: %s", marker, err)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class AgentServiceServicer(service_pb2_grpc.AgentServiceServicer):
    """Implementation of the AgentService gRPC interface.

    ``Invoke`` turns each request into a ``TopicMessage`` or ``FunctionMessage``,
    routes it through ``dispatch_message``, and returns the JSON-encoded result.
    ``HealthCheck`` always reports SERVING.
    """

    def __init__(self, agent_name: str):
        """Initialize the servicer.

        Args:
            agent_name: The name of the agent being served
        """
        self.agent_name = agent_name

    async def Invoke(
        self,
        request: request_response_pb2.InvokeRequest,
        context: grpc.aio.ServicerContext,
    ) -> request_response_pb2.InvokeResponse:
        """Invoke an agent function and return the result.

        Args:
            request: The InvokeRequest containing function name and payload
            context: The gRPC context

        Returns:
            InvokeResponse whose payload is the JSON-encoded result of
            ``dispatch_message`` (SuccessPayload or ErrorPayload); ``is_error``
            is True when the result is an ErrorPayload.

        Any exception while decoding, dispatching or serializing aborts the
        RPC with ``grpc.StatusCode.INTERNAL``.
        """
        logger.info(
            "Received Invoke request: message_type=%s, topic=%s, "
            "function_name=%s, trace_id=%s",
            request.message_type,
            request.topic,
            request.function_name,
            request.trace_id,
        )

        try:
            # The request payload carries UTF-8 encoded JSON bytes.
            payload_data = json.loads(request.payload.data.decode("utf-8"))

            # Create appropriate message type based on message_type field
            # - "topic": Creates TopicMessage, routes via TOPIC_HANDLERS[topic]
            # - otherwise: Creates FunctionMessage, routes via
            #   REGISTERED_HANDLERS[function_name]
            message: Message
            if request.message_type == "topic":
                message = TopicMessage(
                    topic=request.topic,
                    payload=payload_data,
                    uid=request.uid,
                    trace_id=request.trace_id,
                    sender_id="grpc-client",
                    ts=request.ts,
                    parent_id=None,
                )
            else:
                # Default to function message (for backwards compatibility and direct calls)
                message = FunctionMessage(
                    function_name=request.function_name,
                    payload=payload_data,
                    uid=request.uid,
                    trace_id=request.trace_id,
                    sender_id="grpc-client",
                    ts=request.ts,
                    parent_id=None,
                )

            # Dispatch the message to the appropriate handler
            result = await dispatch_message(message)

            # mode="json" first converts values such as datetime, UUID,
            # Decimal, bytes or Enum members into JSON-native types. A plain
            # model_dump() leaves them as Python objects, json.dumps then
            # raises TypeError, and a handler call that succeeded would be
            # reported to the client as an INTERNAL error.
            result_payload = message_pb2.Payload(
                metadata={},
                data=json.dumps(result.model_dump(mode="json")).encode("utf-8"),
            )

            return request_response_pb2.InvokeResponse(
                result=result_payload,
                is_error=isinstance(result, ErrorPayload),
            )

        except Exception as e:
            logger.error(f"Error processing Invoke request: {e}", exc_info=True)
            # Return error as gRPC status
            await context.abort(
                grpc.StatusCode.INTERNAL,
                f"Error processing request: {str(e)}",
            )
            # This line won't be reached, but satisfies type checker
            raise

    async def HealthCheck(
        self,
        request: request_response_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> request_response_pb2.HealthCheckResponse:
        """Check the health of the agent.

        Args:
            request: The HealthCheckRequest
            context: The gRPC context

        Returns:
            HealthCheckResponse with serving status (always SERVING)
        """
        logger.debug("Received HealthCheck request")
        return request_response_pb2.HealthCheckResponse(
            status=request_response_pb2.HealthCheckResponse.SERVING_STATUS_SERVING
        )
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _is_local_dev_mode() -> bool:
|
|
188
|
+
"""Check if running in local development mode.
|
|
189
|
+
|
|
190
|
+
Returns True only if DISPATCH_LOCAL_DEV is explicitly set to a truthy value.
|
|
191
|
+
This enables dev-friendly behaviors like auto-shutdown on backend connection failure.
|
|
192
|
+
|
|
193
|
+
Note: We use explicit opt-in rather than heuristics (like checking for AWS env vars)
|
|
194
|
+
because localstack Docker containers should behave like production agents, not dev mode.
|
|
195
|
+
"""
|
|
196
|
+
local_dev = os.getenv("DISPATCH_LOCAL_DEV", "").lower()
|
|
197
|
+
return local_dev in ("1", "true", "yes")
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _resolve_subscribe_endpoint() -> tuple[str, str]:
    """Return ``(backend_url, subscribe_url)`` derived from the environment.

    BACKEND_URL is required. When DISPATCH_NAMESPACE is set, the
    namespace-scoped API is used. Otherwise the non-namespaced endpoints
    (local router mode) are used for backwards compatibility.

    Raises:
        RuntimeError: If BACKEND_URL is not set.
    """
    backend_url = os.getenv("BACKEND_URL")
    if not backend_url:
        error_msg = (
            "BACKEND_URL environment variable is required but not set. "
            "This should be configured during agent deployment."
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    namespace = os.getenv("DISPATCH_NAMESPACE")
    if namespace:
        # Use namespace-scoped endpoints (backend infrastructure)
        api_base_url = f"{backend_url}/api/unstable/namespace/{namespace}"
    else:
        # Use simple non-namespaced endpoints (local router)
        api_base_url = f"{backend_url}/api/unstable"
        logger.info(
            "DISPATCH_NAMESPACE not set, using non-namespaced endpoints (local router mode)"
        )
    return backend_url, f"{api_base_url}/events/subscribe"


def _build_subscribe_headers() -> dict[str, str]:
    """Build the HTTP headers for the subscription request.

    Adds X-Dispatch-Agent-Version when DISPATCH_AGENT_VERSION is set and a
    Bearer Authorization header when DISPATCH_API_KEY is set.
    """
    api_key = os.getenv("DISPATCH_API_KEY")
    headers: dict[str, str] = {"Content-Type": "application/json"}
    agent_version = os.getenv("DISPATCH_AGENT_VERSION")
    if agent_version:
        headers["X-Dispatch-Agent-Version"] = agent_version
    if api_key:
        # Mask the API key for logging (show first 12 chars only)
        api_key_preview = api_key[:12] + "..." if len(api_key) > 12 else "***"
        headers["Authorization"] = f"Bearer {api_key}"
        logger.debug(f"Using API key for authentication: {api_key_preview}")
    else:
        # In local dev mode, API key is optional - use debug level to reduce noise
        # In production, backend will enforce authentication
        logger.debug(
            "No DISPATCH_API_KEY found in environment (optional for local dev)"
        )
    return headers


def _build_function_specs() -> list[dict]:
    """Serialize every entry of HANDLER_METADATA into an AgentFunction dict.

    @on handlers get one "topic" trigger per topic; every handler (including
    @on handlers) also gets a "callable" trigger under its own name.
    """
    from dispatch_agents.models import AgentFunction, FunctionTrigger

    functions = []
    for handler_name, metadata in HANDLER_METADATA.items():
        triggers = [
            FunctionTrigger(type="topic", topic=topic) for topic in metadata.topics
        ]
        triggers.append(FunctionTrigger(type="callable", function_name=handler_name))
        function = AgentFunction(
            name=handler_name,
            description=metadata.handler_doc,
            input_schema=metadata.input_schema,
            output_schema=metadata.output_schema,
            triggers=triggers,
        )
        functions.append(function.model_dump())
    return functions


async def _subscribe_registered_triggers(
    agent_name: str, *, is_initial: bool = False
) -> bool:
    """Subscribe the agent to all registered topics with the backend.

    This function performs a single subscription attempt. For continuous re-subscription,
    use _subscription_loop().

    Every successful outcome (including the "no handlers" short-circuit)
    refreshes the health marker file; every failed request removes it.

    Args:
        agent_name: The name of the agent to subscribe
        is_initial: Whether this is the initial subscription (affects log level)

    Returns:
        True if subscription succeeded (or there was nothing to subscribe);
        False if the backend could not be reached or timed out in local dev mode.

    Raises:
        RuntimeError: If BACKEND_URL is missing, the backend returns an HTTP
            error status, an unexpected error occurs, or (outside local dev
            mode) the connection fails or times out.
    """
    topics = list(TOPIC_HANDLERS.keys())
    # @fn handlers are registered with an empty topics list (used for logging only)
    fn_handlers = [name for name, meta in HANDLER_METADATA.items() if not meta.topics]

    if not REGISTERED_HANDLERS:
        logger.info("No registered handlers found; skipping subscription.")
        # "Nothing to subscribe" counts as success, so refresh the health
        # marker too; otherwise the file-based container health check never
        # sees a fresh timestamp for an agent without handlers.
        _mark_healthy()
        return True

    backend_url, url = _resolve_subscribe_endpoint()
    headers = _build_subscribe_headers()
    functions = _build_function_specs()
    payload = {"topics": topics, "agent_name": agent_name, "functions": functions}

    logger.debug(
        f"Subscribing agent '{agent_name}': {len(topics)} topic(s), {len(fn_handlers)} callable function(s)"
    )
    logger.debug(f"Backend URL: {backend_url}")
    logger.debug(f"Subscription endpoint: {url}")
    logger.debug(f"Sending {len(functions)} function(s) with schemas")

    try:
        async with httpx.AsyncClient() as client:
            logger.debug(f"Sending POST request to {url}")
            response = await client.post(
                url, json=payload, headers=headers, timeout=10.0
            )

            # Log response details for debugging
            logger.debug(f"Response status: {response.status_code}")
            logger.debug(f"Response headers: {dict(response.headers)}")

            response.raise_for_status()

            # Log at INFO for initial subscription, DEBUG for re-subscriptions to reduce noise
            log_msg = f"✅ Successfully subscribed {len(topics)} topic(s) for agent {agent_name}"
            if is_initial:
                logger.info(log_msg)
            else:
                logger.debug(log_msg)

            # Signal healthy to ECS container health check
            _mark_healthy()
            return True  # Success

    except httpx.ConnectError as e:
        _mark_unhealthy()
        error_msg = (
            f"Failed to connect to backend at {url}. "
            f"Connection error: {e}. "
            f"Ensure BACKEND_URL is set correctly and the backend is accessible."
        )
        logger.error(error_msg)
        # Raise if we're in a deployed environment (ECS)
        if not _is_local_dev_mode():
            raise RuntimeError(error_msg) from e
        logger.warning("Connection failed (local development mode)")
        return False  # Signal failure for retry tracking

    except httpx.HTTPStatusError as e:
        _mark_unhealthy()
        error_msg = (
            f"Backend returned error status {e.response.status_code} for {url}. "
            f"Response: {e.response.text}"
        )
        logger.error(error_msg)
        # Always raise on HTTP errors (401, 403, 404, 500, etc)
        raise RuntimeError(error_msg) from e

    except httpx.TimeoutException as e:
        _mark_unhealthy()
        error_msg = f"Timeout connecting to backend at {url} after 10s: {e}"
        logger.error(error_msg)
        # Raise if we're in a deployed environment
        if not _is_local_dev_mode():
            raise RuntimeError(error_msg) from e
        logger.warning("Connection timed out (local development mode)")
        return False  # Signal failure for retry tracking

    except Exception as e:
        _mark_unhealthy()
        error_msg = (
            f"Unexpected error during subscription to {url}: {type(e).__name__}: {e}"
        )
        logger.error(error_msg, exc_info=True)
        raise RuntimeError(error_msg) from e
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
async def _subscription_loop(
    agent_name: str,
    interval_seconds: int = 30,
    *,
    shutdown_event: asyncio.Event | None = None,
    max_consecutive_failures: int = 2,
) -> None:
    """Continuously re-subscribe the agent to registered topics.

    This background task ensures that the agent maintains its subscription even if
    the backend server is restarted or becomes temporarily unavailable.

    In local development mode, the agent stops after ``max_consecutive_failures``
    consecutive failed attempts (default 2). This avoids running indefinitely
    when the router is not available. Outside local dev mode, failures are
    logged and retried forever.

    Args:
        agent_name: The name of the agent to subscribe
        interval_seconds: Time to wait between subscription attempts (default: 30)
        shutdown_event: Optional event to signal shutdown on fatal errors
        max_consecutive_failures: Consecutive failures tolerated in local dev
            mode before giving up (default: 2)
    """
    logger.info(
        f"Starting subscription loop for agent '{agent_name}' "
        f"(interval: {interval_seconds}s)"
    )

    consecutive_failures = 0

    def _give_up() -> bool:
        """Return True (after logging and signalling shutdown) once local dev has failed too often."""
        if not (
            _is_local_dev_mode() and consecutive_failures >= max_consecutive_failures
        ):
            return False
        logger.error(
            f"Failed to connect to backend {consecutive_failures} times in a row. "
            f"Shutting down. Is the local router running? "
            f"Start it with: dispatch router start"
        )
        if shutdown_event:
            shutdown_event.set()
        return True

    # Perform initial subscription immediately (success is logged at INFO)
    try:
        succeeded = await _subscribe_registered_triggers(agent_name, is_initial=True)
    except Exception as e:
        logger.error(f"Initial subscription failed: {e}")
        succeeded = False

    if succeeded:
        consecutive_failures = 0
    else:
        consecutive_failures += 1
        if _give_up():
            return

    # Continuously re-subscribe at the specified interval
    while True:
        try:
            await asyncio.sleep(interval_seconds)
            succeeded = await _subscribe_registered_triggers(
                agent_name, is_initial=False
            )
        except (asyncio.CancelledError, KeyboardInterrupt):
            logger.info("Subscription loop cancelled, shutting down")
            raise
        except Exception as e:
            logger.error(f"Error in subscription loop: {e}", exc_info=True)
            succeeded = False

        if succeeded:
            consecutive_failures = 0
        else:
            consecutive_failures += 1
            if _give_up():
                return
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def _load_mtls_server_credentials(cert_dir: str | None) -> grpc.ServerCredentials:
    """Build mutual-TLS server credentials from the files in ``cert_dir``.

    Reads server.key, server.crt and ca.crt; client certificates are required
    and verified against ca.crt.

    Raises:
        ValueError: If ``cert_dir`` is None or empty.
        OSError: If one of the certificate files cannot be read.
    """
    if not cert_dir:
        raise ValueError("cert_dir is required when insecure=False")
    tls_dir = Path(cert_dir)
    server_key = (tls_dir / "server.key").read_bytes()
    server_cert = (tls_dir / "server.crt").read_bytes()
    ca_cert = (tls_dir / "ca.crt").read_bytes()
    return grpc.ssl_server_credentials(
        [(server_key, server_cert)],
        root_certificates=ca_cert,
        require_client_auth=True,
    )


async def serve(
    agent_name: str,
    port: int = 50051,
    *,
    insecure: bool = True,
    cert_dir: str | None = None,
    subscription_interval: int = 30,
) -> None:
    """Start the gRPC server for the agent and block until it shuts down.

    Startup order: configure the listening port, run the @init hook, start
    the server, then launch the background subscription loop. Returns after
    SIGTERM/SIGINT or a fatal subscription error (both set the shared shutdown
    event), or after the server terminates on its own.

    Args:
        agent_name: The name of the agent
        port: The port to listen on (default: 50051)
        insecure: Whether to use an insecure server (default: True for development)
        cert_dir: Directory containing server.crt, server.key, ca.crt for mTLS
        subscription_interval: Seconds between subscription attempts (default: 30)

    Raises:
        ValueError: If ``insecure`` is False and ``cert_dir`` is not given.
    """
    server = aio.server()
    servicer = AgentServiceServicer(agent_name=agent_name)
    service_pb2_grpc.add_AgentServiceServicer_to_server(servicer, server)

    listen_addr = f"0.0.0.0:{port}"
    if insecure:
        server.add_insecure_port(listen_addr)
        logger.info(f"Starting insecure gRPC server on {listen_addr}")
    else:
        server.add_secure_port(listen_addr, _load_mtls_server_credentials(cert_dir))
        logger.info("Starting mTLS gRPC server on %s", listen_addr)

    # Run @init before the server starts accepting RPCs. This guarantees no
    # Invoke reaches a handler mid-initialization, and a failing hook does not
    # leave a started server behind.
    await run_init_hook()

    await server.start()
    logger.info(f"gRPC server started for agent '{agent_name}' on {listen_addr}")

    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    # Start background subscription loop (pass shutdown_event for fatal error handling)
    subscription_task = asyncio.create_task(
        _subscription_loop(
            agent_name, subscription_interval, shutdown_event=shutdown_event
        )
    )

    def signal_handler() -> None:
        logger.info("Received shutdown signal, initiating graceful shutdown")
        shutdown_event.set()

    # Register handlers for SIGTERM (container stop) and SIGINT (Ctrl+C).
    # add_signal_handler raises NotImplementedError on event loops without
    # signal support (e.g. Windows) and RuntimeError outside the main thread;
    # in that case fall back to the KeyboardInterrupt path below instead of
    # crashing at startup.
    installed_signals: list[signal.Signals] = []
    for sig in (signal.SIGTERM, signal.SIGINT):
        try:
            loop.add_signal_handler(sig, signal_handler)
        except (NotImplementedError, RuntimeError) as exc:
            logger.warning("Could not install handler for %s: %s", sig.name, exc)
        else:
            installed_signals.append(sig)

    waiters = [
        asyncio.create_task(server.wait_for_termination()),
        asyncio.create_task(shutdown_event.wait()),
    ]
    try:
        # Wait for either server termination or shutdown signal
        await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
        if shutdown_event.is_set():
            logger.info("Stopping gRPC server...")
    except KeyboardInterrupt:
        # Fallback for systems where signal handlers don't work
        logger.info("Received keyboard interrupt, shutting down")
    finally:
        for sig in installed_signals:
            loop.remove_signal_handler(sig)
        for waiter in waiters:
            waiter.cancel()
        try:
            # Server.stop() is idempotent, so calling it unconditionally is
            # safe when the server already terminated, and it also frees the
            # port if serve() is cancelled or fails while waiting.
            await server.stop(grace=5)
        finally:
            # Cancel the subscription loop when server terminates
            subscription_task.cancel()
            try:
                await subscription_task
            except (asyncio.CancelledError, KeyboardInterrupt):
                pass
|