fast-agent-mcp 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126) hide show
  1. {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/METADATA +1 -1
  2. fast_agent_mcp-0.1.13.dist-info/RECORD +164 -0
  3. mcp_agent/agents/agent.py +37 -79
  4. mcp_agent/app.py +16 -22
  5. mcp_agent/cli/commands/bootstrap.py +22 -52
  6. mcp_agent/cli/commands/config.py +4 -4
  7. mcp_agent/cli/commands/setup.py +11 -26
  8. mcp_agent/cli/main.py +6 -9
  9. mcp_agent/cli/terminal.py +2 -2
  10. mcp_agent/config.py +1 -5
  11. mcp_agent/context.py +13 -24
  12. mcp_agent/context_dependent.py +3 -7
  13. mcp_agent/core/agent_app.py +45 -121
  14. mcp_agent/core/agent_utils.py +3 -5
  15. mcp_agent/core/decorators.py +5 -12
  16. mcp_agent/core/enhanced_prompt.py +25 -52
  17. mcp_agent/core/exceptions.py +8 -8
  18. mcp_agent/core/factory.py +29 -70
  19. mcp_agent/core/fastagent.py +48 -88
  20. mcp_agent/core/mcp_content.py +8 -16
  21. mcp_agent/core/prompt.py +8 -15
  22. mcp_agent/core/proxies.py +34 -25
  23. mcp_agent/core/request_params.py +6 -3
  24. mcp_agent/core/types.py +4 -6
  25. mcp_agent/core/validation.py +4 -3
  26. mcp_agent/executor/decorator_registry.py +11 -23
  27. mcp_agent/executor/executor.py +8 -17
  28. mcp_agent/executor/task_registry.py +2 -4
  29. mcp_agent/executor/temporal.py +28 -74
  30. mcp_agent/executor/workflow.py +3 -5
  31. mcp_agent/executor/workflow_signal.py +17 -29
  32. mcp_agent/human_input/handler.py +4 -9
  33. mcp_agent/human_input/types.py +2 -3
  34. mcp_agent/logging/events.py +1 -5
  35. mcp_agent/logging/json_serializer.py +7 -6
  36. mcp_agent/logging/listeners.py +20 -23
  37. mcp_agent/logging/logger.py +15 -17
  38. mcp_agent/logging/rich_progress.py +10 -8
  39. mcp_agent/logging/tracing.py +4 -6
  40. mcp_agent/logging/transport.py +22 -22
  41. mcp_agent/mcp/gen_client.py +4 -12
  42. mcp_agent/mcp/interfaces.py +71 -86
  43. mcp_agent/mcp/mcp_agent_client_session.py +11 -19
  44. mcp_agent/mcp/mcp_agent_server.py +8 -10
  45. mcp_agent/mcp/mcp_aggregator.py +45 -117
  46. mcp_agent/mcp/mcp_connection_manager.py +16 -37
  47. mcp_agent/mcp/prompt_message_multipart.py +12 -18
  48. mcp_agent/mcp/prompt_serialization.py +13 -38
  49. mcp_agent/mcp/prompts/prompt_load.py +99 -0
  50. mcp_agent/mcp/prompts/prompt_server.py +21 -128
  51. mcp_agent/mcp/prompts/prompt_template.py +20 -42
  52. mcp_agent/mcp/resource_utils.py +8 -17
  53. mcp_agent/mcp/sampling.py +5 -14
  54. mcp_agent/mcp/stdio.py +11 -8
  55. mcp_agent/mcp_server/agent_server.py +10 -17
  56. mcp_agent/mcp_server_registry.py +13 -35
  57. mcp_agent/resources/examples/data-analysis/analysis-campaign.py +1 -1
  58. mcp_agent/resources/examples/data-analysis/analysis.py +1 -1
  59. mcp_agent/resources/examples/data-analysis/slides.py +110 -0
  60. mcp_agent/resources/examples/internal/agent.py +2 -1
  61. mcp_agent/resources/examples/internal/job.py +2 -1
  62. mcp_agent/resources/examples/internal/prompt_category.py +1 -1
  63. mcp_agent/resources/examples/internal/prompt_sizing.py +3 -5
  64. mcp_agent/resources/examples/internal/sizer.py +2 -1
  65. mcp_agent/resources/examples/internal/social.py +2 -1
  66. mcp_agent/resources/examples/mcp_researcher/researcher-eval.py +1 -1
  67. mcp_agent/resources/examples/prompting/agent.py +2 -1
  68. mcp_agent/resources/examples/prompting/image_server.py +5 -11
  69. mcp_agent/resources/examples/researcher/researcher-eval.py +1 -1
  70. mcp_agent/resources/examples/researcher/researcher-imp.py +3 -4
  71. mcp_agent/resources/examples/researcher/researcher.py +2 -1
  72. mcp_agent/resources/examples/workflows/agent_build.py +2 -1
  73. mcp_agent/resources/examples/workflows/chaining.py +2 -1
  74. mcp_agent/resources/examples/workflows/evaluator.py +2 -1
  75. mcp_agent/resources/examples/workflows/human_input.py +2 -1
  76. mcp_agent/resources/examples/workflows/orchestrator.py +2 -1
  77. mcp_agent/resources/examples/workflows/parallel.py +2 -1
  78. mcp_agent/resources/examples/workflows/router.py +2 -1
  79. mcp_agent/resources/examples/workflows/sse.py +1 -1
  80. mcp_agent/telemetry/usage_tracking.py +2 -1
  81. mcp_agent/ui/console_display.py +15 -39
  82. mcp_agent/workflows/embedding/embedding_base.py +1 -4
  83. mcp_agent/workflows/embedding/embedding_cohere.py +2 -2
  84. mcp_agent/workflows/embedding/embedding_openai.py +4 -13
  85. mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +23 -57
  86. mcp_agent/workflows/intent_classifier/intent_classifier_base.py +5 -8
  87. mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +7 -11
  88. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +4 -8
  89. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +4 -8
  90. mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +11 -22
  91. mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +3 -3
  92. mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +4 -6
  93. mcp_agent/workflows/llm/anthropic_utils.py +8 -29
  94. mcp_agent/workflows/llm/augmented_llm.py +69 -247
  95. mcp_agent/workflows/llm/augmented_llm_anthropic.py +39 -73
  96. mcp_agent/workflows/llm/augmented_llm_openai.py +42 -97
  97. mcp_agent/workflows/llm/augmented_llm_passthrough.py +13 -20
  98. mcp_agent/workflows/llm/augmented_llm_playback.py +8 -6
  99. mcp_agent/workflows/llm/memory.py +103 -0
  100. mcp_agent/workflows/llm/model_factory.py +8 -20
  101. mcp_agent/workflows/llm/openai_utils.py +1 -1
  102. mcp_agent/workflows/llm/prompt_utils.py +1 -3
  103. mcp_agent/workflows/llm/providers/multipart_converter_anthropic.py +47 -89
  104. mcp_agent/workflows/llm/providers/multipart_converter_openai.py +20 -55
  105. mcp_agent/workflows/llm/providers/openai_multipart.py +19 -61
  106. mcp_agent/workflows/llm/providers/sampling_converter_anthropic.py +10 -12
  107. mcp_agent/workflows/llm/providers/sampling_converter_openai.py +7 -11
  108. mcp_agent/workflows/llm/sampling_converter.py +4 -11
  109. mcp_agent/workflows/llm/sampling_format_converter.py +12 -12
  110. mcp_agent/workflows/orchestrator/orchestrator.py +24 -67
  111. mcp_agent/workflows/orchestrator/orchestrator_models.py +14 -40
  112. mcp_agent/workflows/parallel/fan_in.py +17 -47
  113. mcp_agent/workflows/parallel/fan_out.py +6 -12
  114. mcp_agent/workflows/parallel/parallel_llm.py +9 -26
  115. mcp_agent/workflows/router/router_base.py +19 -49
  116. mcp_agent/workflows/router/router_embedding.py +11 -25
  117. mcp_agent/workflows/router/router_embedding_cohere.py +2 -2
  118. mcp_agent/workflows/router/router_embedding_openai.py +2 -2
  119. mcp_agent/workflows/router/router_llm.py +12 -28
  120. mcp_agent/workflows/swarm/swarm.py +20 -48
  121. mcp_agent/workflows/swarm/swarm_anthropic.py +2 -2
  122. mcp_agent/workflows/swarm/swarm_openai.py +2 -2
  123. fast_agent_mcp-0.1.12.dist-info/RECORD +0 -161
  124. {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/WHEEL +0 -0
  125. {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/entry_points.txt +0 -0
  126. {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/licenses/LICENSE +0 -0
mcp_agent/core/proxies.py CHANGED
@@ -6,16 +6,18 @@ FOR COMPATIBILITY WITH LEGACY MCP-AGENT CODE
6
6
 
7
7
  """
8
8
 
9
- from typing import List, Optional, Dict, Union, TYPE_CHECKING
9
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
10
+
11
+ from mcp.types import EmbeddedResource
10
12
 
11
13
  from mcp_agent.agents.agent import Agent
12
14
  from mcp_agent.app import MCPApp
15
+ from mcp_agent.core.request_params import RequestParams
13
16
  from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
14
- from mcp.types import EmbeddedResource
15
17
 
16
18
  # Handle circular imports
17
19
  if TYPE_CHECKING:
18
- from mcp_agent.core.types import WorkflowType, ProxyDict
20
+ from mcp_agent.core.types import ProxyDict, WorkflowType
19
21
  else:
20
22
  # Define minimal versions for runtime
21
23
  from typing import Any
@@ -28,7 +30,7 @@ else:
28
30
  class BaseAgentProxy:
29
31
  """Base class for all proxy types"""
30
32
 
31
- def __init__(self, app: MCPApp, name: str):
33
+ def __init__(self, app: MCPApp, name: str) -> None:
32
34
  self._app = app
33
35
  self._name = name
34
36
 
@@ -39,9 +41,7 @@ class BaseAgentProxy:
39
41
  return await self.prompt()
40
42
  return await self.send(message)
41
43
 
42
- async def send(
43
- self, message: Optional[Union[str, PromptMessageMultipart]] = None
44
- ) -> str:
44
+ async def send(self, message: Optional[Union[str, PromptMessageMultipart]] = None) -> str:
45
45
  """
46
46
  Allow: agent.researcher.send('message') or agent.researcher.send(Prompt.user('message'))
47
47
 
@@ -87,9 +87,7 @@ class BaseAgentProxy:
87
87
  """Send a message to the agent and return the response"""
88
88
  raise NotImplementedError("Subclasses must implement send(prompt)")
89
89
 
90
- async def apply_prompt(
91
- self, prompt_name: str = None, arguments: dict[str, str] = None
92
- ) -> str:
90
+ async def apply_prompt(self, prompt_name: str = None, arguments: dict[str, str] = None) -> str:
93
91
  """
94
92
  Apply a Prompt from an MCP Server - implemented by subclasses.
95
93
  This is the preferred method for applying prompts.
@@ -105,7 +103,7 @@ class BaseAgentProxy:
105
103
  class LLMAgentProxy(BaseAgentProxy):
106
104
  """Proxy for regular agents that use _llm.generate_str()"""
107
105
 
108
- def __init__(self, app: MCPApp, name: str, agent: Agent):
106
+ def __init__(self, app: MCPApp, name: str, agent: Agent) -> None:
109
107
  super().__init__(app, name)
110
108
  self._agent = agent
111
109
 
@@ -117,9 +115,7 @@ class LLMAgentProxy(BaseAgentProxy):
117
115
  """Send a message to the agent and return the response"""
118
116
  return await self._agent._llm.generate_prompt(prompt, None)
119
117
 
120
- async def apply_prompt(
121
- self, prompt_name: str = None, arguments: dict[str, str] = None
122
- ) -> str:
118
+ async def apply_prompt(self, prompt_name: str = None, arguments: dict[str, str] = None) -> str:
123
119
  """
124
120
  Apply a prompt from an MCP server.
125
121
  This is the preferred method for applying prompts.
@@ -134,9 +130,7 @@ class LLMAgentProxy(BaseAgentProxy):
134
130
  return await self._agent.apply_prompt(prompt_name, arguments)
135
131
 
136
132
  # Add the new methods
137
- async def get_embedded_resources(
138
- self, server_name: str, resource_name: str
139
- ) -> List[EmbeddedResource]:
133
+ async def get_embedded_resources(self, server_name: str, resource_name: str) -> List[EmbeddedResource]:
140
134
  """
141
135
  Get a resource from an MCP server and return it as a list of embedded resources ready for use in prompts.
142
136
 
@@ -166,15 +160,32 @@ class LLMAgentProxy(BaseAgentProxy):
166
160
  Returns:
167
161
  The agent's response as a string
168
162
  """
169
- return await self._agent.with_resource(
170
- prompt_content, server_name, resource_name
171
- )
163
+ return await self._agent.with_resource(prompt_content, server_name, resource_name)
164
+
165
+ async def apply_prompt_messages(
166
+ self,
167
+ multipart_messages: List["PromptMessageMultipart"],
168
+ request_params: RequestParams | None = None,
169
+ ) -> str:
170
+ """
171
+ Apply a list of PromptMessageMultipart messages directly to the LLM.
172
+ This is a cleaner interface to _apply_prompt_template_provider_specific.
173
+
174
+ Args:
175
+ multipart_messages: List of PromptMessageMultipart objects
176
+ request_params: Optional parameters to configure the LLM request
177
+
178
+ Returns:
179
+ String representation of the assistant's response
180
+ """
181
+ # Delegate to the provider-specific implementation
182
+ return await self._agent._llm._apply_prompt_template_provider_specific(multipart_messages, request_params)
172
183
 
173
184
 
174
185
  class WorkflowProxy(BaseAgentProxy):
175
186
  """Proxy for workflow types that implement generate_str() directly"""
176
187
 
177
- def __init__(self, app: MCPApp, name: str, workflow: WorkflowType):
188
+ def __init__(self, app: MCPApp, name: str, workflow: WorkflowType) -> None:
178
189
  super().__init__(app, name)
179
190
  self._workflow = workflow
180
191
 
@@ -186,7 +197,7 @@ class WorkflowProxy(BaseAgentProxy):
186
197
  class RouterProxy(BaseAgentProxy):
187
198
  """Proxy for LLM Routers"""
188
199
 
189
- def __init__(self, app: MCPApp, name: str, workflow: WorkflowType):
200
+ def __init__(self, app: MCPApp, name: str, workflow: WorkflowType) -> None:
190
201
  super().__init__(app, name)
191
202
  self._workflow = workflow
192
203
 
@@ -215,9 +226,7 @@ class RouterProxy(BaseAgentProxy):
215
226
  class ChainProxy(BaseAgentProxy):
216
227
  """Proxy for chained agent operations"""
217
228
 
218
- def __init__(
219
- self, app: MCPApp, name: str, sequence: List[str], agent_proxies: ProxyDict
220
- ):
229
+ def __init__(self, app: MCPApp, name: str, sequence: List[str], agent_proxies: ProxyDict) -> None:
221
230
  super().__init__(app, name)
222
231
  self._sequence = sequence
223
232
  self._agent_proxies = agent_proxies
@@ -2,8 +2,11 @@
2
2
  Request parameters definitions for LLM interactions.
3
3
  """
4
4
 
5
- from pydantic import Field
5
+ from typing import List
6
+
7
+ from mcp import SamplingMessage
6
8
  from mcp.types import CreateMessageRequestParams
9
+ from pydantic import Field
7
10
 
8
11
 
9
12
  class RequestParams(CreateMessageRequestParams):
@@ -11,7 +14,7 @@ class RequestParams(CreateMessageRequestParams):
11
14
  Parameters to configure the AugmentedLLM 'generate' requests.
12
15
  """
13
16
 
14
- messages: None = Field(exclude=True, default=None)
17
+ messages: List[SamplingMessage] = Field(exclude=True, default=[])
15
18
  """
16
19
  Ignored. 'messages' are removed from CreateMessageRequestParams
17
20
  to avoid confusion with the 'message' parameter on 'generate' method.
@@ -40,4 +43,4 @@ class RequestParams(CreateMessageRequestParams):
40
43
  """
41
44
  Whether to allow multiple tool calls per iteration.
42
45
  Also known as multi-step tool use.
43
- """
46
+ """
mcp_agent/core/types.py CHANGED
@@ -2,14 +2,14 @@
2
2
  Type definitions for fast-agent core module.
3
3
  """
4
4
 
5
- from typing import Dict, Union, TypeAlias, TYPE_CHECKING
5
+ from typing import TYPE_CHECKING, Dict, TypeAlias, Union
6
6
 
7
7
  from mcp_agent.agents.agent import Agent
8
- from mcp_agent.workflows.orchestrator.orchestrator import Orchestrator
9
- from mcp_agent.workflows.parallel.parallel_llm import ParallelLLM
10
8
  from mcp_agent.workflows.evaluator_optimizer.evaluator_optimizer import (
11
9
  EvaluatorOptimizerLLM,
12
10
  )
11
+ from mcp_agent.workflows.orchestrator.orchestrator import Orchestrator
12
+ from mcp_agent.workflows.parallel.parallel_llm import ParallelLLM
13
13
  from mcp_agent.workflows.router.router_llm import LLMRouter
14
14
 
15
15
  # Avoid circular imports
@@ -17,8 +17,6 @@ if TYPE_CHECKING:
17
17
  from mcp_agent.core.proxies import BaseAgentProxy
18
18
 
19
19
  # Type aliases for better readability
20
- WorkflowType: TypeAlias = Union[
21
- Orchestrator, ParallelLLM, EvaluatorOptimizerLLM, LLMRouter
22
- ]
20
+ WorkflowType: TypeAlias = Union[Orchestrator, ParallelLLM, EvaluatorOptimizerLLM, LLMRouter]
23
21
  AgentOrWorkflow: TypeAlias = Union[Agent, WorkflowType]
24
22
  ProxyDict: TypeAlias = Dict[str, "BaseAgentProxy"] # Forward reference as string
@@ -2,14 +2,15 @@
2
2
  Validation utilities for FastAgent configuration and dependencies.
3
3
  """
4
4
 
5
- from typing import Dict, List, Any
5
+ from typing import Any, Dict, List
6
+
6
7
  from mcp_agent.core.agent_types import AgentType
7
- from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM
8
8
  from mcp_agent.core.exceptions import (
9
- ServerConfigError,
10
9
  AgentConfigError,
11
10
  CircularDependencyError,
11
+ ServerConfigError,
12
12
  )
13
+ from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM
13
14
 
14
15
 
15
16
  def validate_server_references(context, agents: Dict[str, Dict[str, Any]]) -> None:
@@ -11,17 +11,15 @@ R = TypeVar("R")
11
11
  class DecoratorRegistry:
12
12
  """Centralized decorator management with validation and metadata."""
13
13
 
14
- def __init__(self):
14
+ def __init__(self) -> None:
15
15
  self._workflow_defn_decorators: Dict[str, Callable[[Type], Type]] = {}
16
- self._workflow_run_decorators: Dict[
17
- str, Callable[[Callable[..., R]], Callable[..., R]]
18
- ] = {}
16
+ self._workflow_run_decorators: Dict[str, Callable[[Callable[..., R]], Callable[..., R]]] = {}
19
17
 
20
18
  def register_workflow_defn_decorator(
21
19
  self,
22
20
  executor_name: str,
23
21
  decorator: Callable[[Type], Type],
24
- ):
22
+ ) -> None:
25
23
  """
26
24
  Registers a workflow definition decorator for a given executor.
27
25
 
@@ -48,7 +46,7 @@ class DecoratorRegistry:
48
46
  self,
49
47
  executor_name: str,
50
48
  decorator: Callable[[Callable[..., R]], Callable[..., R]],
51
- ):
49
+ ) -> None:
52
50
  """
53
51
  Registers a workflow run decorator for a given executor.
54
52
 
@@ -62,9 +60,7 @@ class DecoratorRegistry:
62
60
  )
63
61
  self._workflow_run_decorators[executor_name] = decorator
64
62
 
65
- def get_workflow_run_decorator(
66
- self, executor_name: str
67
- ) -> Callable[[Callable[..., R]], Callable[..., R]]:
63
+ def get_workflow_run_decorator(self, executor_name: str) -> Callable[[Callable[..., R]], Callable[..., R]]:
68
64
  """
69
65
  Retrieves a workflow run decorator for a given executor.
70
66
 
@@ -88,18 +84,14 @@ def default_workflow_run(fn: Callable[..., R]) -> Callable[..., R]:
88
84
  return wrapper
89
85
 
90
86
 
91
- def register_asyncio_decorators(decorator_registry: DecoratorRegistry):
87
+ def register_asyncio_decorators(decorator_registry: DecoratorRegistry) -> None:
92
88
  """Registers default asyncio decorators."""
93
89
  executor_name = "asyncio"
94
- decorator_registry.register_workflow_defn_decorator(
95
- executor_name, default_workflow_defn
96
- )
97
- decorator_registry.register_workflow_run_decorator(
98
- executor_name, default_workflow_run
99
- )
90
+ decorator_registry.register_workflow_defn_decorator(executor_name, default_workflow_defn)
91
+ decorator_registry.register_workflow_run_decorator(executor_name, default_workflow_run)
100
92
 
101
93
 
102
- def register_temporal_decorators(decorator_registry: DecoratorRegistry):
94
+ def register_temporal_decorators(decorator_registry: DecoratorRegistry) -> None:
103
95
  """Registers Temporal decorators if Temporal SDK is available."""
104
96
  try:
105
97
  import temporalio.workflow as temporal_workflow
@@ -112,9 +104,5 @@ def register_temporal_decorators(decorator_registry: DecoratorRegistry):
112
104
  return
113
105
 
114
106
  executor_name = "temporal"
115
- decorator_registry.register_workflow_defn_decorator(
116
- executor_name, temporal_workflow.defn
117
- )
118
- decorator_registry.register_workflow_run_decorator(
119
- executor_name, temporal_workflow.run
120
- )
107
+ decorator_registry.register_workflow_defn_decorator(executor_name, temporal_workflow.defn)
108
+ decorator_registry.register_workflow_run_decorator(executor_name, temporal_workflow.run)
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
4
4
  from contextlib import asynccontextmanager
5
5
  from datetime import timedelta
6
6
  from typing import (
7
+ TYPE_CHECKING,
7
8
  Any,
8
9
  AsyncIterator,
9
10
  Callable,
@@ -13,7 +14,6 @@ from typing import (
13
14
  Optional,
14
15
  Type,
15
16
  TypeVar,
16
- TYPE_CHECKING,
17
17
  )
18
18
 
19
19
  from pydantic import BaseModel, ConfigDict
@@ -56,7 +56,7 @@ class Executor(ABC, ContextDependent):
56
56
  signal_bus: SignalHandler = None,
57
57
  context: Optional["Context"] = None,
58
58
  **kwargs,
59
- ):
59
+ ) -> None:
60
60
  super().__init__(context=context, **kwargs)
61
61
  self.execution_engine = engine
62
62
 
@@ -127,9 +127,7 @@ class Executor(ABC, ContextDependent):
127
127
 
128
128
  return results
129
129
 
130
- async def validate_task(
131
- self, task: Callable[..., R] | Coroutine[Any, Any, R]
132
- ) -> None:
130
+ async def validate_task(self, task: Callable[..., R] | Coroutine[Any, Any, R]) -> None:
133
131
  """Validate a task before execution."""
134
132
  if not (asyncio.iscoroutine(task) or asyncio.iscoroutinefunction(task)):
135
133
  raise TypeError(f"Task must be async: {task}")
@@ -164,7 +162,7 @@ class Executor(ABC, ContextDependent):
164
162
 
165
163
  # Notify any callbacks that the workflow is about to be paused waiting for a signal
166
164
  if self.context.signal_notification:
167
- self.context.signal_notification(
165
+ await self.context.signal_notification(
168
166
  signal_name=signal_name,
169
167
  request_id=request_id,
170
168
  workflow_id=workflow_id,
@@ -188,15 +186,13 @@ class AsyncioExecutor(Executor):
188
186
  self,
189
187
  config: ExecutorConfig | None = None,
190
188
  signal_bus: SignalHandler | None = None,
191
- ):
189
+ ) -> None:
192
190
  signal_bus = signal_bus or AsyncioSignalHandler()
193
191
  super().__init__(engine="asyncio", config=config, signal_bus=signal_bus)
194
192
 
195
193
  self._activity_semaphore: asyncio.Semaphore | None = None
196
194
  if self.config.max_concurrent_activities is not None:
197
- self._activity_semaphore = asyncio.Semaphore(
198
- self.config.max_concurrent_activities
199
- )
195
+ self._activity_semaphore = asyncio.Semaphore(self.config.max_concurrent_activities)
200
196
 
201
197
  async def _execute_task(
202
198
  self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any
@@ -253,16 +249,11 @@ class AsyncioExecutor(Executor):
253
249
  # TODO: saqadri - validate if async with self.execution_context() is needed here
254
250
  async with self.execution_context():
255
251
  # Create futures for all tasks
256
- futures = [
257
- asyncio.create_task(self._execute_task(task, **kwargs))
258
- for task in tasks
259
- ]
252
+ futures = [asyncio.create_task(self._execute_task(task, **kwargs)) for task in tasks]
260
253
  pending = set(futures)
261
254
 
262
255
  while pending:
263
- done, pending = await asyncio.wait(
264
- pending, return_when=asyncio.FIRST_COMPLETED
265
- )
256
+ done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
266
257
  for future in done:
267
258
  yield await future
268
259
 
@@ -10,13 +10,11 @@ from typing import Any, Callable, Dict, List
10
10
  class ActivityRegistry:
11
11
  """Centralized task/activity management with validation and metadata."""
12
12
 
13
- def __init__(self):
13
+ def __init__(self) -> None:
14
14
  self._activities: Dict[str, Callable] = {}
15
15
  self._metadata: Dict[str, Dict[str, Any]] = {}
16
16
 
17
- def register(
18
- self, name: str, func: Callable, metadata: Dict[str, Any] | None = None
19
- ):
17
+ def register(self, name: str, func: Callable, metadata: Dict[str, Any] | None = None) -> None:
20
18
  if name in self._activities:
21
19
  raise ValueError(f"Activity '{name}' is already registered.")
22
20
  self._activities[name] = func
@@ -9,6 +9,7 @@ import asyncio
9
9
  import functools
10
10
  import uuid
11
11
  from typing import (
12
+ TYPE_CHECKING,
12
13
  Any,
13
14
  AsyncIterator,
14
15
  Callable,
@@ -16,11 +17,10 @@ from typing import (
16
17
  Dict,
17
18
  List,
18
19
  Optional,
19
- TYPE_CHECKING,
20
20
  )
21
21
 
22
22
  from pydantic import ConfigDict
23
- from temporalio import activity, workflow, exceptions
23
+ from temporalio import activity, exceptions, workflow
24
24
  from temporalio.client import Client as TemporalClient
25
25
  from temporalio.worker import Worker
26
26
 
@@ -43,9 +43,7 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
43
43
 
44
44
  async def wait_for_signal(self, signal, timeout_seconds=None) -> SignalValueT:
45
45
  if not workflow._Runtime.current():
46
- raise RuntimeError(
47
- "TemporalSignalHandler.wait_for_signal must be called from within a workflow"
48
- )
46
+ raise RuntimeError("TemporalSignalHandler.wait_for_signal must be called from within a workflow")
49
47
 
50
48
  unique_signal_name = f"{signal.name}_{uuid.uuid4()}"
51
49
  registration = SignalRegistration(
@@ -59,22 +57,18 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
59
57
 
60
58
  # Define the signal handler for this specific registration
61
59
  @workflow.signal(name=unique_signal_name)
62
- def signal_handler(value: SignalValueT):
60
+ def signal_handler(value: SignalValueT) -> None:
63
61
  container["value"] = value
64
62
  container["completed"] = True
65
63
 
66
64
  async with self._lock:
67
65
  # Register both the signal registration and handler atomically
68
66
  self._pending_signals.setdefault(signal.name, []).append(registration)
69
- self._handlers.setdefault(signal.name, []).append(
70
- (unique_signal_name, signal_handler)
71
- )
67
+ self._handlers.setdefault(signal.name, []).append((unique_signal_name, signal_handler))
72
68
 
73
69
  try:
74
70
  # Wait for signal with optional timeout
75
- await workflow.wait_condition(
76
- lambda: container["completed"], timeout=timeout_seconds
77
- )
71
+ await workflow.wait_condition(lambda: container["completed"], timeout=timeout_seconds)
78
72
 
79
73
  return container["value"]
80
74
  except asyncio.TimeoutError as exc:
@@ -83,21 +77,13 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
83
77
  async with self._lock:
84
78
  # Remove ourselves from _pending_signals
85
79
  if signal.name in self._pending_signals:
86
- self._pending_signals[signal.name] = [
87
- sr
88
- for sr in self._pending_signals[signal.name]
89
- if sr.unique_name != unique_signal_name
90
- ]
80
+ self._pending_signals[signal.name] = [sr for sr in self._pending_signals[signal.name] if sr.unique_name != unique_signal_name]
91
81
  if not self._pending_signals[signal.name]:
92
82
  del self._pending_signals[signal.name]
93
83
 
94
84
  # Remove ourselves from _handlers
95
85
  if signal.name in self._handlers:
96
- self._handlers[signal.name] = [
97
- h
98
- for h in self._handlers[signal.name]
99
- if h[0] != unique_signal_name
100
- ]
86
+ self._handlers[signal.name] = [h for h in self._handlers[signal.name] if h[0] != unique_signal_name]
101
87
  if not self._handlers[signal.name]:
102
88
  del self._handlers[signal.name]
103
89
 
@@ -110,7 +96,7 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
110
96
 
111
97
  # Create the actual handler that will be registered with Temporal
112
98
  @workflow.signal(name=unique_signal_name)
113
- async def wrapped(signal_value: SignalValueT):
99
+ async def wrapped(signal_value: SignalValueT) -> None:
114
100
  # Create a signal object to pass to the handler
115
101
  signal = Signal(
116
102
  name=signal_name,
@@ -123,19 +109,15 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
123
109
  func(signal)
124
110
 
125
111
  # Register the handler under the original signal name
126
- self._handlers.setdefault(signal_name, []).append(
127
- (unique_signal_name, wrapped)
128
- )
112
+ self._handlers.setdefault(signal_name, []).append((unique_signal_name, wrapped))
129
113
  return func
130
114
 
131
115
  return decorator
132
116
 
133
- async def signal(self, signal):
117
+ async def signal(self, signal) -> None:
134
118
  self.validate_signal(signal)
135
119
 
136
- workflow_handle = workflow.get_external_workflow_handle(
137
- workflow_id=signal.workflow_id
138
- )
120
+ workflow_handle = workflow.get_external_workflow_handle(workflow_id=signal.workflow_id)
139
121
 
140
122
  # Send the signal to all registrations of this signal
141
123
  async with self._lock:
@@ -146,30 +128,22 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
146
128
  registration = pending_signal.registration
147
129
  if registration.workflow_id == signal.workflow_id:
148
130
  # Only signal for registrations of that workflow
149
- signal_tasks.append(
150
- workflow_handle.signal(
151
- registration.unique_name, signal.payload
152
- )
153
- )
131
+ signal_tasks.append(workflow_handle.signal(registration.unique_name, signal.payload))
154
132
  else:
155
133
  continue
156
134
 
157
135
  # Notify any registered handler functions
158
136
  if signal.name in self._handlers:
159
137
  for unique_name, _ in self._handlers[signal.name]:
160
- signal_tasks.append(
161
- workflow_handle.signal(unique_name, signal.payload)
162
- )
138
+ signal_tasks.append(workflow_handle.signal(unique_name, signal.payload))
163
139
 
164
140
  await asyncio.gather(*signal_tasks, return_exceptions=True)
165
141
 
166
- def validate_signal(self, signal):
142
+ def validate_signal(self, signal) -> None:
167
143
  super().validate_signal(signal)
168
144
  # Add TemporalSignalHandler-specific validation
169
145
  if signal.workflow_id is None:
170
- raise ValueError(
171
- "No workflow_id provided on Signal. That is required for Temporal signals"
172
- )
146
+ raise ValueError("No workflow_id provided on Signal. That is required for Temporal signals")
173
147
 
174
148
 
175
149
  class TemporalExecutorConfig(ExecutorConfig, TemporalSettings):
@@ -188,7 +162,7 @@ class TemporalExecutor(Executor):
188
162
  client: TemporalClient | None = None,
189
163
  context: Optional["Context"] = None,
190
164
  **kwargs,
191
- ):
165
+ ) -> None:
192
166
  signal_bus = signal_bus or TemporalSignalHandler()
193
167
  super().__init__(
194
168
  engine="temporal",
@@ -197,17 +171,13 @@ class TemporalExecutor(Executor):
197
171
  context=context,
198
172
  **kwargs,
199
173
  )
200
- self.config: TemporalExecutorConfig = (
201
- config or self.context.config.temporal or TemporalExecutorConfig()
202
- )
174
+ self.config: TemporalExecutorConfig = config or self.context.config.temporal or TemporalExecutorConfig()
203
175
  self.client = client
204
176
  self._worker = None
205
177
  self._activity_semaphore = None
206
178
 
207
179
  if config.max_concurrent_activities is not None:
208
- self._activity_semaphore = asyncio.Semaphore(
209
- self.config.max_concurrent_activities
210
- )
180
+ self._activity_semaphore = asyncio.Semaphore(self.config.max_concurrent_activities)
211
181
 
212
182
  @staticmethod
213
183
  def wrap_as_activity(
@@ -234,9 +204,7 @@ class TemporalExecutor(Executor):
234
204
 
235
205
  return wrapped_activity
236
206
 
237
- async def _execute_task_as_async(
238
- self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any
239
- ) -> R | BaseException:
207
+ async def _execute_task_as_async(self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any) -> R | BaseException:
240
208
  async def run_task(task: Callable[..., R] | Coroutine[Any, Any, R]) -> R:
241
209
  try:
242
210
  if asyncio.iscoroutine(task):
@@ -269,15 +237,11 @@ class TemporalExecutor(Executor):
269
237
  else:
270
238
  return await run_task(task)
271
239
 
272
- async def _execute_task(
273
- self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any
274
- ) -> R | BaseException:
240
+ async def _execute_task(self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any) -> R | BaseException:
275
241
  func = task.func if isinstance(task, functools.partial) else task
276
242
  is_workflow_task = getattr(func, "is_workflow_task", False)
277
243
  if not is_workflow_task:
278
- return await asyncio.create_task(
279
- self._execute_task_as_async(task, **kwargs)
280
- )
244
+ return await asyncio.create_task(self._execute_task_as_async(task, **kwargs))
281
245
 
282
246
  execution_metadata: Dict[str, Any] = getattr(func, "execution_metadata", {})
283
247
 
@@ -286,9 +250,7 @@ class TemporalExecutor(Executor):
286
250
  if not activity_name:
287
251
  activity_name = f"{func.__module__}.{func.__qualname__}"
288
252
 
289
- schedule_to_close = execution_metadata.get(
290
- "schedule_to_close_timeout", self.config.timeout_seconds
291
- )
253
+ schedule_to_close = execution_metadata.get("schedule_to_close_timeout", self.config.timeout_seconds)
292
254
 
293
255
  retry_policy = execution_metadata.get("retry_policy", None)
294
256
 
@@ -319,9 +281,7 @@ class TemporalExecutor(Executor):
319
281
  ) -> List[R | BaseException]:
320
282
  # Must be called from within a workflow
321
283
  if not workflow._Runtime.current():
322
- raise RuntimeError(
323
- "TemporalExecutor.execute must be called from within a workflow"
324
- )
284
+ raise RuntimeError("TemporalExecutor.execute must be called from within a workflow")
325
285
 
326
286
  # TODO: saqadri - validate if async with self.execution_context() is needed here
327
287
  async with self.execution_context():
@@ -336,9 +296,7 @@ class TemporalExecutor(Executor):
336
296
  **kwargs: Any,
337
297
  ) -> AsyncIterator[R | BaseException]:
338
298
  if not workflow._Runtime.current():
339
- raise RuntimeError(
340
- "TemporalExecutor.execute_streaming must be called from within a workflow"
341
- )
299
+ raise RuntimeError("TemporalExecutor.execute_streaming must be called from within a workflow")
342
300
 
343
301
  # TODO: saqadri - validate if async with self.execution_context() is needed here
344
302
  async with self.execution_context():
@@ -347,9 +305,7 @@ class TemporalExecutor(Executor):
347
305
  pending = set(futures)
348
306
 
349
307
  while pending:
350
- done, pending = await workflow.wait(
351
- pending, return_when=asyncio.FIRST_COMPLETED
352
- )
308
+ done, pending = await workflow.wait(pending, return_when=asyncio.FIRST_COMPLETED)
353
309
  for future in done:
354
310
  try:
355
311
  result = await future
@@ -368,7 +324,7 @@ class TemporalExecutor(Executor):
368
324
 
369
325
  return self.client
370
326
 
371
- async def start_worker(self):
327
+ async def start_worker(self) -> None:
372
328
  """
373
329
  Start a worker in this process, auto-registering all tasks
374
330
  from the global registry. Also picks up any classes decorated
@@ -398,8 +354,6 @@ class TemporalExecutor(Executor):
398
354
  activities=activities,
399
355
  workflows=[], # We'll auto-load by Python scanning or let the user specify
400
356
  )
401
- print(
402
- f"Starting Temporal Worker on task queue '{self.config.task_queue}' with {len(activities)} activities."
403
- )
357
+ print(f"Starting Temporal Worker on task queue '{self.config.task_queue}' with {len(activities)} activities.")
404
358
 
405
359
  await self._worker.run()
@@ -62,7 +62,7 @@ class Workflow(ABC, Generic[T]):
62
62
  name: str | None = None,
63
63
  metadata: Dict[str, Any] | None = None,
64
64
  **kwargs: Any,
65
- ):
65
+ ) -> None:
66
66
  self.executor = executor
67
67
  self.name = name or self.__class__.__name__
68
68
  self.init_kwargs = kwargs
@@ -80,7 +80,7 @@ class Workflow(ABC, Generic[T]):
80
80
  Main workflow implementation. Must be overridden by subclasses.
81
81
  """
82
82
 
83
- async def update_state(self, **kwargs):
83
+ async def update_state(self, **kwargs) -> None:
84
84
  """Syntactic sugar to update workflow state."""
85
85
  for key, value in kwargs.items():
86
86
  self.state[key] = value
@@ -93,9 +93,7 @@ class Workflow(ABC, Generic[T]):
93
93
  Convenience method for human input. Uses `human_input` signal
94
94
  so we can unify local (console input) and Temporal signals.
95
95
  """
96
- return await self.executor.wait_for_signal(
97
- "human_input", description=description
98
- )
96
+ return await self.executor.wait_for_signal("human_input", description=description)
99
97
 
100
98
 
101
99
  # ############################