fast-agent-mcp 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.14.dist-info}/METADATA +1 -1
- {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.14.dist-info}/RECORD +33 -33
- mcp_agent/agents/agent.py +2 -2
- mcp_agent/agents/base_agent.py +3 -3
- mcp_agent/agents/workflow/chain_agent.py +2 -2
- mcp_agent/agents/workflow/evaluator_optimizer.py +3 -3
- mcp_agent/agents/workflow/orchestrator_agent.py +3 -3
- mcp_agent/agents/workflow/parallel_agent.py +2 -2
- mcp_agent/agents/workflow/router_agent.py +2 -2
- mcp_agent/cli/commands/check_config.py +450 -0
- mcp_agent/cli/commands/setup.py +1 -1
- mcp_agent/cli/main.py +8 -15
- mcp_agent/core/agent_types.py +8 -8
- mcp_agent/core/direct_decorators.py +10 -8
- mcp_agent/core/direct_factory.py +4 -1
- mcp_agent/core/validation.py +6 -4
- mcp_agent/event_progress.py +6 -6
- mcp_agent/llm/augmented_llm.py +10 -2
- mcp_agent/llm/augmented_llm_passthrough.py +5 -3
- mcp_agent/llm/augmented_llm_playback.py +2 -1
- mcp_agent/llm/model_factory.py +7 -27
- mcp_agent/llm/provider_key_manager.py +83 -0
- mcp_agent/llm/provider_types.py +16 -0
- mcp_agent/llm/providers/augmented_llm_anthropic.py +5 -26
- mcp_agent/llm/providers/augmented_llm_deepseek.py +5 -24
- mcp_agent/llm/providers/augmented_llm_generic.py +2 -16
- mcp_agent/llm/providers/augmented_llm_openai.py +4 -26
- mcp_agent/llm/providers/augmented_llm_openrouter.py +17 -45
- mcp_agent/mcp/interfaces.py +2 -1
- mcp_agent/mcp_server/agent_server.py +120 -38
- mcp_agent/cli/commands/config.py +0 -11
- mcp_agent/executor/temporal.py +0 -383
- mcp_agent/executor/workflow.py +0 -195
- {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.14.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.14.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.14.dist-info}/licenses/LICENSE +0 -0
@@ -143,13 +143,38 @@ class AgentMCPServer:
|
|
143
143
|
self.mcp_server.settings.host = host
|
144
144
|
self.mcp_server.settings.port = port
|
145
145
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
146
|
+
# For synchronous run, we can use the simpler approach
|
147
|
+
try:
|
148
|
+
# Add any server attributes that might help with shutdown
|
149
|
+
if not hasattr(self.mcp_server, "_server_should_exit"):
|
150
|
+
self.mcp_server._server_should_exit = False
|
151
|
+
|
152
|
+
# Run the server
|
153
|
+
self.mcp_server.run(transport=transport)
|
154
|
+
except KeyboardInterrupt:
|
155
|
+
print("\nServer stopped by user (CTRL+C)")
|
156
|
+
except SystemExit as e:
|
157
|
+
# Handle normal exit
|
158
|
+
print(f"\nServer exiting with code {e.code}")
|
159
|
+
# Re-raise to allow normal exit process
|
160
|
+
raise
|
161
|
+
except Exception as e:
|
162
|
+
print(f"\nServer error: {e}")
|
163
|
+
finally:
|
164
|
+
# Run an async cleanup in a new event loop
|
165
|
+
try:
|
166
|
+
asyncio.run(self.shutdown())
|
167
|
+
except (SystemExit, KeyboardInterrupt):
|
168
|
+
# These are expected during shutdown
|
169
|
+
pass
|
170
|
+
else: # stdio
|
171
|
+
try:
|
172
|
+
self.mcp_server.run(transport=transport)
|
173
|
+
except KeyboardInterrupt:
|
174
|
+
print("\nServer stopped by user (CTRL+C)")
|
175
|
+
finally:
|
176
|
+
# Minimal cleanup for stdio
|
177
|
+
asyncio.run(self._cleanup_stdio())
|
153
178
|
|
154
179
|
async def run_async(
|
155
180
|
self, transport: str = "sse", host: str = "0.0.0.0", port: int = 8000
|
@@ -169,20 +194,26 @@ class AgentMCPServer:
|
|
169
194
|
try:
|
170
195
|
# Wait for the server task to complete
|
171
196
|
await self._server_task
|
172
|
-
except asyncio.CancelledError:
|
173
|
-
|
174
|
-
|
197
|
+
except (asyncio.CancelledError, KeyboardInterrupt):
|
198
|
+
# Both cancellation and KeyboardInterrupt are expected during shutdown
|
199
|
+
logger.info("Server stopped via cancellation or interrupt")
|
200
|
+
print("\nServer stopped")
|
201
|
+
except SystemExit as e:
|
202
|
+
# Handle normal exit cleanly
|
203
|
+
logger.info(f"Server exiting with code {e.code}")
|
204
|
+
print(f"\nServer exiting with code {e.code}")
|
205
|
+
# If this is exit code 0, let it propagate for normal exit
|
206
|
+
if e.code == 0:
|
207
|
+
raise
|
175
208
|
except Exception as e:
|
176
209
|
logger.error(f"Server error: {e}", exc_info=True)
|
177
210
|
print(f"\nServer error: {e}")
|
178
211
|
finally:
|
179
|
-
#
|
180
|
-
await self.
|
181
|
-
logger.info("Server shutdown complete.")
|
212
|
+
# Only do minimal cleanup - don't try to be too clever
|
213
|
+
await self._cleanup_stdio()
|
182
214
|
print("\nServer shutdown complete.")
|
183
215
|
else: # stdio
|
184
216
|
# For STDIO, use simpler approach that respects STDIO lifecycle
|
185
|
-
# STDIO will naturally terminate when streams close
|
186
217
|
try:
|
187
218
|
# Run directly without extra monitoring or signal handlers
|
188
219
|
# This preserves the natural lifecycle of STDIO connections
|
@@ -190,9 +221,14 @@ class AgentMCPServer:
|
|
190
221
|
except (asyncio.CancelledError, KeyboardInterrupt):
|
191
222
|
logger.info("Server stopped (CTRL+C)")
|
192
223
|
print("\nServer stopped (CTRL+C)")
|
193
|
-
|
224
|
+
except SystemExit as e:
|
225
|
+
# Handle normal exit cleanly
|
226
|
+
logger.info(f"Server exiting with code {e.code}")
|
227
|
+
print(f"\nServer exiting with code {e.code}")
|
228
|
+
# If this is exit code 0, let it propagate for normal exit
|
229
|
+
if e.code == 0:
|
230
|
+
raise
|
194
231
|
# Only perform minimal cleanup needed for STDIO
|
195
|
-
# Don't use our full shutdown procedure which could keep process alive
|
196
232
|
await self._cleanup_stdio()
|
197
233
|
|
198
234
|
async def _run_server_with_shutdown(self, transport: str):
|
@@ -246,7 +282,6 @@ class AgentMCPServer:
|
|
246
282
|
force_shutdown_task = asyncio.create_task(self._force_shutdown_event.wait())
|
247
283
|
timeout_task = asyncio.create_task(asyncio.sleep(self._shutdown_timeout))
|
248
284
|
|
249
|
-
# Wait for either force shutdown or timeout
|
250
285
|
done, pending = await asyncio.wait(
|
251
286
|
[force_shutdown_task, timeout_task], return_when=asyncio.FIRST_COMPLETED
|
252
287
|
)
|
@@ -255,21 +290,16 @@ class AgentMCPServer:
|
|
255
290
|
for task in pending:
|
256
291
|
task.cancel()
|
257
292
|
|
258
|
-
# Determine
|
293
|
+
# Determine shutdown reason
|
259
294
|
if force_shutdown_task in done:
|
260
|
-
logger.info("Force shutdown requested")
|
261
|
-
print("\
|
295
|
+
logger.info("Force shutdown requested by user")
|
296
|
+
print("\nForce shutdown initiated...")
|
262
297
|
else:
|
263
298
|
logger.info(f"Graceful shutdown timed out after {self._shutdown_timeout} seconds")
|
264
299
|
print(f"\nGraceful shutdown timed out after {self._shutdown_timeout} seconds")
|
265
300
|
|
266
|
-
|
267
|
-
await self._close_sse_connections()
|
301
|
+
os._exit(0)
|
268
302
|
|
269
|
-
# Cancel the server task if running
|
270
|
-
if self._server_task and not self._server_task.done():
|
271
|
-
logger.info("Cancelling server task")
|
272
|
-
self._server_task.cancel()
|
273
303
|
except asyncio.CancelledError:
|
274
304
|
# Monitor was cancelled - clean exit
|
275
305
|
pass
|
@@ -302,11 +332,36 @@ class AgentMCPServer:
|
|
302
332
|
for session_id, writer in writers:
|
303
333
|
try:
|
304
334
|
logger.debug(f"Closing SSE connection: {session_id}")
|
335
|
+
# Instead of aclose, try to close more gracefully
|
336
|
+
# Send a special event to notify client, then close
|
337
|
+
try:
|
338
|
+
if hasattr(writer, "send") and not getattr(writer, "_closed", False):
|
339
|
+
try:
|
340
|
+
# Try to send a close event if possible
|
341
|
+
await writer.send(Exception("Server shutting down"))
|
342
|
+
except (AttributeError, asyncio.CancelledError):
|
343
|
+
pass
|
344
|
+
except Exception:
|
345
|
+
pass
|
346
|
+
|
347
|
+
# Now close the stream
|
305
348
|
await writer.aclose()
|
306
349
|
sse._read_stream_writers.pop(session_id, None)
|
307
350
|
except Exception as e:
|
308
351
|
logger.error(f"Error closing SSE connection {session_id}: {e}")
|
309
352
|
|
353
|
+
# If we have a ASGI lifespan hook, try to signal closure
|
354
|
+
if (
|
355
|
+
hasattr(self.mcp_server, "_lifespan_state")
|
356
|
+
and self.mcp_server._lifespan_state == "started"
|
357
|
+
):
|
358
|
+
logger.debug("Attempting to signal ASGI lifespan shutdown")
|
359
|
+
try:
|
360
|
+
if hasattr(self.mcp_server, "_on_shutdown"):
|
361
|
+
await self.mcp_server._on_shutdown()
|
362
|
+
except Exception as e:
|
363
|
+
logger.error(f"Error during ASGI lifespan shutdown: {e}")
|
364
|
+
|
310
365
|
async def with_bridged_context(self, agent_context, mcp_context, func, *args, **kwargs):
|
311
366
|
"""
|
312
367
|
Execute a function with bridged context between MCP and agent
|
@@ -374,18 +429,45 @@ class AgentMCPServer:
|
|
374
429
|
# Signal shutdown
|
375
430
|
self._graceful_shutdown_event.set()
|
376
431
|
|
377
|
-
|
378
|
-
|
432
|
+
try:
|
433
|
+
# Close SSE connections
|
434
|
+
await self._close_sse_connections()
|
379
435
|
|
380
|
-
|
381
|
-
|
436
|
+
# Close any resources in the exit stack
|
437
|
+
await self._exit_stack.aclose()
|
382
438
|
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
439
|
+
# Shutdown any agent resources
|
440
|
+
for agent_name, agent in self.agent_app._agents.items():
|
441
|
+
try:
|
442
|
+
if hasattr(agent, "shutdown"):
|
443
|
+
await agent.shutdown()
|
444
|
+
except Exception as e:
|
445
|
+
logger.error(f"Error shutting down agent {agent_name}: {e}")
|
446
|
+
except Exception as e:
|
447
|
+
# Log any errors but don't let them prevent shutdown
|
448
|
+
logger.error(f"Error during shutdown: {e}", exc_info=True)
|
449
|
+
finally:
|
450
|
+
logger.info("Full shutdown complete")
|
451
|
+
|
452
|
+
async def _cleanup_minimal(self):
|
453
|
+
"""Perform minimal cleanup before simulating a KeyboardInterrupt."""
|
454
|
+
logger.info("Performing minimal cleanup before interrupt")
|
455
|
+
|
456
|
+
# Only close SSE connection writers directly
|
457
|
+
if (
|
458
|
+
hasattr(self.mcp_server, "_sse_transport")
|
459
|
+
and self.mcp_server._sse_transport is not None
|
460
|
+
):
|
461
|
+
sse = self.mcp_server._sse_transport
|
462
|
+
|
463
|
+
# Close all read stream writers
|
464
|
+
if hasattr(sse, "_read_stream_writers"):
|
465
|
+
for session_id, writer in list(sse._read_stream_writers.items()):
|
466
|
+
try:
|
467
|
+
await writer.aclose()
|
468
|
+
except Exception:
|
469
|
+
# Ignore errors during cleanup
|
470
|
+
pass
|
390
471
|
|
391
|
-
|
472
|
+
# Clear active connections set to prevent further operations
|
473
|
+
self._active_connections.clear()
|
mcp_agent/cli/commands/config.py
DELETED
mcp_agent/executor/temporal.py
DELETED
@@ -1,383 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Temporal based orchestrator for the MCP Agent.
|
3
|
-
Temporal provides durable execution and robust workflow orchestration,
|
4
|
-
as well as dynamic control flow, making it a good choice for an AI agent orchestrator.
|
5
|
-
Read more: https://docs.temporal.io/develop/python/core-application
|
6
|
-
"""
|
7
|
-
|
8
|
-
import asyncio
|
9
|
-
import functools
|
10
|
-
import uuid
|
11
|
-
from typing import (
|
12
|
-
TYPE_CHECKING,
|
13
|
-
Any,
|
14
|
-
AsyncIterator,
|
15
|
-
Callable,
|
16
|
-
Coroutine,
|
17
|
-
Dict,
|
18
|
-
List,
|
19
|
-
Optional,
|
20
|
-
)
|
21
|
-
|
22
|
-
from pydantic import ConfigDict
|
23
|
-
from temporalio import activity, exceptions, workflow
|
24
|
-
from temporalio.client import Client as TemporalClient
|
25
|
-
from temporalio.worker import Worker
|
26
|
-
|
27
|
-
from mcp_agent.config import TemporalSettings
|
28
|
-
from mcp_agent.executor.executor import Executor, ExecutorConfig, R
|
29
|
-
from mcp_agent.executor.workflow_signal import (
|
30
|
-
BaseSignalHandler,
|
31
|
-
Signal,
|
32
|
-
SignalHandler,
|
33
|
-
SignalRegistration,
|
34
|
-
SignalValueT,
|
35
|
-
)
|
36
|
-
|
37
|
-
if TYPE_CHECKING:
|
38
|
-
from mcp_agent.context import Context
|
39
|
-
|
40
|
-
|
41
|
-
class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
|
42
|
-
"""Temporal-based signal handling using workflow signals"""
|
43
|
-
|
44
|
-
async def wait_for_signal(self, signal, timeout_seconds=None) -> SignalValueT:
|
45
|
-
if not workflow._Runtime.current():
|
46
|
-
raise RuntimeError(
|
47
|
-
"TemporalSignalHandler.wait_for_signal must be called from within a workflow"
|
48
|
-
)
|
49
|
-
|
50
|
-
unique_signal_name = f"{signal.name}_{uuid.uuid4()}"
|
51
|
-
registration = SignalRegistration(
|
52
|
-
signal_name=signal.name,
|
53
|
-
unique_name=unique_signal_name,
|
54
|
-
workflow_id=workflow.info().workflow_id,
|
55
|
-
)
|
56
|
-
|
57
|
-
# Container for signal value
|
58
|
-
container = {"value": None, "completed": False}
|
59
|
-
|
60
|
-
# Define the signal handler for this specific registration
|
61
|
-
@workflow.signal(name=unique_signal_name)
|
62
|
-
def signal_handler(value: SignalValueT) -> None:
|
63
|
-
container["value"] = value
|
64
|
-
container["completed"] = True
|
65
|
-
|
66
|
-
async with self._lock:
|
67
|
-
# Register both the signal registration and handler atomically
|
68
|
-
self._pending_signals.setdefault(signal.name, []).append(registration)
|
69
|
-
self._handlers.setdefault(signal.name, []).append((unique_signal_name, signal_handler))
|
70
|
-
|
71
|
-
try:
|
72
|
-
# Wait for signal with optional timeout
|
73
|
-
await workflow.wait_condition(lambda: container["completed"], timeout=timeout_seconds)
|
74
|
-
|
75
|
-
return container["value"]
|
76
|
-
except asyncio.TimeoutError as exc:
|
77
|
-
raise TimeoutError(f"Timeout waiting for signal {signal.name}") from exc
|
78
|
-
finally:
|
79
|
-
async with self._lock:
|
80
|
-
# Remove ourselves from _pending_signals
|
81
|
-
if signal.name in self._pending_signals:
|
82
|
-
self._pending_signals[signal.name] = [
|
83
|
-
sr
|
84
|
-
for sr in self._pending_signals[signal.name]
|
85
|
-
if sr.unique_name != unique_signal_name
|
86
|
-
]
|
87
|
-
if not self._pending_signals[signal.name]:
|
88
|
-
del self._pending_signals[signal.name]
|
89
|
-
|
90
|
-
# Remove ourselves from _handlers
|
91
|
-
if signal.name in self._handlers:
|
92
|
-
self._handlers[signal.name] = [
|
93
|
-
h for h in self._handlers[signal.name] if h[0] != unique_signal_name
|
94
|
-
]
|
95
|
-
if not self._handlers[signal.name]:
|
96
|
-
del self._handlers[signal.name]
|
97
|
-
|
98
|
-
def on_signal(self, signal_name):
|
99
|
-
"""Decorator to register a signal handler."""
|
100
|
-
|
101
|
-
def decorator(func: Callable) -> Callable:
|
102
|
-
# Create unique signal name for this handler
|
103
|
-
unique_signal_name = f"{signal_name}_{uuid.uuid4()}"
|
104
|
-
|
105
|
-
# Create the actual handler that will be registered with Temporal
|
106
|
-
@workflow.signal(name=unique_signal_name)
|
107
|
-
async def wrapped(signal_value: SignalValueT) -> None:
|
108
|
-
# Create a signal object to pass to the handler
|
109
|
-
signal = Signal(
|
110
|
-
name=signal_name,
|
111
|
-
payload=signal_value,
|
112
|
-
workflow_id=workflow.info().workflow_id,
|
113
|
-
)
|
114
|
-
if asyncio.iscoroutinefunction(func):
|
115
|
-
await func(signal)
|
116
|
-
else:
|
117
|
-
func(signal)
|
118
|
-
|
119
|
-
# Register the handler under the original signal name
|
120
|
-
self._handlers.setdefault(signal_name, []).append((unique_signal_name, wrapped))
|
121
|
-
return func
|
122
|
-
|
123
|
-
return decorator
|
124
|
-
|
125
|
-
async def signal(self, signal) -> None:
|
126
|
-
self.validate_signal(signal)
|
127
|
-
|
128
|
-
workflow_handle = workflow.get_external_workflow_handle(workflow_id=signal.workflow_id)
|
129
|
-
|
130
|
-
# Send the signal to all registrations of this signal
|
131
|
-
async with self._lock:
|
132
|
-
signal_tasks = []
|
133
|
-
|
134
|
-
if signal.name in self._pending_signals:
|
135
|
-
for pending_signal in self._pending_signals[signal.name]:
|
136
|
-
registration = pending_signal.registration
|
137
|
-
if registration.workflow_id == signal.workflow_id:
|
138
|
-
# Only signal for registrations of that workflow
|
139
|
-
signal_tasks.append(
|
140
|
-
workflow_handle.signal(registration.unique_name, signal.payload)
|
141
|
-
)
|
142
|
-
else:
|
143
|
-
continue
|
144
|
-
|
145
|
-
# Notify any registered handler functions
|
146
|
-
if signal.name in self._handlers:
|
147
|
-
for unique_name, _ in self._handlers[signal.name]:
|
148
|
-
signal_tasks.append(workflow_handle.signal(unique_name, signal.payload))
|
149
|
-
|
150
|
-
await asyncio.gather(*signal_tasks, return_exceptions=True)
|
151
|
-
|
152
|
-
def validate_signal(self, signal) -> None:
|
153
|
-
super().validate_signal(signal)
|
154
|
-
# Add TemporalSignalHandler-specific validation
|
155
|
-
if signal.workflow_id is None:
|
156
|
-
raise ValueError(
|
157
|
-
"No workflow_id provided on Signal. That is required for Temporal signals"
|
158
|
-
)
|
159
|
-
|
160
|
-
|
161
|
-
class TemporalExecutorConfig(ExecutorConfig, TemporalSettings):
|
162
|
-
"""Configuration for Temporal executors."""
|
163
|
-
|
164
|
-
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
|
165
|
-
|
166
|
-
|
167
|
-
class TemporalExecutor(Executor):
|
168
|
-
"""Executor that runs @workflows as Temporal workflows, with @workflow_tasks as Temporal activities"""
|
169
|
-
|
170
|
-
def __init__(
|
171
|
-
self,
|
172
|
-
config: TemporalExecutorConfig | None = None,
|
173
|
-
signal_bus: SignalHandler | None = None,
|
174
|
-
client: TemporalClient | None = None,
|
175
|
-
context: Optional["Context"] = None,
|
176
|
-
**kwargs,
|
177
|
-
) -> None:
|
178
|
-
signal_bus = signal_bus or TemporalSignalHandler()
|
179
|
-
super().__init__(
|
180
|
-
engine="temporal",
|
181
|
-
config=config,
|
182
|
-
signal_bus=signal_bus,
|
183
|
-
context=context,
|
184
|
-
**kwargs,
|
185
|
-
)
|
186
|
-
self.config: TemporalExecutorConfig = (
|
187
|
-
config or self.context.config.temporal or TemporalExecutorConfig()
|
188
|
-
)
|
189
|
-
self.client = client
|
190
|
-
self._worker = None
|
191
|
-
self._activity_semaphore = None
|
192
|
-
|
193
|
-
if config.max_concurrent_activities is not None:
|
194
|
-
self._activity_semaphore = asyncio.Semaphore(self.config.max_concurrent_activities)
|
195
|
-
|
196
|
-
@staticmethod
|
197
|
-
def wrap_as_activity(
|
198
|
-
activity_name: str,
|
199
|
-
func: Callable[..., R] | Coroutine[Any, Any, R],
|
200
|
-
**kwargs: Any,
|
201
|
-
) -> Coroutine[Any, Any, R]:
|
202
|
-
"""
|
203
|
-
Convert a function into a Temporal activity and return its info.
|
204
|
-
"""
|
205
|
-
|
206
|
-
@activity.defn(name=activity_name)
|
207
|
-
async def wrapped_activity(*args, **local_kwargs):
|
208
|
-
try:
|
209
|
-
if asyncio.iscoroutinefunction(func):
|
210
|
-
return await func(*args, **local_kwargs)
|
211
|
-
elif asyncio.iscoroutine(func):
|
212
|
-
return await func
|
213
|
-
else:
|
214
|
-
return func(*args, **local_kwargs)
|
215
|
-
except Exception as e:
|
216
|
-
# Handle exceptions gracefully
|
217
|
-
raise e
|
218
|
-
|
219
|
-
return wrapped_activity
|
220
|
-
|
221
|
-
async def _execute_task_as_async(
|
222
|
-
self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any
|
223
|
-
) -> R | BaseException:
|
224
|
-
async def run_task(task: Callable[..., R] | Coroutine[Any, Any, R]) -> R:
|
225
|
-
try:
|
226
|
-
if asyncio.iscoroutine(task):
|
227
|
-
return await task
|
228
|
-
elif asyncio.iscoroutinefunction(task):
|
229
|
-
return await task(**kwargs)
|
230
|
-
else:
|
231
|
-
# Execute the callable and await if it returns a coroutine
|
232
|
-
loop = asyncio.get_running_loop()
|
233
|
-
|
234
|
-
# If kwargs are provided, wrap the function with partial
|
235
|
-
if kwargs:
|
236
|
-
wrapped_task = functools.partial(task, **kwargs)
|
237
|
-
result = await loop.run_in_executor(None, wrapped_task)
|
238
|
-
else:
|
239
|
-
result = await loop.run_in_executor(None, task)
|
240
|
-
|
241
|
-
# Handle case where the sync function returns a coroutine
|
242
|
-
if asyncio.iscoroutine(result):
|
243
|
-
return await result
|
244
|
-
|
245
|
-
return result
|
246
|
-
except Exception as e:
|
247
|
-
# TODO: saqadri - adding logging or other error handling here
|
248
|
-
return e
|
249
|
-
|
250
|
-
if self._activity_semaphore:
|
251
|
-
async with self._activity_semaphore:
|
252
|
-
return await run_task(task)
|
253
|
-
else:
|
254
|
-
return await run_task(task)
|
255
|
-
|
256
|
-
async def _execute_task(
|
257
|
-
self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any
|
258
|
-
) -> R | BaseException:
|
259
|
-
func = task.func if isinstance(task, functools.partial) else task
|
260
|
-
is_workflow_task = getattr(func, "is_workflow_task", False)
|
261
|
-
if not is_workflow_task:
|
262
|
-
return await asyncio.create_task(self._execute_task_as_async(task, **kwargs))
|
263
|
-
|
264
|
-
execution_metadata: Dict[str, Any] = getattr(func, "execution_metadata", {})
|
265
|
-
|
266
|
-
# Derive stable activity name, e.g. module + qualname
|
267
|
-
activity_name = execution_metadata.get("activity_name")
|
268
|
-
if not activity_name:
|
269
|
-
activity_name = f"{func.__module__}.{func.__qualname__}"
|
270
|
-
|
271
|
-
schedule_to_close = execution_metadata.get(
|
272
|
-
"schedule_to_close_timeout", self.config.timeout_seconds
|
273
|
-
)
|
274
|
-
|
275
|
-
retry_policy = execution_metadata.get("retry_policy", None)
|
276
|
-
|
277
|
-
_task_activity = self.wrap_as_activity(activity_name=activity_name, func=task)
|
278
|
-
|
279
|
-
# # For partials, we pass the partial's arguments
|
280
|
-
# args = task.args if isinstance(task, functools.partial) else ()
|
281
|
-
try:
|
282
|
-
result = await workflow.execute_activity(
|
283
|
-
activity_name,
|
284
|
-
args=kwargs.get("args", ()),
|
285
|
-
task_queue=self.config.task_queue,
|
286
|
-
schedule_to_close_timeout=schedule_to_close,
|
287
|
-
retry_policy=retry_policy,
|
288
|
-
**kwargs,
|
289
|
-
)
|
290
|
-
return result
|
291
|
-
except Exception as e:
|
292
|
-
# Properly propagate activity errors
|
293
|
-
if isinstance(e, exceptions.ActivityError):
|
294
|
-
raise e.cause if e.cause else e
|
295
|
-
raise
|
296
|
-
|
297
|
-
async def execute(
|
298
|
-
self,
|
299
|
-
*tasks: Callable[..., R] | Coroutine[Any, Any, R],
|
300
|
-
**kwargs: Any,
|
301
|
-
) -> List[R | BaseException]:
|
302
|
-
# Must be called from within a workflow
|
303
|
-
if not workflow._Runtime.current():
|
304
|
-
raise RuntimeError("TemporalExecutor.execute must be called from within a workflow")
|
305
|
-
|
306
|
-
# TODO: saqadri - validate if async with self.execution_context() is needed here
|
307
|
-
async with self.execution_context():
|
308
|
-
return await asyncio.gather(
|
309
|
-
*(self._execute_task(task, **kwargs) for task in tasks),
|
310
|
-
return_exceptions=True,
|
311
|
-
)
|
312
|
-
|
313
|
-
async def execute_streaming(
|
314
|
-
self,
|
315
|
-
*tasks: Callable[..., R] | Coroutine[Any, Any, R],
|
316
|
-
**kwargs: Any,
|
317
|
-
) -> AsyncIterator[R | BaseException]:
|
318
|
-
if not workflow._Runtime.current():
|
319
|
-
raise RuntimeError(
|
320
|
-
"TemporalExecutor.execute_streaming must be called from within a workflow"
|
321
|
-
)
|
322
|
-
|
323
|
-
# TODO: saqadri - validate if async with self.execution_context() is needed here
|
324
|
-
async with self.execution_context():
|
325
|
-
# Create futures for all tasks
|
326
|
-
futures = [self._execute_task(task, **kwargs) for task in tasks]
|
327
|
-
pending = set(futures)
|
328
|
-
|
329
|
-
while pending:
|
330
|
-
done, pending = await workflow.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
331
|
-
for future in done:
|
332
|
-
try:
|
333
|
-
result = await future
|
334
|
-
yield result
|
335
|
-
except Exception as e:
|
336
|
-
yield e
|
337
|
-
|
338
|
-
async def ensure_client(self):
|
339
|
-
"""Ensure we have a connected Temporal client."""
|
340
|
-
if self.client is None:
|
341
|
-
self.client = await TemporalClient.connect(
|
342
|
-
target_host=self.config.host,
|
343
|
-
namespace=self.config.namespace,
|
344
|
-
api_key=self.config.api_key,
|
345
|
-
)
|
346
|
-
|
347
|
-
return self.client
|
348
|
-
|
349
|
-
async def start_worker(self) -> None:
|
350
|
-
"""
|
351
|
-
Start a worker in this process, auto-registering all tasks
|
352
|
-
from the global registry. Also picks up any classes decorated
|
353
|
-
with @workflow_defn as recognized workflows.
|
354
|
-
"""
|
355
|
-
await self.ensure_client()
|
356
|
-
|
357
|
-
if self._worker is None:
|
358
|
-
# We'll collect the activities from the global registry
|
359
|
-
# and optionally wrap them with `activity.defn` if we want
|
360
|
-
# (Not strictly required if your code calls `execute_activity("name")` by name)
|
361
|
-
activity_registry = self.context.task_registry
|
362
|
-
activities = []
|
363
|
-
for name in activity_registry.list_activities():
|
364
|
-
activities.append(activity_registry.get_activity(name))
|
365
|
-
|
366
|
-
# Now we attempt to discover any classes that are recognized as workflows
|
367
|
-
# But in this simple example, we rely on the user specifying them or
|
368
|
-
# we might do a dynamic scan.
|
369
|
-
# For demonstration, we'll just assume the user is only using
|
370
|
-
# the workflow classes they decorate with `@workflow_defn`.
|
371
|
-
# We'll rely on them passing the classes or scanning your code.
|
372
|
-
|
373
|
-
self._worker = Worker(
|
374
|
-
client=self.client,
|
375
|
-
task_queue=self.config.task_queue,
|
376
|
-
activities=activities,
|
377
|
-
workflows=[], # We'll auto-load by Python scanning or let the user specify
|
378
|
-
)
|
379
|
-
print(
|
380
|
-
f"Starting Temporal Worker on task queue '{self.config.task_queue}' with {len(activities)} activities."
|
381
|
-
)
|
382
|
-
|
383
|
-
await self._worker.run()
|