nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +41 -21
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +46 -26
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +46 -11
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  53. nat/cli/commands/workflow/workflow_commands.py +9 -13
  54. nat/cli/entrypoint.py +8 -10
  55. nat/cli/register_workflow.py +38 -4
  56. nat/cli/type_registry.py +75 -6
  57. nat/control_flow/__init__.py +0 -0
  58. nat/control_flow/register.py +20 -0
  59. nat/control_flow/router_agent/__init__.py +0 -0
  60. nat/control_flow/router_agent/agent.py +329 -0
  61. nat/control_flow/router_agent/prompt.py +48 -0
  62. nat/control_flow/router_agent/register.py +91 -0
  63. nat/control_flow/sequential_executor.py +166 -0
  64. nat/data_models/agent.py +34 -0
  65. nat/data_models/api_server.py +10 -10
  66. nat/data_models/authentication.py +23 -9
  67. nat/data_models/common.py +1 -1
  68. nat/data_models/component.py +2 -0
  69. nat/data_models/component_ref.py +11 -0
  70. nat/data_models/config.py +41 -17
  71. nat/data_models/dataset_handler.py +1 -1
  72. nat/data_models/discovery_metadata.py +4 -4
  73. nat/data_models/evaluate.py +4 -1
  74. nat/data_models/function.py +34 -0
  75. nat/data_models/function_dependencies.py +14 -6
  76. nat/data_models/gated_field_mixin.py +242 -0
  77. nat/data_models/intermediate_step.py +3 -3
  78. nat/data_models/optimizable.py +119 -0
  79. nat/data_models/optimizer.py +149 -0
  80. nat/data_models/swe_bench_model.py +1 -1
  81. nat/data_models/temperature_mixin.py +44 -0
  82. nat/data_models/thinking_mixin.py +86 -0
  83. nat/data_models/top_p_mixin.py +44 -0
  84. nat/embedder/nim_embedder.py +1 -1
  85. nat/embedder/openai_embedder.py +1 -1
  86. nat/embedder/register.py +0 -1
  87. nat/eval/config.py +3 -1
  88. nat/eval/dataset_handler/dataset_handler.py +71 -7
  89. nat/eval/evaluate.py +86 -31
  90. nat/eval/evaluator/base_evaluator.py +1 -1
  91. nat/eval/evaluator/evaluator_model.py +13 -0
  92. nat/eval/intermediate_step_adapter.py +1 -1
  93. nat/eval/rag_evaluator/evaluate.py +2 -2
  94. nat/eval/rag_evaluator/register.py +3 -3
  95. nat/eval/register.py +4 -1
  96. nat/eval/remote_workflow.py +3 -3
  97. nat/eval/runtime_evaluator/__init__.py +14 -0
  98. nat/eval/runtime_evaluator/evaluate.py +123 -0
  99. nat/eval/runtime_evaluator/register.py +100 -0
  100. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  101. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  102. nat/eval/trajectory_evaluator/register.py +1 -1
  103. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  104. nat/eval/utils/eval_trace_ctx.py +89 -0
  105. nat/eval/utils/weave_eval.py +18 -9
  106. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  107. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  108. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  109. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  110. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  111. nat/experimental/test_time_compute/register.py +0 -1
  112. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  113. nat/front_ends/console/authentication_flow_handler.py +82 -30
  114. nat/front_ends/console/console_front_end_plugin.py +8 -5
  115. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  116. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  117. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  118. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  119. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  120. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
  121. nat/front_ends/fastapi/job_store.py +518 -99
  122. nat/front_ends/fastapi/main.py +11 -19
  123. nat/front_ends/fastapi/message_handler.py +13 -14
  124. nat/front_ends/fastapi/message_validator.py +17 -19
  125. nat/front_ends/fastapi/response_helpers.py +4 -4
  126. nat/front_ends/fastapi/step_adaptor.py +2 -2
  127. nat/front_ends/fastapi/utils.py +57 -0
  128. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  129. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  130. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  131. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  132. nat/front_ends/mcp/tool_converter.py +44 -14
  133. nat/front_ends/register.py +0 -1
  134. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  135. nat/llm/aws_bedrock_llm.py +24 -12
  136. nat/llm/azure_openai_llm.py +13 -6
  137. nat/llm/litellm_llm.py +69 -0
  138. nat/llm/nim_llm.py +20 -8
  139. nat/llm/openai_llm.py +14 -6
  140. nat/llm/register.py +4 -1
  141. nat/llm/utils/env_config_value.py +2 -3
  142. nat/llm/utils/thinking.py +215 -0
  143. nat/meta/pypi.md +9 -9
  144. nat/object_store/register.py +0 -1
  145. nat/observability/exporter/base_exporter.py +3 -3
  146. nat/observability/exporter/file_exporter.py +1 -1
  147. nat/observability/exporter/processing_exporter.py +309 -81
  148. nat/observability/exporter/span_exporter.py +1 -1
  149. nat/observability/exporter_manager.py +7 -7
  150. nat/observability/mixin/file_mixin.py +7 -7
  151. nat/observability/mixin/redaction_config_mixin.py +42 -0
  152. nat/observability/mixin/tagging_config_mixin.py +62 -0
  153. nat/observability/mixin/type_introspection_mixin.py +420 -107
  154. nat/observability/processor/batching_processor.py +5 -7
  155. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  156. nat/observability/processor/processor.py +3 -0
  157. nat/observability/processor/processor_factory.py +70 -0
  158. nat/observability/processor/redaction/__init__.py +24 -0
  159. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  160. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  161. nat/observability/processor/redaction/redaction_processor.py +177 -0
  162. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  163. nat/observability/processor/span_tagging_processor.py +68 -0
  164. nat/observability/register.py +6 -4
  165. nat/profiler/calc/calc_runner.py +3 -4
  166. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  167. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  168. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  169. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  170. nat/profiler/data_frame_row.py +1 -1
  171. nat/profiler/decorators/framework_wrapper.py +62 -13
  172. nat/profiler/decorators/function_tracking.py +160 -3
  173. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  174. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  175. nat/profiler/inference_optimization/data_models.py +3 -3
  176. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  177. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  178. nat/profiler/parameter_optimization/__init__.py +0 -0
  179. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  180. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  181. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  182. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  183. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  184. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  185. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  186. nat/profiler/profile_runner.py +14 -9
  187. nat/profiler/utils.py +4 -2
  188. nat/registry_handlers/local/local_handler.py +2 -2
  189. nat/registry_handlers/package_utils.py +1 -2
  190. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  191. nat/registry_handlers/register.py +3 -4
  192. nat/registry_handlers/rest/rest_handler.py +12 -13
  193. nat/retriever/milvus/retriever.py +2 -2
  194. nat/retriever/nemo_retriever/retriever.py +1 -1
  195. nat/retriever/register.py +0 -1
  196. nat/runtime/loader.py +2 -2
  197. nat/runtime/runner.py +3 -2
  198. nat/runtime/session.py +43 -8
  199. nat/settings/global_settings.py +16 -5
  200. nat/tool/chat_completion.py +5 -2
  201. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  202. nat/tool/datetime_tools.py +49 -9
  203. nat/tool/document_search.py +2 -2
  204. nat/tool/github_tools.py +450 -0
  205. nat/tool/nvidia_rag.py +1 -1
  206. nat/tool/register.py +2 -9
  207. nat/tool/retriever.py +3 -2
  208. nat/utils/callable_utils.py +70 -0
  209. nat/utils/data_models/schema_validator.py +3 -3
  210. nat/utils/exception_handlers/automatic_retries.py +104 -51
  211. nat/utils/exception_handlers/schemas.py +1 -1
  212. nat/utils/io/yaml_tools.py +2 -2
  213. nat/utils/log_levels.py +25 -0
  214. nat/utils/reactive/base/observable_base.py +2 -2
  215. nat/utils/reactive/base/observer_base.py +1 -1
  216. nat/utils/reactive/observable.py +2 -2
  217. nat/utils/reactive/observer.py +4 -4
  218. nat/utils/reactive/subscription.py +1 -1
  219. nat/utils/settings/global_settings.py +6 -8
  220. nat/utils/type_converter.py +4 -3
  221. nat/utils/type_utils.py +9 -5
  222. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
  223. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
  224. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
  225. nat/cli/commands/info/list_mcp.py +0 -304
  226. nat/tool/github_tools/create_github_commit.py +0 -133
  227. nat/tool/github_tools/create_github_issue.py +0 -87
  228. nat/tool/github_tools/create_github_pr.py +0 -106
  229. nat/tool/github_tools/get_github_file.py +0 -106
  230. nat/tool/github_tools/get_github_issue.py +0 -166
  231. nat/tool/github_tools/get_github_pr.py +0 -256
  232. nat/tool/github_tools/update_github_issue.py +0 -100
  233. nat/tool/mcp/exceptions.py +0 -142
  234. nat/tool/mcp/mcp_client.py +0 -255
  235. nat/tool/mcp/mcp_tool.py +0 -96
  236. nat/utils/exception_handlers/mcp.py +0 -211
  237. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  238. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  239. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
  240. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  241. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
  242. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281

@@ -14,9 +14,9 @@
 # limitations under the License.
 
 import asyncio
+import json
 import logging
 import os
-import time
 import typing
 from abc import ABC
 from abc import abstractmethod
@@ -25,19 +25,21 @@ from collections.abc import Callable
 from contextlib import asynccontextmanager
 from pathlib import Path
 
-from fastapi import BackgroundTasks
+import httpx
+from authlib.common.errors import AuthlibBaseError as OAuthError
 from fastapi import Body
 from fastapi import FastAPI
+from fastapi import HTTPException
 from fastapi import Request
 from fastapi import Response
 from fastapi import UploadFile
-from fastapi.exceptions import HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from pydantic import Field
 from starlette.websockets import WebSocket
 
+from nat.builder.function import Function
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.data_models.api_server import ChatRequest
 from nat.data_models.api_server import ChatResponse
@@ -58,18 +60,30 @@ from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateRequest
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateStatusResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
-from nat.front_ends.fastapi.job_store import JobInfo
-from nat.front_ends.fastapi.job_store import JobStore
 from nat.front_ends.fastapi.message_handler import WebSocketMessageHandler
 from nat.front_ends.fastapi.response_helpers import generate_single_response
 from nat.front_ends.fastapi.response_helpers import generate_streaming_response_as_str
 from nat.front_ends.fastapi.response_helpers import generate_streaming_response_full_as_str
 from nat.front_ends.fastapi.step_adaptor import StepAdaptor
+from nat.front_ends.fastapi.utils import get_config_file_path
 from nat.object_store.models import ObjectStoreItem
+from nat.runtime.loader import load_workflow
 from nat.runtime.session import SessionManager
 
 logger = logging.getLogger(__name__)
 
+_DASK_AVAILABLE = False
+
+try:
+    from nat.front_ends.fastapi.job_store import JobInfo
+    from nat.front_ends.fastapi.job_store import JobStatus
+    from nat.front_ends.fastapi.job_store import JobStore
+    _DASK_AVAILABLE = True
+except ImportError:
+    JobInfo = None
+    JobStatus = None
+    JobStore = None
+
 
 class FastApiFrontEndPluginWorkerBase(ABC):
 
@@ -80,10 +94,29 @@ class FastApiFrontEndPluginWorkerBase(ABC):
                           FastApiFrontEndConfig), ("Front end config is not FastApiFrontEndConfig")
 
         self._front_end_config = config.general.front_end
-
-        self._cleanup_tasks: list[str] = []
-        self._cleanup_tasks_lock = asyncio.Lock()
+        self._dask_available = False
+        self._job_store = None
         self._http_flow_handler: HTTPAuthenticationFlowHandler | None = HTTPAuthenticationFlowHandler()
+        self._scheduler_address = os.environ.get("NAT_DASK_SCHEDULER_ADDRESS")
+        self._db_url = os.environ.get("NAT_JOB_STORE_DB_URL")
+        self._config_file_path = get_config_file_path()
+
+        if self._scheduler_address is not None:
+            if not _DASK_AVAILABLE:
+                raise RuntimeError("Dask is not available, please install it to use the FastAPI front end with Dask.")
+
+            if self._db_url is None:
+                raise RuntimeError(
+                    "NAT_JOB_STORE_DB_URL must be set when using Dask (configure a persistent JobStore database).")
+
+            try:
+                self._job_store = JobStore(scheduler_address=self._scheduler_address, db_url=self._db_url)
+                self._dask_available = True
+                logger.debug("Connected to Dask scheduler at %s", self._scheduler_address)
+            except Exception as e:
+                raise RuntimeError(f"Failed to connect to Dask scheduler at {self._scheduler_address}: {e}") from e
+        else:
+            logger.debug("No Dask scheduler address provided, running without Dask support.")
 
     @property
     def config(self) -> Config:
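Note: the worker now opts into distributed job handling only when the environment is configured for it; otherwise it degrades to a plain single-process server. A minimal sketch of enabling the Dask-backed job store, assuming a scheduler is already running at the given address and using an illustrative database URL (both values are assumptions, not defaults shipped with the package):

    import os

    # Both variables must be set together: the constructor above raises if a
    # scheduler address is given without NAT_JOB_STORE_DB_URL.
    os.environ["NAT_DASK_SCHEDULER_ADDRESS"] = "tcp://127.0.0.1:8786"  # assumed address
    os.environ["NAT_JOB_STORE_DB_URL"] = "sqlite:///nat_jobs.db"       # assumed URL

With neither variable set, `self._dask_available` stays False and the async job and evaluation endpoints are skipped, as the hunks below show.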
@@ -107,20 +140,6 @@ class FastApiFrontEndPluginWorkerBase(ABC):
 
             yield
 
-            # If a cleanup task is running, cancel it
-            async with self._cleanup_tasks_lock:
-
-                # Cancel all cleanup tasks
-                for task_name in self._cleanup_tasks:
-                    cleanup_task: asyncio.Task | None = getattr(starting_app.state, task_name, None)
-                    if cleanup_task is not None:
-                        logger.info("Cancelling %s cleanup task", task_name)
-                        cleanup_task.cancel()
-                    else:
-                        logger.warning("No cleanup task found for %s", task_name)
-
-                self._cleanup_tasks.clear()
-
             logger.debug("Closing NAT server from process %s", os.getpid())
 
         nat_app = FastAPI(lifespan=lifespan)
@@ -208,32 +227,6 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         self._outstanding_flows: dict[str, FlowState] = {}
         self._outstanding_flows_lock = asyncio.Lock()
 
-    @staticmethod
-    async def _periodic_cleanup(name: str, job_store: JobStore, sleep_time_sec: int = 300):
-        while True:
-            try:
-                job_store.cleanup_expired_jobs()
-                logger.debug("Expired %s jobs cleaned up", name)
-            except Exception as e:
-                logger.error("Error during %s job cleanup: %s", name, e)
-            await asyncio.sleep(sleep_time_sec)
-
-    async def create_cleanup_task(self, app: FastAPI, name: str, job_store: JobStore, sleep_time_sec: int = 300):
-        # Schedule periodic cleanup of expired jobs on first job creation
-        attr_name = f"{name}_cleanup_task"
-
-        # Cheap check, if it doesn't exist, we will need to re-check after we acquire the lock
-        if not hasattr(app.state, attr_name):
-            async with self._cleanup_tasks_lock:
-                if not hasattr(app.state, attr_name):
-                    logger.info("Starting %s periodic cleanup task", name)
-                    setattr(
-                        app.state,
-                        attr_name,
-                        asyncio.create_task(
-                            self._periodic_cleanup(name=name, job_store=job_store, sleep_time_sec=sleep_time_sec)))
-                    self._cleanup_tasks.append(attr_name)
-
     def get_step_adaptor(self) -> StepAdaptor:
 
         return StepAdaptor(self.front_end_config.step_adaptor)
@@ -247,14 +240,15 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
     async def add_routes(self, app: FastAPI, builder: WorkflowBuilder):
 
-        await self.add_default_route(app, SessionManager(builder.build()))
-        await self.add_evaluate_route(app, SessionManager(builder.build()))
+        await self.add_default_route(app, SessionManager(await builder.build()))
+        await self.add_evaluate_route(app, SessionManager(await builder.build()))
         await self.add_static_files_route(app, builder)
         await self.add_authorization_route(app)
+        await self.add_mcp_client_tool_list_route(app, builder)
 
         for ep in self.front_end_config.endpoints:
 
-            entry_workflow = builder.build(entry_function=ep.function_name)
+            entry_workflow = await builder.build(entry_function=ep.function_name)
 
             await self.add_route(app, endpoint=ep, session_manager=SessionManager(entry_workflow))
 
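Note: `builder.build()` is now a coroutine, so every call site gains an `await`. A one-line sketch of the same change for downstream code that builds workflows itself (the `entry_function` value is illustrative):

    # 1.3.0.dev2: workflow = builder.build(entry_function="my_fn")
    workflow = await builder.build(entry_function="my_fn")  # 1.3.0rc1: build() is async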
@@ -276,52 +270,72 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             },
         }
 
-        # Create job store for tracking evaluation jobs
-        job_store = JobStore()
-        # Don't run multiple evaluations at the same time
-        evaluation_lock = asyncio.Lock()
-
-        async def run_evaluation(job_id: str, config_file: str, reps: int, session_manager: SessionManager):
+        # TODO: Find another way to limit the number of concurrent evaluations
+        async def run_evaluation(scheduler_address: str,
+                                 db_url: str,
+                                 workflow_config_file_path: str,
+                                 job_id: str,
+                                 eval_config_file: str,
+                                 reps: int):
             """Background task to run the evaluation."""
-            async with evaluation_lock:
-                try:
-                    # Create EvaluationRunConfig using the CLI defaults
-                    eval_config = EvaluationRunConfig(config_file=Path(config_file), dataset=None, reps=reps)
-
-                    # Create a new EvaluationRun with the evaluation-specific config
-                    job_store.update_status(job_id, "running")
-                    eval_runner = EvaluationRun(eval_config)
-                    output: EvaluationRunOutput = await eval_runner.run_and_evaluate(session_manager=session_manager,
-                                                                                     job_id=job_id)
-                    if output.workflow_interrupted:
-                        job_store.update_status(job_id, "interrupted")
-                    else:
-                        parent_dir = os.path.dirname(
-                            output.workflow_output_file) if output.workflow_output_file else None
-
-                        job_store.update_status(job_id, "success", output_path=str(parent_dir))
-                except Exception as e:
-                    logger.error("Error in evaluation job %s: %s", job_id, str(e))
-                    job_store.update_status(job_id, "failure", error=str(e))
-
-        async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
-            """Handle evaluation requests."""
+            job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
 
-            async with session_manager.session(request=http_request):
+            try:
+                # We have two config files, one for the workflow and one for the evaluation
+                # Create EvaluationRunConfig using the CLI defaults
+                eval_config = EvaluationRunConfig(config_file=Path(eval_config_file), dataset=None, reps=reps)
 
-                # if job_id is present and already exists return the job info
-                if request.job_id:
-                    job = job_store.get_job(request.job_id)
-                    if job:
-                        return EvaluateResponse(job_id=job.job_id, status=job.status)
+                # Create a new EvaluationRun with the evaluation-specific config
+                await job_store.update_status(job_id, JobStatus.RUNNING)
+                eval_runner = EvaluationRun(eval_config)
+
+                async with load_workflow(workflow_config_file_path) as local_session_manager:
+                    output: EvaluationRunOutput = await eval_runner.run_and_evaluate(
+                        session_manager=local_session_manager, job_id=job_id)
+
+                if output.workflow_interrupted:
+                    await job_store.update_status(job_id, JobStatus.INTERRUPTED)
+                else:
+                    parent_dir = os.path.dirname(output.workflow_output_file) if output.workflow_output_file else None
 
-                job_id = job_store.create_job(request.config_file, request.job_id, request.expiry_seconds)
-                await self.create_cleanup_task(app=app, name="async_evaluation", job_store=job_store)
-                background_tasks.add_task(run_evaluation, job_id, request.config_file, request.reps, session_manager)
+                    await job_store.update_status(job_id, JobStatus.SUCCESS, output_path=str(parent_dir))
+            except Exception as e:
+                logger.exception("Error in evaluation job %s", job_id)
+                await job_store.update_status(job_id, JobStatus.FAILURE, error=str(e))
+
+        async def start_evaluation(request: EvaluateRequest, http_request: Request):
+            """Handle evaluation requests."""
 
-                return EvaluateResponse(job_id=job_id, status="submitted")
+            async with session_manager.session(http_connection=http_request):
 
-        def translate_job_to_response(job: JobInfo) -> EvaluateStatusResponse:
+                # if job_id is present and already exists return the job info
+                # There is a race condition between this check and the actual job submission, however if the client is
+                # supplying their own job_ids, then it is their responsibility to ensure that the job_id is unique.
+                if request.job_id:
+                    job_status = await self._job_store.get_status(request.job_id)
+                    if job_status != JobStatus.NOT_FOUND:
+                        return EvaluateResponse(job_id=request.job_id, status=job_status)
+
+                job_id = self._job_store.ensure_job_id(request.job_id)
+
+                await self._job_store.submit_job(job_id=job_id,
+                                                 config_file=request.config_file,
+                                                 expiry_seconds=request.expiry_seconds,
+                                                 job_fn=run_evaluation,
+                                                 job_args=[
+                                                     self._scheduler_address,
+                                                     self._db_url,
+                                                     self._config_file_path,
+                                                     job_id,
+                                                     request.config_file,
+                                                     request.reps
+                                                 ])
+
+                logger.info("Submitted evaluation job %s with config %s", job_id, request.config_file)
+
+                return EvaluateResponse(job_id=job_id, status=JobStatus.SUBMITTED)
+
+        def translate_job_to_response(job: "JobInfo") -> EvaluateStatusResponse:
             """Translate a JobInfo object to an EvaluateStatusResponse."""
             return EvaluateStatusResponse(job_id=job.job_id,
                                           status=job.status,
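Note: `run_evaluation` now executes on a Dask worker rather than in the web process, so it reconstructs its own `JobStore` from the scheduler address and database URL and loads the workflow from the config file path instead of sharing in-process state. A sketch of submitting a job over HTTP, assuming the evaluate endpoint is mounted at `/evaluate` on localhost:8000 (both assumptions; the field names mirror the `EvaluateRequest` usage above):

    import httpx

    # Assumed config path and payload values, for illustration only.
    resp = httpx.post("http://localhost:8000/evaluate",
                      json={"config_file": "configs/eval.yml", "reps": 1})
    print(resp.json())  # e.g. {"job_id": "...", "status": "submitted"}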
@@ -330,15 +344,15 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                           output_path=str(job.output_path),
                                           created_at=job.created_at,
                                           updated_at=job.updated_at,
-                                          expires_at=job_store.get_expires_at(job))
+                                          expires_at=self._job_store.get_expires_at(job))
 
         async def get_job_status(job_id: str, http_request: Request) -> EvaluateStatusResponse:
             """Get the status of an evaluation job."""
             logger.info("Getting status for job %s", job_id)
 
-            async with session_manager.session(request=http_request):
+            async with session_manager.session(http_connection=http_request):
 
-                job = job_store.get_job(job_id)
+                job = await self._job_store.get_job(job_id)
                 if not job:
                     logger.warning("Job %s not found", job_id)
                     raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
@@ -349,9 +363,9 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of the last created evaluation job."""
             logger.info("Getting last job status")
 
-            async with session_manager.session(request=http_request):
+            async with session_manager.session(http_connection=http_request):
 
-                job = job_store.get_last_job()
+                job = await self._job_store.get_last_job()
                 if not job:
                     logger.warning("No jobs found when requesting last job status")
                     raise HTTPException(status_code=404, detail="No jobs found")
@@ -361,65 +375,69 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         async def get_jobs(http_request: Request, status: str | None = None) -> list[EvaluateStatusResponse]:
             """Get all jobs, optionally filtered by status."""
 
-            async with session_manager.session(request=http_request):
+            async with session_manager.session(http_connection=http_request):
 
                 if status is None:
                     logger.info("Getting all jobs")
-                    jobs = job_store.get_all_jobs()
+                    jobs = await self._job_store.get_all_jobs()
                 else:
                     logger.info("Getting jobs with status %s", status)
-                    jobs = job_store.get_jobs_by_status(status)
+                    jobs = await self._job_store.get_jobs_by_status(JobStatus(status))
+
                 logger.info("Found %d jobs", len(jobs))
                 return [translate_job_to_response(job) for job in jobs]
 
         if self.front_end_config.evaluate.path:
-            # Add last job endpoint first (most specific)
-            app.add_api_route(
-                path=f"{self.front_end_config.evaluate.path}/job/last",
-                endpoint=get_last_job_status,
-                methods=["GET"],
-                response_model=EvaluateStatusResponse,
-                description="Get the status of the last created evaluation job",
-                responses={
-                    404: {
-                        "description": "No jobs found"
-                    }, 500: response_500
-                },
-            )
+            if self._dask_available:
+                # Add last job endpoint first (most specific)
+                app.add_api_route(
+                    path=f"{self.front_end_config.evaluate.path}/job/last",
+                    endpoint=get_last_job_status,
+                    methods=["GET"],
+                    response_model=EvaluateStatusResponse,
+                    description="Get the status of the last created evaluation job",
+                    responses={
+                        404: {
+                            "description": "No jobs found"
+                        }, 500: response_500
+                    },
+                )
 
-            # Add specific job endpoint (least specific)
-            app.add_api_route(
-                path=f"{self.front_end_config.evaluate.path}/job/{{job_id}}",
-                endpoint=get_job_status,
-                methods=["GET"],
-                response_model=EvaluateStatusResponse,
-                description="Get the status of an evaluation job",
-                responses={
-                    404: {
-                        "description": "Job not found"
-                    }, 500: response_500
-                },
-            )
+                # Add specific job endpoint (least specific)
+                app.add_api_route(
+                    path=f"{self.front_end_config.evaluate.path}/job/{{job_id}}",
+                    endpoint=get_job_status,
+                    methods=["GET"],
+                    response_model=EvaluateStatusResponse,
+                    description="Get the status of an evaluation job",
+                    responses={
+                        404: {
+                            "description": "Job not found"
+                        }, 500: response_500
+                    },
+                )
 
-            # Add jobs endpoint with optional status query parameter
-            app.add_api_route(
-                path=f"{self.front_end_config.evaluate.path}/jobs",
-                endpoint=get_jobs,
-                methods=["GET"],
-                response_model=list[EvaluateStatusResponse],
-                description="Get all jobs, optionally filtered by status",
-                responses={500: response_500},
-            )
+                # Add jobs endpoint with optional status query parameter
+                app.add_api_route(
+                    path=f"{self.front_end_config.evaluate.path}/jobs",
+                    endpoint=get_jobs,
+                    methods=["GET"],
+                    response_model=list[EvaluateStatusResponse],
+                    description="Get all jobs, optionally filtered by status",
+                    responses={500: response_500},
+                )
 
-            # Add HTTP endpoint for evaluation
-            app.add_api_route(
-                path=self.front_end_config.evaluate.path,
-                endpoint=start_evaluation,
-                methods=[self.front_end_config.evaluate.method],
-                response_model=EvaluateResponse,
-                description=self.front_end_config.evaluate.description,
-                responses={500: response_500},
-            )
+                # Add HTTP endpoint for evaluation
+                app.add_api_route(
+                    path=self.front_end_config.evaluate.path,
+                    endpoint=start_evaluation,
+                    methods=[self.front_end_config.evaluate.method],
+                    response_model=EvaluateResponse,
+                    description=self.front_end_config.evaluate.description,
+                    responses={500: response_500},
+                )
+            else:
+                logger.warning("Dask is not available, evaluation endpoints will not be added.")
 
     async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):
 
522
540
 
523
541
  workflow = session_manager.workflow
524
542
 
525
- GenerateBodyType = workflow.input_schema # pylint: disable=invalid-name
526
- GenerateStreamResponseType = workflow.streaming_output_schema # pylint: disable=invalid-name
527
- GenerateSingleResponseType = workflow.single_output_schema # pylint: disable=invalid-name
528
-
529
- # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
530
- # Consider prefixing these with "nat_" to avoid conflicts
531
- class AsyncGenerateRequest(GenerateBodyType):
532
- job_id: str | None = Field(default=None, description="Unique identifier for the evaluation job")
533
- sync_timeout: int = Field(
534
- default=0,
535
- ge=0,
536
- le=300,
537
- description="Attempt to perform the job synchronously up until `sync_timeout` sectonds, "
538
- "if the job hasn't been completed by then a job_id will be returned with a status code of 202.")
539
- expiry_seconds: int = Field(default=JobStore.DEFAULT_EXPIRY,
540
- ge=JobStore.MIN_EXPIRY,
541
- le=JobStore.MAX_EXPIRY,
542
- description="Optional time (in seconds) before the job expires. "
543
- "Clamped between 600 (10 min) and 86400 (24h).")
543
+ GenerateBodyType = workflow.input_schema
544
+ GenerateStreamResponseType = workflow.streaming_output_schema
545
+ GenerateSingleResponseType = workflow.single_output_schema
546
+
547
+ if self._dask_available:
548
+ # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
549
+ # Consider prefixing these with "nat_" to avoid conflicts
550
+
551
+ class AsyncGenerateRequest(GenerateBodyType):
552
+ job_id: str | None = Field(default=None, description="Unique identifier for the evaluation job")
553
+ sync_timeout: int = Field(
554
+ default=0,
555
+ ge=0,
556
+ le=300,
557
+ description="Attempt to perform the job synchronously up until `sync_timeout` sectonds, "
558
+ "if the job hasn't been completed by then a job_id will be returned with a status code of 202.")
559
+ expiry_seconds: int = Field(default=JobStore.DEFAULT_EXPIRY,
560
+ ge=JobStore.MIN_EXPIRY,
561
+ le=JobStore.MAX_EXPIRY,
562
+ description="Optional time (in seconds) before the job expires. "
563
+ "Clamped between 600 (10 min) and 86400 (24h).")
544
564
 
545
565
  # Ensure that the input is in the body. POD types are treated as query parameters
546
566
  if (not issubclass(GenerateBodyType, BaseModel)):
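Note: `AsyncGenerateRequest` extends the workflow's own input schema, so `job_id`, `sync_timeout`, and `expiry_seconds` become reserved field names alongside whatever the workflow defines. A sketch of a request body under an assumed workflow input field:

    payload = {
        "input_message": "Hello",  # assumed workflow input field, not part of the schema above
        "job_id": None,            # omit or None to let the server generate one
        "sync_timeout": 30,        # wait up to 30 s for a synchronous (200) result
        "expiry_seconds": 3600,    # job record retention, clamped to 600-86400 s
    }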
@@ -560,19 +580,13 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                 },
             }
 
-        # Create job store for tracking async generation jobs
-        job_store = JobStore()
-
-        # Run up to max_running_async_jobs jobs at the same time
-        async_job_concurrency = asyncio.Semaphore(self._front_end_config.max_running_async_jobs)
-
         def get_single_endpoint(result_type: type | None):
 
             async def get_single(response: Response, request: Request):
 
                 response.headers["Content-Type"] = "application/json"
 
-                async with session_manager.session(request=request,
+                async with session_manager.session(http_connection=request,
                                                    user_authentication_callback=self._http_flow_handler.authenticate):
 
                     return await generate_single_response(None, session_manager, result_type=result_type)
@@ -583,7 +597,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             async def get_stream(request: Request):
 
-                async with session_manager.session(request=request,
+                async with session_manager.session(http_connection=request,
                                                    user_authentication_callback=self._http_flow_handler.authenticate):
 
                     return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -618,7 +632,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
                 response.headers["Content-Type"] = "application/json"
 
-                async with session_manager.session(request=request,
+                async with session_manager.session(http_connection=request,
                                                    user_authentication_callback=self._http_flow_handler.authenticate):
 
                     return await generate_single_response(payload, session_manager, result_type=result_type)
@@ -632,7 +646,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             async def post_stream(request: Request, payload: request_type):
 
-                async with session_manager.session(request=request,
+                async with session_manager.session(http_connection=request,
                                                    user_authentication_callback=self._http_flow_handler.authenticate):
 
                     return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -677,7 +691,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             # Check if streaming is requested
             stream_requested = getattr(payload, 'stream', False)
 
-            async with session_manager.session(request=request):
+            async with session_manager.session(http_connection=request):
                 if stream_requested:
                     # Return streaming response
                     return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -688,115 +702,112 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                                  step_adaptor=self.get_step_adaptor(),
                                                  result_type=ChatResponseChunk,
                                                  output_type=ChatResponseChunk))
-                else:
-                    # Return single response - check if workflow supports non-streaming
-                    try:
+
+                # Return single response - check if workflow supports non-streaming
+                try:
+                    response.headers["Content-Type"] = "application/json"
+                    return await generate_single_response(payload, session_manager, result_type=ChatResponse)
+                except ValueError as e:
+                    if "Cannot get a single output value for streaming workflows" in str(e):
+                        # Workflow only supports streaming, but client requested non-streaming
+                        # Fall back to streaming and collect the result
+                        chunks = []
+                        async for chunk_str in generate_streaming_response_as_str(
+                                payload,
+                                session_manager=session_manager,
+                                streaming=True,
+                                step_adaptor=self.get_step_adaptor(),
+                                result_type=ChatResponseChunk,
+                                output_type=ChatResponseChunk):
+                            if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
+                                chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
+                                if chunk_data:
+                                    try:
+                                        chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
+                                        if (chunk_json.choices and len(chunk_json.choices) > 0
+                                                and chunk_json.choices[0].delta
+                                                and chunk_json.choices[0].delta.content is not None):
+                                            chunks.append(chunk_json.choices[0].delta.content)
+                                    except Exception:
+                                        continue
+
+                        # Create a single response from collected chunks
+                        content = "".join(chunks)
+                        single_response = ChatResponse.from_string(content)
                         response.headers["Content-Type"] = "application/json"
-                        return await generate_single_response(payload, session_manager, result_type=ChatResponse)
-                    except ValueError as e:
-                        if "Cannot get a single output value for streaming workflows" in str(e):
-                            # Workflow only supports streaming, but client requested non-streaming
-                            # Fall back to streaming and collect the result
-                            chunks = []
-                            async for chunk_str in generate_streaming_response_as_str(
-                                    payload,
-                                    session_manager=session_manager,
-                                    streaming=True,
-                                    step_adaptor=self.get_step_adaptor(),
-                                    result_type=ChatResponseChunk,
-                                    output_type=ChatResponseChunk):
-                                if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
-                                    chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
-                                    if chunk_data:
-                                        try:
-                                            chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
-                                            if (chunk_json.choices and len(chunk_json.choices) > 0
-                                                    and chunk_json.choices[0].delta
-                                                    and chunk_json.choices[0].delta.content is not None):
-                                                chunks.append(chunk_json.choices[0].delta.content)
-                                        except Exception:
-                                            continue
-
-                            # Create a single response from collected chunks
-                            content = "".join(chunks)
-                            single_response = ChatResponse.from_string(content)
-                            response.headers["Content-Type"] = "application/json"
-                            return single_response
-                        else:
-                            raise
+                        return single_response
+                    raise
 
         return post_openai_api_compatible
 
-        async def run_generation(job_id: str, payload: typing.Any, session_manager: SessionManager, result_type: type):
-            """Background task to run the evaluation."""
-            async with async_job_concurrency:
-                try:
-                    result = await generate_single_response(payload=payload,
-                                                            session_manager=session_manager,
-                                                            result_type=result_type)
-                    job_store.update_status(job_id, "success", output=result)
-                except Exception as e:
-                    logger.error("Error in evaluation job %s: %s", job_id, e)
-                    job_store.update_status(job_id, "failure", error=str(e))
-
-        def _job_status_to_response(job: JobInfo) -> AsyncGenerationStatusResponse:
+        def _job_status_to_response(job: "JobInfo") -> AsyncGenerationStatusResponse:
             job_output = job.output
             if job_output is not None:
-                job_output = job_output.model_dump()
+                try:
+                    job_output = json.loads(job_output)
+                except json.JSONDecodeError:
+                    logger.error("Failed to parse job output as JSON: %s", job_output)
+                    job_output = {"error": "Output parsing failed"}
+
             return AsyncGenerationStatusResponse(job_id=job.job_id,
                                                  status=job.status,
                                                  error=job.error,
                                                  output=job_output,
                                                  created_at=job.created_at,
                                                  updated_at=job.updated_at,
-                                                 expires_at=job_store.get_expires_at(job))
+                                                 expires_at=self._job_store.get_expires_at(job))
+
+        async def run_generation(scheduler_address: str,
+                                 db_url: str,
+                                 config_file_path: str,
+                                 job_id: str,
+                                 payload: typing.Any):
+            """Background task to run the workflow."""
+            job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
+            try:
+                async with load_workflow(config_file_path) as local_session_manager:
+                    result = await generate_single_response(
+                        payload, local_session_manager, result_type=local_session_manager.workflow.single_output_schema)
 
-        def post_async_generation(request_type: type, final_result_type: type):
+                await job_store.update_status(job_id, JobStatus.SUCCESS, output=result)
+            except Exception as e:
+                logger.exception("Error in async job %s", job_id)
+                await job_store.update_status(job_id, JobStatus.FAILURE, error=str(e))
+
+        def post_async_generation(request_type: type):
 
             async def start_async_generation(
-                    request: request_type, background_tasks: BackgroundTasks, response: Response,
+                    request: request_type, response: Response,
                     http_request: Request) -> AsyncGenerateResponse | AsyncGenerationStatusResponse:
                 """Handle async generation requests."""
 
-                async with session_manager.session(request=http_request):
+                async with session_manager.session(http_connection=http_request):
 
                     # if job_id is present and already exists return the job info
                     if request.job_id:
-                        job = job_store.get_job(request.job_id)
+                        job = await self._job_store.get_job(request.job_id)
                         if job:
                             return AsyncGenerateResponse(job_id=job.job_id, status=job.status)
 
-                    job_id = job_store.create_job(job_id=request.job_id, expiry_seconds=request.expiry_seconds)
-                    await self.create_cleanup_task(app=app, name="async_generation", job_store=job_store)
-
-                    # The fastapi/starlette background tasks won't begin executing until after the response is sent
-                    # to the client, so we need to wrap the task in a function, alowing us to start the task now,
-                    # and allowing the background task function to await the results.
-                    task = asyncio.create_task(
-                        run_generation(job_id=job_id,
-                                       payload=request,
-                                       session_manager=session_manager,
-                                       result_type=final_result_type))
-
-                    async def wrapped_task(t: asyncio.Task):
-                        return await t
-
-                    background_tasks.add_task(wrapped_task, task)
-
-                    now = time.time()
-                    sync_timeout = now + request.sync_timeout
-                    while time.time() < sync_timeout:
-                        job = job_store.get_job(job_id)
-                        if job is not None and job.status not in job_store.ACTIVE_STATUS:
-                            # If the job is done, return the result
-                            response.status_code = 200
-                            return _job_status_to_response(job)
-
-                        # Sleep for a short time before checking again
-                        await asyncio.sleep(0.1)
+                    job_id = self._job_store.ensure_job_id(request.job_id)
+                    (_, job) = await self._job_store.submit_job(job_id=job_id,
+                                                                expiry_seconds=request.expiry_seconds,
+                                                                job_fn=run_generation,
+                                                                sync_timeout=request.sync_timeout,
+                                                                job_args=[
+                                                                    self._scheduler_address,
+                                                                    self._db_url,
+                                                                    self._config_file_path,
+                                                                    job_id,
+                                                                    request.model_dump(mode="json")
+                                                                ])
+
+                    if job is not None:
+                        response.status_code = 200
+                        return _job_status_to_response(job)
 
                     response.status_code = 202
-                    return AsyncGenerateResponse(job_id=job_id, status="submitted")
+                    return AsyncGenerateResponse(job_id=job_id, status=JobStatus.SUBMITTED)
 
             return start_async_generation
 
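Note: `submit_job` now owns the sync-wait logic, returning a tuple whose second element is the finished JobInfo when the job completes within `sync_timeout` (and None otherwise), so the old polling loop collapses to a `None` check. A sketch of the resulting client flow, assuming the generate endpoint is mounted at `/generate` and reusing the payload sketched earlier:

    import httpx

    resp = httpx.post("http://localhost:8000/generate/async", json=payload)
    body = resp.json()
    if resp.status_code == 200:   # finished within sync_timeout
        print(body["output"])
    else:                         # 202: poll the status route registered below
        status = httpx.get(f"http://localhost:8000/generate/async/job/{body['job_id']}").json()
        print(status["status"])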
@@ -804,10 +815,10 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of an async job."""
             logger.info("Getting status for job %s", job_id)
 
-            async with session_manager.session(request=http_request):
+            async with session_manager.session(http_connection=http_request):
 
-                job = job_store.get_job(job_id)
-                if not job:
+                job = await self._job_store.get_job(job_id)
+                if job is None:
                     logger.warning("Job %s not found", job_id)
                     raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
 
@@ -935,30 +946,33 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                 responses={500: response_500},
             )
 
-            app.add_api_route(
-                path=f"{endpoint.path}/async",
-                endpoint=post_async_generation(request_type=AsyncGenerateRequest,
-                                               final_result_type=GenerateSingleResponseType),
-                methods=[endpoint.method],
-                response_model=AsyncGenerateResponse | AsyncGenerationStatusResponse,
-                description="Start an async generate job",
-                responses={500: response_500},
-            )
+            if self._dask_available:
+                app.add_api_route(
+                    path=f"{endpoint.path}/async",
+                    endpoint=post_async_generation(request_type=AsyncGenerateRequest),
+                    methods=[endpoint.method],
+                    response_model=AsyncGenerateResponse | AsyncGenerationStatusResponse,
+                    description="Start an async generate job",
+                    responses={500: response_500},
+                )
+            else:
+                logger.warning("Dask is not available, async generation endpoints will not be added.")
         else:
             raise ValueError(f"Unsupported method {endpoint.method}")
 
-        app.add_api_route(
-            path=f"{endpoint.path}/async/job/{{job_id}}",
-            endpoint=get_async_job_status,
-            methods=["GET"],
-            response_model=AsyncGenerationStatusResponse,
-            description="Get the status of an async job",
-            responses={
-                404: {
-                    "description": "Job not found"
-                }, 500: response_500
-            },
-        )
+        if self._dask_available:
+            app.add_api_route(
+                path=f"{endpoint.path}/async/job/{{job_id}}",
+                endpoint=get_async_job_status,
+                methods=["GET"],
+                response_model=AsyncGenerationStatusResponse,
+                description="Get the status of an async job",
+                responses={
+                    404: {
+                        "description": "Job not found"
+                    }, 500: response_500
+                },
+            )
 
         if (endpoint.openai_api_path):
             if (endpoint.method == "GET"):
@@ -1061,8 +1075,13 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                               code_verifier=verifier,
                                               state=state)
                 flow_state.future.set_result(res)
+            except OAuthError as e:
+                flow_state.future.set_exception(
+                    RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
+            except httpx.HTTPError as e:
+                flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
             except Exception as e:
-                flow_state.future.set_exception(e)
+                flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
 
             return HTMLResponse(content=AUTH_REDIRECT_SUCCESS_HTML,
                                 status_code=200,
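Note: the redirect handler now normalizes every failure into a RuntimeError before resolving the pending future, so code awaiting the flow sees one exception type whether the authorization server rejected the request, the token fetch hit a network error, or something else went wrong. A sketch of what the awaiting side observes (hypothetical consumer code, not part of this file):

    try:
        result = await flow_state.future  # resolved by set_result(...) or set_exception(...) above
    except RuntimeError as e:
        logger.error("OAuth flow failed: %s", e)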
@@ -1078,6 +1097,183 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             methods=["GET"],
             description="Handles the authorization code and state returned from the Authorization Code Grant Flow.")
 
+    async def add_mcp_client_tool_list_route(self, app: FastAPI, builder: WorkflowBuilder):
+        """Add the MCP client tool list endpoint to the FastAPI app."""
+        from typing import Any
+
+        from pydantic import BaseModel
+
+        class MCPToolInfo(BaseModel):
+            name: str
+            description: str
+            server: str
+            available: bool
+
+        class MCPClientToolListResponse(BaseModel):
+            mcp_clients: list[dict[str, Any]]
+
+        async def get_mcp_client_tool_list() -> MCPClientToolListResponse:
+            """
+            Get the list of MCP tools from all MCP clients in the workflow configuration.
+            Checks session health and compares with workflow function group configuration.
+            """
+            mcp_clients_info = []
+
+            try:
+                # Get all function groups from the builder
+                function_groups = builder._function_groups
+
+                # Find MCP client function groups
+                for group_name, configured_group in function_groups.items():
+                    if configured_group.config.type != "mcp_client":
+                        continue
+
+                    from nat.plugins.mcp.client_impl import MCPClientConfig
+
+                    config = configured_group.config
+                    assert isinstance(config, MCPClientConfig)
+
+                    # Reuse the existing MCP client session stored on the function group instance
+                    group_instance = configured_group.instance
+
+                    client = group_instance.mcp_client
+                    if client is None:
+                        raise RuntimeError(f"MCP client not found for group {group_name}")
+
+                    try:
+                        session_healthy = False
+                        server_tools: dict[str, Any] = {}
+
+                        try:
+                            server_tools = await client.get_tools()
+                            session_healthy = True
+                        except Exception as e:
+                            logger.exception(f"Failed to connect to MCP server {client.server_name}: {e}")
+                            session_healthy = False
+
+                        # Get workflow function group configuration (configured client-side tools)
+                        configured_short_names: set[str] = set()
+                        configured_full_to_fn: dict[str, Function] = {}
+                        try:
+                            # Pass a no-op filter function to bypass any default filtering that might check
+                            # health status, preventing potential infinite recursion during health status checks.
+                            async def pass_through_filter(fn):
+                                return fn
+
+                            accessible_functions = await group_instance.get_accessible_functions(
+                                filter_fn=pass_through_filter)
+                            configured_full_to_fn = accessible_functions
+                            configured_short_names = {name.split('.', 1)[1] for name in accessible_functions.keys()}
+                        except Exception as e:
+                            logger.exception(f"Failed to get accessible functions for group {group_name}: {e}")
+
+                        # Build alias->original mapping and override configs from overrides
+                        alias_to_original: dict[str, str] = {}
+                        override_configs: dict[str, Any] = {}
+                        try:
+                            if config.tool_overrides is not None:
+                                for orig_name, override in config.tool_overrides.items():
+                                    if override.alias is not None:
+                                        alias_to_original[override.alias] = orig_name
+                                        override_configs[override.alias] = override
+                                    else:
+                                        override_configs[orig_name] = override
+                        except Exception:
+                            pass
+
+                        # Create tool info list (always return configured tools; mark availability)
+                        tools_info: list[dict[str, Any]] = []
+                        available_count = 0
+                        for wf_fn, fn_short in zip(configured_full_to_fn.values(), configured_short_names):
+                            orig_name = alias_to_original.get(fn_short, fn_short)
+                            available = session_healthy and (orig_name in server_tools)
+                            if available:
+                                available_count += 1
+
+                            # Prefer tool override description, then workflow function description,
+                            # then server description
+                            description = ""
+                            if fn_short in override_configs and override_configs[fn_short].description:
+                                description = override_configs[fn_short].description
+                            elif wf_fn.description:
+                                description = wf_fn.description
+                            elif available and orig_name in server_tools:
+                                description = server_tools[orig_name].description or ""
+
+                            tools_info.append(
+                                MCPToolInfo(name=fn_short,
+                                            description=description or "",
+                                            server=client.server_name,
+                                            available=available).model_dump())
+
+                        # Sort tools_info by name to maintain consistent ordering
+                        tools_info.sort(key=lambda x: x['name'])
+
+                        mcp_clients_info.append({
+                            "function_group": group_name,
+                            "server": client.server_name,
+                            "transport": config.server.transport,
+                            "session_healthy": session_healthy,
+                            "tools": tools_info,
+                            "total_tools": len(configured_short_names),
+                            "available_tools": available_count
+                        })
+
+                    except Exception as e:
+                        logger.error(f"Error processing MCP client {group_name}: {e}")
+                        mcp_clients_info.append({
+                            "function_group": group_name,
+                            "server": "unknown",
+                            "transport": config.server.transport if config.server else "unknown",
+                            "session_healthy": False,
+                            "error": str(e),
+                            "tools": [],
+                            "total_tools": 0,
+                            "workflow_tools": 0
+                        })
+
+                return MCPClientToolListResponse(mcp_clients=mcp_clients_info)
+
+            except Exception as e:
+                logger.error(f"Error in MCP client tool list endpoint: {e}")
+                raise HTTPException(status_code=500, detail=f"Failed to retrieve MCP client information: {str(e)}")
+
+        # Add the route to the FastAPI app
+        app.add_api_route(
+            path="/mcp/client/tool/list",
+            endpoint=get_mcp_client_tool_list,
+            methods=["GET"],
+            response_model=MCPClientToolListResponse,
+            description="Get list of MCP client tools with session health and workflow configuration comparison",
+            responses={
+                200: {
+                    "description": "Successfully retrieved MCP client tool information",
+                    "content": {
+                        "application/json": {
+                            "example": {
+                                "mcp_clients": [{
+                                    "function_group": "mcp_tools",
+                                    "server": "streamable-http:http://localhost:9901/mcp",
+                                    "transport": "streamable-http",
+                                    "session_healthy": True,
+                                    "tools": [{
+                                        "name": "tool_a",
+                                        "description": "Tool A description",
+                                        "server": "streamable-http:http://localhost:9901/mcp",
+                                        "available": True
+                                    }],
+                                    "total_tools": 1,
+                                    "available_tools": 1
+                                }]
+                            }
+                        }
+                    }
+                },
+                500: {
+                    "description": "Internal Server Error"
+                }
+            })
+
 
     async def _add_flow(self, state: str, flow_state: FlowState):
         async with self._outstanding_flows_lock:
            self._outstanding_flows[state] = flow_state
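Note: a sketch of querying the new route, assuming the server listens on localhost:8000. The error path above emits "workflow_tools" instead of "available_tools", so a tolerant lookup is used:

    import httpx

    info = httpx.get("http://localhost:8000/mcp/client/tool/list").json()
    for client in info["mcp_clients"]:
        print(client["server"],
              "healthy" if client["session_healthy"] else "unhealthy",
              f'{client.get("available_tools", 0)}/{client["total_tools"]} tools available')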
@@ -1085,3 +1281,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
     async def _remove_flow(self, state: str):
         async with self._outstanding_flows_lock:
             del self._outstanding_flows[state]
+
+
+# Prevent Sphinx from documenting items not a part of the public API
+__all__ = ["FastApiFrontEndPluginWorkerBase", "FastApiFrontEndPluginWorker", "RouteInfo"]