nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py

@@ -14,9 +14,9 @@
 # limitations under the License.

 import asyncio
+import json
 import logging
 import os
-import time
 import typing
 from abc import ABC
 from abc import abstractmethod
@@ -25,19 +25,21 @@ from collections.abc import Callable
 from contextlib import asynccontextmanager
 from pathlib import Path

-
+import httpx
+from authlib.common.errors import AuthlibBaseError as OAuthError
 from fastapi import Body
 from fastapi import FastAPI
+from fastapi import HTTPException
 from fastapi import Request
 from fastapi import Response
 from fastapi import UploadFile
-from fastapi.exceptions import HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from pydantic import Field
 from starlette.websockets import WebSocket

+from nat.builder.function import Function
 from nat.builder.workflow_builder import WorkflowBuilder
 from nat.data_models.api_server import ChatRequest
 from nat.data_models.api_server import ChatResponse
@@ -58,18 +60,30 @@ from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateRequest
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateStatusResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
-from nat.front_ends.fastapi.job_store import JobInfo
-from nat.front_ends.fastapi.job_store import JobStore
 from nat.front_ends.fastapi.message_handler import WebSocketMessageHandler
 from nat.front_ends.fastapi.response_helpers import generate_single_response
 from nat.front_ends.fastapi.response_helpers import generate_streaming_response_as_str
 from nat.front_ends.fastapi.response_helpers import generate_streaming_response_full_as_str
 from nat.front_ends.fastapi.step_adaptor import StepAdaptor
+from nat.front_ends.fastapi.utils import get_config_file_path
 from nat.object_store.models import ObjectStoreItem
+from nat.runtime.loader import load_workflow
 from nat.runtime.session import SessionManager

 logger = logging.getLogger(__name__)

+_DASK_AVAILABLE = False
+
+try:
+    from nat.front_ends.fastapi.job_store import JobInfo
+    from nat.front_ends.fastapi.job_store import JobStatus
+    from nat.front_ends.fastapi.job_store import JobStore
+    _DASK_AVAILABLE = True
+except ImportError:
+    JobInfo = None
+    JobStatus = None
+    JobStore = None
+

 class FastApiFrontEndPluginWorkerBase(ABC):

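
The new module-level `try`/`except ImportError` above is what lets the worker run without the job-store stack installed: the imports are attempted once, a `_DASK_AVAILABLE` flag records the outcome, and the names fall back to `None`. A minimal sketch of the same pattern, assuming the optional extra ships the `distributed` client (the flag and function names here are illustrative, not part of the package):

```python
# Sketch of the optional-dependency gate used above. The availability flag is
# computed once at import time; callers check the flag instead of re-attempting
# the import on every request.
_DASK_AVAILABLE = False

try:
    from distributed import Client  # resolves only when the Dask extra is installed
    _DASK_AVAILABLE = True
except ImportError:
    Client = None  # sentinel; callers must consult the flag before use


def require_dask() -> None:
    """Fail fast with an actionable message when a Dask-backed feature is requested."""
    if not _DASK_AVAILABLE:
        raise RuntimeError("Dask is not available; install the optional extra to enable persistent jobs.")
```
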
@@ -80,10 +94,29 @@ class FastApiFrontEndPluginWorkerBase(ABC):
                           FastApiFrontEndConfig), ("Front end config is not FastApiFrontEndConfig")

         self._front_end_config = config.general.front_end
-
-        self.
-        self._cleanup_tasks_lock = asyncio.Lock()
+        self._dask_available = False
+        self._job_store = None
         self._http_flow_handler: HTTPAuthenticationFlowHandler | None = HTTPAuthenticationFlowHandler()
+        self._scheduler_address = os.environ.get("NAT_DASK_SCHEDULER_ADDRESS")
+        self._db_url = os.environ.get("NAT_JOB_STORE_DB_URL")
+        self._config_file_path = get_config_file_path()
+
+        if self._scheduler_address is not None:
+            if not _DASK_AVAILABLE:
+                raise RuntimeError("Dask is not available, please install it to use the FastAPI front end with Dask.")
+
+            if self._db_url is None:
+                raise RuntimeError(
+                    "NAT_JOB_STORE_DB_URL must be set when using Dask (configure a persistent JobStore database).")
+
+            try:
+                self._job_store = JobStore(scheduler_address=self._scheduler_address, db_url=self._db_url)
+                self._dask_available = True
+                logger.debug("Connected to Dask scheduler at %s", self._scheduler_address)
+            except Exception as e:
+                raise RuntimeError(f"Failed to connect to Dask scheduler at {self._scheduler_address}: {e}") from e
+        else:
+            logger.debug("No Dask scheduler address provided, running without Dask support.")

     @property
     def config(self) -> Config:
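
The constructor now wires the distributed job store entirely from the environment: `NAT_DASK_SCHEDULER_ADDRESS` opts into Dask mode, and `NAT_JOB_STORE_DB_URL` must accompany it. A sketch of launching the server with both set; the scheduler address, database URL, and the launch command itself are placeholders for whatever your deployment actually uses:

```python
import os
import subprocess

# Placeholder values: point these at your own Dask scheduler and job database.
# Leaving NAT_DASK_SCHEDULER_ADDRESS unset runs the front end without Dask.
env = {
    **os.environ,
    "NAT_DASK_SCHEDULER_ADDRESS": "tcp://127.0.0.1:8786",
    "NAT_JOB_STORE_DB_URL": "sqlite:///nat_jobs.db",
}

# Illustrative launch command; substitute however you normally start the FastAPI front end.
subprocess.run(["nat", "start", "fastapi", "--config_file", "config.yml"], env=env, check=True)
```
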
@@ -107,20 +140,6 @@ class FastApiFrontEndPluginWorkerBase(ABC):

             yield

-            # If a cleanup task is running, cancel it
-            async with self._cleanup_tasks_lock:
-
-                # Cancel all cleanup tasks
-                for task_name in self._cleanup_tasks:
-                    cleanup_task: asyncio.Task | None = getattr(starting_app.state, task_name, None)
-                    if cleanup_task is not None:
-                        logger.info("Cancelling %s cleanup task", task_name)
-                        cleanup_task.cancel()
-                    else:
-                        logger.warning("No cleanup task found for %s", task_name)
-
-                self._cleanup_tasks.clear()
-
             logger.debug("Closing NAT server from process %s", os.getpid())

         nat_app = FastAPI(lifespan=lifespan)
@@ -208,32 +227,6 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         self._outstanding_flows: dict[str, FlowState] = {}
         self._outstanding_flows_lock = asyncio.Lock()

-    @staticmethod
-    async def _periodic_cleanup(name: str, job_store: JobStore, sleep_time_sec: int = 300):
-        while True:
-            try:
-                job_store.cleanup_expired_jobs()
-                logger.debug("Expired %s jobs cleaned up", name)
-            except Exception as e:
-                logger.error("Error during %s job cleanup: %s", name, e)
-            await asyncio.sleep(sleep_time_sec)
-
-    async def create_cleanup_task(self, app: FastAPI, name: str, job_store: JobStore, sleep_time_sec: int = 300):
-        # Schedule periodic cleanup of expired jobs on first job creation
-        attr_name = f"{name}_cleanup_task"
-
-        # Cheap check, if it doesn't exist, we will need to re-check after we acquire the lock
-        if not hasattr(app.state, attr_name):
-            async with self._cleanup_tasks_lock:
-                if not hasattr(app.state, attr_name):
-                    logger.info("Starting %s periodic cleanup task", name)
-                    setattr(
-                        app.state,
-                        attr_name,
-                        asyncio.create_task(
-                            self._periodic_cleanup(name=name, job_store=job_store, sleep_time_sec=sleep_time_sec)))
-                    self._cleanup_tasks.append(attr_name)
-
     def get_step_adaptor(self) -> StepAdaptor:

         return StepAdaptor(self.front_end_config.step_adaptor)
@@ -247,14 +240,15 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):

     async def add_routes(self, app: FastAPI, builder: WorkflowBuilder):

-        await self.add_default_route(app, SessionManager(builder.build()))
-        await self.add_evaluate_route(app, SessionManager(builder.build()))
+        await self.add_default_route(app, SessionManager(await builder.build()))
+        await self.add_evaluate_route(app, SessionManager(await builder.build()))
         await self.add_static_files_route(app, builder)
         await self.add_authorization_route(app)
+        await self.add_mcp_client_tool_list_route(app, builder)

         for ep in self.front_end_config.endpoints:

-            entry_workflow = builder.build(entry_function=ep.function_name)
+            entry_workflow = await builder.build(entry_function=ep.function_name)

             await self.add_route(app, endpoint=ep, session_manager=SessionManager(entry_workflow))

@@ -276,52 +270,72 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             },
         }

-        #
-
-
-
-
-
+        # TODO: Find another way to limit the number of concurrent evaluations
+        async def run_evaluation(scheduler_address: str,
+                                 db_url: str,
+                                 workflow_config_file_path: str,
+                                 job_id: str,
+                                 eval_config_file: str,
+                                 reps: int):
             """Background task to run the evaluation."""
-
-            try:
-                # Create EvaluationRunConfig using the CLI defaults
-                eval_config = EvaluationRunConfig(config_file=Path(config_file), dataset=None, reps=reps)
-
-                # Create a new EvaluationRun with the evaluation-specific config
-                job_store.update_status(job_id, "running")
-                eval_runner = EvaluationRun(eval_config)
-                output: EvaluationRunOutput = await eval_runner.run_and_evaluate(session_manager=session_manager,
-                                                                                 job_id=job_id)
-                if output.workflow_interrupted:
-                    job_store.update_status(job_id, "interrupted")
-                else:
-                    parent_dir = os.path.dirname(
-                        output.workflow_output_file) if output.workflow_output_file else None
-
-                    job_store.update_status(job_id, "success", output_path=str(parent_dir))
-            except Exception as e:
-                logger.error("Error in evaluation job %s: %s", job_id, str(e))
-                job_store.update_status(job_id, "failure", error=str(e))
-
-        async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
-            """Handle evaluation requests."""
+            job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)

-
+            try:
+                # We have two config files, one for the workflow and one for the evaluation
+                # Create EvaluationRunConfig using the CLI defaults
+                eval_config = EvaluationRunConfig(config_file=Path(eval_config_file), dataset=None, reps=reps)

-                #
-
-
-
-
+                # Create a new EvaluationRun with the evaluation-specific config
+                await job_store.update_status(job_id, JobStatus.RUNNING)
+                eval_runner = EvaluationRun(eval_config)
+
+                async with load_workflow(workflow_config_file_path) as local_session_manager:
+                    output: EvaluationRunOutput = await eval_runner.run_and_evaluate(
+                        session_manager=local_session_manager, job_id=job_id)
+
+                    if output.workflow_interrupted:
+                        await job_store.update_status(job_id, JobStatus.INTERRUPTED)
+                    else:
+                        parent_dir = os.path.dirname(output.workflow_output_file) if output.workflow_output_file else None

-
-
-
+                        await job_store.update_status(job_id, JobStatus.SUCCESS, output_path=str(parent_dir))
+            except Exception as e:
+                logger.exception("Error in evaluation job %s", job_id)
+                await job_store.update_status(job_id, JobStatus.FAILURE, error=str(e))
+
+        async def start_evaluation(request: EvaluateRequest, http_request: Request):
+            """Handle evaluation requests."""

-
+            async with session_manager.session(http_connection=http_request):

-
+                # if job_id is present and already exists return the job info
+                # There is a race condition between this check and the actual job submission, however if the client is
+                # supplying their own job_ids, then it is their responsibility to ensure that the job_id is unique.
+                if request.job_id:
+                    job_status = await self._job_store.get_status(request.job_id)
+                    if job_status != JobStatus.NOT_FOUND:
+                        return EvaluateResponse(job_id=request.job_id, status=job_status)
+
+                job_id = self._job_store.ensure_job_id(request.job_id)
+
+                await self._job_store.submit_job(job_id=job_id,
+                                                 config_file=request.config_file,
+                                                 expiry_seconds=request.expiry_seconds,
+                                                 job_fn=run_evaluation,
+                                                 job_args=[
+                                                     self._scheduler_address,
+                                                     self._db_url,
+                                                     self._config_file_path,
+                                                     job_id,
+                                                     request.config_file,
+                                                     request.reps
+                                                 ])
+
+                logger.info("Submitted evaluation job %s with config %s", job_id, request.config_file)
+
+                return EvaluateResponse(job_id=job_id, status=JobStatus.SUBMITTED)
+
+        def translate_job_to_response(job: "JobInfo") -> EvaluateStatusResponse:
             """Translate a JobInfo object to an EvaluateStatusResponse."""
             return EvaluateStatusResponse(job_id=job.job_id,
                                           status=job.status,
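
`start_evaluation` now returns immediately with a `job_id` after handing the work to the Dask-backed job store, so clients poll the job routes registered below for completion. A hedged sketch of that flow; the base URL and evaluate path are deployment-specific placeholders, and the serialized status strings are assumptions about how `JobStatus` renders in JSON:

```python
import time

import httpx

BASE = "http://localhost:8000"  # assumed server address
EVALUATE = "/evaluate"          # assumed value of front_end_config.evaluate.path

# Submit an evaluation job (fields mirror EvaluateRequest in the diff).
job_id = httpx.post(f"{BASE}{EVALUATE}",
                    json={"config_file": "eval_config.yml", "reps": 1}).json()["job_id"]

# Poll the per-job status route until the job leaves its active states.
# "submitted"/"running" are assumed serializations of JobStatus.
while True:
    status = httpx.get(f"{BASE}{EVALUATE}/job/{job_id}").json()["status"]
    if status.lower() not in ("submitted", "running"):
        break
    time.sleep(2)
print("final status:", status)
```
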
@@ -330,15 +344,15 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                           output_path=str(job.output_path),
                                           created_at=job.created_at,
                                           updated_at=job.updated_at,
-                                          expires_at=
+                                          expires_at=self._job_store.get_expires_at(job))

         async def get_job_status(job_id: str, http_request: Request) -> EvaluateStatusResponse:
             """Get the status of an evaluation job."""
             logger.info("Getting status for job %s", job_id)

-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):

-                job =
+                job = await self._job_store.get_job(job_id)
                 if not job:
                     logger.warning("Job %s not found", job_id)
                     raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
@@ -349,9 +363,9 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of the last created evaluation job."""
             logger.info("Getting last job status")

-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):

-                job =
+                job = await self._job_store.get_last_job()
                 if not job:
                     logger.warning("No jobs found when requesting last job status")
                     raise HTTPException(status_code=404, detail="No jobs found")
@@ -361,65 +375,69 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         async def get_jobs(http_request: Request, status: str | None = None) -> list[EvaluateStatusResponse]:
             """Get all jobs, optionally filtered by status."""

-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):

                 if status is None:
                     logger.info("Getting all jobs")
-                    jobs =
+                    jobs = await self._job_store.get_all_jobs()
                 else:
                     logger.info("Getting jobs with status %s", status)
-                    jobs =
+                    jobs = await self._job_store.get_jobs_by_status(JobStatus(status))
+
                 logger.info("Found %d jobs", len(jobs))
                 return [translate_job_to_response(job) for job in jobs]

         if self.front_end_config.evaluate.path:
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if self._dask_available:
+                # Add last job endpoint first (most specific)
+                app.add_api_route(
+                    path=f"{self.front_end_config.evaluate.path}/job/last",
+                    endpoint=get_last_job_status,
+                    methods=["GET"],
+                    response_model=EvaluateStatusResponse,
+                    description="Get the status of the last created evaluation job",
+                    responses={
+                        404: {
+                            "description": "No jobs found"
+                        }, 500: response_500
+                    },
+                )

-
-
-
-
-
-
-
-
-
-
-
-
-
+                # Add specific job endpoint (least specific)
+                app.add_api_route(
+                    path=f"{self.front_end_config.evaluate.path}/job/{{job_id}}",
+                    endpoint=get_job_status,
+                    methods=["GET"],
+                    response_model=EvaluateStatusResponse,
+                    description="Get the status of an evaluation job",
+                    responses={
+                        404: {
+                            "description": "Job not found"
+                        }, 500: response_500
+                    },
+                )

-
-
-
-
-
-
-
-
-
+                # Add jobs endpoint with optional status query parameter
+                app.add_api_route(
+                    path=f"{self.front_end_config.evaluate.path}/jobs",
+                    endpoint=get_jobs,
+                    methods=["GET"],
+                    response_model=list[EvaluateStatusResponse],
+                    description="Get all jobs, optionally filtered by status",
+                    responses={500: response_500},
+                )

-
-
-
-
-
-
-
-
-
+                # Add HTTP endpoint for evaluation
+                app.add_api_route(
+                    path=self.front_end_config.evaluate.path,
+                    endpoint=start_evaluation,
+                    methods=[self.front_end_config.evaluate.method],
+                    response_model=EvaluateResponse,
+                    description=self.front_end_config.evaluate.description,
+                    responses={500: response_500},
+                )
+            else:
+                logger.warning("Dask is not available, evaluation endpoints will not be added.")

     async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):

@@ -522,25 +540,27 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):

         workflow = session_manager.workflow

-        GenerateBodyType = workflow.input_schema
-        GenerateStreamResponseType = workflow.streaming_output_schema
-        GenerateSingleResponseType = workflow.single_output_schema
-
-
-
-
-
-
-            default=
-
-
-
-
-
-
-
-
-
+        GenerateBodyType = workflow.input_schema
+        GenerateStreamResponseType = workflow.streaming_output_schema
+        GenerateSingleResponseType = workflow.single_output_schema
+
+        if self._dask_available:
+            # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
+            # Consider prefixing these with "nat_" to avoid conflicts
+
+            class AsyncGenerateRequest(GenerateBodyType):
+                job_id: str | None = Field(default=None, description="Unique identifier for the evaluation job")
+                sync_timeout: int = Field(
+                    default=0,
+                    ge=0,
+                    le=300,
+                    description="Attempt to perform the job synchronously up until `sync_timeout` sectonds, "
+                    "if the job hasn't been completed by then a job_id will be returned with a status code of 202.")
+                expiry_seconds: int = Field(default=JobStore.DEFAULT_EXPIRY,
+                                            ge=JobStore.MIN_EXPIRY,
+                                            le=JobStore.MAX_EXPIRY,
+                                            description="Optional time (in seconds) before the job expires. "
+                                            "Clamped between 600 (10 min) and 86400 (24h).")

         # Ensure that the input is in the body. POD types are treated as query parameters
         if (not issubclass(GenerateBodyType, BaseModel)):
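
`AsyncGenerateRequest` is built by subclassing the workflow's own input schema, so the job-control fields ride alongside whatever the workflow already accepts. A self-contained sketch of the pattern; `WorkflowInput`, its `prompt` field, and the 3600-second default are illustrative stand-ins (the 600/86400 bounds come from the field description above):

```python
from pydantic import BaseModel, Field


class WorkflowInput(BaseModel):
    """Stand-in for workflow.input_schema."""
    prompt: str


# Subclassing appends job-control fields to the workflow schema, which is why
# "job_id", "sync_timeout" and "expiry_seconds" become reserved keywords.
class AsyncRequest(WorkflowInput):
    job_id: str | None = Field(default=None)
    sync_timeout: int = Field(default=0, ge=0, le=300)
    expiry_seconds: int = Field(default=3600, ge=600, le=86400)


print(AsyncRequest(prompt="hi", sync_timeout=30).model_dump())
# {'prompt': 'hi', 'job_id': None, 'sync_timeout': 30, 'expiry_seconds': 3600}
```
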
@@ -560,19 +580,13 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             },
         }

-        # Create job store for tracking async generation jobs
-        job_store = JobStore()
-
-        # Run up to max_running_async_jobs jobs at the same time
-        async_job_concurrency = asyncio.Semaphore(self._front_end_config.max_running_async_jobs)
-
         def get_single_endpoint(result_type: type | None):

             async def get_single(response: Response, request: Request):

                 response.headers["Content-Type"] = "application/json"

-                async with session_manager.session(
+                async with session_manager.session(http_connection=request,
                                                    user_authentication_callback=self._http_flow_handler.authenticate):

                     return await generate_single_response(None, session_manager, result_type=result_type)
@@ -583,7 +597,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):

             async def get_stream(request: Request):

-                async with session_manager.session(
+                async with session_manager.session(http_connection=request,
                                                    user_authentication_callback=self._http_flow_handler.authenticate):

                     return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -618,7 +632,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):

                 response.headers["Content-Type"] = "application/json"

-                async with session_manager.session(
+                async with session_manager.session(http_connection=request,
                                                    user_authentication_callback=self._http_flow_handler.authenticate):

                     return await generate_single_response(payload, session_manager, result_type=result_type)
@@ -632,7 +646,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):

             async def post_stream(request: Request, payload: request_type):

-                async with session_manager.session(
+                async with session_manager.session(http_connection=request,
                                                    user_authentication_callback=self._http_flow_handler.authenticate):

                     return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -677,7 +691,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                 # Check if streaming is requested
                 stream_requested = getattr(payload, 'stream', False)

-                async with session_manager.session(
+                async with session_manager.session(http_connection=request):
                     if stream_requested:
                         # Return streaming response
                         return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -688,115 +702,112 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                                  step_adaptor=self.get_step_adaptor(),
                                                  result_type=ChatResponseChunk,
                                                  output_type=ChatResponseChunk))
-
-
-
+
+                    # Return single response - check if workflow supports non-streaming
+                    try:
+                        response.headers["Content-Type"] = "application/json"
+                        return await generate_single_response(payload, session_manager, result_type=ChatResponse)
+                    except ValueError as e:
+                        if "Cannot get a single output value for streaming workflows" in str(e):
+                            # Workflow only supports streaming, but client requested non-streaming
+                            # Fall back to streaming and collect the result
+                            chunks = []
+                            async for chunk_str in generate_streaming_response_as_str(
+                                    payload,
+                                    session_manager=session_manager,
+                                    streaming=True,
+                                    step_adaptor=self.get_step_adaptor(),
+                                    result_type=ChatResponseChunk,
+                                    output_type=ChatResponseChunk):
+                                if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
+                                    chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
+                                    if chunk_data:
+                                        try:
+                                            chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
+                                            if (chunk_json.choices and len(chunk_json.choices) > 0
+                                                    and chunk_json.choices[0].delta
+                                                    and chunk_json.choices[0].delta.content is not None):
+                                                chunks.append(chunk_json.choices[0].delta.content)
+                                        except Exception:
+                                            continue
+
+                            # Create a single response from collected chunks
+                            content = "".join(chunks)
+                            single_response = ChatResponse.from_string(content)
                             response.headers["Content-Type"] = "application/json"
-                            return
-
-                    if "Cannot get a single output value for streaming workflows" in str(e):
-                        # Workflow only supports streaming, but client requested non-streaming
-                        # Fall back to streaming and collect the result
-                        chunks = []
-                        async for chunk_str in generate_streaming_response_as_str(
-                                payload,
-                                session_manager=session_manager,
-                                streaming=True,
-                                step_adaptor=self.get_step_adaptor(),
-                                result_type=ChatResponseChunk,
-                                output_type=ChatResponseChunk):
-                            if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
-                                chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
-                                if chunk_data:
-                                    try:
-                                        chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
-                                        if (chunk_json.choices and len(chunk_json.choices) > 0
-                                                and chunk_json.choices[0].delta
-                                                and chunk_json.choices[0].delta.content is not None):
-                                            chunks.append(chunk_json.choices[0].delta.content)
-                                    except Exception:
-                                        continue
-
-                        # Create a single response from collected chunks
-                        content = "".join(chunks)
-                        single_response = ChatResponse.from_string(content)
-                        response.headers["Content-Type"] = "application/json"
-                        return single_response
-                    else:
-                        raise
+                            return single_response
+                        raise

             return post_openai_api_compatible

-
-            """Background task to run the evaluation."""
-            async with async_job_concurrency:
-                try:
-                    result = await generate_single_response(payload=payload,
-                                                            session_manager=session_manager,
-                                                            result_type=result_type)
-                    job_store.update_status(job_id, "success", output=result)
-                except Exception as e:
-                    logger.error("Error in evaluation job %s: %s", job_id, e)
-                    job_store.update_status(job_id, "failure", error=str(e))
-
-        def _job_status_to_response(job: JobInfo) -> AsyncGenerationStatusResponse:
+        def _job_status_to_response(job: "JobInfo") -> AsyncGenerationStatusResponse:
             job_output = job.output
             if job_output is not None:
-
+                try:
+                    job_output = json.loads(job_output)
+                except json.JSONDecodeError:
+                    logger.error("Failed to parse job output as JSON: %s", job_output)
+                    job_output = {"error": "Output parsing failed"}
+
             return AsyncGenerationStatusResponse(job_id=job.job_id,
                                                  status=job.status,
                                                  error=job.error,
                                                  output=job_output,
                                                  created_at=job.created_at,
                                                  updated_at=job.updated_at,
-                                                 expires_at=
+                                                 expires_at=self._job_store.get_expires_at(job))
+
+        async def run_generation(scheduler_address: str,
+                                 db_url: str,
+                                 config_file_path: str,
+                                 job_id: str,
+                                 payload: typing.Any):
+            """Background task to run the workflow."""
+            job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
+            try:
+                async with load_workflow(config_file_path) as local_session_manager:
+                    result = await generate_single_response(
+                        payload, local_session_manager, result_type=local_session_manager.workflow.single_output_schema)

-
+                    await job_store.update_status(job_id, JobStatus.SUCCESS, output=result)
+            except Exception as e:
+                logger.exception("Error in async job %s", job_id)
+                await job_store.update_status(job_id, JobStatus.FAILURE, error=str(e))
+
+        def post_async_generation(request_type: type):

             async def start_async_generation(
-                    request: request_type,
+                    request: request_type, response: Response,
                     http_request: Request) -> AsyncGenerateResponse | AsyncGenerationStatusResponse:
                 """Handle async generation requests."""

-                async with session_manager.session(
+                async with session_manager.session(http_connection=http_request):

                     # if job_id is present and already exists return the job info
                     if request.job_id:
-                        job =
+                        job = await self._job_store.get_job(request.job_id)
                         if job:
                             return AsyncGenerateResponse(job_id=job.job_id, status=job.status)

-                    job_id =
-                    await self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    now = time.time()
-                    sync_timeout = now + request.sync_timeout
-                    while time.time() < sync_timeout:
-                        job = job_store.get_job(job_id)
-                        if job is not None and job.status not in job_store.ACTIVE_STATUS:
-                            # If the job is done, return the result
-                            response.status_code = 200
-                            return _job_status_to_response(job)
-
-                        # Sleep for a short time before checking again
-                        await asyncio.sleep(0.1)
+                    job_id = self._job_store.ensure_job_id(request.job_id)
+                    (_, job) = await self._job_store.submit_job(job_id=job_id,
+                                                                expiry_seconds=request.expiry_seconds,
+                                                                job_fn=run_generation,
+                                                                sync_timeout=request.sync_timeout,
+                                                                job_args=[
+                                                                    self._scheduler_address,
+                                                                    self._db_url,
+                                                                    self._config_file_path,
+                                                                    job_id,
+                                                                    request.model_dump(mode="json")
+                                                                ])
+
+                    if job is not None:
+                        response.status_code = 200
+                        return _job_status_to_response(job)

                 response.status_code = 202
-                return AsyncGenerateResponse(job_id=job_id, status=
+                return AsyncGenerateResponse(job_id=job_id, status=JobStatus.SUBMITTED)

             return start_async_generation

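
The async handler now delegates the sync-or-async decision to `JobStore.submit_job`: if the job finishes within `sync_timeout` the full status comes back with a 200, otherwise the caller gets a 202 and a `job_id` to poll. A sketch from the client side; the base URL, endpoint path, and the workflow's `prompt` field are placeholders:

```python
import httpx

BASE = "http://localhost:8000"  # assumed server address
GENERATE = "/generate"          # assumed endpoint.path for the workflow

# Ask the server to wait up to 30s before falling back to async (per the handler above).
resp = httpx.post(f"{BASE}{GENERATE}/async",
                  json={"prompt": "hello", "sync_timeout": 30, "expiry_seconds": 3600})

if resp.status_code == 202:
    # Job is still active: poll the status route registered below.
    job_id = resp.json()["job_id"]
    print("pending:", httpx.get(f"{BASE}{GENERATE}/async/job/{job_id}").json())
else:
    # 200: the job completed within the timeout and the full status is inlined.
    print("completed:", resp.json())
```
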
@@ -804,10 +815,10 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of an async job."""
             logger.info("Getting status for job %s", job_id)

-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):

-                job =
-                if
+                job = await self._job_store.get_job(job_id)
+                if job is None:
                     logger.warning("Job %s not found", job_id)
                     raise HTTPException(status_code=404, detail=f"Job {job_id} not found")

@@ -935,30 +946,33 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                 responses={500: response_500},
             )

-
-
-
-
-
-
-
-
+            if self._dask_available:
+                app.add_api_route(
+                    path=f"{endpoint.path}/async",
+                    endpoint=post_async_generation(request_type=AsyncGenerateRequest),
+                    methods=[endpoint.method],
+                    response_model=AsyncGenerateResponse | AsyncGenerationStatusResponse,
+                    description="Start an async generate job",
+                    responses={500: response_500},
+                )
+            else:
+                logger.warning("Dask is not available, async generation endpoints will not be added.")
         else:
             raise ValueError(f"Unsupported method {endpoint.method}")

-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self._dask_available:
+            app.add_api_route(
+                path=f"{endpoint.path}/async/job/{{job_id}}",
+                endpoint=get_async_job_status,
+                methods=["GET"],
+                response_model=AsyncGenerationStatusResponse,
+                description="Get the status of an async job",
+                responses={
+                    404: {
+                        "description": "Job not found"
+                    }, 500: response_500
+                },
+            )

         if (endpoint.openai_api_path):
             if (endpoint.method == "GET"):
@@ -1061,8 +1075,13 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                                                  code_verifier=verifier,
                                                                  state=state)
                 flow_state.future.set_result(res)
+            except OAuthError as e:
+                flow_state.future.set_exception(
+                    RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
+            except httpx.HTTPError as e:
+                flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
             except Exception as e:
-                flow_state.future.set_exception(e)
+                flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))

             return HTMLResponse(content=AUTH_REDIRECT_SUCCESS_HTML,
                                 status_code=200,
@@ -1078,6 +1097,183 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             methods=["GET"],
             description="Handles the authorization code and state returned from the Authorization Code Grant Flow.")

+    async def add_mcp_client_tool_list_route(self, app: FastAPI, builder: WorkflowBuilder):
+        """Add the MCP client tool list endpoint to the FastAPI app."""
+        from typing import Any
+
+        from pydantic import BaseModel
+
+        class MCPToolInfo(BaseModel):
+            name: str
+            description: str
+            server: str
+            available: bool
+
+        class MCPClientToolListResponse(BaseModel):
+            mcp_clients: list[dict[str, Any]]
+
+        async def get_mcp_client_tool_list() -> MCPClientToolListResponse:
+            """
+            Get the list of MCP tools from all MCP clients in the workflow configuration.
+            Checks session health and compares with workflow function group configuration.
+            """
+            mcp_clients_info = []
+
+            try:
+                # Get all function groups from the builder
+                function_groups = builder._function_groups
+
+                # Find MCP client function groups
+                for group_name, configured_group in function_groups.items():
+                    if configured_group.config.type != "mcp_client":
+                        continue
+
+                    from nat.plugins.mcp.client_impl import MCPClientConfig
+
+                    config = configured_group.config
+                    assert isinstance(config, MCPClientConfig)
+
+                    # Reuse the existing MCP client session stored on the function group instance
+                    group_instance = configured_group.instance
+
+                    client = group_instance.mcp_client
+                    if client is None:
+                        raise RuntimeError(f"MCP client not found for group {group_name}")
+
+                    try:
+                        session_healthy = False
+                        server_tools: dict[str, Any] = {}
+
+                        try:
+                            server_tools = await client.get_tools()
+                            session_healthy = True
+                        except Exception as e:
+                            logger.exception(f"Failed to connect to MCP server {client.server_name}: {e}")
+                            session_healthy = False
+
+                        # Get workflow function group configuration (configured client-side tools)
+                        configured_short_names: set[str] = set()
+                        configured_full_to_fn: dict[str, Function] = {}
+                        try:
+                            # Pass a no-op filter function to bypass any default filtering that might check
+                            # health status, preventing potential infinite recursion during health status checks.
+                            async def pass_through_filter(fn):
+                                return fn
+
+                            accessible_functions = await group_instance.get_accessible_functions(
+                                filter_fn=pass_through_filter)
+                            configured_full_to_fn = accessible_functions
+                            configured_short_names = {name.split('.', 1)[1] for name in accessible_functions.keys()}
+                        except Exception as e:
+                            logger.exception(f"Failed to get accessible functions for group {group_name}: {e}")
+
+                        # Build alias->original mapping and override configs from overrides
+                        alias_to_original: dict[str, str] = {}
+                        override_configs: dict[str, Any] = {}
+                        try:
+                            if config.tool_overrides is not None:
+                                for orig_name, override in config.tool_overrides.items():
+                                    if override.alias is not None:
+                                        alias_to_original[override.alias] = orig_name
+                                        override_configs[override.alias] = override
+                                    else:
+                                        override_configs[orig_name] = override
+                        except Exception:
+                            pass
+
+                        # Create tool info list (always return configured tools; mark availability)
+                        tools_info: list[dict[str, Any]] = []
+                        available_count = 0
+                        for wf_fn, fn_short in zip(configured_full_to_fn.values(), configured_short_names):
+                            orig_name = alias_to_original.get(fn_short, fn_short)
+                            available = session_healthy and (orig_name in server_tools)
+                            if available:
+                                available_count += 1
+
+                            # Prefer tool override description, then workflow function description,
+                            # then server description
+                            description = ""
+                            if fn_short in override_configs and override_configs[fn_short].description:
+                                description = override_configs[fn_short].description
+                            elif wf_fn.description:
+                                description = wf_fn.description
+                            elif available and orig_name in server_tools:
+                                description = server_tools[orig_name].description or ""
+
+                            tools_info.append(
+                                MCPToolInfo(name=fn_short,
+                                            description=description or "",
+                                            server=client.server_name,
+                                            available=available).model_dump())
+
+                        # Sort tools_info by name to maintain consistent ordering
+                        tools_info.sort(key=lambda x: x['name'])
+
+                        mcp_clients_info.append({
+                            "function_group": group_name,
+                            "server": client.server_name,
+                            "transport": config.server.transport,
+                            "session_healthy": session_healthy,
+                            "tools": tools_info,
+                            "total_tools": len(configured_short_names),
+                            "available_tools": available_count
+                        })
+
+                    except Exception as e:
+                        logger.error(f"Error processing MCP client {group_name}: {e}")
+                        mcp_clients_info.append({
+                            "function_group": group_name,
+                            "server": "unknown",
+                            "transport": config.server.transport if config.server else "unknown",
+                            "session_healthy": False,
+                            "error": str(e),
+                            "tools": [],
+                            "total_tools": 0,
+                            "workflow_tools": 0
+                        })
+
+                return MCPClientToolListResponse(mcp_clients=mcp_clients_info)
+
+            except Exception as e:
+                logger.error(f"Error in MCP client tool list endpoint: {e}")
+                raise HTTPException(status_code=500, detail=f"Failed to retrieve MCP client information: {str(e)}")
+
+        # Add the route to the FastAPI app
+        app.add_api_route(
+            path="/mcp/client/tool/list",
+            endpoint=get_mcp_client_tool_list,
+            methods=["GET"],
+            response_model=MCPClientToolListResponse,
+            description="Get list of MCP client tools with session health and workflow configuration comparison",
+            responses={
+                200: {
+                    "description": "Successfully retrieved MCP client tool information",
+                    "content": {
+                        "application/json": {
+                            "example": {
+                                "mcp_clients": [{
+                                    "function_group": "mcp_tools",
+                                    "server": "streamable-http:http://localhost:9901/mcp",
+                                    "transport": "streamable-http",
+                                    "session_healthy": True,
+                                    "tools": [{
+                                        "name": "tool_a",
+                                        "description": "Tool A description",
+                                        "server": "streamable-http:http://localhost:9901/mcp",
+                                        "available": True
+                                    }],
+                                    "total_tools": 1,
+                                    "available_tools": 1
+                                }]
+                            }
+                        }
+                    }
+                },
+                500: {
+                    "description": "Internal Server Error"
+                }
+            })
+
     async def _add_flow(self, state: str, flow_state: FlowState):
         async with self._outstanding_flows_lock:
             self._outstanding_flows[state] = flow_state
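
The new route reports, per configured MCP client, both session health and which configured tools the live server actually exposes. A small consumer sketch; only the host and port are assumptions, and the field names come from the response example embedded above (error entries may omit the tool counts, hence the `.get` defaults):

```python
import httpx

resp = httpx.get("http://localhost:8000/mcp/client/tool/list")  # fixed path; host/port assumed
resp.raise_for_status()

for client_info in resp.json()["mcp_clients"]:
    health = "healthy" if client_info["session_healthy"] else "UNHEALTHY"
    print(f"{client_info['function_group']} ({client_info['server']}): {health}, "
          f"{client_info.get('available_tools', 0)}/{client_info.get('total_tools', 0)} tools available")
    for tool in client_info.get("tools", []):
        marker = "+" if tool["available"] else "-"
        print(f"  {marker} {tool['name']}: {tool['description']}")
```
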
@@ -1085,3 +1281,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
     async def _remove_flow(self, state: str):
         async with self._outstanding_flows_lock:
             del self._outstanding_flows[state]
+
+
+# Prevent Sphinx from documenting items not a part of the public API
+__all__ = ["FastApiFrontEndPluginWorkerBase", "FastApiFrontEndPluginWorker", "RouteInfo"]