fast-agent-mcp 0.1.12__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.2.0.dist-info}/METADATA +3 -4
  2. fast_agent_mcp-0.2.0.dist-info/RECORD +123 -0
  3. mcp_agent/__init__.py +75 -0
  4. mcp_agent/agents/agent.py +61 -415
  5. mcp_agent/agents/base_agent.py +522 -0
  6. mcp_agent/agents/workflow/__init__.py +1 -0
  7. mcp_agent/agents/workflow/chain_agent.py +173 -0
  8. mcp_agent/agents/workflow/evaluator_optimizer.py +362 -0
  9. mcp_agent/agents/workflow/orchestrator_agent.py +591 -0
  10. mcp_agent/{workflows/orchestrator → agents/workflow}/orchestrator_models.py +11 -21
  11. mcp_agent/agents/workflow/parallel_agent.py +182 -0
  12. mcp_agent/agents/workflow/router_agent.py +307 -0
  13. mcp_agent/app.py +15 -19
  14. mcp_agent/cli/commands/bootstrap.py +19 -38
  15. mcp_agent/cli/commands/config.py +4 -4
  16. mcp_agent/cli/commands/setup.py +7 -14
  17. mcp_agent/cli/main.py +7 -10
  18. mcp_agent/cli/terminal.py +3 -3
  19. mcp_agent/config.py +25 -40
  20. mcp_agent/context.py +12 -21
  21. mcp_agent/context_dependent.py +3 -5
  22. mcp_agent/core/agent_types.py +10 -7
  23. mcp_agent/core/direct_agent_app.py +179 -0
  24. mcp_agent/core/direct_decorators.py +443 -0
  25. mcp_agent/core/direct_factory.py +476 -0
  26. mcp_agent/core/enhanced_prompt.py +23 -55
  27. mcp_agent/core/exceptions.py +8 -8
  28. mcp_agent/core/fastagent.py +145 -371
  29. mcp_agent/core/interactive_prompt.py +424 -0
  30. mcp_agent/core/mcp_content.py +17 -17
  31. mcp_agent/core/prompt.py +6 -9
  32. mcp_agent/core/request_params.py +6 -3
  33. mcp_agent/core/validation.py +92 -18
  34. mcp_agent/executor/decorator_registry.py +9 -17
  35. mcp_agent/executor/executor.py +8 -17
  36. mcp_agent/executor/task_registry.py +2 -4
  37. mcp_agent/executor/temporal.py +19 -41
  38. mcp_agent/executor/workflow.py +3 -5
  39. mcp_agent/executor/workflow_signal.py +15 -21
  40. mcp_agent/human_input/handler.py +4 -7
  41. mcp_agent/human_input/types.py +2 -3
  42. mcp_agent/llm/__init__.py +2 -0
  43. mcp_agent/llm/augmented_llm.py +450 -0
  44. mcp_agent/llm/augmented_llm_passthrough.py +162 -0
  45. mcp_agent/llm/augmented_llm_playback.py +83 -0
  46. mcp_agent/llm/memory.py +103 -0
  47. mcp_agent/{workflows/llm → llm}/model_factory.py +22 -16
  48. mcp_agent/{workflows/llm → llm}/prompt_utils.py +1 -3
  49. mcp_agent/llm/providers/__init__.py +8 -0
  50. mcp_agent/{workflows/llm → llm/providers}/anthropic_utils.py +8 -25
  51. mcp_agent/{workflows/llm → llm/providers}/augmented_llm_anthropic.py +56 -194
  52. mcp_agent/llm/providers/augmented_llm_deepseek.py +53 -0
  53. mcp_agent/{workflows/llm → llm/providers}/augmented_llm_openai.py +99 -190
  54. mcp_agent/{workflows/llm → llm}/providers/multipart_converter_anthropic.py +72 -71
  55. mcp_agent/{workflows/llm → llm}/providers/multipart_converter_openai.py +65 -71
  56. mcp_agent/{workflows/llm → llm}/providers/openai_multipart.py +16 -44
  57. mcp_agent/{workflows/llm → llm/providers}/openai_utils.py +4 -4
  58. mcp_agent/{workflows/llm → llm}/providers/sampling_converter_anthropic.py +9 -11
  59. mcp_agent/{workflows/llm → llm}/providers/sampling_converter_openai.py +8 -12
  60. mcp_agent/{workflows/llm → llm}/sampling_converter.py +3 -31
  61. mcp_agent/llm/sampling_format_converter.py +37 -0
  62. mcp_agent/logging/events.py +1 -5
  63. mcp_agent/logging/json_serializer.py +7 -6
  64. mcp_agent/logging/listeners.py +20 -23
  65. mcp_agent/logging/logger.py +17 -19
  66. mcp_agent/logging/rich_progress.py +10 -8
  67. mcp_agent/logging/tracing.py +4 -6
  68. mcp_agent/logging/transport.py +22 -22
  69. mcp_agent/mcp/gen_client.py +1 -3
  70. mcp_agent/mcp/interfaces.py +117 -110
  71. mcp_agent/mcp/logger_textio.py +97 -0
  72. mcp_agent/mcp/mcp_agent_client_session.py +7 -7
  73. mcp_agent/mcp/mcp_agent_server.py +8 -8
  74. mcp_agent/mcp/mcp_aggregator.py +102 -143
  75. mcp_agent/mcp/mcp_connection_manager.py +20 -27
  76. mcp_agent/mcp/prompt_message_multipart.py +68 -16
  77. mcp_agent/mcp/prompt_render.py +77 -0
  78. mcp_agent/mcp/prompt_serialization.py +30 -48
  79. mcp_agent/mcp/prompts/prompt_constants.py +18 -0
  80. mcp_agent/mcp/prompts/prompt_helpers.py +327 -0
  81. mcp_agent/mcp/prompts/prompt_load.py +109 -0
  82. mcp_agent/mcp/prompts/prompt_server.py +155 -195
  83. mcp_agent/mcp/prompts/prompt_template.py +35 -66
  84. mcp_agent/mcp/resource_utils.py +7 -14
  85. mcp_agent/mcp/sampling.py +17 -17
  86. mcp_agent/mcp_server/agent_server.py +13 -17
  87. mcp_agent/mcp_server_registry.py +13 -22
  88. mcp_agent/resources/examples/{workflows → in_dev}/agent_build.py +3 -2
  89. mcp_agent/resources/examples/in_dev/slides.py +110 -0
  90. mcp_agent/resources/examples/internal/agent.py +6 -3
  91. mcp_agent/resources/examples/internal/fastagent.config.yaml +8 -2
  92. mcp_agent/resources/examples/internal/job.py +2 -1
  93. mcp_agent/resources/examples/internal/prompt_category.py +1 -1
  94. mcp_agent/resources/examples/internal/prompt_sizing.py +3 -5
  95. mcp_agent/resources/examples/internal/sizer.py +2 -1
  96. mcp_agent/resources/examples/internal/social.py +2 -1
  97. mcp_agent/resources/examples/prompting/agent.py +2 -1
  98. mcp_agent/resources/examples/prompting/image_server.py +4 -8
  99. mcp_agent/resources/examples/prompting/work_with_image.py +19 -0
  100. mcp_agent/ui/console_display.py +16 -20
  101. fast_agent_mcp-0.1.12.dist-info/RECORD +0 -161
  102. mcp_agent/core/agent_app.py +0 -646
  103. mcp_agent/core/agent_utils.py +0 -71
  104. mcp_agent/core/decorators.py +0 -455
  105. mcp_agent/core/factory.py +0 -463
  106. mcp_agent/core/proxies.py +0 -269
  107. mcp_agent/core/types.py +0 -24
  108. mcp_agent/eval/__init__.py +0 -0
  109. mcp_agent/mcp/stdio.py +0 -111
  110. mcp_agent/resources/examples/data-analysis/analysis-campaign.py +0 -188
  111. mcp_agent/resources/examples/data-analysis/analysis.py +0 -65
  112. mcp_agent/resources/examples/data-analysis/fastagent.config.yaml +0 -41
  113. mcp_agent/resources/examples/data-analysis/mount-point/WA_Fn-UseC_-HR-Employee-Attrition.csv +0 -1471
  114. mcp_agent/resources/examples/mcp_researcher/researcher-eval.py +0 -53
  115. mcp_agent/resources/examples/researcher/fastagent.config.yaml +0 -66
  116. mcp_agent/resources/examples/researcher/researcher-eval.py +0 -53
  117. mcp_agent/resources/examples/researcher/researcher-imp.py +0 -190
  118. mcp_agent/resources/examples/researcher/researcher.py +0 -38
  119. mcp_agent/resources/examples/workflows/chaining.py +0 -44
  120. mcp_agent/resources/examples/workflows/evaluator.py +0 -78
  121. mcp_agent/resources/examples/workflows/fastagent.config.yaml +0 -24
  122. mcp_agent/resources/examples/workflows/human_input.py +0 -25
  123. mcp_agent/resources/examples/workflows/orchestrator.py +0 -73
  124. mcp_agent/resources/examples/workflows/parallel.py +0 -78
  125. mcp_agent/resources/examples/workflows/router.py +0 -53
  126. mcp_agent/resources/examples/workflows/sse.py +0 -23
  127. mcp_agent/telemetry/__init__.py +0 -0
  128. mcp_agent/telemetry/usage_tracking.py +0 -18
  129. mcp_agent/workflows/__init__.py +0 -0
  130. mcp_agent/workflows/embedding/__init__.py +0 -0
  131. mcp_agent/workflows/embedding/embedding_base.py +0 -61
  132. mcp_agent/workflows/embedding/embedding_cohere.py +0 -49
  133. mcp_agent/workflows/embedding/embedding_openai.py +0 -46
  134. mcp_agent/workflows/evaluator_optimizer/__init__.py +0 -0
  135. mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +0 -481
  136. mcp_agent/workflows/intent_classifier/__init__.py +0 -0
  137. mcp_agent/workflows/intent_classifier/intent_classifier_base.py +0 -120
  138. mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +0 -134
  139. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +0 -45
  140. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +0 -45
  141. mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +0 -161
  142. mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +0 -60
  143. mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +0 -60
  144. mcp_agent/workflows/llm/__init__.py +0 -0
  145. mcp_agent/workflows/llm/augmented_llm.py +0 -753
  146. mcp_agent/workflows/llm/augmented_llm_passthrough.py +0 -241
  147. mcp_agent/workflows/llm/augmented_llm_playback.py +0 -109
  148. mcp_agent/workflows/llm/providers/__init__.py +0 -8
  149. mcp_agent/workflows/llm/sampling_format_converter.py +0 -22
  150. mcp_agent/workflows/orchestrator/__init__.py +0 -0
  151. mcp_agent/workflows/orchestrator/orchestrator.py +0 -578
  152. mcp_agent/workflows/parallel/__init__.py +0 -0
  153. mcp_agent/workflows/parallel/fan_in.py +0 -350
  154. mcp_agent/workflows/parallel/fan_out.py +0 -187
  155. mcp_agent/workflows/parallel/parallel_llm.py +0 -166
  156. mcp_agent/workflows/router/__init__.py +0 -0
  157. mcp_agent/workflows/router/router_base.py +0 -368
  158. mcp_agent/workflows/router/router_embedding.py +0 -240
  159. mcp_agent/workflows/router/router_embedding_cohere.py +0 -59
  160. mcp_agent/workflows/router/router_embedding_openai.py +0 -59
  161. mcp_agent/workflows/router/router_llm.py +0 -320
  162. mcp_agent/workflows/swarm/__init__.py +0 -0
  163. mcp_agent/workflows/swarm/swarm.py +0 -320
  164. mcp_agent/workflows/swarm/swarm_anthropic.py +0 -42
  165. mcp_agent/workflows/swarm/swarm_openai.py +0 -41
  166. {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.2.0.dist-info}/WHEEL +0 -0
  167. {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.2.0.dist-info}/entry_points.txt +0 -0
  168. {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.2.0.dist-info}/licenses/LICENSE +0 -0
  169. /mcp_agent/{workflows/orchestrator → agents/workflow}/orchestrator_prompts.py +0 -0
@@ -2,14 +2,15 @@
2
2
  Validation utilities for FastAgent configuration and dependencies.
3
3
  """
4
4
 
5
- from typing import Dict, List, Any
5
+ from typing import Any, Dict, List
6
+
6
7
  from mcp_agent.core.agent_types import AgentType
7
- from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM
8
8
  from mcp_agent.core.exceptions import (
9
- ServerConfigError,
10
9
  AgentConfigError,
11
10
  CircularDependencyError,
11
+ ServerConfigError,
12
12
  )
13
+ from mcp_agent.llm.augmented_llm import AugmentedLLM
13
14
 
14
15
 
15
16
  def validate_server_references(context, agents: Dict[str, Dict[str, Any]]) -> None:
@@ -55,7 +56,7 @@ def validate_workflow_references(agents: Dict[str, Dict[str, Any]]) -> None:
55
56
  if agent_type == AgentType.PARALLEL.value:
56
57
  # Check fan_in exists
57
58
  fan_in = agent_data["fan_in"]
58
- if fan_in not in available_components:
59
+ if fan_in and fan_in not in available_components:
59
60
  raise AgentConfigError(
60
61
  f"Parallel workflow '{name}' references non-existent fan_in component: {fan_in}"
61
62
  )
@@ -105,7 +106,7 @@ def validate_workflow_references(agents: Dict[str, Dict[str, Any]]) -> None:
105
106
 
106
107
  elif agent_type == AgentType.ROUTER.value:
107
108
  # Check all referenced agents exist
108
- router_agents = agent_data["agents"]
109
+ router_agents = agent_data["router_agents"]
109
110
  missing = [a for a in router_agents if a not in available_components]
110
111
  if missing:
111
112
  raise AgentConfigError(
@@ -186,7 +187,7 @@ def get_dependencies(
186
187
  deps.extend(get_dependencies(fan_out, agents, visited, path, agent_type))
187
188
  elif config["type"] == AgentType.CHAIN.value:
188
189
  # Get dependencies from sequence agents
189
- sequence = config.get("sequence", config.get("agents", []))
190
+ sequence = config.get("sequence", config.get("router_agents", []))
190
191
  for agent_name in sequence:
191
192
  deps.extend(get_dependencies(agent_name, agents, visited, path, agent_type))
192
193
 
@@ -198,23 +199,96 @@ def get_dependencies(
198
199
  return deps
199
200
 
200
201
 
201
- def get_parallel_dependencies(
202
- name: str, agents: Dict[str, Dict[str, Any]], visited: set, path: set
203
- ) -> List[str]:
202
+ def get_dependencies_groups(
203
+ agents_dict: Dict[str, Dict[str, Any]], allow_cycles: bool = False
204
+ ) -> List[List[str]]:
204
205
  """
205
- Get dependencies for a parallel agent in topological order.
206
- Legacy function that calls the more general get_dependencies.
206
+ Get dependencies between agents and group them into dependency layers.
207
+ Each layer can be initialized in parallel.
207
208
 
208
209
  Args:
209
- name: Name of the parallel agent
210
- agents: Dictionary of agent configurations
211
- visited: Set of already visited agents
212
- path: Current path for cycle detection
210
+ agents_dict: Dictionary of agent configurations
211
+ allow_cycles: Whether to allow cyclic dependencies
213
212
 
214
213
  Returns:
215
- List of agent names in dependency order
214
+ List of lists, where each inner list is a group of agents that can be initialized together
216
215
 
217
216
  Raises:
218
- CircularDependencyError: If circular dependency detected
217
+ CircularDependencyError: If circular dependency detected and allow_cycles is False
219
218
  """
220
- return get_dependencies(name, agents, visited, path, AgentType.PARALLEL)
219
+ # Get all agent names
220
+ agent_names = list(agents_dict.keys())
221
+
222
+ # Dictionary to store dependencies for each agent
223
+ dependencies = {name: set() for name in agent_names}
224
+
225
+ # Build the dependency graph
226
+ for name, agent_data in agents_dict.items():
227
+ agent_type = agent_data["type"]
228
+
229
+ if agent_type == AgentType.PARALLEL.value:
230
+ # Parallel agents depend on their fan-out and fan-in agents
231
+ dependencies[name].update(agent_data.get("parallel_agents", []))
232
+ elif agent_type == AgentType.CHAIN.value:
233
+ # Chain agents depend on the agents in their sequence
234
+ dependencies[name].update(agent_data.get("chain_agents", []))
235
+ elif agent_type == AgentType.ROUTER.value:
236
+ # Router agents depend on the agents they route to
237
+ dependencies[name].update(agent_data.get("router_agents", []))
238
+ elif agent_type == AgentType.ORCHESTRATOR.value:
239
+ # Orchestrator agents depend on their child agents
240
+ dependencies[name].update(agent_data.get("child_agents", []))
241
+ elif agent_type == AgentType.EVALUATOR_OPTIMIZER.value:
242
+ # Evaluator-Optimizer agents depend on their evaluation and optimization agents
243
+ dependencies[name].update(agent_data.get("eval_optimizer_agents", []))
244
+
245
+ # Check for cycles if not allowed
246
+ if not allow_cycles:
247
+ visited = set()
248
+ path = set()
249
+
250
+ def visit(node) -> None:
251
+ if node in path:
252
+ path_str = " -> ".join(path) + " -> " + node
253
+ raise CircularDependencyError(f"Circular dependency detected: {path_str}")
254
+ if node in visited:
255
+ return
256
+
257
+ path.add(node)
258
+ for dep in dependencies[node]:
259
+ if dep in agent_names: # Skip dependencies to non-existent agents
260
+ visit(dep)
261
+ path.remove(node)
262
+ visited.add(node)
263
+
264
+ # Check each node
265
+ for name in agent_names:
266
+ if name not in visited:
267
+ visit(name)
268
+
269
+ # Group agents by dependency level
270
+ result = []
271
+ remaining = set(agent_names)
272
+
273
+ while remaining:
274
+ # Find all agents that have no remaining dependencies
275
+ current_level = set()
276
+ for name in remaining:
277
+ if not dependencies[name] & remaining: # If no dependencies in remaining agents
278
+ current_level.add(name)
279
+
280
+ if not current_level:
281
+ if allow_cycles:
282
+ # If cycles are allowed, just add one remaining node to break the cycle
283
+ current_level.add(next(iter(remaining)))
284
+ else:
285
+ # This should not happen if we checked for cycles
286
+ raise CircularDependencyError("Unresolvable dependency cycle detected")
287
+
288
+ # Add the current level to the result
289
+ result.append(list(current_level))
290
+
291
+ # Remove current level from remaining
292
+ remaining -= current_level
293
+
294
+ return result
@@ -11,7 +11,7 @@ R = TypeVar("R")
11
11
  class DecoratorRegistry:
12
12
  """Centralized decorator management with validation and metadata."""
13
13
 
14
- def __init__(self):
14
+ def __init__(self) -> None:
15
15
  self._workflow_defn_decorators: Dict[str, Callable[[Type], Type]] = {}
16
16
  self._workflow_run_decorators: Dict[
17
17
  str, Callable[[Callable[..., R]], Callable[..., R]]
@@ -21,7 +21,7 @@ class DecoratorRegistry:
21
21
  self,
22
22
  executor_name: str,
23
23
  decorator: Callable[[Type], Type],
24
- ):
24
+ ) -> None:
25
25
  """
26
26
  Registers a workflow definition decorator for a given executor.
27
27
 
@@ -48,7 +48,7 @@ class DecoratorRegistry:
48
48
  self,
49
49
  executor_name: str,
50
50
  decorator: Callable[[Callable[..., R]], Callable[..., R]],
51
- ):
51
+ ) -> None:
52
52
  """
53
53
  Registers a workflow run decorator for a given executor.
54
54
 
@@ -88,18 +88,14 @@ def default_workflow_run(fn: Callable[..., R]) -> Callable[..., R]:
88
88
  return wrapper
89
89
 
90
90
 
91
- def register_asyncio_decorators(decorator_registry: DecoratorRegistry):
91
+ def register_asyncio_decorators(decorator_registry: DecoratorRegistry) -> None:
92
92
  """Registers default asyncio decorators."""
93
93
  executor_name = "asyncio"
94
- decorator_registry.register_workflow_defn_decorator(
95
- executor_name, default_workflow_defn
96
- )
97
- decorator_registry.register_workflow_run_decorator(
98
- executor_name, default_workflow_run
99
- )
94
+ decorator_registry.register_workflow_defn_decorator(executor_name, default_workflow_defn)
95
+ decorator_registry.register_workflow_run_decorator(executor_name, default_workflow_run)
100
96
 
101
97
 
102
- def register_temporal_decorators(decorator_registry: DecoratorRegistry):
98
+ def register_temporal_decorators(decorator_registry: DecoratorRegistry) -> None:
103
99
  """Registers Temporal decorators if Temporal SDK is available."""
104
100
  try:
105
101
  import temporalio.workflow as temporal_workflow
@@ -112,9 +108,5 @@ def register_temporal_decorators(decorator_registry: DecoratorRegistry):
112
108
  return
113
109
 
114
110
  executor_name = "temporal"
115
- decorator_registry.register_workflow_defn_decorator(
116
- executor_name, temporal_workflow.defn
117
- )
118
- decorator_registry.register_workflow_run_decorator(
119
- executor_name, temporal_workflow.run
120
- )
111
+ decorator_registry.register_workflow_defn_decorator(executor_name, temporal_workflow.defn)
112
+ decorator_registry.register_workflow_run_decorator(executor_name, temporal_workflow.run)
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
4
4
  from contextlib import asynccontextmanager
5
5
  from datetime import timedelta
6
6
  from typing import (
7
+ TYPE_CHECKING,
7
8
  Any,
8
9
  AsyncIterator,
9
10
  Callable,
@@ -13,7 +14,6 @@ from typing import (
13
14
  Optional,
14
15
  Type,
15
16
  TypeVar,
16
- TYPE_CHECKING,
17
17
  )
18
18
 
19
19
  from pydantic import BaseModel, ConfigDict
@@ -56,7 +56,7 @@ class Executor(ABC, ContextDependent):
56
56
  signal_bus: SignalHandler = None,
57
57
  context: Optional["Context"] = None,
58
58
  **kwargs,
59
- ):
59
+ ) -> None:
60
60
  super().__init__(context=context, **kwargs)
61
61
  self.execution_engine = engine
62
62
 
@@ -127,9 +127,7 @@ class Executor(ABC, ContextDependent):
127
127
 
128
128
  return results
129
129
 
130
- async def validate_task(
131
- self, task: Callable[..., R] | Coroutine[Any, Any, R]
132
- ) -> None:
130
+ async def validate_task(self, task: Callable[..., R] | Coroutine[Any, Any, R]) -> None:
133
131
  """Validate a task before execution."""
134
132
  if not (asyncio.iscoroutine(task) or asyncio.iscoroutinefunction(task)):
135
133
  raise TypeError(f"Task must be async: {task}")
@@ -164,7 +162,7 @@ class Executor(ABC, ContextDependent):
164
162
 
165
163
  # Notify any callbacks that the workflow is about to be paused waiting for a signal
166
164
  if self.context.signal_notification:
167
- self.context.signal_notification(
165
+ await self.context.signal_notification(
168
166
  signal_name=signal_name,
169
167
  request_id=request_id,
170
168
  workflow_id=workflow_id,
@@ -188,15 +186,13 @@ class AsyncioExecutor(Executor):
188
186
  self,
189
187
  config: ExecutorConfig | None = None,
190
188
  signal_bus: SignalHandler | None = None,
191
- ):
189
+ ) -> None:
192
190
  signal_bus = signal_bus or AsyncioSignalHandler()
193
191
  super().__init__(engine="asyncio", config=config, signal_bus=signal_bus)
194
192
 
195
193
  self._activity_semaphore: asyncio.Semaphore | None = None
196
194
  if self.config.max_concurrent_activities is not None:
197
- self._activity_semaphore = asyncio.Semaphore(
198
- self.config.max_concurrent_activities
199
- )
195
+ self._activity_semaphore = asyncio.Semaphore(self.config.max_concurrent_activities)
200
196
 
201
197
  async def _execute_task(
202
198
  self, task: Callable[..., R] | Coroutine[Any, Any, R], **kwargs: Any
@@ -253,16 +249,11 @@ class AsyncioExecutor(Executor):
253
249
  # TODO: saqadri - validate if async with self.execution_context() is needed here
254
250
  async with self.execution_context():
255
251
  # Create futures for all tasks
256
- futures = [
257
- asyncio.create_task(self._execute_task(task, **kwargs))
258
- for task in tasks
259
- ]
252
+ futures = [asyncio.create_task(self._execute_task(task, **kwargs)) for task in tasks]
260
253
  pending = set(futures)
261
254
 
262
255
  while pending:
263
- done, pending = await asyncio.wait(
264
- pending, return_when=asyncio.FIRST_COMPLETED
265
- )
256
+ done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
266
257
  for future in done:
267
258
  yield await future
268
259
 
@@ -10,13 +10,11 @@ from typing import Any, Callable, Dict, List
10
10
  class ActivityRegistry:
11
11
  """Centralized task/activity management with validation and metadata."""
12
12
 
13
- def __init__(self):
13
+ def __init__(self) -> None:
14
14
  self._activities: Dict[str, Callable] = {}
15
15
  self._metadata: Dict[str, Dict[str, Any]] = {}
16
16
 
17
- def register(
18
- self, name: str, func: Callable, metadata: Dict[str, Any] | None = None
19
- ):
17
+ def register(self, name: str, func: Callable, metadata: Dict[str, Any] | None = None) -> None:
20
18
  if name in self._activities:
21
19
  raise ValueError(f"Activity '{name}' is already registered.")
22
20
  self._activities[name] = func
@@ -9,6 +9,7 @@ import asyncio
9
9
  import functools
10
10
  import uuid
11
11
  from typing import (
12
+ TYPE_CHECKING,
12
13
  Any,
13
14
  AsyncIterator,
14
15
  Callable,
@@ -16,11 +17,10 @@ from typing import (
16
17
  Dict,
17
18
  List,
18
19
  Optional,
19
- TYPE_CHECKING,
20
20
  )
21
21
 
22
22
  from pydantic import ConfigDict
23
- from temporalio import activity, workflow, exceptions
23
+ from temporalio import activity, exceptions, workflow
24
24
  from temporalio.client import Client as TemporalClient
25
25
  from temporalio.worker import Worker
26
26
 
@@ -59,22 +59,18 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
59
59
 
60
60
  # Define the signal handler for this specific registration
61
61
  @workflow.signal(name=unique_signal_name)
62
- def signal_handler(value: SignalValueT):
62
+ def signal_handler(value: SignalValueT) -> None:
63
63
  container["value"] = value
64
64
  container["completed"] = True
65
65
 
66
66
  async with self._lock:
67
67
  # Register both the signal registration and handler atomically
68
68
  self._pending_signals.setdefault(signal.name, []).append(registration)
69
- self._handlers.setdefault(signal.name, []).append(
70
- (unique_signal_name, signal_handler)
71
- )
69
+ self._handlers.setdefault(signal.name, []).append((unique_signal_name, signal_handler))
72
70
 
73
71
  try:
74
72
  # Wait for signal with optional timeout
75
- await workflow.wait_condition(
76
- lambda: container["completed"], timeout=timeout_seconds
77
- )
73
+ await workflow.wait_condition(lambda: container["completed"], timeout=timeout_seconds)
78
74
 
79
75
  return container["value"]
80
76
  except asyncio.TimeoutError as exc:
@@ -94,9 +90,7 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
94
90
  # Remove ourselves from _handlers
95
91
  if signal.name in self._handlers:
96
92
  self._handlers[signal.name] = [
97
- h
98
- for h in self._handlers[signal.name]
99
- if h[0] != unique_signal_name
93
+ h for h in self._handlers[signal.name] if h[0] != unique_signal_name
100
94
  ]
101
95
  if not self._handlers[signal.name]:
102
96
  del self._handlers[signal.name]
@@ -110,7 +104,7 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
110
104
 
111
105
  # Create the actual handler that will be registered with Temporal
112
106
  @workflow.signal(name=unique_signal_name)
113
- async def wrapped(signal_value: SignalValueT):
107
+ async def wrapped(signal_value: SignalValueT) -> None:
114
108
  # Create a signal object to pass to the handler
115
109
  signal = Signal(
116
110
  name=signal_name,
@@ -123,19 +117,15 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
123
117
  func(signal)
124
118
 
125
119
  # Register the handler under the original signal name
126
- self._handlers.setdefault(signal_name, []).append(
127
- (unique_signal_name, wrapped)
128
- )
120
+ self._handlers.setdefault(signal_name, []).append((unique_signal_name, wrapped))
129
121
  return func
130
122
 
131
123
  return decorator
132
124
 
133
- async def signal(self, signal):
125
+ async def signal(self, signal) -> None:
134
126
  self.validate_signal(signal)
135
127
 
136
- workflow_handle = workflow.get_external_workflow_handle(
137
- workflow_id=signal.workflow_id
138
- )
128
+ workflow_handle = workflow.get_external_workflow_handle(workflow_id=signal.workflow_id)
139
129
 
140
130
  # Send the signal to all registrations of this signal
141
131
  async with self._lock:
@@ -147,9 +137,7 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
147
137
  if registration.workflow_id == signal.workflow_id:
148
138
  # Only signal for registrations of that workflow
149
139
  signal_tasks.append(
150
- workflow_handle.signal(
151
- registration.unique_name, signal.payload
152
- )
140
+ workflow_handle.signal(registration.unique_name, signal.payload)
153
141
  )
154
142
  else:
155
143
  continue
@@ -157,13 +145,11 @@ class TemporalSignalHandler(BaseSignalHandler[SignalValueT]):
157
145
  # Notify any registered handler functions
158
146
  if signal.name in self._handlers:
159
147
  for unique_name, _ in self._handlers[signal.name]:
160
- signal_tasks.append(
161
- workflow_handle.signal(unique_name, signal.payload)
162
- )
148
+ signal_tasks.append(workflow_handle.signal(unique_name, signal.payload))
163
149
 
164
150
  await asyncio.gather(*signal_tasks, return_exceptions=True)
165
151
 
166
- def validate_signal(self, signal):
152
+ def validate_signal(self, signal) -> None:
167
153
  super().validate_signal(signal)
168
154
  # Add TemporalSignalHandler-specific validation
169
155
  if signal.workflow_id is None:
@@ -188,7 +174,7 @@ class TemporalExecutor(Executor):
188
174
  client: TemporalClient | None = None,
189
175
  context: Optional["Context"] = None,
190
176
  **kwargs,
191
- ):
177
+ ) -> None:
192
178
  signal_bus = signal_bus or TemporalSignalHandler()
193
179
  super().__init__(
194
180
  engine="temporal",
@@ -205,9 +191,7 @@ class TemporalExecutor(Executor):
205
191
  self._activity_semaphore = None
206
192
 
207
193
  if config.max_concurrent_activities is not None:
208
- self._activity_semaphore = asyncio.Semaphore(
209
- self.config.max_concurrent_activities
210
- )
194
+ self._activity_semaphore = asyncio.Semaphore(self.config.max_concurrent_activities)
211
195
 
212
196
  @staticmethod
213
197
  def wrap_as_activity(
@@ -275,9 +259,7 @@ class TemporalExecutor(Executor):
275
259
  func = task.func if isinstance(task, functools.partial) else task
276
260
  is_workflow_task = getattr(func, "is_workflow_task", False)
277
261
  if not is_workflow_task:
278
- return await asyncio.create_task(
279
- self._execute_task_as_async(task, **kwargs)
280
- )
262
+ return await asyncio.create_task(self._execute_task_as_async(task, **kwargs))
281
263
 
282
264
  execution_metadata: Dict[str, Any] = getattr(func, "execution_metadata", {})
283
265
 
@@ -319,9 +301,7 @@ class TemporalExecutor(Executor):
319
301
  ) -> List[R | BaseException]:
320
302
  # Must be called from within a workflow
321
303
  if not workflow._Runtime.current():
322
- raise RuntimeError(
323
- "TemporalExecutor.execute must be called from within a workflow"
324
- )
304
+ raise RuntimeError("TemporalExecutor.execute must be called from within a workflow")
325
305
 
326
306
  # TODO: saqadri - validate if async with self.execution_context() is needed here
327
307
  async with self.execution_context():
@@ -347,9 +327,7 @@ class TemporalExecutor(Executor):
347
327
  pending = set(futures)
348
328
 
349
329
  while pending:
350
- done, pending = await workflow.wait(
351
- pending, return_when=asyncio.FIRST_COMPLETED
352
- )
330
+ done, pending = await workflow.wait(pending, return_when=asyncio.FIRST_COMPLETED)
353
331
  for future in done:
354
332
  try:
355
333
  result = await future
@@ -368,7 +346,7 @@ class TemporalExecutor(Executor):
368
346
 
369
347
  return self.client
370
348
 
371
- async def start_worker(self):
349
+ async def start_worker(self) -> None:
372
350
  """
373
351
  Start a worker in this process, auto-registering all tasks
374
352
  from the global registry. Also picks up any classes decorated
@@ -62,7 +62,7 @@ class Workflow(ABC, Generic[T]):
62
62
  name: str | None = None,
63
63
  metadata: Dict[str, Any] | None = None,
64
64
  **kwargs: Any,
65
- ):
65
+ ) -> None:
66
66
  self.executor = executor
67
67
  self.name = name or self.__class__.__name__
68
68
  self.init_kwargs = kwargs
@@ -80,7 +80,7 @@ class Workflow(ABC, Generic[T]):
80
80
  Main workflow implementation. Must be overridden by subclasses.
81
81
  """
82
82
 
83
- async def update_state(self, **kwargs):
83
+ async def update_state(self, **kwargs) -> None:
84
84
  """Syntactic sugar to update workflow state."""
85
85
  for key, value in kwargs.items():
86
86
  self.state[key] = value
@@ -93,9 +93,7 @@ class Workflow(ABC, Generic[T]):
93
93
  Convenience method for human input. Uses `human_input` signal
94
94
  so we can unify local (console input) and Temporal signals.
95
95
  """
96
- return await self.executor.wait_for_signal(
97
- "human_input", description=description
98
- )
96
+ return await self.executor.wait_for_signal("human_input", description=description)
99
97
 
100
98
 
101
99
  # ############################
@@ -1,6 +1,6 @@
1
1
  import asyncio
2
2
  import uuid
3
- from abc import abstractmethod, ABC
3
+ from abc import ABC, abstractmethod
4
4
  from typing import Any, Callable, Dict, Generic, List, Protocol, TypeVar
5
5
 
6
6
  from pydantic import BaseModel, ConfigDict
@@ -71,14 +71,14 @@ class PendingSignal(BaseModel):
71
71
  class BaseSignalHandler(ABC, Generic[SignalValueT]):
72
72
  """Base class implementing common signal handling functionality."""
73
73
 
74
- def __init__(self):
74
+ def __init__(self) -> None:
75
75
  # Map signal_name -> list of PendingSignal objects
76
76
  self._pending_signals: Dict[str, List[PendingSignal]] = {}
77
77
  # Map signal_name -> list of (unique_name, handler) tuples
78
78
  self._handlers: Dict[str, List[tuple[str, Callable]]] = {}
79
79
  self._lock = asyncio.Lock()
80
80
 
81
- async def cleanup(self, signal_name: str | None = None):
81
+ async def cleanup(self, signal_name: str | None = None) -> None:
82
82
  """Clean up handlers and registrations for a signal or all signals."""
83
83
  async with self._lock:
84
84
  if signal_name:
@@ -90,7 +90,7 @@ class BaseSignalHandler(ABC, Generic[SignalValueT]):
90
90
  self._handlers.clear()
91
91
  self._pending_signals.clear()
92
92
 
93
- def validate_signal(self, signal: Signal[SignalValueT]):
93
+ def validate_signal(self, signal: Signal[SignalValueT]) -> None:
94
94
  """Validate signal properties."""
95
95
  if not signal.name:
96
96
  raise ValueError("Signal name is required")
@@ -102,7 +102,7 @@ class BaseSignalHandler(ABC, Generic[SignalValueT]):
102
102
  def decorator(func: Callable) -> Callable:
103
103
  unique_name = f"{signal_name}_{uuid.uuid4()}"
104
104
 
105
- async def wrapped(value: SignalValueT):
105
+ async def wrapped(value: SignalValueT) -> None:
106
106
  try:
107
107
  if asyncio.iscoroutinefunction(func):
108
108
  await func(value)
@@ -133,7 +133,7 @@ class BaseSignalHandler(ABC, Generic[SignalValueT]):
133
133
  class ConsoleSignalHandler(SignalHandler[str]):
134
134
  """Simple console-based signal handling (blocks on input)."""
135
135
 
136
- def __init__(self):
136
+ def __init__(self) -> None:
137
137
  self._pending_signals: Dict[str, List[PendingSignal]] = {}
138
138
  self._handlers: Dict[str, List[Callable]] = {}
139
139
 
@@ -163,7 +163,7 @@ class ConsoleSignalHandler(SignalHandler[str]):
163
163
 
164
164
  def on_signal(self, signal_name):
165
165
  def decorator(func):
166
- async def wrapped(value: SignalValueT):
166
+ async def wrapped(value: SignalValueT) -> None:
167
167
  if asyncio.iscoroutinefunction(func):
168
168
  await func(value)
169
169
  else:
@@ -174,13 +174,11 @@ class ConsoleSignalHandler(SignalHandler[str]):
174
174
 
175
175
  return decorator
176
176
 
177
- async def signal(self, signal):
177
+ async def signal(self, signal) -> None:
178
178
  print(f"[SIGNAL SENT: {signal.name}] Value: {signal.payload}")
179
179
 
180
180
  handlers = self._handlers.get(signal.name, [])
181
- await asyncio.gather(
182
- *(handler(signal) for handler in handlers), return_exceptions=True
183
- )
181
+ await asyncio.gather(*(handler(signal) for handler in handlers), return_exceptions=True)
184
182
 
185
183
  # Notify any waiting coroutines
186
184
  if signal.name in self._pending_signals:
@@ -194,9 +192,7 @@ class AsyncioSignalHandler(BaseSignalHandler[SignalValueT]):
194
192
  Asyncio-based signal handling using an internal dictionary of asyncio Events.
195
193
  """
196
194
 
197
- async def wait_for_signal(
198
- self, signal, timeout_seconds: int | None = None
199
- ) -> SignalValueT:
195
+ async def wait_for_signal(self, signal, timeout_seconds: int | None = None) -> SignalValueT:
200
196
  event = asyncio.Event()
201
197
  unique_name = str(uuid.uuid4())
202
198
 
@@ -236,7 +232,7 @@ class AsyncioSignalHandler(BaseSignalHandler[SignalValueT]):
236
232
 
237
233
  def on_signal(self, signal_name):
238
234
  def decorator(func):
239
- async def wrapped(value: SignalValueT):
235
+ async def wrapped(value: SignalValueT) -> None:
240
236
  if asyncio.iscoroutinefunction(func):
241
237
  await func(value)
242
238
  else:
@@ -247,7 +243,7 @@ class AsyncioSignalHandler(BaseSignalHandler[SignalValueT]):
247
243
 
248
244
  return decorator
249
245
 
250
- async def signal(self, signal):
246
+ async def signal(self, signal) -> None:
251
247
  async with self._lock:
252
248
  # Notify any waiting coroutines
253
249
  if signal.name in self._pending_signals:
@@ -272,11 +268,11 @@ class LocalSignalStore:
272
268
  and triggers them when a signal is emitted.
273
269
  """
274
270
 
275
- def __init__(self):
271
+ def __init__(self) -> None:
276
272
  # For each signal_name, store a list of futures that are waiting for it
277
273
  self._waiters: Dict[str, List[asyncio.Future]] = {}
278
274
 
279
- async def emit(self, signal_name: str, payload: Any):
275
+ async def emit(self, signal_name: str, payload: Any) -> None:
280
276
  # If we have waiting futures, set their result
281
277
  if signal_name in self._waiters:
282
278
  for future in self._waiters[signal_name]:
@@ -284,9 +280,7 @@ class LocalSignalStore:
284
280
  future.set_result(payload)
285
281
  self._waiters[signal_name].clear()
286
282
 
287
- async def wait_for(
288
- self, signal_name: str, timeout_seconds: int | None = None
289
- ) -> Any:
283
+ async def wait_for(self, signal_name: str, timeout_seconds: int | None = None) -> Any:
290
284
  loop = asyncio.get_running_loop()
291
285
  future = loop.create_future()
292
286