fast-agent-mcp 0.1.12__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.2.0.dist-info}/METADATA +3 -4
- fast_agent_mcp-0.2.0.dist-info/RECORD +123 -0
- mcp_agent/__init__.py +75 -0
- mcp_agent/agents/agent.py +61 -415
- mcp_agent/agents/base_agent.py +522 -0
- mcp_agent/agents/workflow/__init__.py +1 -0
- mcp_agent/agents/workflow/chain_agent.py +173 -0
- mcp_agent/agents/workflow/evaluator_optimizer.py +362 -0
- mcp_agent/agents/workflow/orchestrator_agent.py +591 -0
- mcp_agent/{workflows/orchestrator → agents/workflow}/orchestrator_models.py +11 -21
- mcp_agent/agents/workflow/parallel_agent.py +182 -0
- mcp_agent/agents/workflow/router_agent.py +307 -0
- mcp_agent/app.py +15 -19
- mcp_agent/cli/commands/bootstrap.py +19 -38
- mcp_agent/cli/commands/config.py +4 -4
- mcp_agent/cli/commands/setup.py +7 -14
- mcp_agent/cli/main.py +7 -10
- mcp_agent/cli/terminal.py +3 -3
- mcp_agent/config.py +25 -40
- mcp_agent/context.py +12 -21
- mcp_agent/context_dependent.py +3 -5
- mcp_agent/core/agent_types.py +10 -7
- mcp_agent/core/direct_agent_app.py +179 -0
- mcp_agent/core/direct_decorators.py +443 -0
- mcp_agent/core/direct_factory.py +476 -0
- mcp_agent/core/enhanced_prompt.py +23 -55
- mcp_agent/core/exceptions.py +8 -8
- mcp_agent/core/fastagent.py +145 -371
- mcp_agent/core/interactive_prompt.py +424 -0
- mcp_agent/core/mcp_content.py +17 -17
- mcp_agent/core/prompt.py +6 -9
- mcp_agent/core/request_params.py +6 -3
- mcp_agent/core/validation.py +92 -18
- mcp_agent/executor/decorator_registry.py +9 -17
- mcp_agent/executor/executor.py +8 -17
- mcp_agent/executor/task_registry.py +2 -4
- mcp_agent/executor/temporal.py +19 -41
- mcp_agent/executor/workflow.py +3 -5
- mcp_agent/executor/workflow_signal.py +15 -21
- mcp_agent/human_input/handler.py +4 -7
- mcp_agent/human_input/types.py +2 -3
- mcp_agent/llm/__init__.py +2 -0
- mcp_agent/llm/augmented_llm.py +450 -0
- mcp_agent/llm/augmented_llm_passthrough.py +162 -0
- mcp_agent/llm/augmented_llm_playback.py +83 -0
- mcp_agent/llm/memory.py +103 -0
- mcp_agent/{workflows/llm → llm}/model_factory.py +22 -16
- mcp_agent/{workflows/llm → llm}/prompt_utils.py +1 -3
- mcp_agent/llm/providers/__init__.py +8 -0
- mcp_agent/{workflows/llm → llm/providers}/anthropic_utils.py +8 -25
- mcp_agent/{workflows/llm → llm/providers}/augmented_llm_anthropic.py +56 -194
- mcp_agent/llm/providers/augmented_llm_deepseek.py +53 -0
- mcp_agent/{workflows/llm → llm/providers}/augmented_llm_openai.py +99 -190
- mcp_agent/{workflows/llm → llm}/providers/multipart_converter_anthropic.py +72 -71
- mcp_agent/{workflows/llm → llm}/providers/multipart_converter_openai.py +65 -71
- mcp_agent/{workflows/llm → llm}/providers/openai_multipart.py +16 -44
- mcp_agent/{workflows/llm → llm/providers}/openai_utils.py +4 -4
- mcp_agent/{workflows/llm → llm}/providers/sampling_converter_anthropic.py +9 -11
- mcp_agent/{workflows/llm → llm}/providers/sampling_converter_openai.py +8 -12
- mcp_agent/{workflows/llm → llm}/sampling_converter.py +3 -31
- mcp_agent/llm/sampling_format_converter.py +37 -0
- mcp_agent/logging/events.py +1 -5
- mcp_agent/logging/json_serializer.py +7 -6
- mcp_agent/logging/listeners.py +20 -23
- mcp_agent/logging/logger.py +17 -19
- mcp_agent/logging/rich_progress.py +10 -8
- mcp_agent/logging/tracing.py +4 -6
- mcp_agent/logging/transport.py +22 -22
- mcp_agent/mcp/gen_client.py +1 -3
- mcp_agent/mcp/interfaces.py +117 -110
- mcp_agent/mcp/logger_textio.py +97 -0
- mcp_agent/mcp/mcp_agent_client_session.py +7 -7
- mcp_agent/mcp/mcp_agent_server.py +8 -8
- mcp_agent/mcp/mcp_aggregator.py +102 -143
- mcp_agent/mcp/mcp_connection_manager.py +20 -27
- mcp_agent/mcp/prompt_message_multipart.py +68 -16
- mcp_agent/mcp/prompt_render.py +77 -0
- mcp_agent/mcp/prompt_serialization.py +30 -48
- mcp_agent/mcp/prompts/prompt_constants.py +18 -0
- mcp_agent/mcp/prompts/prompt_helpers.py +327 -0
- mcp_agent/mcp/prompts/prompt_load.py +109 -0
- mcp_agent/mcp/prompts/prompt_server.py +155 -195
- mcp_agent/mcp/prompts/prompt_template.py +35 -66
- mcp_agent/mcp/resource_utils.py +7 -14
- mcp_agent/mcp/sampling.py +17 -17
- mcp_agent/mcp_server/agent_server.py +13 -17
- mcp_agent/mcp_server_registry.py +13 -22
- mcp_agent/resources/examples/{workflows → in_dev}/agent_build.py +3 -2
- mcp_agent/resources/examples/in_dev/slides.py +110 -0
- mcp_agent/resources/examples/internal/agent.py +6 -3
- mcp_agent/resources/examples/internal/fastagent.config.yaml +8 -2
- mcp_agent/resources/examples/internal/job.py +2 -1
- mcp_agent/resources/examples/internal/prompt_category.py +1 -1
- mcp_agent/resources/examples/internal/prompt_sizing.py +3 -5
- mcp_agent/resources/examples/internal/sizer.py +2 -1
- mcp_agent/resources/examples/internal/social.py +2 -1
- mcp_agent/resources/examples/prompting/agent.py +2 -1
- mcp_agent/resources/examples/prompting/image_server.py +4 -8
- mcp_agent/resources/examples/prompting/work_with_image.py +19 -0
- mcp_agent/ui/console_display.py +16 -20
- fast_agent_mcp-0.1.12.dist-info/RECORD +0 -161
- mcp_agent/core/agent_app.py +0 -646
- mcp_agent/core/agent_utils.py +0 -71
- mcp_agent/core/decorators.py +0 -455
- mcp_agent/core/factory.py +0 -463
- mcp_agent/core/proxies.py +0 -269
- mcp_agent/core/types.py +0 -24
- mcp_agent/eval/__init__.py +0 -0
- mcp_agent/mcp/stdio.py +0 -111
- mcp_agent/resources/examples/data-analysis/analysis-campaign.py +0 -188
- mcp_agent/resources/examples/data-analysis/analysis.py +0 -65
- mcp_agent/resources/examples/data-analysis/fastagent.config.yaml +0 -41
- mcp_agent/resources/examples/data-analysis/mount-point/WA_Fn-UseC_-HR-Employee-Attrition.csv +0 -1471
- mcp_agent/resources/examples/mcp_researcher/researcher-eval.py +0 -53
- mcp_agent/resources/examples/researcher/fastagent.config.yaml +0 -66
- mcp_agent/resources/examples/researcher/researcher-eval.py +0 -53
- mcp_agent/resources/examples/researcher/researcher-imp.py +0 -190
- mcp_agent/resources/examples/researcher/researcher.py +0 -38
- mcp_agent/resources/examples/workflows/chaining.py +0 -44
- mcp_agent/resources/examples/workflows/evaluator.py +0 -78
- mcp_agent/resources/examples/workflows/fastagent.config.yaml +0 -24
- mcp_agent/resources/examples/workflows/human_input.py +0 -25
- mcp_agent/resources/examples/workflows/orchestrator.py +0 -73
- mcp_agent/resources/examples/workflows/parallel.py +0 -78
- mcp_agent/resources/examples/workflows/router.py +0 -53
- mcp_agent/resources/examples/workflows/sse.py +0 -23
- mcp_agent/telemetry/__init__.py +0 -0
- mcp_agent/telemetry/usage_tracking.py +0 -18
- mcp_agent/workflows/__init__.py +0 -0
- mcp_agent/workflows/embedding/__init__.py +0 -0
- mcp_agent/workflows/embedding/embedding_base.py +0 -61
- mcp_agent/workflows/embedding/embedding_cohere.py +0 -49
- mcp_agent/workflows/embedding/embedding_openai.py +0 -46
- mcp_agent/workflows/evaluator_optimizer/__init__.py +0 -0
- mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +0 -481
- mcp_agent/workflows/intent_classifier/__init__.py +0 -0
- mcp_agent/workflows/intent_classifier/intent_classifier_base.py +0 -120
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +0 -134
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +0 -45
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +0 -45
- mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +0 -161
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +0 -60
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +0 -60
- mcp_agent/workflows/llm/__init__.py +0 -0
- mcp_agent/workflows/llm/augmented_llm.py +0 -753
- mcp_agent/workflows/llm/augmented_llm_passthrough.py +0 -241
- mcp_agent/workflows/llm/augmented_llm_playback.py +0 -109
- mcp_agent/workflows/llm/providers/__init__.py +0 -8
- mcp_agent/workflows/llm/sampling_format_converter.py +0 -22
- mcp_agent/workflows/orchestrator/__init__.py +0 -0
- mcp_agent/workflows/orchestrator/orchestrator.py +0 -578
- mcp_agent/workflows/parallel/__init__.py +0 -0
- mcp_agent/workflows/parallel/fan_in.py +0 -350
- mcp_agent/workflows/parallel/fan_out.py +0 -187
- mcp_agent/workflows/parallel/parallel_llm.py +0 -166
- mcp_agent/workflows/router/__init__.py +0 -0
- mcp_agent/workflows/router/router_base.py +0 -368
- mcp_agent/workflows/router/router_embedding.py +0 -240
- mcp_agent/workflows/router/router_embedding_cohere.py +0 -59
- mcp_agent/workflows/router/router_embedding_openai.py +0 -59
- mcp_agent/workflows/router/router_llm.py +0 -320
- mcp_agent/workflows/swarm/__init__.py +0 -0
- mcp_agent/workflows/swarm/swarm.py +0 -320
- mcp_agent/workflows/swarm/swarm_anthropic.py +0 -42
- mcp_agent/workflows/swarm/swarm_openai.py +0 -41
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.2.0.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.2.0.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.2.0.dist-info}/licenses/LICENSE +0 -0
- /mcp_agent/{workflows/orchestrator → agents/workflow}/orchestrator_prompts.py +0 -0
mcp_agent/core/validation.py
CHANGED
@@ -2,14 +2,15 @@
|
|
2
2
|
Validation utilities for FastAgent configuration and dependencies.
|
3
3
|
"""
|
4
4
|
|
5
|
-
from typing import Dict, List
|
5
|
+
from typing import Any, Dict, List
|
6
|
+
|
6
7
|
from mcp_agent.core.agent_types import AgentType
|
7
|
-
from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM
|
8
8
|
from mcp_agent.core.exceptions import (
|
9
|
-
ServerConfigError,
|
10
9
|
AgentConfigError,
|
11
10
|
CircularDependencyError,
|
11
|
+
ServerConfigError,
|
12
12
|
)
|
13
|
+
from mcp_agent.llm.augmented_llm import AugmentedLLM
|
13
14
|
|
14
15
|
|
15
16
|
def validate_server_references(context, agents: Dict[str, Dict[str, Any]]) -> None:
|
@@ -55,7 +56,7 @@ def validate_workflow_references(agents: Dict[str, Dict[str, Any]]) -> None:
|
|
55
56
|
if agent_type == AgentType.PARALLEL.value:
|
56
57
|
# Check fan_in exists
|
57
58
|
fan_in = agent_data["fan_in"]
|
58
|
-
if fan_in not in available_components:
|
59
|
+
if fan_in and fan_in not in available_components:
|
59
60
|
raise AgentConfigError(
|
60
61
|
f"Parallel workflow '{name}' references non-existent fan_in component: {fan_in}"
|
61
62
|
)
|
@@ -105,7 +106,7 @@ def validate_workflow_references(agents: Dict[str, Dict[str, Any]]) -> None:
|
|
105
106
|
|
106
107
|
elif agent_type == AgentType.ROUTER.value:
|
107
108
|
# Check all referenced agents exist
|
108
|
-
router_agents = agent_data["
|
109
|
+
router_agents = agent_data["router_agents"]
|
109
110
|
missing = [a for a in router_agents if a not in available_components]
|
110
111
|
if missing:
|
111
112
|
raise AgentConfigError(
|
@@ -186,7 +187,7 @@ def get_dependencies(
|
|
186
187
|
deps.extend(get_dependencies(fan_out, agents, visited, path, agent_type))
|
187
188
|
elif config["type"] == AgentType.CHAIN.value:
|
188
189
|
# Get dependencies from sequence agents
|
189
|
-
sequence = config.get("sequence", config.get("
|
190
|
+
sequence = config.get("sequence", config.get("router_agents", []))
|
190
191
|
for agent_name in sequence:
|
191
192
|
deps.extend(get_dependencies(agent_name, agents, visited, path, agent_type))
|
192
193
|
|
@@ -198,23 +199,96 @@ def get_dependencies(
|
|
198
199
|
return deps
|
199
200
|
|
200
201
|
|
201
|
-
def
|
202
|
-
|
203
|
-
) -> List[str]:
|
202
|
+
def get_dependencies_groups(
|
203
|
+
agents_dict: Dict[str, Dict[str, Any]], allow_cycles: bool = False
|
204
|
+
) -> List[List[str]]:
|
204
205
|
"""
|
205
|
-
Get dependencies
|
206
|
-
|
206
|
+
Get dependencies between agents and group them into dependency layers.
|
207
|
+
Each layer can be initialized in parallel.
|
207
208
|
|
208
209
|
Args:
|
209
|
-
|
210
|
-
|
211
|
-
visited: Set of already visited agents
|
212
|
-
path: Current path for cycle detection
|
210
|
+
agents_dict: Dictionary of agent configurations
|
211
|
+
allow_cycles: Whether to allow cyclic dependencies
|
213
212
|
|
214
213
|
Returns:
|
215
|
-
List of
|
214
|
+
List of lists, where each inner list is a group of agents that can be initialized together
|
216
215
|
|
217
216
|
Raises:
|
218
|
-
CircularDependencyError: If circular dependency detected
|
217
|
+
CircularDependencyError: If circular dependency detected and allow_cycles is False
|
219
218
|
"""
|
220
|
-
|
219
|
+
# Get all agent names
|
220
|
+
agent_names = list(agents_dict.keys())
|
221
|
+
|
222
|
+
# Dictionary to store dependencies for each agent
|
223
|
+
dependencies = {name: set() for name in agent_names}
|
224
|
+
|
225
|
+
# Build the dependency graph
|
226
|
+
for name, agent_data in agents_dict.items():
|
227
|
+
agent_type = agent_data["type"]
|
228
|
+
|
229
|
+
if agent_type == AgentType.PARALLEL.value:
|
230
|
+
# Parallel agents depend on their fan-out and fan-in agents
|
231
|
+
dependencies[name].update(agent_data.get("parallel_agents", []))
|
232
|
+
elif agent_type == AgentType.CHAIN.value:
|
233
|
+
# Chain agents depend on the agents in their sequence
|
234
|
+
dependencies[name].update(agent_data.get("chain_agents", []))
|
235
|
+
elif agent_type == AgentType.ROUTER.value:
|
236
|
+
# Router agents depend on the agents they route to
|
237
|
+
dependencies[name].update(agent_data.get("router_agents", []))
|
238
|
+
elif agent_type == AgentType.ORCHESTRATOR.value:
|
239
|
+
# Orchestrator agents depend on their child agents
|
240
|
+
dependencies[name].update(agent_data.get("child_agents", []))
|
241
|
+
elif agent_type == AgentType.EVALUATOR_OPTIMIZER.value:
|
242
|
+
# Evaluator-Optimizer agents depend on their evaluation and optimization agents
|
243
|
+
dependencies[name].update(agent_data.get("eval_optimizer_agents", []))
|
244
|
+
|
245
|
+
# Check for cycles if not allowed
|
246
|
+
if not allow_cycles:
|
247
|
+
visited = set()
|
248
|
+
path = set()
|
249
|
+
|
250
|
+
def visit(node) -> None:
|
251
|
+
if node in path:
|
252
|
+
path_str = " -> ".join(path) + " -> " + node
|
253
|
+
raise CircularDependencyError(f"Circular dependency detected: {path_str}")
|
254
|
+
if node in visited:
|
255
|
+
return
|
256
|
+
|
257
|
+
path.add(node)
|
258
|
+
for dep in dependencies[node]:
|
259
|
+
if dep in agent_names: # Skip dependencies to non-existent agents
|
260
|
+
visit(dep)
|
261
|
+
path.remove(node)
|
262
|
+
visited.add(node)
|
263
|
+
|
264
|
+
# Check each node
|
265
|
+
for name in agent_names:
|
266
|
+
if name not in visited:
|
267
|
+
visit(name)
|
268
|
+
|
269
|
+
# Group agents by dependency level
|
270
|
+
result = []
|
271
|
+
remaining = set(agent_names)
|
272
|
+
|
273
|
+
while remaining:
|
274
|
+
# Find all agents that have no remaining dependencies
|
275
|
+
current_level = set()
|
276
|
+
for name in remaining:
|
277
|
+
if not dependencies[name] & remaining: # If no dependencies in remaining agents
|
278
|
+
current_level.add(name)
|
279
|
+
|
280
|
+
if not current_level:
|
281
|
+
if allow_cycles:
|
282
|
+
# If cycles are allowed, just add one remaining node to break the cycle
|
283
|
+
current_level.add(next(iter(remaining)))
|
284
|
+
else:
|
285
|
+
# This should not happen if we checked for cycles
|
286
|
+
raise CircularDependencyError("Unresolvable dependency cycle detected")
|
287
|
+
|
288
|
+
# Add the current level to the result
|
289
|
+
result.append(list(current_level))
|
290
|
+
|
291
|
+
# Remove current level from remaining
|
292
|
+
remaining -= current_level
|
293
|
+
|
294
|
+
return result
|
mcp_agent/executor/decorator_registry.py
CHANGED
@@ -11,7 +11,7 @@ R = TypeVar("R")
|
|
11
11
|
class DecoratorRegistry:
|
12
12
|
"""Centralized decorator management with validation and metadata."""
|
13
13
|
|
14
|
-
def __init__(self):
|
14
|
+
def __init__(self) -> None:
|
15
15
|
self._workflow_defn_decorators: Dict[str, Callable[[Type], Type]] = {}
|
16
16
|
self._workflow_run_decorators: Dict[
|
17
17
|
str, Callable[[Callable[..., R]], Callable[..., R]]
|
@@ -21,7 +21,7 @@ class DecoratorRegistry:
|
|
21
21
|
self,
|
22
22
|
executor_name: str,
|
23
23
|
decorator: Callable[[Type], Type],
|
24
|
-
):
|
24
|
+
) -> None:
|
25
25
|
"""
|
26
26
|
Registers a workflow definition decorator for a given executor.
|
27
27
|
|
@@ -48,7 +48,7 @@ class DecoratorRegistry:
|
|
48
48
|
self,
|
49
49
|
executor_name: str,
|
50
50
|
decorator: Callable[[Callable[..., R]], Callable[..., R]],
|
51
|
-
):
|
51
|
+
) -> None:
|
52
52
|
"""
|
53
53
|
Registers a workflow run decorator for a given executor.
|
54
54
|
|
@@ -88,18 +88,14 @@ def default_workflow_run(fn: Callable[..., R]) -> Callable[..., R]:
|
|
88
88
|
return wrapper
|
89
89
|
|
90
90
|
|
91
|
-
def register_asyncio_decorators(decorator_registry: DecoratorRegistry):
|
91
|
+
def register_asyncio_decorators(decorator_registry: DecoratorRegistry) -> None:
|
92
92
|
"""Registers default asyncio decorators."""
|
93
93
|
executor_name = "asyncio"
|
94
|
-
decorator_registry.register_workflow_defn_decorator(
|
95
|
-
|
96
|
-
)
|
97
|
-
decorator_registry.register_workflow_run_decorator(
|
98
|
-
executor_name, default_workflow_run
|
99
|
-
)
|
94
|
+
decorator_registry.register_workflow_defn_decorator(executor_name, default_workflow_defn)
|
95
|
+
decorator_registry.register_workflow_run_decorator(executor_name, default_workflow_run)
|
100
96
|
|
101
97
|
|
102
|
-
def register_temporal_decorators(decorator_registry: DecoratorRegistry):
|
98
|
+
def register_temporal_decorators(decorator_registry: DecoratorRegistry) -> None:
|
103
99
|
"""Registers Temporal decorators if Temporal SDK is available."""
|
104
100
|
try:
|
105
101
|
import temporalio.workflow as temporal_workflow
|
@@ -112,9 +108,5 @@ def register_temporal_decorators(decorator_registry: DecoratorRegistry):
|
|
112
108
|
return
|
113
109
|
|
114
110
|
executor_name = "temporal"
|
115
|
-
decorator_registry.register_workflow_defn_decorator(
|
116
|
-
|
117
|
-
)
|
118
|
-
decorator_registry.register_workflow_run_decorator(
|
119
|
-
executor_name, temporal_workflow.run
|
120
|
-
)
|
111
|
+
decorator_registry.register_workflow_defn_decorator(executor_name, temporal_workflow.defn)
|
112
|
+
decorator_registry.register_workflow_run_decorator(executor_name, temporal_workflow.run)
|
mcp_agent/executor/executor.py
CHANGED
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
|
4
4
|
from contextlib import asynccontextmanager
|
5
5
|
from datetime import timedelta
|
6
6
|
from typing import (
|
7
|
+
TYPE_CHECKING,
|
7
8
|
Any,
|
8
9
|
AsyncIterator,
|
9
10
|
Callable,
|
@@ -13,7 +14,6 @@ from typing import (
|
|
13
14
|
Optional,
|
14
15
|
Type,
|
15
16
|
TypeVar,
|
16
|
-
TYPE_CHECKING,
|
17
17
|
)
|
18
18
|
|
19
19
|
from pydantic import BaseModel, ConfigDict
|
@@ -56,7 +56,7 @@ class Executor(ABC, ContextDependent):
|
|
56
56
|
signal_bus: SignalHandler = None,
|
57
57
|
context: Optional["Context"] = None,
|
58
58
|
**kwargs,
|
59
|
-
):
|
59
|
+
) -> None:
|
60
60
|
super().__init__(context=context, **kwargs)
|
61
61
|
self.execution_engine = engine
|
62
62
|
|
@@ -127,9 +127,7 @@ class Executor(ABC, ContextDependent):
|
|
127
127
|
|
128
128
|
return results
|
129
129
|
|
130
|
-
async def validate_task(
|
131
|
-
self, task: Callable[..., R] | Coroutine[Any, Any, R]
|
132
|
-
) -> None:
|
130
|
+
async def validate_task(self, task: Callable[..., R] | Coroutine[Any, Any, R]) -> None:
|
133
131
|
"""Validate a task before execution."""
|
134
132
|
if not (asyncio.iscoroutine(task) or asyncio.iscoroutinefunction(task)):
|
135
133
|
raise TypeError(f"Task must be async: {task}")
|
@@ -164,7 +162,7 @@ class Executor(ABC, ContextDependent):
|
|
164
162
|
|
165
163
|
# Notify any callbacks that the workflow is about to be paused waiting for a signal
|
166
164
|
if self.context.signal_notification:
|
167
|
-
self.context.signal_notification(
|
165
|
+
await self.context.signal_notification(
|
168
166
|
signal_name=signal_name,
|
169
167
|
request_id=request_id,
|
170
168
|
workflow_id=workflow_id,
|
@@ -188,15 +186,13 @@ class AsyncioExecutor(Executor):
|
|
188
186
|
self,
|
189
187
|
config: ExecutorConfig | None = None,
|
190
188
|
signal_bus: SignalHandler | None = None,
|
191
|
-
):
|
189
|
+
) -> None:
|
192
190
|
signal_bus = signal_bus or AsyncioSignalHandler()
|
193
191
|
super().__init__(engine="asyncio", config=config, signal_bus=signal_bus)
|
194
192
|
|
195
193
|
self._activity_semaphore: asyncio.Semaphore | None = None
|
196
194
|
if self.config.max_concurrent_activities is not None:
|
197
|
-
self._activity_semaphore = asyncio.Semaphore(
|
198
|
-
self.config.max_concurrent_activities
|
199
|
-
)
|
195
|
+
self._activity_semaphore = asyncio.Semaphore(self.config.max_concurrent_activities)
|
200
196
|
|
201
197
|
async def _execute_task(
|
202
198
|
self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any
|
@@ -253,16 +249,11 @@ class AsyncioExecutor(Executor):
|
|
253
249
|
# TODO: saqadri - validate if async with self.execution_context() is needed here
|
254
250
|
async with self.execution_context():
|
255
251
|
# Create futures for all tasks
|
256
|
-
futures = [
|
257
|
-
asyncio.create_task(self._execute_task(task, **kwargs))
|
258
|
-
for task in tasks
|
259
|
-
]
|
252
|
+
futures = [asyncio.create_task(self._execute_task(task, **kwargs)) for task in tasks]
|
260
253
|
pending = set(futures)
|
261
254
|
|
262
255
|
while pending:
|
263
|
-
done, pending = await asyncio.wait(
|
264
|
-
pending, return_when=asyncio.FIRST_COMPLETED
|
265
|
-
)
|
256
|
+
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
266
257
|
for future in done:
|
267
258
|
yield await future
|
268
259
|
|
mcp_agent/executor/task_registry.py
CHANGED
@@ -10,13 +10,11 @@ from typing import Any, Callable, Dict, List
|
|
10
10
|
class ActivityRegistry:
|
11
11
|
"""Centralized task/activity management with validation and metadata."""
|
12
12
|
|
13
|
-
def __init__(self):
|
13
|
+
def __init__(self) -> None:
|
14
14
|
self._activities: Dict[str, Callable] = {}
|
15
15
|
self._metadata: Dict[str, Dict[str, Any]] = {}
|
16
16
|
|
17
|
-
def register(
|
18
|
-
self, name: str, func: Callable, metadata: Dict[str, Any] | None = None
|
19
|
-
):
|
17
|
+
def register(self, name: str, func: Callable, metadata: Dict[str, Any] | None = None) -> None:
|
20
18
|
if name in self._activities:
|
21
19
|
raise ValueError(f"Activity '{name}' is already registered.")
|
22
20
|
self._activities[name] = func
|
mcp_agent/executor/temporal.py
CHANGED
@@ -9,6 +9,7 @@ import asyncio
|
|
9
9
|
import functools
|
10
10
|
import uuid
|
11
11
|
from typing import (
|
12
|
+
TYPE_CHECKING,
|
12
13
|
Any,
|
13
14
|
AsyncIterator,
|
14
15
|
Callable,
|
@@ -16,11 +17,10 @@ from typing import (
|
|
16
17
|
Dict,
|
17
18
|
List,
|
18
19
|
Optional,
|
19
|
-
TYPE_CHECKING,
|
20
20
|
)
|
21
21
|
|
22
22
|
from pydantic import ConfigDict
|
23
|
-
from temporalio import activity,
|
23
|
+
from temporalio import activity, exceptions, workflow
|
24
24
|
from temporalio.client import Client as TemporalClient
|
25
25
|
from temporalio.worker import Worker
|
26
26
|
|
@@ -59,22 +59,18 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
|
|
59
59
|
|
60
60
|
# Define the signal handler for this specific registration
|
61
61
|
@workflow.signal(name=unique_signal_name)
|
62
|
-
def signal_handler(value: SignalValueT):
|
62
|
+
def signal_handler(value: SignalValueT) -> None:
|
63
63
|
container["value"] = value
|
64
64
|
container["completed"] = True
|
65
65
|
|
66
66
|
async with self._lock:
|
67
67
|
# Register both the signal registration and handler atomically
|
68
68
|
self._pending_signals.setdefault(signal.name, []).append(registration)
|
69
|
-
self._handlers.setdefault(signal.name, []).append(
|
70
|
-
(unique_signal_name, signal_handler)
|
71
|
-
)
|
69
|
+
self._handlers.setdefault(signal.name, []).append((unique_signal_name, signal_handler))
|
72
70
|
|
73
71
|
try:
|
74
72
|
# Wait for signal with optional timeout
|
75
|
-
await workflow.wait_condition(
|
76
|
-
lambda: container["completed"], timeout=timeout_seconds
|
77
|
-
)
|
73
|
+
await workflow.wait_condition(lambda: container["completed"], timeout=timeout_seconds)
|
78
74
|
|
79
75
|
return container["value"]
|
80
76
|
except asyncio.TimeoutError as exc:
|
@@ -94,9 +90,7 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
|
|
94
90
|
# Remove ourselves from _handlers
|
95
91
|
if signal.name in self._handlers:
|
96
92
|
self._handlers[signal.name] = [
|
97
|
-
h
|
98
|
-
for h in self._handlers[signal.name]
|
99
|
-
if h[0] != unique_signal_name
|
93
|
+
h for h in self._handlers[signal.name] if h[0] != unique_signal_name
|
100
94
|
]
|
101
95
|
if not self._handlers[signal.name]:
|
102
96
|
del self._handlers[signal.name]
|
@@ -110,7 +104,7 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
|
|
110
104
|
|
111
105
|
# Create the actual handler that will be registered with Temporal
|
112
106
|
@workflow.signal(name=unique_signal_name)
|
113
|
-
async def wrapped(signal_value: SignalValueT):
|
107
|
+
async def wrapped(signal_value: SignalValueT) -> None:
|
114
108
|
# Create a signal object to pass to the handler
|
115
109
|
signal = Signal(
|
116
110
|
name=signal_name,
|
@@ -123,19 +117,15 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
|
|
123
117
|
func(signal)
|
124
118
|
|
125
119
|
# Register the handler under the original signal name
|
126
|
-
self._handlers.setdefault(signal_name, []).append(
|
127
|
-
(unique_signal_name, wrapped)
|
128
|
-
)
|
120
|
+
self._handlers.setdefault(signal_name, []).append((unique_signal_name, wrapped))
|
129
121
|
return func
|
130
122
|
|
131
123
|
return decorator
|
132
124
|
|
133
|
-
async def signal(self, signal):
|
125
|
+
async def signal(self, signal) -> None:
|
134
126
|
self.validate_signal(signal)
|
135
127
|
|
136
|
-
workflow_handle = workflow.get_external_workflow_handle(
|
137
|
-
workflow_id=signal.workflow_id
|
138
|
-
)
|
128
|
+
workflow_handle = workflow.get_external_workflow_handle(workflow_id=signal.workflow_id)
|
139
129
|
|
140
130
|
# Send the signal to all registrations of this signal
|
141
131
|
async with self._lock:
|
@@ -147,9 +137,7 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
|
|
147
137
|
if registration.workflow_id == signal.workflow_id:
|
148
138
|
# Only signal for registrations of that workflow
|
149
139
|
signal_tasks.append(
|
150
|
-
workflow_handle.signal(
|
151
|
-
registration.unique_name, signal.payload
|
152
|
-
)
|
140
|
+
workflow_handle.signal(registration.unique_name, signal.payload)
|
153
141
|
)
|
154
142
|
else:
|
155
143
|
continue
|
@@ -157,13 +145,11 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
|
|
157
145
|
# Notify any registered handler functions
|
158
146
|
if signal.name in self._handlers:
|
159
147
|
for unique_name, _ in self._handlers[signal.name]:
|
160
|
-
signal_tasks.append(
|
161
|
-
workflow_handle.signal(unique_name, signal.payload)
|
162
|
-
)
|
148
|
+
signal_tasks.append(workflow_handle.signal(unique_name, signal.payload))
|
163
149
|
|
164
150
|
await asyncio.gather(*signal_tasks, return_exceptions=True)
|
165
151
|
|
166
|
-
def validate_signal(self, signal):
|
152
|
+
def validate_signal(self, signal) -> None:
|
167
153
|
super().validate_signal(signal)
|
168
154
|
# Add TemporalSignalHandler-specific validation
|
169
155
|
if signal.workflow_id is None:
|
@@ -188,7 +174,7 @@ class TemporalExecutor(Executor):
|
|
188
174
|
client: TemporalClient | None = None,
|
189
175
|
context: Optional["Context"] = None,
|
190
176
|
**kwargs,
|
191
|
-
):
|
177
|
+
) -> None:
|
192
178
|
signal_bus = signal_bus or TemporalSignalHandler()
|
193
179
|
super().__init__(
|
194
180
|
engine="temporal",
|
@@ -205,9 +191,7 @@ class TemporalExecutor(Executor):
|
|
205
191
|
self._activity_semaphore = None
|
206
192
|
|
207
193
|
if config.max_concurrent_activities is not None:
|
208
|
-
self._activity_semaphore = asyncio.Semaphore(
|
209
|
-
self.config.max_concurrent_activities
|
210
|
-
)
|
194
|
+
self._activity_semaphore = asyncio.Semaphore(self.config.max_concurrent_activities)
|
211
195
|
|
212
196
|
@staticmethod
|
213
197
|
def wrap_as_activity(
|
@@ -275,9 +259,7 @@ class TemporalExecutor(Executor):
|
|
275
259
|
func = task.func if isinstance(task, functools.partial) else task
|
276
260
|
is_workflow_task = getattr(func, "is_workflow_task", False)
|
277
261
|
if not is_workflow_task:
|
278
|
-
return await asyncio.create_task(
|
279
|
-
self._execute_task_as_async(task, **kwargs)
|
280
|
-
)
|
262
|
+
return await asyncio.create_task(self._execute_task_as_async(task, **kwargs))
|
281
263
|
|
282
264
|
execution_metadata: Dict[str, Any] = getattr(func, "execution_metadata", {})
|
283
265
|
|
@@ -319,9 +301,7 @@ class TemporalExecutor(Executor):
|
|
319
301
|
) -> List[R | BaseException]:
|
320
302
|
# Must be called from within a workflow
|
321
303
|
if not workflow._Runtime.current():
|
322
|
-
raise RuntimeError(
|
323
|
-
"TemporalExecutor.execute must be called from within a workflow"
|
324
|
-
)
|
304
|
+
raise RuntimeError("TemporalExecutor.execute must be called from within a workflow")
|
325
305
|
|
326
306
|
# TODO: saqadri - validate if async with self.execution_context() is needed here
|
327
307
|
async with self.execution_context():
|
@@ -347,9 +327,7 @@ class TemporalExecutor(Executor):
|
|
347
327
|
pending = set(futures)
|
348
328
|
|
349
329
|
while pending:
|
350
|
-
done, pending = await workflow.wait(
|
351
|
-
pending, return_when=asyncio.FIRST_COMPLETED
|
352
|
-
)
|
330
|
+
done, pending = await workflow.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
353
331
|
for future in done:
|
354
332
|
try:
|
355
333
|
result = await future
|
@@ -368,7 +346,7 @@ class TemporalExecutor(Executor):
|
|
368
346
|
|
369
347
|
return self.client
|
370
348
|
|
371
|
-
async def start_worker(self):
|
349
|
+
async def start_worker(self) -> None:
|
372
350
|
"""
|
373
351
|
Start a worker in this process, auto-registering all tasks
|
374
352
|
from the global registry. Also picks up any classes decorated
|
mcp_agent/executor/workflow.py
CHANGED
@@ -62,7 +62,7 @@ class Workflow(ABC, Generic[T]):
|
|
62
62
|
name: str | None = None,
|
63
63
|
metadata: Dict[str, Any] | None = None,
|
64
64
|
**kwargs: Any,
|
65
|
-
):
|
65
|
+
) -> None:
|
66
66
|
self.executor = executor
|
67
67
|
self.name = name or self.__class__.__name__
|
68
68
|
self.init_kwargs = kwargs
|
@@ -80,7 +80,7 @@ class Workflow(ABC, Generic[T]):
|
|
80
80
|
Main workflow implementation. Must be overridden by subclasses.
|
81
81
|
"""
|
82
82
|
|
83
|
-
async def update_state(self, **kwargs):
|
83
|
+
async def update_state(self, **kwargs) -> None:
|
84
84
|
"""Syntactic sugar to update workflow state."""
|
85
85
|
for key, value in kwargs.items():
|
86
86
|
self.state[key] = value
|
@@ -93,9 +93,7 @@ class Workflow(ABC, Generic[T]):
|
|
93
93
|
Convenience method for human input. Uses `human_input` signal
|
94
94
|
so we can unify local (console input) and Temporal signals.
|
95
95
|
"""
|
96
|
-
return await self.executor.wait_for_signal(
|
97
|
-
"human_input", description=description
|
98
|
-
)
|
96
|
+
return await self.executor.wait_for_signal("human_input", description=description)
|
99
97
|
|
100
98
|
|
101
99
|
# ############################
|
mcp_agent/executor/workflow_signal.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import asyncio
|
2
2
|
import uuid
|
3
|
-
from abc import
|
3
|
+
from abc import ABC, abstractmethod
|
4
4
|
from typing import Any, Callable, Dict, Generic, List, Protocol, TypeVar
|
5
5
|
|
6
6
|
from pydantic import BaseModel, ConfigDict
|
@@ -71,14 +71,14 @@ class PendingSignal(BaseModel):
|
|
71
71
|
class BaseSignalHandler(ABC, Generic[SignalValueT]):
|
72
72
|
"""Base class implementing common signal handling functionality."""
|
73
73
|
|
74
|
-
def __init__(self):
|
74
|
+
def __init__(self) -> None:
|
75
75
|
# Map signal_name -> list of PendingSignal objects
|
76
76
|
self._pending_signals: Dict[str, List[PendingSignal]] = {}
|
77
77
|
# Map signal_name -> list of (unique_name, handler) tuples
|
78
78
|
self._handlers: Dict[str, List[tuple[str, Callable]]] = {}
|
79
79
|
self._lock = asyncio.Lock()
|
80
80
|
|
81
|
-
async def cleanup(self, signal_name: str | None = None):
|
81
|
+
async def cleanup(self, signal_name: str | None = None) -> None:
|
82
82
|
"""Clean up handlers and registrations for a signal or all signals."""
|
83
83
|
async with self._lock:
|
84
84
|
if signal_name:
|
@@ -90,7 +90,7 @@ class BaseSignalHandler(ABC, Generic[SignalValueT]):
|
|
90
90
|
self._handlers.clear()
|
91
91
|
self._pending_signals.clear()
|
92
92
|
|
93
|
-
def validate_signal(self, signal: Signal[SignalValueT]):
|
93
|
+
def validate_signal(self, signal: Signal[SignalValueT]) -> None:
|
94
94
|
"""Validate signal properties."""
|
95
95
|
if not signal.name:
|
96
96
|
raise ValueError("Signal name is required")
|
@@ -102,7 +102,7 @@ class BaseSignalHandler(ABC, Generic[SignalValueT]):
|
|
102
102
|
def decorator(func: Callable) -> Callable:
|
103
103
|
unique_name = f"{signal_name}_{uuid.uuid4()}"
|
104
104
|
|
105
|
-
async def wrapped(value: SignalValueT):
|
105
|
+
async def wrapped(value: SignalValueT) -> None:
|
106
106
|
try:
|
107
107
|
if asyncio.iscoroutinefunction(func):
|
108
108
|
await func(value)
|
@@ -133,7 +133,7 @@ class BaseSignalHandler(ABC, Generic[SignalValueT]):
|
|
133
133
|
class ConsoleSignalHandler(SignalHandler[str]):
|
134
134
|
"""Simple console-based signal handling (blocks on input)."""
|
135
135
|
|
136
|
-
def __init__(self):
|
136
|
+
def __init__(self) -> None:
|
137
137
|
self._pending_signals: Dict[str, List[PendingSignal]] = {}
|
138
138
|
self._handlers: Dict[str, List[Callable]] = {}
|
139
139
|
|
@@ -163,7 +163,7 @@ class ConsoleSignalHandler(SignalHandler[str]):
|
|
163
163
|
|
164
164
|
def on_signal(self, signal_name):
|
165
165
|
def decorator(func):
|
166
|
-
async def wrapped(value: SignalValueT):
|
166
|
+
async def wrapped(value: SignalValueT) -> None:
|
167
167
|
if asyncio.iscoroutinefunction(func):
|
168
168
|
await func(value)
|
169
169
|
else:
|
@@ -174,13 +174,11 @@ class ConsoleSignalHandler(SignalHandler[str]):
|
|
174
174
|
|
175
175
|
return decorator
|
176
176
|
|
177
|
-
async def signal(self, signal):
|
177
|
+
async def signal(self, signal) -> None:
|
178
178
|
print(f"[SIGNAL SENT: {signal.name}] Value: {signal.payload}")
|
179
179
|
|
180
180
|
handlers = self._handlers.get(signal.name, [])
|
181
|
-
await asyncio.gather(
|
182
|
-
*(handler(signal) for handler in handlers), return_exceptions=True
|
183
|
-
)
|
181
|
+
await asyncio.gather(*(handler(signal) for handler in handlers), return_exceptions=True)
|
184
182
|
|
185
183
|
# Notify any waiting coroutines
|
186
184
|
if signal.name in self._pending_signals:
|
@@ -194,9 +192,7 @@ class AsyncioSignalHandler(BaseSignalHandler[SignalValueT]):
|
|
194
192
|
Asyncio-based signal handling using an internal dictionary of asyncio Events.
|
195
193
|
"""
|
196
194
|
|
197
|
-
async def wait_for_signal(
|
198
|
-
self, signal, timeout_seconds: int | None = None
|
199
|
-
) -> SignalValueT:
|
195
|
+
async def wait_for_signal(self, signal, timeout_seconds: int | None = None) -> SignalValueT:
|
200
196
|
event = asyncio.Event()
|
201
197
|
unique_name = str(uuid.uuid4())
|
202
198
|
|
@@ -236,7 +232,7 @@ class AsyncioSignalHandler(BaseSignalHandler[SignalValueT]):
|
|
236
232
|
|
237
233
|
def on_signal(self, signal_name):
|
238
234
|
def decorator(func):
|
239
|
-
async def wrapped(value: SignalValueT):
|
235
|
+
async def wrapped(value: SignalValueT) -> None:
|
240
236
|
if asyncio.iscoroutinefunction(func):
|
241
237
|
await func(value)
|
242
238
|
else:
|
@@ -247,7 +243,7 @@ class AsyncioSignalHandler(BaseSignalHandler[SignalValueT]):
|
|
247
243
|
|
248
244
|
return decorator
|
249
245
|
|
250
|
-
async def signal(self, signal):
|
246
|
+
async def signal(self, signal) -> None:
|
251
247
|
async with self._lock:
|
252
248
|
# Notify any waiting coroutines
|
253
249
|
if signal.name in self._pending_signals:
|
@@ -272,11 +268,11 @@ class LocalSignalStore:
|
|
272
268
|
and triggers them when a signal is emitted.
|
273
269
|
"""
|
274
270
|
|
275
|
-
def __init__(self):
|
271
|
+
def __init__(self) -> None:
|
276
272
|
# For each signal_name, store a list of futures that are waiting for it
|
277
273
|
self._waiters: Dict[str, List[asyncio.Future]] = {}
|
278
274
|
|
279
|
-
async def emit(self, signal_name: str, payload: Any):
|
275
|
+
async def emit(self, signal_name: str, payload: Any) -> None:
|
280
276
|
# If we have waiting futures, set their result
|
281
277
|
if signal_name in self._waiters:
|
282
278
|
for future in self._waiters[signal_name]:
|
@@ -284,9 +280,7 @@ class LocalSignalStore:
|
|
284
280
|
future.set_result(payload)
|
285
281
|
self._waiters[signal_name].clear()
|
286
282
|
|
287
|
-
async def wait_for(
|
288
|
-
self, signal_name: str, timeout_seconds: int | None = None
|
289
|
-
) -> Any:
|
283
|
+
async def wait_for(self, signal_name: str, timeout_seconds: int | None = None) -> Any:
|
290
284
|
loop = asyncio.get_running_loop()
|
291
285
|
future = loop.create_future()
|
292
286
|
|