fast-agent-mcp 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fast-agent-mcp might be problematic. Click here for more details.

Files changed (100) hide show
  1. fast_agent_mcp-0.0.7.dist-info/METADATA +322 -0
  2. fast_agent_mcp-0.0.7.dist-info/RECORD +100 -0
  3. fast_agent_mcp-0.0.7.dist-info/WHEEL +4 -0
  4. fast_agent_mcp-0.0.7.dist-info/entry_points.txt +5 -0
  5. fast_agent_mcp-0.0.7.dist-info/licenses/LICENSE +201 -0
  6. mcp_agent/__init__.py +0 -0
  7. mcp_agent/agents/__init__.py +0 -0
  8. mcp_agent/agents/agent.py +277 -0
  9. mcp_agent/app.py +303 -0
  10. mcp_agent/cli/__init__.py +0 -0
  11. mcp_agent/cli/__main__.py +4 -0
  12. mcp_agent/cli/commands/bootstrap.py +221 -0
  13. mcp_agent/cli/commands/config.py +11 -0
  14. mcp_agent/cli/commands/setup.py +229 -0
  15. mcp_agent/cli/main.py +68 -0
  16. mcp_agent/cli/terminal.py +24 -0
  17. mcp_agent/config.py +334 -0
  18. mcp_agent/console.py +28 -0
  19. mcp_agent/context.py +251 -0
  20. mcp_agent/context_dependent.py +48 -0
  21. mcp_agent/core/fastagent.py +1013 -0
  22. mcp_agent/eval/__init__.py +0 -0
  23. mcp_agent/event_progress.py +88 -0
  24. mcp_agent/executor/__init__.py +0 -0
  25. mcp_agent/executor/decorator_registry.py +120 -0
  26. mcp_agent/executor/executor.py +293 -0
  27. mcp_agent/executor/task_registry.py +34 -0
  28. mcp_agent/executor/temporal.py +405 -0
  29. mcp_agent/executor/workflow.py +197 -0
  30. mcp_agent/executor/workflow_signal.py +325 -0
  31. mcp_agent/human_input/__init__.py +0 -0
  32. mcp_agent/human_input/handler.py +49 -0
  33. mcp_agent/human_input/types.py +58 -0
  34. mcp_agent/logging/__init__.py +0 -0
  35. mcp_agent/logging/events.py +123 -0
  36. mcp_agent/logging/json_serializer.py +163 -0
  37. mcp_agent/logging/listeners.py +216 -0
  38. mcp_agent/logging/logger.py +365 -0
  39. mcp_agent/logging/rich_progress.py +120 -0
  40. mcp_agent/logging/tracing.py +140 -0
  41. mcp_agent/logging/transport.py +461 -0
  42. mcp_agent/mcp/__init__.py +0 -0
  43. mcp_agent/mcp/gen_client.py +85 -0
  44. mcp_agent/mcp/mcp_activity.py +18 -0
  45. mcp_agent/mcp/mcp_agent_client_session.py +242 -0
  46. mcp_agent/mcp/mcp_agent_server.py +56 -0
  47. mcp_agent/mcp/mcp_aggregator.py +394 -0
  48. mcp_agent/mcp/mcp_connection_manager.py +330 -0
  49. mcp_agent/mcp/stdio.py +104 -0
  50. mcp_agent/mcp_server_registry.py +275 -0
  51. mcp_agent/progress_display.py +10 -0
  52. mcp_agent/resources/examples/decorator/main.py +26 -0
  53. mcp_agent/resources/examples/decorator/optimizer.py +78 -0
  54. mcp_agent/resources/examples/decorator/orchestrator.py +68 -0
  55. mcp_agent/resources/examples/decorator/parallel.py +81 -0
  56. mcp_agent/resources/examples/decorator/router.py +56 -0
  57. mcp_agent/resources/examples/decorator/tiny.py +22 -0
  58. mcp_agent/resources/examples/mcp_researcher/main-evalopt.py +53 -0
  59. mcp_agent/resources/examples/mcp_researcher/main.py +38 -0
  60. mcp_agent/telemetry/__init__.py +0 -0
  61. mcp_agent/telemetry/usage_tracking.py +18 -0
  62. mcp_agent/workflows/__init__.py +0 -0
  63. mcp_agent/workflows/embedding/__init__.py +0 -0
  64. mcp_agent/workflows/embedding/embedding_base.py +61 -0
  65. mcp_agent/workflows/embedding/embedding_cohere.py +49 -0
  66. mcp_agent/workflows/embedding/embedding_openai.py +46 -0
  67. mcp_agent/workflows/evaluator_optimizer/__init__.py +0 -0
  68. mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +359 -0
  69. mcp_agent/workflows/intent_classifier/__init__.py +0 -0
  70. mcp_agent/workflows/intent_classifier/intent_classifier_base.py +120 -0
  71. mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +134 -0
  72. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +45 -0
  73. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +45 -0
  74. mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +161 -0
  75. mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +60 -0
  76. mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +60 -0
  77. mcp_agent/workflows/llm/__init__.py +0 -0
  78. mcp_agent/workflows/llm/augmented_llm.py +645 -0
  79. mcp_agent/workflows/llm/augmented_llm_anthropic.py +539 -0
  80. mcp_agent/workflows/llm/augmented_llm_openai.py +615 -0
  81. mcp_agent/workflows/llm/llm_selector.py +345 -0
  82. mcp_agent/workflows/llm/model_factory.py +175 -0
  83. mcp_agent/workflows/orchestrator/__init__.py +0 -0
  84. mcp_agent/workflows/orchestrator/orchestrator.py +407 -0
  85. mcp_agent/workflows/orchestrator/orchestrator_models.py +154 -0
  86. mcp_agent/workflows/orchestrator/orchestrator_prompts.py +113 -0
  87. mcp_agent/workflows/parallel/__init__.py +0 -0
  88. mcp_agent/workflows/parallel/fan_in.py +350 -0
  89. mcp_agent/workflows/parallel/fan_out.py +187 -0
  90. mcp_agent/workflows/parallel/parallel_llm.py +141 -0
  91. mcp_agent/workflows/router/__init__.py +0 -0
  92. mcp_agent/workflows/router/router_base.py +276 -0
  93. mcp_agent/workflows/router/router_embedding.py +240 -0
  94. mcp_agent/workflows/router/router_embedding_cohere.py +59 -0
  95. mcp_agent/workflows/router/router_embedding_openai.py +59 -0
  96. mcp_agent/workflows/router/router_llm.py +301 -0
  97. mcp_agent/workflows/swarm/__init__.py +0 -0
  98. mcp_agent/workflows/swarm/swarm.py +320 -0
  99. mcp_agent/workflows/swarm/swarm_anthropic.py +42 -0
  100. mcp_agent/workflows/swarm/swarm_openai.py +41 -0
File without changes
@@ -0,0 +1,88 @@
1
+ """Module for converting log events to progress events."""
2
+
3
+ from dataclasses import dataclass
4
+ from enum import Enum
5
+ from typing import Optional
6
+
7
+ from mcp_agent.logging.events import Event
8
+
9
+
10
class ProgressAction(str, Enum):
    """Closed set of progress states shown by the progress display.

    Members are also ``str`` instances, so a member can be used wherever its
    display text is expected (for example with ``str.ljust``).
    """

    # Lifecycle states.
    STARTING = "Starting"
    INITIALIZED = "Initialized"
    CHATTING = "Chatting"
    READY = "Ready"
    CALLING_TOOL = "Calling Tool"
    FINISHED = "Finished"
    SHUTDOWN = "Shutdown"
    # NOTE(review): the member name and its display text differ here
    # ("AGGREGATOR_INITIALIZED" is rendered as "Running") - confirm intended.
    AGGREGATOR_INITIALIZED = "Running"
    ROUTING = "Routing"
22
+
23
+
24
@dataclass
class ProgressEvent:
    """A progress update derived from a log event.

    ``str(event)`` renders it as
    ``[agent_name] <action>. <target> - <details>``; the agent prefix and the
    details suffix are left out when the corresponding field is empty.
    """

    action: ProgressAction
    target: str
    details: Optional[str] = None
    agent_name: Optional[str] = None

    def __str__(self) -> str:
        """Return the single-line display form of this event."""
        # The action text is left-justified to 11 characters so targets line up.
        text = f"{self.action.ljust(11)}. {self.target}"
        if self.details:
            text = f"{text} - {self.details}"
        return f"[{self.agent_name}] {text}" if self.agent_name else text
41
+
42
+
43
def convert_log_event(event: Event) -> Optional[ProgressEvent]:
    """Translate a log event into a ProgressEvent when it carries progress info.

    Returns ``None`` when the event has no data, when ``event.data["data"]``
    is not a dict, or when that dict has no truthy ``"progress_action"``.
    The resulting display is currently ``[time] [event] --- [target] [details]``.
    """
    payload = event.data.get("data") if event.data else None
    if not isinstance(payload, dict):
        return None

    action_name = payload.get("progress_action")
    if not action_name:
        return None

    scope = event.namespace
    agent = payload.get("agent_name")

    if "mcp_aggregator" in scope:
        # Aggregator events: describe the server, plus the tool when present,
        # e.g. "fetch (fetch)".
        server = payload.get("server_name", "")
        tool = payload.get("tool_name")
        subject = agent
        info = f"{server} ({tool})" if tool else f"{server}"
    elif "augmented_llm" in scope:
        # LLM events: describe the model, plus the chat turn when present.
        model = payload.get("model", "")
        turn = payload.get("chat_turn")
        subject = agent
        info = f"{model}" if turn is None else f"{model} turn {turn}"
    else:
        # Any other namespace: use the explicit target and no details.
        subject = payload.get("target", "unknown")
        info = ""

    return ProgressEvent(
        ProgressAction(action_name),
        subject,
        info,
        agent_name=agent,
    )
File without changes
@@ -0,0 +1,120 @@
1
+ """
2
+ Keep track of all workflow decorator overloads indexed by executor backend.
3
+ Different executors may have different ways of configuring workflows.
4
+ """
5
+
6
+ from typing import Callable, Dict, Type, TypeVar
7
+
8
+ R = TypeVar("R")
9
+
10
+
11
class DecoratorRegistry:
    """Centralized decorator management with validation and metadata.

    Keeps one workflow-definition decorator and one workflow-run decorator per
    executor name.
    """

    def __init__(self):
        self._workflow_defn_decorators: Dict[str, Callable[[Type], Type]] = {}
        self._workflow_run_decorators: Dict[
            str, Callable[[Callable[..., R]], Callable[..., R]]
        ] = {}

    def register_workflow_defn_decorator(
        self,
        executor_name: str,
        decorator: Callable[[Type], Type],
    ):
        """
        Registers a workflow definition decorator for a given executor.

        An existing registration for the same executor is overwritten after a
        notice is printed.

        :param executor_name: Unique name of the executor.
        :param decorator: The decorator to register.
        """
        if executor_name in self._workflow_defn_decorators:
            # print() does not interpolate logging-style "%s" arguments, so the
            # message is formatted before it is printed.
            print(
                "Workflow definition decorator already registered for "
                f"'{executor_name}'. Overwriting."
            )
        self._workflow_defn_decorators[executor_name] = decorator

    def get_workflow_defn_decorator(self, executor_name: str) -> Callable[[Type], Type]:
        """
        Retrieves a workflow definition decorator for a given executor.

        :param executor_name: Unique name of the executor.
        :return: The decorator function, or None if none is registered.
        """
        return self._workflow_defn_decorators.get(executor_name)

    def register_workflow_run_decorator(
        self,
        executor_name: str,
        decorator: Callable[[Callable[..., R]], Callable[..., R]],
    ):
        """
        Registers a workflow run decorator for a given executor.

        An existing registration for the same executor is overwritten after a
        notice is printed.

        :param executor_name: Unique name of the executor.
        :param decorator: The decorator to register.
        """
        if executor_name in self._workflow_run_decorators:
            # See register_workflow_defn_decorator: format first, then print.
            print(
                "Workflow run decorator already registered for "
                f"'{executor_name}'. Overwriting."
            )
        self._workflow_run_decorators[executor_name] = decorator

    def get_workflow_run_decorator(
        self, executor_name: str
    ) -> Callable[[Callable[..., R]], Callable[..., R]]:
        """
        Retrieves a workflow run decorator for a given executor.

        :param executor_name: Unique name of the executor.
        :return: The decorator function, or None if none is registered.
        """
        return self._workflow_run_decorators.get(executor_name)
75
+
76
+
77
def default_workflow_defn(cls: Type, *args, **kwargs) -> Type:
    """No-op workflow definition decorator: hand back ``cls`` unchanged.

    Extra positional and keyword arguments are accepted and ignored.
    """
    return cls
80
+
81
+
82
def default_workflow_run(fn: Callable[..., R]) -> Callable[..., R]:
    """Default no-op workflow run decorator.

    Returns a wrapper that forwards all arguments to ``fn`` and returns its
    result. The wrapper carries ``fn``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``) so the decorated function stays identifiable.

    :param fn: The workflow run callable to wrap.
    :return: A wrapper with the same call behavior as ``fn``.
    """
    # Imported locally because this module's top-level imports only pull in
    # names from typing.
    import functools

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper
89
+
90
+
91
def register_asyncio_decorators(decorator_registry: DecoratorRegistry):
    """Register the no-op default decorators under the "asyncio" executor name."""
    decorator_registry.register_workflow_defn_decorator(
        "asyncio", default_workflow_defn
    )
    decorator_registry.register_workflow_run_decorator(
        "asyncio", default_workflow_run
    )
100
+
101
+
102
def register_temporal_decorators(decorator_registry: DecoratorRegistry):
    """Register Temporal's ``defn``/``run`` decorators under "temporal".

    Does nothing when ``temporalio.workflow`` cannot be imported.
    """
    try:
        import temporalio.workflow as temporal_workflow
    except ImportError:
        # Temporal SDK is not installed; leave the registry untouched.
        return

    name = "temporal"
    decorator_registry.register_workflow_defn_decorator(name, temporal_workflow.defn)
    decorator_registry.register_workflow_run_decorator(name, temporal_workflow.run)
@@ -0,0 +1,293 @@
1
+ import asyncio
2
+ import functools
3
+ from abc import ABC, abstractmethod
4
+ from contextlib import asynccontextmanager
5
+ from datetime import timedelta
6
+ from typing import (
7
+ Any,
8
+ AsyncIterator,
9
+ Callable,
10
+ Coroutine,
11
+ Dict,
12
+ List,
13
+ Optional,
14
+ Type,
15
+ TypeVar,
16
+ TYPE_CHECKING,
17
+ )
18
+
19
+ from pydantic import BaseModel, ConfigDict
20
+
21
+ from mcp_agent.context_dependent import ContextDependent
22
+ from mcp_agent.executor.workflow_signal import (
23
+ AsyncioSignalHandler,
24
+ Signal,
25
+ SignalHandler,
26
+ SignalValueT,
27
+ )
28
+ from mcp_agent.logging.logger import get_logger
29
+
30
+ if TYPE_CHECKING:
31
+ from mcp_agent.context import Context
32
+
33
+ logger = get_logger(__name__)
34
+
35
+ # Type variable for the return type of tasks
36
+ R = TypeVar("R")
37
+
38
+
39
class ExecutorConfig(BaseModel):
    """Settings for an executor; every field defaults to ``None``."""

    # Accept undeclared extra fields and field types pydantic has no schema for.
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    # Unbounded by default.
    max_concurrent_activities: int | None = None
    # No timeout by default.
    timeout_seconds: timedelta | None = None
    retry_policy: Dict[str, Any] | None = None
47
+
48
+
49
class Executor(ABC, ContextDependent):
    """Abstract base class for different execution backends.

    Stores the engine name, the executor configuration and the signal bus.
    Subclasses implement ``execute`` and ``execute_streaming``.
    """

    def __init__(
        self,
        engine: str,
        config: ExecutorConfig | None = None,
        signal_bus: SignalHandler | None = None,
        context: Optional["Context"] = None,
        **kwargs,
    ):
        """
        :param engine: Name of the execution engine, stored as ``execution_engine``.
        :param config: Executor configuration; a default ``ExecutorConfig`` is
            created when omitted.
        :param signal_bus: Handler used by ``signal`` and ``wait_for_signal``.
        :param context: Passed through to ``ContextDependent``.
        """
        super().__init__(context=context, **kwargs)
        self.execution_engine = engine

        if config:
            self.config = config
        else:
            # TODO: saqadri - executor config should be loaded from settings
            # ctx = get_current_context()
            self.config = ExecutorConfig()

        self.signal_bus = signal_bus

    @asynccontextmanager
    async def execution_context(self):
        """Context manager for execution setup/teardown."""
        try:
            yield
        except Exception:
            # TODO: saqadri - add logging or other error handling here
            raise

    @abstractmethod
    async def execute(
        self,
        *tasks: Callable[..., R] | Coroutine[Any, Any, R],
        **kwargs: Any,
    ) -> List[R | BaseException]:
        """Execute a list of tasks and return their results"""

    @abstractmethod
    async def execute_streaming(
        self,
        *tasks: List[Callable[..., R] | Coroutine[Any, Any, R]],
        **kwargs: Any,
    ) -> AsyncIterator[R | BaseException]:
        """Execute tasks and yield results as they complete"""

    async def map(
        self,
        func: Callable[..., R],
        inputs: List[Any],
        **kwargs: Any,
    ) -> List[R | BaseException]:
        """
        Run `func(item)` for each item in `inputs` with concurrency limit.

        Each item is submitted through ``execute``; the per-item result lists
        are flattened into one list. An exception raised while running an
        item is placed in the result list instead of being raised.

        :param func: Callable applied to each input item.
        :param inputs: Items to process.
        :param kwargs: Forwarded to ``execute`` for every item.
        :return: Results (or exceptions) for all items.
        """
        results: List[R | BaseException] = []

        # One semaphore shared by every item. Creating a semaphore per item
        # would give each coroutine its own counter, so nothing would ever
        # wait and the limit would have no effect.
        limit = self.config.max_concurrent_activities
        semaphore = asyncio.Semaphore(limit) if limit else None

        async def run(item):
            call = functools.partial(func, item)
            if semaphore is None:
                return await self.execute(call, **kwargs)
            async with semaphore:
                return await self.execute(call, **kwargs)

        coros = [run(x) for x in inputs]
        # gather all, each returns a single-element list
        list_of_lists = await asyncio.gather(*coros, return_exceptions=True)

        # Flatten results
        for entry in list_of_lists:
            if isinstance(entry, list):
                results.extend(entry)
            else:
                # Means we got an exception at the gather level
                results.append(entry)

        return results

    async def validate_task(
        self, task: Callable[..., R] | Coroutine[Any, Any, R]
    ) -> None:
        """Validate a task before execution.

        :raises TypeError: If ``task`` is neither a coroutine nor a coroutine
            function.
        """
        if not (asyncio.iscoroutine(task) or asyncio.iscoroutinefunction(task)):
            raise TypeError(f"Task must be async: {task}")

    async def signal(
        self,
        signal_name: str,
        payload: SignalValueT = None,
        signal_description: str | None = None,
    ) -> None:
        """
        Emit a signal.

        :param signal_name: Name of the signal to emit.
        :param payload: Value carried by the signal.
        :param signal_description: Optional description attached to the signal.
        """
        signal = Signal[SignalValueT](
            name=signal_name, payload=payload, description=signal_description
        )
        await self.signal_bus.signal(signal)

    async def wait_for_signal(
        self,
        signal_name: str,
        request_id: str | None = None,
        workflow_id: str | None = None,
        signal_description: str | None = None,
        timeout_seconds: int | None = None,
        signal_type: Type[SignalValueT] = str,
    ) -> SignalValueT:
        """
        Wait until a signal with signal_name is emitted (or timeout).
        Return the signal's payload when triggered, or raise on timeout.
        """

        # Notify any callbacks that the workflow is about to be paused waiting for a signal
        if self.context.signal_notification:
            self.context.signal_notification(
                signal_name=signal_name,
                request_id=request_id,
                workflow_id=workflow_id,
                metadata={
                    "description": signal_description,
                    "timeout_seconds": timeout_seconds,
                    "signal_type": signal_type,
                },
            )

        signal = Signal[signal_type](
            name=signal_name, description=signal_description, workflow_id=workflow_id
        )
        # NOTE(review): timeout_seconds is only reported to the notification
        # callback above; it is not passed to the signal bus here - confirm
        # whether the bus is expected to enforce the timeout.
        return await self.signal_bus.wait_for_signal(signal)
182
+
183
+
184
class AsyncioExecutor(Executor):
    """Executor that runs tasks on the running asyncio event loop."""

    def __init__(
        self,
        config: ExecutorConfig | None = None,
        signal_bus: SignalHandler | None = None,
    ):
        """Set up the "asyncio" engine, defaulting to an AsyncioSignalHandler bus."""
        super().__init__(
            engine="asyncio",
            config=config,
            signal_bus=signal_bus or AsyncioSignalHandler(),
        )

        # Only create a semaphore when a concurrency limit is configured.
        limit = self.config.max_concurrent_activities
        self._activity_semaphore: asyncio.Semaphore | None = (
            asyncio.Semaphore(limit) if limit is not None else None
        )

    async def _execute_task(
        self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any
    ) -> R | BaseException:
        """Run one task and return its result, or the Exception it raised.

        Coroutines are awaited directly, coroutine functions are called with
        ``kwargs`` and awaited, and any other callable is run through the
        loop's default executor.
        """

        async def invoke(job: Callable[..., R] | Coroutine[Any, Any, R]) -> R:
            try:
                if asyncio.iscoroutine(job):
                    return await job
                if asyncio.iscoroutinefunction(job):
                    return await job(**kwargs)

                # Plain callable: bind kwargs (if any) and run it off-loop.
                target = functools.partial(job, **kwargs) if kwargs else job
                outcome = await asyncio.get_running_loop().run_in_executor(
                    None, target
                )

                # A sync callable may itself hand back a coroutine.
                if asyncio.iscoroutine(outcome):
                    return await outcome
                return outcome
            except Exception as exc:
                # TODO: saqadri - adding logging or other error handling here
                return exc

        gate = self._activity_semaphore
        if not gate:
            return await invoke(task)
        async with gate:
            return await invoke(task)

    async def execute(
        self,
        *tasks: Callable[..., R] | Coroutine[Any, Any, R],
        **kwargs: Any,
    ) -> List[R | BaseException]:
        """Run all tasks concurrently and return their results in order."""
        # TODO: saqadri - validate if async with self.execution_context() is needed here
        async with self.execution_context():
            calls = [self._execute_task(job, **kwargs) for job in tasks]
            return await asyncio.gather(*calls, return_exceptions=True)

    async def execute_streaming(
        self,
        *tasks: List[Callable[..., R] | Coroutine[Any, Any, R]],
        **kwargs: Any,
    ) -> AsyncIterator[R | BaseException]:
        """Run all tasks concurrently and yield each result as it completes."""
        # TODO: saqadri - validate if async with self.execution_context() is needed here
        async with self.execution_context():
            # Schedule every task up front, then drain them as they finish.
            outstanding = {
                asyncio.create_task(self._execute_task(job, **kwargs))
                for job in tasks
            }

            while outstanding:
                finished, outstanding = await asyncio.wait(
                    outstanding, return_when=asyncio.FIRST_COMPLETED
                )
                for completed in finished:
                    yield await completed

    async def signal(
        self,
        signal_name: str,
        payload: SignalValueT = None,
        signal_description: str | None = None,
    ) -> None:
        """Emit a signal by delegating to the base implementation."""
        await super().signal(signal_name, payload, signal_description)

    async def wait_for_signal(
        self,
        signal_name: str,
        request_id: str | None = None,
        workflow_id: str | None = None,
        signal_description: str | None = None,
        timeout_seconds: int | None = None,
        signal_type: Type[SignalValueT] = str,
    ) -> SignalValueT:
        """Wait for a signal by delegating to the base implementation."""
        payload = await super().wait_for_signal(
            signal_name,
            request_id,
            workflow_id,
            signal_description,
            timeout_seconds,
            signal_type,
        )
        return payload
@@ -0,0 +1,34 @@
1
+ """
2
+ Keep track of all activities/tasks that the executor needs to run.
3
+ This is used by the workflow engine to dynamically orchestrate a workflow graph.
4
+ The user just writes standard functions annotated with @workflow_task, but behind the scenes a workflow graph is built.
5
+ """
6
+
7
+ from typing import Any, Callable, Dict, List
8
+
9
+
10
+ class ActivityRegistry:
11
+ """Centralized task/activity management with validation and metadata."""
12
+
13
+ def __init__(self):
14
+ self._activities: Dict[str, Callable] = {}
15
+ self._metadata: Dict[str, Dict[str, Any]] = {}
16
+
17
+ def register(
18
+ self, name: str, func: Callable, metadata: Dict[str, Any] | None = None
19
+ ):
20
+ if name in self._activities:
21
+ raise ValueError(f"Activity '{name}' is already registered.")
22
+ self._activities[name] = func
23
+ self._metadata[name] = metadata or {}
24
+
25
+ def get_activity(self, name: str) -> Callable:
26
+ if name not in self._activities:
27
+ raise KeyError(f"Activity '{name}' not found.")
28
+ return self._activities[name]
29
+
30
+ def get_metadata(self, name: str) -> Dict[str, Any]:
31
+ return self._metadata.get(name, {})
32
+
33
+ def list_activities(self) -> List[str]:
34
+ return list(self._activities.keys())