fast-agent-mcp 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.14.dist-info}/METADATA +1 -1
  2. {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.14.dist-info}/RECORD +33 -33
  3. mcp_agent/agents/agent.py +2 -2
  4. mcp_agent/agents/base_agent.py +3 -3
  5. mcp_agent/agents/workflow/chain_agent.py +2 -2
  6. mcp_agent/agents/workflow/evaluator_optimizer.py +3 -3
  7. mcp_agent/agents/workflow/orchestrator_agent.py +3 -3
  8. mcp_agent/agents/workflow/parallel_agent.py +2 -2
  9. mcp_agent/agents/workflow/router_agent.py +2 -2
  10. mcp_agent/cli/commands/check_config.py +450 -0
  11. mcp_agent/cli/commands/setup.py +1 -1
  12. mcp_agent/cli/main.py +8 -15
  13. mcp_agent/core/agent_types.py +8 -8
  14. mcp_agent/core/direct_decorators.py +10 -8
  15. mcp_agent/core/direct_factory.py +4 -1
  16. mcp_agent/core/validation.py +6 -4
  17. mcp_agent/event_progress.py +6 -6
  18. mcp_agent/llm/augmented_llm.py +10 -2
  19. mcp_agent/llm/augmented_llm_passthrough.py +5 -3
  20. mcp_agent/llm/augmented_llm_playback.py +2 -1
  21. mcp_agent/llm/model_factory.py +7 -27
  22. mcp_agent/llm/provider_key_manager.py +83 -0
  23. mcp_agent/llm/provider_types.py +16 -0
  24. mcp_agent/llm/providers/augmented_llm_anthropic.py +5 -26
  25. mcp_agent/llm/providers/augmented_llm_deepseek.py +5 -24
  26. mcp_agent/llm/providers/augmented_llm_generic.py +2 -16
  27. mcp_agent/llm/providers/augmented_llm_openai.py +4 -26
  28. mcp_agent/llm/providers/augmented_llm_openrouter.py +17 -45
  29. mcp_agent/mcp/interfaces.py +2 -1
  30. mcp_agent/mcp_server/agent_server.py +120 -38
  31. mcp_agent/cli/commands/config.py +0 -11
  32. mcp_agent/executor/temporal.py +0 -383
  33. mcp_agent/executor/workflow.py +0 -195
  34. {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.14.dist-info}/WHEEL +0 -0
  35. {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.14.dist-info}/entry_points.txt +0 -0
  36. {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.14.dist-info}/licenses/LICENSE +0 -0
@@ -143,13 +143,38 @@ class AgentMCPServer:
143
143
  self.mcp_server.settings.host = host
144
144
  self.mcp_server.settings.port = port
145
145
 
146
- try:
147
- self.mcp_server.run(transport=transport)
148
- except KeyboardInterrupt:
149
- print("\nServer stopped by user (CTRL+C)")
150
- finally:
151
- # Run an async cleanup in a new event loop
152
- asyncio.run(self.shutdown())
146
+ # For synchronous run, we can use the simpler approach
147
+ try:
148
+ # Add any server attributes that might help with shutdown
149
+ if not hasattr(self.mcp_server, "_server_should_exit"):
150
+ self.mcp_server._server_should_exit = False
151
+
152
+ # Run the server
153
+ self.mcp_server.run(transport=transport)
154
+ except KeyboardInterrupt:
155
+ print("\nServer stopped by user (CTRL+C)")
156
+ except SystemExit as e:
157
+ # Handle normal exit
158
+ print(f"\nServer exiting with code {e.code}")
159
+ # Re-raise to allow normal exit process
160
+ raise
161
+ except Exception as e:
162
+ print(f"\nServer error: {e}")
163
+ finally:
164
+ # Run an async cleanup in a new event loop
165
+ try:
166
+ asyncio.run(self.shutdown())
167
+ except (SystemExit, KeyboardInterrupt):
168
+ # These are expected during shutdown
169
+ pass
170
+ else: # stdio
171
+ try:
172
+ self.mcp_server.run(transport=transport)
173
+ except KeyboardInterrupt:
174
+ print("\nServer stopped by user (CTRL+C)")
175
+ finally:
176
+ # Minimal cleanup for stdio
177
+ asyncio.run(self._cleanup_stdio())
153
178
 
154
179
  async def run_async(
155
180
  self, transport: str = "sse", host: str = "0.0.0.0", port: int = 8000
@@ -169,20 +194,26 @@ class AgentMCPServer:
169
194
  try:
170
195
  # Wait for the server task to complete
171
196
  await self._server_task
172
- except asyncio.CancelledError:
173
- logger.info("Server task cancelled.")
174
- print("\nServer task cancelled.")
197
+ except (asyncio.CancelledError, KeyboardInterrupt):
198
+ # Both cancellation and KeyboardInterrupt are expected during shutdown
199
+ logger.info("Server stopped via cancellation or interrupt")
200
+ print("\nServer stopped")
201
+ except SystemExit as e:
202
+ # Handle normal exit cleanly
203
+ logger.info(f"Server exiting with code {e.code}")
204
+ print(f"\nServer exiting with code {e.code}")
205
+ # If this is exit code 0, let it propagate for normal exit
206
+ if e.code == 0:
207
+ raise
175
208
  except Exception as e:
176
209
  logger.error(f"Server error: {e}", exc_info=True)
177
210
  print(f"\nServer error: {e}")
178
211
  finally:
179
- # Ensure cleanup happens
180
- await self.shutdown()
181
- logger.info("Server shutdown complete.")
212
+ # Only do minimal cleanup - don't try to be too clever
213
+ await self._cleanup_stdio()
182
214
  print("\nServer shutdown complete.")
183
215
  else: # stdio
184
216
  # For STDIO, use simpler approach that respects STDIO lifecycle
185
- # STDIO will naturally terminate when streams close
186
217
  try:
187
218
  # Run directly without extra monitoring or signal handlers
188
219
  # This preserves the natural lifecycle of STDIO connections
@@ -190,9 +221,14 @@ class AgentMCPServer:
190
221
  except (asyncio.CancelledError, KeyboardInterrupt):
191
222
  logger.info("Server stopped (CTRL+C)")
192
223
  print("\nServer stopped (CTRL+C)")
193
-
224
+ except SystemExit as e:
225
+ # Handle normal exit cleanly
226
+ logger.info(f"Server exiting with code {e.code}")
227
+ print(f"\nServer exiting with code {e.code}")
228
+ # If this is exit code 0, let it propagate for normal exit
229
+ if e.code == 0:
230
+ raise
194
231
  # Only perform minimal cleanup needed for STDIO
195
- # Don't use our full shutdown procedure which could keep process alive
196
232
  await self._cleanup_stdio()
197
233
 
198
234
  async def _run_server_with_shutdown(self, transport: str):
@@ -246,7 +282,6 @@ class AgentMCPServer:
246
282
  force_shutdown_task = asyncio.create_task(self._force_shutdown_event.wait())
247
283
  timeout_task = asyncio.create_task(asyncio.sleep(self._shutdown_timeout))
248
284
 
249
- # Wait for either force shutdown or timeout
250
285
  done, pending = await asyncio.wait(
251
286
  [force_shutdown_task, timeout_task], return_when=asyncio.FIRST_COMPLETED
252
287
  )
@@ -255,21 +290,16 @@ class AgentMCPServer:
255
290
  for task in pending:
256
291
  task.cancel()
257
292
 
258
- # Determine the shutdown reason
293
+ # Determine shutdown reason
259
294
  if force_shutdown_task in done:
260
- logger.info("Force shutdown requested")
261
- print("\nForced shutdown initiated...")
295
+ logger.info("Force shutdown requested by user")
296
+ print("\nForce shutdown initiated...")
262
297
  else:
263
298
  logger.info(f"Graceful shutdown timed out after {self._shutdown_timeout} seconds")
264
299
  print(f"\nGraceful shutdown timed out after {self._shutdown_timeout} seconds")
265
300
 
266
- # Force close any remaining SSE connections
267
- await self._close_sse_connections()
301
+ os._exit(0)
268
302
 
269
- # Cancel the server task if running
270
- if self._server_task and not self._server_task.done():
271
- logger.info("Cancelling server task")
272
- self._server_task.cancel()
273
303
  except asyncio.CancelledError:
274
304
  # Monitor was cancelled - clean exit
275
305
  pass
@@ -302,11 +332,36 @@ class AgentMCPServer:
302
332
  for session_id, writer in writers:
303
333
  try:
304
334
  logger.debug(f"Closing SSE connection: {session_id}")
335
+ # Instead of aclose, try to close more gracefully
336
+ # Send a special event to notify client, then close
337
+ try:
338
+ if hasattr(writer, "send") and not getattr(writer, "_closed", False):
339
+ try:
340
+ # Try to send a close event if possible
341
+ await writer.send(Exception("Server shutting down"))
342
+ except (AttributeError, asyncio.CancelledError):
343
+ pass
344
+ except Exception:
345
+ pass
346
+
347
+ # Now close the stream
305
348
  await writer.aclose()
306
349
  sse._read_stream_writers.pop(session_id, None)
307
350
  except Exception as e:
308
351
  logger.error(f"Error closing SSE connection {session_id}: {e}")
309
352
 
353
+ # If we have a ASGI lifespan hook, try to signal closure
354
+ if (
355
+ hasattr(self.mcp_server, "_lifespan_state")
356
+ and self.mcp_server._lifespan_state == "started"
357
+ ):
358
+ logger.debug("Attempting to signal ASGI lifespan shutdown")
359
+ try:
360
+ if hasattr(self.mcp_server, "_on_shutdown"):
361
+ await self.mcp_server._on_shutdown()
362
+ except Exception as e:
363
+ logger.error(f"Error during ASGI lifespan shutdown: {e}")
364
+
310
365
  async def with_bridged_context(self, agent_context, mcp_context, func, *args, **kwargs):
311
366
  """
312
367
  Execute a function with bridged context between MCP and agent
@@ -374,18 +429,45 @@ class AgentMCPServer:
374
429
  # Signal shutdown
375
430
  self._graceful_shutdown_event.set()
376
431
 
377
- # Close SSE connections
378
- await self._close_sse_connections()
432
+ try:
433
+ # Close SSE connections
434
+ await self._close_sse_connections()
379
435
 
380
- # Close any resources in the exit stack
381
- await self._exit_stack.aclose()
436
+ # Close any resources in the exit stack
437
+ await self._exit_stack.aclose()
382
438
 
383
- # Shutdown any agent resources
384
- for agent_name, agent in self.agent_app._agents.items():
385
- try:
386
- if hasattr(agent, "shutdown"):
387
- await agent.shutdown()
388
- except Exception as e:
389
- logger.error(f"Error shutting down agent {agent_name}: {e}")
439
+ # Shutdown any agent resources
440
+ for agent_name, agent in self.agent_app._agents.items():
441
+ try:
442
+ if hasattr(agent, "shutdown"):
443
+ await agent.shutdown()
444
+ except Exception as e:
445
+ logger.error(f"Error shutting down agent {agent_name}: {e}")
446
+ except Exception as e:
447
+ # Log any errors but don't let them prevent shutdown
448
+ logger.error(f"Error during shutdown: {e}", exc_info=True)
449
+ finally:
450
+ logger.info("Full shutdown complete")
451
+
452
+ async def _cleanup_minimal(self):
453
+ """Perform minimal cleanup before simulating a KeyboardInterrupt."""
454
+ logger.info("Performing minimal cleanup before interrupt")
455
+
456
+ # Only close SSE connection writers directly
457
+ if (
458
+ hasattr(self.mcp_server, "_sse_transport")
459
+ and self.mcp_server._sse_transport is not None
460
+ ):
461
+ sse = self.mcp_server._sse_transport
462
+
463
+ # Close all read stream writers
464
+ if hasattr(sse, "_read_stream_writers"):
465
+ for session_id, writer in list(sse._read_stream_writers.items()):
466
+ try:
467
+ await writer.aclose()
468
+ except Exception:
469
+ # Ignore errors during cleanup
470
+ pass
390
471
 
391
- logger.info("Full shutdown complete")
472
+ # Clear active connections set to prevent further operations
473
+ self._active_connections.clear()
@@ -1,11 +0,0 @@
1
- from typing import NoReturn
2
-
3
- import typer
4
-
5
- app = typer.Typer()
6
-
7
-
8
- @app.command()
9
- def show() -> NoReturn:
10
- """Show the configuration."""
11
- raise NotImplementedError("The show configuration command has not been implemented yet")
@@ -1,383 +0,0 @@
1
- """
2
- Temporal based orchestrator for the MCP Agent.
3
- Temporal provides durable execution and robust workflow orchestration,
4
- as well as dynamic control flow, making it a good choice for an AI agent orchestrator.
5
- Read more: https://docs.temporal.io/develop/python/core-application
6
- """
7
-
8
- import asyncio
9
- import functools
10
- import uuid
11
- from typing import (
12
- TYPE_CHECKING,
13
- Any,
14
- AsyncIterator,
15
- Callable,
16
- Coroutine,
17
- Dict,
18
- List,
19
- Optional,
20
- )
21
-
22
- from pydantic import ConfigDict
23
- from temporalio import activity, exceptions, workflow
24
- from temporalio.client import Client as TemporalClient
25
- from temporalio.worker import Worker
26
-
27
- from mcp_agent.config import TemporalSettings
28
- from mcp_agent.executor.executor import Executor, ExecutorConfig, R
29
- from mcp_agent.executor.workflow_signal import (
30
- BaseSignalHandler,
31
- Signal,
32
- SignalHandler,
33
- SignalRegistration,
34
- SignalValueT,
35
- )
36
-
37
- if TYPE_CHECKING:
38
- from mcp_agent.context import Context
39
-
40
-
41
- class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
42
- """Temporal-based signal handling using workflow signals"""
43
-
44
- async def wait_for_signal(self, signal, timeout_seconds=None) -> SignalValueT:
45
- if not workflow._Runtime.current():
46
- raise RuntimeError(
47
- "TemporalSignalHandler.wait_for_signal must be called from within a workflow"
48
- )
49
-
50
- unique_signal_name = f"{signal.name}_{uuid.uuid4()}"
51
- registration = SignalRegistration(
52
- signal_name=signal.name,
53
- unique_name=unique_signal_name,
54
- workflow_id=workflow.info().workflow_id,
55
- )
56
-
57
- # Container for signal value
58
- container = {"value": None, "completed": False}
59
-
60
- # Define the signal handler for this specific registration
61
- @workflow.signal(name=unique_signal_name)
62
- def signal_handler(value: SignalValueT) -> None:
63
- container["value"] = value
64
- container["completed"] = True
65
-
66
- async with self._lock:
67
- # Register both the signal registration and handler atomically
68
- self._pending_signals.setdefault(signal.name, []).append(registration)
69
- self._handlers.setdefault(signal.name, []).append((unique_signal_name, signal_handler))
70
-
71
- try:
72
- # Wait for signal with optional timeout
73
- await workflow.wait_condition(lambda: container["completed"], timeout=timeout_seconds)
74
-
75
- return container["value"]
76
- except asyncio.TimeoutError as exc:
77
- raise TimeoutError(f"Timeout waiting for signal {signal.name}") from exc
78
- finally:
79
- async with self._lock:
80
- # Remove ourselves from _pending_signals
81
- if signal.name in self._pending_signals:
82
- self._pending_signals[signal.name] = [
83
- sr
84
- for sr in self._pending_signals[signal.name]
85
- if sr.unique_name != unique_signal_name
86
- ]
87
- if not self._pending_signals[signal.name]:
88
- del self._pending_signals[signal.name]
89
-
90
- # Remove ourselves from _handlers
91
- if signal.name in self._handlers:
92
- self._handlers[signal.name] = [
93
- h for h in self._handlers[signal.name] if h[0] != unique_signal_name
94
- ]
95
- if not self._handlers[signal.name]:
96
- del self._handlers[signal.name]
97
-
98
- def on_signal(self, signal_name):
99
- """Decorator to register a signal handler."""
100
-
101
- def decorator(func: Callable) -> Callable:
102
- # Create unique signal name for this handler
103
- unique_signal_name = f"{signal_name}_{uuid.uuid4()}"
104
-
105
- # Create the actual handler that will be registered with Temporal
106
- @workflow.signal(name=unique_signal_name)
107
- async def wrapped(signal_value: SignalValueT) -> None:
108
- # Create a signal object to pass to the handler
109
- signal = Signal(
110
- name=signal_name,
111
- payload=signal_value,
112
- workflow_id=workflow.info().workflow_id,
113
- )
114
- if asyncio.iscoroutinefunction(func):
115
- await func(signal)
116
- else:
117
- func(signal)
118
-
119
- # Register the handler under the original signal name
120
- self._handlers.setdefault(signal_name, []).append((unique_signal_name, wrapped))
121
- return func
122
-
123
- return decorator
124
-
125
- async def signal(self, signal) -> None:
126
- self.validate_signal(signal)
127
-
128
- workflow_handle = workflow.get_external_workflow_handle(workflow_id=signal.workflow_id)
129
-
130
- # Send the signal to all registrations of this signal
131
- async with self._lock:
132
- signal_tasks = []
133
-
134
- if signal.name in self._pending_signals:
135
- for pending_signal in self._pending_signals[signal.name]:
136
- registration = pending_signal.registration
137
- if registration.workflow_id == signal.workflow_id:
138
- # Only signal for registrations of that workflow
139
- signal_tasks.append(
140
- workflow_handle.signal(registration.unique_name, signal.payload)
141
- )
142
- else:
143
- continue
144
-
145
- # Notify any registered handler functions
146
- if signal.name in self._handlers:
147
- for unique_name, _ in self._handlers[signal.name]:
148
- signal_tasks.append(workflow_handle.signal(unique_name, signal.payload))
149
-
150
- await asyncio.gather(*signal_tasks, return_exceptions=True)
151
-
152
- def validate_signal(self, signal) -> None:
153
- super().validate_signal(signal)
154
- # Add TemporalSignalHandler-specific validation
155
- if signal.workflow_id is None:
156
- raise ValueError(
157
- "No workflow_id provided on Signal. That is required for Temporal signals"
158
- )
159
-
160
-
161
- class TemporalExecutorConfig(ExecutorConfig, TemporalSettings):
162
- """Configuration for Temporal executors."""
163
-
164
- model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
165
-
166
-
167
- class TemporalExecutor(Executor):
168
- """Executor that runs @workflows as Temporal workflows, with @workflow_tasks as Temporal activities"""
169
-
170
- def __init__(
171
- self,
172
- config: TemporalExecutorConfig | None = None,
173
- signal_bus: SignalHandler | None = None,
174
- client: TemporalClient | None = None,
175
- context: Optional["Context"] = None,
176
- **kwargs,
177
- ) -> None:
178
- signal_bus = signal_bus or TemporalSignalHandler()
179
- super().__init__(
180
- engine="temporal",
181
- config=config,
182
- signal_bus=signal_bus,
183
- context=context,
184
- **kwargs,
185
- )
186
- self.config: TemporalExecutorConfig = (
187
- config or self.context.config.temporal or TemporalExecutorConfig()
188
- )
189
- self.client = client
190
- self._worker = None
191
- self._activity_semaphore = None
192
-
193
- if config.max_concurrent_activities is not None:
194
- self._activity_semaphore = asyncio.Semaphore(self.config.max_concurrent_activities)
195
-
196
- @staticmethod
197
- def wrap_as_activity(
198
- activity_name: str,
199
- func: Callable[..., R] | Coroutine[Any, Any, R],
200
- **kwargs: Any,
201
- ) -> Coroutine[Any, Any, R]:
202
- """
203
- Convert a function into a Temporal activity and return its info.
204
- """
205
-
206
- @activity.defn(name=activity_name)
207
- async def wrapped_activity(*args, **local_kwargs):
208
- try:
209
- if asyncio.iscoroutinefunction(func):
210
- return await func(*args, **local_kwargs)
211
- elif asyncio.iscoroutine(func):
212
- return await func
213
- else:
214
- return func(*args, **local_kwargs)
215
- except Exception as e:
216
- # Handle exceptions gracefully
217
- raise e
218
-
219
- return wrapped_activity
220
-
221
- async def _execute_task_as_async(
222
- self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any
223
- ) -> R | BaseException:
224
- async def run_task(task: Callable[..., R] | Coroutine[Any, Any, R]) -> R:
225
- try:
226
- if asyncio.iscoroutine(task):
227
- return await task
228
- elif asyncio.iscoroutinefunction(task):
229
- return await task(**kwargs)
230
- else:
231
- # Execute the callable and await if it returns a coroutine
232
- loop = asyncio.get_running_loop()
233
-
234
- # If kwargs are provided, wrap the function with partial
235
- if kwargs:
236
- wrapped_task = functools.partial(task, **kwargs)
237
- result = await loop.run_in_executor(None, wrapped_task)
238
- else:
239
- result = await loop.run_in_executor(None, task)
240
-
241
- # Handle case where the sync function returns a coroutine
242
- if asyncio.iscoroutine(result):
243
- return await result
244
-
245
- return result
246
- except Exception as e:
247
- # TODO: saqadri - adding logging or other error handling here
248
- return e
249
-
250
- if self._activity_semaphore:
251
- async with self._activity_semaphore:
252
- return await run_task(task)
253
- else:
254
- return await run_task(task)
255
-
256
- async def _execute_task(
257
- self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any
258
- ) -> R | BaseException:
259
- func = task.func if isinstance(task, functools.partial) else task
260
- is_workflow_task = getattr(func, "is_workflow_task", False)
261
- if not is_workflow_task:
262
- return await asyncio.create_task(self._execute_task_as_async(task, **kwargs))
263
-
264
- execution_metadata: Dict[str, Any] = getattr(func, "execution_metadata", {})
265
-
266
- # Derive stable activity name, e.g. module + qualname
267
- activity_name = execution_metadata.get("activity_name")
268
- if not activity_name:
269
- activity_name = f"{func.__module__}.{func.__qualname__}"
270
-
271
- schedule_to_close = execution_metadata.get(
272
- "schedule_to_close_timeout", self.config.timeout_seconds
273
- )
274
-
275
- retry_policy = execution_metadata.get("retry_policy", None)
276
-
277
- _task_activity = self.wrap_as_activity(activity_name=activity_name, func=task)
278
-
279
- # # For partials, we pass the partial's arguments
280
- # args = task.args if isinstance(task, functools.partial) else ()
281
- try:
282
- result = await workflow.execute_activity(
283
- activity_name,
284
- args=kwargs.get("args", ()),
285
- task_queue=self.config.task_queue,
286
- schedule_to_close_timeout=schedule_to_close,
287
- retry_policy=retry_policy,
288
- **kwargs,
289
- )
290
- return result
291
- except Exception as e:
292
- # Properly propagate activity errors
293
- if isinstance(e, exceptions.ActivityError):
294
- raise e.cause if e.cause else e
295
- raise
296
-
297
- async def execute(
298
- self,
299
- *tasks: Callable[..., R] | Coroutine[Any, Any, R],
300
- **kwargs: Any,
301
- ) -> List[R | BaseException]:
302
- # Must be called from within a workflow
303
- if not workflow._Runtime.current():
304
- raise RuntimeError("TemporalExecutor.execute must be called from within a workflow")
305
-
306
- # TODO: saqadri - validate if async with self.execution_context() is needed here
307
- async with self.execution_context():
308
- return await asyncio.gather(
309
- *(self._execute_task(task, **kwargs) for task in tasks),
310
- return_exceptions=True,
311
- )
312
-
313
- async def execute_streaming(
314
- self,
315
- *tasks: Callable[..., R] | Coroutine[Any, Any, R],
316
- **kwargs: Any,
317
- ) -> AsyncIterator[R | BaseException]:
318
- if not workflow._Runtime.current():
319
- raise RuntimeError(
320
- "TemporalExecutor.execute_streaming must be called from within a workflow"
321
- )
322
-
323
- # TODO: saqadri - validate if async with self.execution_context() is needed here
324
- async with self.execution_context():
325
- # Create futures for all tasks
326
- futures = [self._execute_task(task, **kwargs) for task in tasks]
327
- pending = set(futures)
328
-
329
- while pending:
330
- done, pending = await workflow.wait(pending, return_when=asyncio.FIRST_COMPLETED)
331
- for future in done:
332
- try:
333
- result = await future
334
- yield result
335
- except Exception as e:
336
- yield e
337
-
338
- async def ensure_client(self):
339
- """Ensure we have a connected Temporal client."""
340
- if self.client is None:
341
- self.client = await TemporalClient.connect(
342
- target_host=self.config.host,
343
- namespace=self.config.namespace,
344
- api_key=self.config.api_key,
345
- )
346
-
347
- return self.client
348
-
349
- async def start_worker(self) -> None:
350
- """
351
- Start a worker in this process, auto-registering all tasks
352
- from the global registry. Also picks up any classes decorated
353
- with @workflow_defn as recognized workflows.
354
- """
355
- await self.ensure_client()
356
-
357
- if self._worker is None:
358
- # We'll collect the activities from the global registry
359
- # and optionally wrap them with `activity.defn` if we want
360
- # (Not strictly required if your code calls `execute_activity("name")` by name)
361
- activity_registry = self.context.task_registry
362
- activities = []
363
- for name in activity_registry.list_activities():
364
- activities.append(activity_registry.get_activity(name))
365
-
366
- # Now we attempt to discover any classes that are recognized as workflows
367
- # But in this simple example, we rely on the user specifying them or
368
- # we might do a dynamic scan.
369
- # For demonstration, we'll just assume the user is only using
370
- # the workflow classes they decorate with `@workflow_defn`.
371
- # We'll rely on them passing the classes or scanning your code.
372
-
373
- self._worker = Worker(
374
- client=self.client,
375
- task_queue=self.config.task_queue,
376
- activities=activities,
377
- workflows=[], # We'll auto-load by Python scanning or let the user specify
378
- )
379
- print(
380
- f"Starting Temporal Worker on task queue '{self.config.task_queue}' with {len(activities)} activities."
381
- )
382
-
383
- await self._worker.run()