nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +17 -14
- nat/agent/reasoning_agent/reasoning_agent.py +9 -7
- nat/agent/register.py +1 -0
- nat/agent/rewoo_agent/agent.py +9 -2
- nat/agent/rewoo_agent/register.py +16 -12
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +14 -13
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/context.py +28 -6
- nat/builder/function.py +313 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +215 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -7
- nat/cli/entrypoint.py +4 -9
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +167 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/authentication.py +38 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +18 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +134 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +5 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +111 -3
- nat/front_ends/mcp/tool_converter.py +3 -0
- nat/llm/aws_bedrock_llm.py +14 -3
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +4 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/log_levels.py +25 -0
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/METADATA +10 -1
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/RECORD +105 -76
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/entry_points.txt +1 -0
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/top_level.txt +0 -0
|
@@ -14,9 +14,9 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import asyncio
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
import os
|
|
19
|
-
import time
|
|
20
20
|
import typing
|
|
21
21
|
from abc import ABC
|
|
22
22
|
from abc import abstractmethod
|
|
@@ -25,7 +25,6 @@ from collections.abc import Callable
|
|
|
25
25
|
from contextlib import asynccontextmanager
|
|
26
26
|
from pathlib import Path
|
|
27
27
|
|
|
28
|
-
from fastapi import BackgroundTasks
|
|
29
28
|
from fastapi import Body
|
|
30
29
|
from fastapi import FastAPI
|
|
31
30
|
from fastapi import Request
|
|
@@ -58,18 +57,30 @@ from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateRequest
|
|
|
58
57
|
from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateResponse
|
|
59
58
|
from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateStatusResponse
|
|
60
59
|
from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
61
|
-
from nat.front_ends.fastapi.job_store import JobInfo
|
|
62
|
-
from nat.front_ends.fastapi.job_store import JobStore
|
|
63
60
|
from nat.front_ends.fastapi.message_handler import WebSocketMessageHandler
|
|
64
61
|
from nat.front_ends.fastapi.response_helpers import generate_single_response
|
|
65
62
|
from nat.front_ends.fastapi.response_helpers import generate_streaming_response_as_str
|
|
66
63
|
from nat.front_ends.fastapi.response_helpers import generate_streaming_response_full_as_str
|
|
67
64
|
from nat.front_ends.fastapi.step_adaptor import StepAdaptor
|
|
65
|
+
from nat.front_ends.fastapi.utils import get_config_file_path
|
|
68
66
|
from nat.object_store.models import ObjectStoreItem
|
|
67
|
+
from nat.runtime.loader import load_workflow
|
|
69
68
|
from nat.runtime.session import SessionManager
|
|
70
69
|
|
|
71
70
|
logger = logging.getLogger(__name__)
|
|
72
71
|
|
|
72
|
+
_DASK_AVAILABLE = False
|
|
73
|
+
|
|
74
|
+
try:
|
|
75
|
+
from nat.front_ends.fastapi.job_store import JobInfo
|
|
76
|
+
from nat.front_ends.fastapi.job_store import JobStatus
|
|
77
|
+
from nat.front_ends.fastapi.job_store import JobStore
|
|
78
|
+
_DASK_AVAILABLE = True
|
|
79
|
+
except ImportError:
|
|
80
|
+
JobInfo = None
|
|
81
|
+
JobStatus = None
|
|
82
|
+
JobStore = None
|
|
83
|
+
|
|
73
84
|
|
|
74
85
|
class FastApiFrontEndPluginWorkerBase(ABC):
|
|
75
86
|
|
|
@@ -80,10 +91,29 @@ class FastApiFrontEndPluginWorkerBase(ABC):
|
|
|
80
91
|
FastApiFrontEndConfig), ("Front end config is not FastApiFrontEndConfig")
|
|
81
92
|
|
|
82
93
|
self._front_end_config = config.general.front_end
|
|
83
|
-
|
|
84
|
-
self.
|
|
85
|
-
self._cleanup_tasks_lock = asyncio.Lock()
|
|
94
|
+
self._dask_available = False
|
|
95
|
+
self._job_store = None
|
|
86
96
|
self._http_flow_handler: HTTPAuthenticationFlowHandler | None = HTTPAuthenticationFlowHandler()
|
|
97
|
+
self._scheduler_address = os.environ.get("NAT_DASK_SCHEDULER_ADDRESS")
|
|
98
|
+
self._db_url = os.environ.get("NAT_JOB_STORE_DB_URL")
|
|
99
|
+
self._config_file_path = get_config_file_path()
|
|
100
|
+
|
|
101
|
+
if self._scheduler_address is not None:
|
|
102
|
+
if not _DASK_AVAILABLE:
|
|
103
|
+
raise RuntimeError("Dask is not available, please install it to use the FastAPI front end with Dask.")
|
|
104
|
+
|
|
105
|
+
if self._db_url is None:
|
|
106
|
+
raise RuntimeError(
|
|
107
|
+
"NAT_JOB_STORE_DB_URL must be set when using Dask (configure a persistent JobStore database).")
|
|
108
|
+
|
|
109
|
+
try:
|
|
110
|
+
self._job_store = JobStore(scheduler_address=self._scheduler_address, db_url=self._db_url)
|
|
111
|
+
self._dask_available = True
|
|
112
|
+
logger.debug("Connected to Dask scheduler at %s", self._scheduler_address)
|
|
113
|
+
except Exception as e:
|
|
114
|
+
raise RuntimeError(f"Failed to connect to Dask scheduler at {self._scheduler_address}: {e}") from e
|
|
115
|
+
else:
|
|
116
|
+
logger.debug("No Dask scheduler address provided, running without Dask support.")
|
|
87
117
|
|
|
88
118
|
@property
|
|
89
119
|
def config(self) -> Config:
|
|
@@ -107,20 +137,6 @@ class FastApiFrontEndPluginWorkerBase(ABC):
|
|
|
107
137
|
|
|
108
138
|
yield
|
|
109
139
|
|
|
110
|
-
# If a cleanup task is running, cancel it
|
|
111
|
-
async with self._cleanup_tasks_lock:
|
|
112
|
-
|
|
113
|
-
# Cancel all cleanup tasks
|
|
114
|
-
for task_name in self._cleanup_tasks:
|
|
115
|
-
cleanup_task: asyncio.Task | None = getattr(starting_app.state, task_name, None)
|
|
116
|
-
if cleanup_task is not None:
|
|
117
|
-
logger.info("Cancelling %s cleanup task", task_name)
|
|
118
|
-
cleanup_task.cancel()
|
|
119
|
-
else:
|
|
120
|
-
logger.warning("No cleanup task found for %s", task_name)
|
|
121
|
-
|
|
122
|
-
self._cleanup_tasks.clear()
|
|
123
|
-
|
|
124
140
|
logger.debug("Closing NAT server from process %s", os.getpid())
|
|
125
141
|
|
|
126
142
|
nat_app = FastAPI(lifespan=lifespan)
|
|
@@ -208,32 +224,6 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
208
224
|
self._outstanding_flows: dict[str, FlowState] = {}
|
|
209
225
|
self._outstanding_flows_lock = asyncio.Lock()
|
|
210
226
|
|
|
211
|
-
@staticmethod
|
|
212
|
-
async def _periodic_cleanup(name: str, job_store: JobStore, sleep_time_sec: int = 300):
|
|
213
|
-
while True:
|
|
214
|
-
try:
|
|
215
|
-
job_store.cleanup_expired_jobs()
|
|
216
|
-
logger.debug("Expired %s jobs cleaned up", name)
|
|
217
|
-
except Exception as e:
|
|
218
|
-
logger.exception("Error during %s job cleanup: %s", name, e)
|
|
219
|
-
await asyncio.sleep(sleep_time_sec)
|
|
220
|
-
|
|
221
|
-
async def create_cleanup_task(self, app: FastAPI, name: str, job_store: JobStore, sleep_time_sec: int = 300):
|
|
222
|
-
# Schedule periodic cleanup of expired jobs on first job creation
|
|
223
|
-
attr_name = f"{name}_cleanup_task"
|
|
224
|
-
|
|
225
|
-
# Cheap check, if it doesn't exist, we will need to re-check after we acquire the lock
|
|
226
|
-
if not hasattr(app.state, attr_name):
|
|
227
|
-
async with self._cleanup_tasks_lock:
|
|
228
|
-
if not hasattr(app.state, attr_name):
|
|
229
|
-
logger.info("Starting %s periodic cleanup task", name)
|
|
230
|
-
setattr(
|
|
231
|
-
app.state,
|
|
232
|
-
attr_name,
|
|
233
|
-
asyncio.create_task(
|
|
234
|
-
self._periodic_cleanup(name=name, job_store=job_store, sleep_time_sec=sleep_time_sec)))
|
|
235
|
-
self._cleanup_tasks.append(attr_name)
|
|
236
|
-
|
|
237
227
|
def get_step_adaptor(self) -> StepAdaptor:
|
|
238
228
|
|
|
239
229
|
return StepAdaptor(self.front_end_config.step_adaptor)
|
|
@@ -276,52 +266,72 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
276
266
|
},
|
|
277
267
|
}
|
|
278
268
|
|
|
279
|
-
#
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
269
|
+
# TODO: Find another way to limit the number of concurrent evaluations
|
|
270
|
+
async def run_evaluation(scheduler_address: str,
|
|
271
|
+
db_url: str,
|
|
272
|
+
workflow_config_file_path: str,
|
|
273
|
+
job_id: str,
|
|
274
|
+
eval_config_file: str,
|
|
275
|
+
reps: int):
|
|
285
276
|
"""Background task to run the evaluation."""
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
277
|
+
job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
|
|
278
|
+
|
|
279
|
+
try:
|
|
280
|
+
# We have two config files, one for the workflow and one for the evaluation
|
|
281
|
+
# Create EvaluationRunConfig using the CLI defaults
|
|
282
|
+
eval_config = EvaluationRunConfig(config_file=Path(eval_config_file), dataset=None, reps=reps)
|
|
283
|
+
|
|
284
|
+
# Create a new EvaluationRun with the evaluation-specific config
|
|
285
|
+
await job_store.update_status(job_id, JobStatus.RUNNING)
|
|
286
|
+
eval_runner = EvaluationRun(eval_config)
|
|
287
|
+
|
|
288
|
+
async with load_workflow(workflow_config_file_path) as local_session_manager:
|
|
289
|
+
output: EvaluationRunOutput = await eval_runner.run_and_evaluate(
|
|
290
|
+
session_manager=local_session_manager, job_id=job_id)
|
|
291
|
+
|
|
292
|
+
if output.workflow_interrupted:
|
|
293
|
+
await job_store.update_status(job_id, JobStatus.INTERRUPTED)
|
|
294
|
+
else:
|
|
295
|
+
parent_dir = os.path.dirname(output.workflow_output_file) if output.workflow_output_file else None
|
|
296
|
+
|
|
297
|
+
await job_store.update_status(job_id, JobStatus.SUCCESS, output_path=str(parent_dir))
|
|
298
|
+
except Exception as e:
|
|
299
|
+
logger.exception("Error in evaluation job %s", job_id)
|
|
300
|
+
await job_store.update_status(job_id, JobStatus.FAILURE, error=str(e))
|
|
301
|
+
|
|
302
|
+
async def start_evaluation(request: EvaluateRequest, http_request: Request):
|
|
308
303
|
"""Handle evaluation requests."""
|
|
309
304
|
|
|
310
305
|
async with session_manager.session(http_connection=http_request):
|
|
311
306
|
|
|
312
307
|
# if job_id is present and already exists return the job info
|
|
308
|
+
# There is a race condition between this check and the actual job submission, however if the client is
|
|
309
|
+
# supplying their own job_ids, then it is their responsibility to ensure that the job_id is unique.
|
|
313
310
|
if request.job_id:
|
|
314
|
-
|
|
315
|
-
if
|
|
316
|
-
return EvaluateResponse(job_id=
|
|
317
|
-
|
|
318
|
-
job_id =
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
311
|
+
job_status = await self._job_store.get_status(request.job_id)
|
|
312
|
+
if job_status != JobStatus.NOT_FOUND:
|
|
313
|
+
return EvaluateResponse(job_id=request.job_id, status=job_status)
|
|
314
|
+
|
|
315
|
+
job_id = self._job_store.ensure_job_id(request.job_id)
|
|
316
|
+
|
|
317
|
+
await self._job_store.submit_job(job_id=job_id,
|
|
318
|
+
config_file=request.config_file,
|
|
319
|
+
expiry_seconds=request.expiry_seconds,
|
|
320
|
+
job_fn=run_evaluation,
|
|
321
|
+
job_args=[
|
|
322
|
+
self._scheduler_address,
|
|
323
|
+
self._db_url,
|
|
324
|
+
self._config_file_path,
|
|
325
|
+
job_id,
|
|
326
|
+
request.config_file,
|
|
327
|
+
request.reps
|
|
328
|
+
])
|
|
329
|
+
|
|
330
|
+
logger.info("Submitted evaluation job %s with config %s", job_id, request.config_file)
|
|
331
|
+
|
|
332
|
+
return EvaluateResponse(job_id=job_id, status=JobStatus.SUBMITTED)
|
|
333
|
+
|
|
334
|
+
def translate_job_to_response(job: "JobInfo") -> EvaluateStatusResponse:
|
|
325
335
|
"""Translate a JobInfo object to an EvaluateStatusResponse."""
|
|
326
336
|
return EvaluateStatusResponse(job_id=job.job_id,
|
|
327
337
|
status=job.status,
|
|
@@ -330,7 +340,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
330
340
|
output_path=str(job.output_path),
|
|
331
341
|
created_at=job.created_at,
|
|
332
342
|
updated_at=job.updated_at,
|
|
333
|
-
expires_at=
|
|
343
|
+
expires_at=self._job_store.get_expires_at(job))
|
|
334
344
|
|
|
335
345
|
async def get_job_status(job_id: str, http_request: Request) -> EvaluateStatusResponse:
|
|
336
346
|
"""Get the status of an evaluation job."""
|
|
@@ -338,7 +348,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
338
348
|
|
|
339
349
|
async with session_manager.session(http_connection=http_request):
|
|
340
350
|
|
|
341
|
-
job =
|
|
351
|
+
job = await self._job_store.get_job(job_id)
|
|
342
352
|
if not job:
|
|
343
353
|
logger.warning("Job %s not found", job_id)
|
|
344
354
|
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
|
|
@@ -351,7 +361,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
351
361
|
|
|
352
362
|
async with session_manager.session(http_connection=http_request):
|
|
353
363
|
|
|
354
|
-
job =
|
|
364
|
+
job = await self._job_store.get_last_job()
|
|
355
365
|
if not job:
|
|
356
366
|
logger.warning("No jobs found when requesting last job status")
|
|
357
367
|
raise HTTPException(status_code=404, detail="No jobs found")
|
|
@@ -365,61 +375,65 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
365
375
|
|
|
366
376
|
if status is None:
|
|
367
377
|
logger.info("Getting all jobs")
|
|
368
|
-
jobs =
|
|
378
|
+
jobs = await self._job_store.get_all_jobs()
|
|
369
379
|
else:
|
|
370
380
|
logger.info("Getting jobs with status %s", status)
|
|
371
|
-
jobs =
|
|
381
|
+
jobs = await self._job_store.get_jobs_by_status(JobStatus(status))
|
|
382
|
+
|
|
372
383
|
logger.info("Found %d jobs", len(jobs))
|
|
373
384
|
return [translate_job_to_response(job) for job in jobs]
|
|
374
385
|
|
|
375
386
|
if self.front_end_config.evaluate.path:
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
# Add specific job endpoint (least specific)
|
|
391
|
-
app.add_api_route(
|
|
392
|
-
path=f"{self.front_end_config.evaluate.path}/job/{{job_id}}",
|
|
393
|
-
endpoint=get_job_status,
|
|
394
|
-
methods=["GET"],
|
|
395
|
-
response_model=EvaluateStatusResponse,
|
|
396
|
-
description="Get the status of an evaluation job",
|
|
397
|
-
responses={
|
|
398
|
-
404: {
|
|
399
|
-
"description": "Job not found"
|
|
400
|
-
}, 500: response_500
|
|
401
|
-
},
|
|
402
|
-
)
|
|
403
|
-
|
|
404
|
-
# Add jobs endpoint with optional status query parameter
|
|
405
|
-
app.add_api_route(
|
|
406
|
-
path=f"{self.front_end_config.evaluate.path}/jobs",
|
|
407
|
-
endpoint=get_jobs,
|
|
408
|
-
methods=["GET"],
|
|
409
|
-
response_model=list[EvaluateStatusResponse],
|
|
410
|
-
description="Get all jobs, optionally filtered by status",
|
|
411
|
-
responses={500: response_500},
|
|
412
|
-
)
|
|
387
|
+
if self._dask_available:
|
|
388
|
+
# Add last job endpoint first (most specific)
|
|
389
|
+
app.add_api_route(
|
|
390
|
+
path=f"{self.front_end_config.evaluate.path}/job/last",
|
|
391
|
+
endpoint=get_last_job_status,
|
|
392
|
+
methods=["GET"],
|
|
393
|
+
response_model=EvaluateStatusResponse,
|
|
394
|
+
description="Get the status of the last created evaluation job",
|
|
395
|
+
responses={
|
|
396
|
+
404: {
|
|
397
|
+
"description": "No jobs found"
|
|
398
|
+
}, 500: response_500
|
|
399
|
+
},
|
|
400
|
+
)
|
|
413
401
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
402
|
+
# Add specific job endpoint (least specific)
|
|
403
|
+
app.add_api_route(
|
|
404
|
+
path=f"{self.front_end_config.evaluate.path}/job/{{job_id}}",
|
|
405
|
+
endpoint=get_job_status,
|
|
406
|
+
methods=["GET"],
|
|
407
|
+
response_model=EvaluateStatusResponse,
|
|
408
|
+
description="Get the status of an evaluation job",
|
|
409
|
+
responses={
|
|
410
|
+
404: {
|
|
411
|
+
"description": "Job not found"
|
|
412
|
+
}, 500: response_500
|
|
413
|
+
},
|
|
414
|
+
)
|
|
415
|
+
|
|
416
|
+
# Add jobs endpoint with optional status query parameter
|
|
417
|
+
app.add_api_route(
|
|
418
|
+
path=f"{self.front_end_config.evaluate.path}/jobs",
|
|
419
|
+
endpoint=get_jobs,
|
|
420
|
+
methods=["GET"],
|
|
421
|
+
response_model=list[EvaluateStatusResponse],
|
|
422
|
+
description="Get all jobs, optionally filtered by status",
|
|
423
|
+
responses={500: response_500},
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
# Add HTTP endpoint for evaluation
|
|
427
|
+
app.add_api_route(
|
|
428
|
+
path=self.front_end_config.evaluate.path,
|
|
429
|
+
endpoint=start_evaluation,
|
|
430
|
+
methods=[self.front_end_config.evaluate.method],
|
|
431
|
+
response_model=EvaluateResponse,
|
|
432
|
+
description=self.front_end_config.evaluate.description,
|
|
433
|
+
responses={500: response_500},
|
|
434
|
+
)
|
|
435
|
+
else:
|
|
436
|
+
logger.warning("Dask is not available, evaluation endpoints will not be added.")
|
|
423
437
|
|
|
424
438
|
async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):
|
|
425
439
|
|
|
@@ -526,21 +540,23 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
526
540
|
GenerateStreamResponseType = workflow.streaming_output_schema
|
|
527
541
|
GenerateSingleResponseType = workflow.single_output_schema
|
|
528
542
|
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
default=
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
543
|
+
if self._dask_available:
|
|
544
|
+
# Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
|
|
545
|
+
# Consider prefixing these with "nat_" to avoid conflicts
|
|
546
|
+
|
|
547
|
+
class AsyncGenerateRequest(GenerateBodyType):
|
|
548
|
+
job_id: str | None = Field(default=None, description="Unique identifier for the evaluation job")
|
|
549
|
+
sync_timeout: int = Field(
|
|
550
|
+
default=0,
|
|
551
|
+
ge=0,
|
|
552
|
+
le=300,
|
|
553
|
+
description="Attempt to perform the job synchronously up until `sync_timeout` sectonds, "
|
|
554
|
+
"if the job hasn't been completed by then a job_id will be returned with a status code of 202.")
|
|
555
|
+
expiry_seconds: int = Field(default=JobStore.DEFAULT_EXPIRY,
|
|
556
|
+
ge=JobStore.MIN_EXPIRY,
|
|
557
|
+
le=JobStore.MAX_EXPIRY,
|
|
558
|
+
description="Optional time (in seconds) before the job expires. "
|
|
559
|
+
"Clamped between 600 (10 min) and 86400 (24h).")
|
|
544
560
|
|
|
545
561
|
# Ensure that the input is in the body. POD types are treated as query parameters
|
|
546
562
|
if (not issubclass(GenerateBodyType, BaseModel)):
|
|
@@ -560,12 +576,6 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
560
576
|
},
|
|
561
577
|
}
|
|
562
578
|
|
|
563
|
-
# Create job store for tracking async generation jobs
|
|
564
|
-
job_store = JobStore()
|
|
565
|
-
|
|
566
|
-
# Run up to max_running_async_jobs jobs at the same time
|
|
567
|
-
async_job_concurrency = asyncio.Semaphore(self._front_end_config.max_running_async_jobs)
|
|
568
|
-
|
|
569
579
|
def get_single_endpoint(result_type: type | None):
|
|
570
580
|
|
|
571
581
|
async def get_single(response: Response, request: Request):
|
|
@@ -726,34 +736,44 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
726
736
|
|
|
727
737
|
return post_openai_api_compatible
|
|
728
738
|
|
|
729
|
-
|
|
730
|
-
"""Background task to run the evaluation."""
|
|
731
|
-
async with async_job_concurrency:
|
|
732
|
-
try:
|
|
733
|
-
result = await generate_single_response(payload=payload,
|
|
734
|
-
session_manager=session_manager,
|
|
735
|
-
result_type=result_type)
|
|
736
|
-
job_store.update_status(job_id, "success", output=result)
|
|
737
|
-
except Exception as e:
|
|
738
|
-
logger.exception("Error in evaluation job %s: %s", job_id, e)
|
|
739
|
-
job_store.update_status(job_id, "failure", error=str(e))
|
|
740
|
-
|
|
741
|
-
def _job_status_to_response(job: JobInfo) -> AsyncGenerationStatusResponse:
|
|
739
|
+
def _job_status_to_response(job: "JobInfo") -> AsyncGenerationStatusResponse:
|
|
742
740
|
job_output = job.output
|
|
743
741
|
if job_output is not None:
|
|
744
|
-
|
|
742
|
+
try:
|
|
743
|
+
job_output = json.loads(job_output)
|
|
744
|
+
except json.JSONDecodeError:
|
|
745
|
+
logger.error("Failed to parse job output as JSON: %s", job_output)
|
|
746
|
+
job_output = {"error": "Output parsing failed"}
|
|
747
|
+
|
|
745
748
|
return AsyncGenerationStatusResponse(job_id=job.job_id,
|
|
746
749
|
status=job.status,
|
|
747
750
|
error=job.error,
|
|
748
751
|
output=job_output,
|
|
749
752
|
created_at=job.created_at,
|
|
750
753
|
updated_at=job.updated_at,
|
|
751
|
-
expires_at=
|
|
754
|
+
expires_at=self._job_store.get_expires_at(job))
|
|
755
|
+
|
|
756
|
+
async def run_generation(scheduler_address: str,
|
|
757
|
+
db_url: str,
|
|
758
|
+
config_file_path: str,
|
|
759
|
+
job_id: str,
|
|
760
|
+
payload: typing.Any):
|
|
761
|
+
"""Background task to run the workflow."""
|
|
762
|
+
job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
|
|
763
|
+
try:
|
|
764
|
+
async with load_workflow(config_file_path) as local_session_manager:
|
|
765
|
+
result = await generate_single_response(
|
|
766
|
+
payload, local_session_manager, result_type=local_session_manager.workflow.single_output_schema)
|
|
752
767
|
|
|
753
|
-
|
|
768
|
+
await job_store.update_status(job_id, JobStatus.SUCCESS, output=result)
|
|
769
|
+
except Exception as e:
|
|
770
|
+
logger.exception("Error in async job %s", job_id)
|
|
771
|
+
await job_store.update_status(job_id, JobStatus.FAILURE, error=str(e))
|
|
772
|
+
|
|
773
|
+
def post_async_generation(request_type: type):
|
|
754
774
|
|
|
755
775
|
async def start_async_generation(
|
|
756
|
-
request: request_type,
|
|
776
|
+
request: request_type, response: Response,
|
|
757
777
|
http_request: Request) -> AsyncGenerateResponse | AsyncGenerationStatusResponse:
|
|
758
778
|
"""Handle async generation requests."""
|
|
759
779
|
|
|
@@ -761,41 +781,29 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
761
781
|
|
|
762
782
|
# if job_id is present and already exists return the job info
|
|
763
783
|
if request.job_id:
|
|
764
|
-
job =
|
|
784
|
+
job = await self._job_store.get_job(request.job_id)
|
|
765
785
|
if job:
|
|
766
786
|
return AsyncGenerateResponse(job_id=job.job_id, status=job.status)
|
|
767
787
|
|
|
768
|
-
job_id =
|
|
769
|
-
await self.
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
now = time.time()
|
|
786
|
-
sync_timeout = now + request.sync_timeout
|
|
787
|
-
while time.time() < sync_timeout:
|
|
788
|
-
job = job_store.get_job(job_id)
|
|
789
|
-
if job is not None and job.status not in job_store.ACTIVE_STATUS:
|
|
790
|
-
# If the job is done, return the result
|
|
791
|
-
response.status_code = 200
|
|
792
|
-
return _job_status_to_response(job)
|
|
793
|
-
|
|
794
|
-
# Sleep for a short time before checking again
|
|
795
|
-
await asyncio.sleep(0.1)
|
|
788
|
+
job_id = self._job_store.ensure_job_id(request.job_id)
|
|
789
|
+
(_, job) = await self._job_store.submit_job(job_id=job_id,
|
|
790
|
+
expiry_seconds=request.expiry_seconds,
|
|
791
|
+
job_fn=run_generation,
|
|
792
|
+
sync_timeout=request.sync_timeout,
|
|
793
|
+
job_args=[
|
|
794
|
+
self._scheduler_address,
|
|
795
|
+
self._db_url,
|
|
796
|
+
self._config_file_path,
|
|
797
|
+
job_id,
|
|
798
|
+
request.model_dump(mode="json")
|
|
799
|
+
])
|
|
800
|
+
|
|
801
|
+
if job is not None:
|
|
802
|
+
response.status_code = 200
|
|
803
|
+
return _job_status_to_response(job)
|
|
796
804
|
|
|
797
805
|
response.status_code = 202
|
|
798
|
-
return AsyncGenerateResponse(job_id=job_id, status=
|
|
806
|
+
return AsyncGenerateResponse(job_id=job_id, status=JobStatus.SUBMITTED)
|
|
799
807
|
|
|
800
808
|
return start_async_generation
|
|
801
809
|
|
|
@@ -805,8 +813,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
805
813
|
|
|
806
814
|
async with session_manager.session(http_connection=http_request):
|
|
807
815
|
|
|
808
|
-
job =
|
|
809
|
-
if
|
|
816
|
+
job = await self._job_store.get_job(job_id)
|
|
817
|
+
if job is None:
|
|
810
818
|
logger.warning("Job %s not found", job_id)
|
|
811
819
|
raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
|
|
812
820
|
|
|
@@ -934,30 +942,33 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
934
942
|
responses={500: response_500},
|
|
935
943
|
)
|
|
936
944
|
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
945
|
+
if self._dask_available:
|
|
946
|
+
app.add_api_route(
|
|
947
|
+
path=f"{endpoint.path}/async",
|
|
948
|
+
endpoint=post_async_generation(request_type=AsyncGenerateRequest),
|
|
949
|
+
methods=[endpoint.method],
|
|
950
|
+
response_model=AsyncGenerateResponse | AsyncGenerationStatusResponse,
|
|
951
|
+
description="Start an async generate job",
|
|
952
|
+
responses={500: response_500},
|
|
953
|
+
)
|
|
954
|
+
else:
|
|
955
|
+
logger.warning("Dask is not available, async generation endpoints will not be added.")
|
|
946
956
|
else:
|
|
947
957
|
raise ValueError(f"Unsupported method {endpoint.method}")
|
|
948
958
|
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
959
|
+
if self._dask_available:
|
|
960
|
+
app.add_api_route(
|
|
961
|
+
path=f"{endpoint.path}/async/job/{{job_id}}",
|
|
962
|
+
endpoint=get_async_job_status,
|
|
963
|
+
methods=["GET"],
|
|
964
|
+
response_model=AsyncGenerationStatusResponse,
|
|
965
|
+
description="Get the status of an async job",
|
|
966
|
+
responses={
|
|
967
|
+
404: {
|
|
968
|
+
"description": "Job not found"
|
|
969
|
+
}, 500: response_500
|
|
970
|
+
},
|
|
971
|
+
)
|
|
961
972
|
|
|
962
973
|
if (endpoint.openai_api_path):
|
|
963
974
|
if (endpoint.method == "GET"):
|
|
@@ -1084,3 +1095,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
|
|
|
1084
1095
|
async def _remove_flow(self, state: str):
|
|
1085
1096
|
async with self._outstanding_flows_lock:
|
|
1086
1097
|
del self._outstanding_flows[state]
|
|
1098
|
+
|
|
1099
|
+
|
|
1100
|
+
# Prevent Sphinx from documenting items not a part of the public API
|
|
1101
|
+
__all__ = ["FastApiFrontEndPluginWorkerBase", "FastApiFrontEndPluginWorker", "RouteInfo"]
|