nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. nat/agent/base.py +9 -4
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +1 -1
  5. nat/agent/react_agent/register.py +15 -5
  6. nat/agent/reasoning_agent/reasoning_agent.py +6 -1
  7. nat/agent/register.py +2 -0
  8. nat/agent/rewoo_agent/agent.py +4 -2
  9. nat/agent/rewoo_agent/register.py +8 -3
  10. nat/agent/router_agent/__init__.py +0 -0
  11. nat/agent/router_agent/agent.py +329 -0
  12. nat/agent/router_agent/prompt.py +48 -0
  13. nat/agent/router_agent/register.py +97 -0
  14. nat/agent/tool_calling_agent/agent.py +69 -7
  15. nat/agent/tool_calling_agent/register.py +11 -3
  16. nat/builder/builder.py +27 -4
  17. nat/builder/component_utils.py +7 -3
  18. nat/builder/function.py +167 -0
  19. nat/builder/function_info.py +1 -1
  20. nat/builder/workflow.py +5 -0
  21. nat/builder/workflow_builder.py +213 -16
  22. nat/cli/commands/optimize.py +90 -0
  23. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  24. nat/cli/commands/workflow/workflow_commands.py +4 -7
  25. nat/cli/entrypoint.py +2 -0
  26. nat/cli/register_workflow.py +38 -4
  27. nat/cli/type_registry.py +71 -0
  28. nat/data_models/component.py +2 -0
  29. nat/data_models/component_ref.py +11 -0
  30. nat/data_models/config.py +40 -16
  31. nat/data_models/function.py +34 -0
  32. nat/data_models/function_dependencies.py +8 -0
  33. nat/data_models/optimizable.py +119 -0
  34. nat/data_models/optimizer.py +149 -0
  35. nat/data_models/temperature_mixin.py +4 -3
  36. nat/data_models/top_p_mixin.py +4 -3
  37. nat/embedder/nim_embedder.py +1 -1
  38. nat/embedder/openai_embedder.py +1 -1
  39. nat/eval/config.py +1 -1
  40. nat/eval/evaluate.py +5 -1
  41. nat/eval/register.py +4 -0
  42. nat/eval/runtime_evaluator/__init__.py +14 -0
  43. nat/eval/runtime_evaluator/evaluate.py +123 -0
  44. nat/eval/runtime_evaluator/register.py +100 -0
  45. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
  46. nat/front_ends/fastapi/dask_client_mixin.py +43 -0
  47. nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
  48. nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
  49. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
  50. nat/front_ends/fastapi/job_store.py +518 -99
  51. nat/front_ends/fastapi/main.py +11 -19
  52. nat/front_ends/fastapi/utils.py +57 -0
  53. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
  54. nat/llm/aws_bedrock_llm.py +14 -3
  55. nat/llm/nim_llm.py +14 -3
  56. nat/llm/openai_llm.py +8 -1
  57. nat/observability/exporter/processing_exporter.py +29 -55
  58. nat/observability/mixin/redaction_config_mixin.py +5 -4
  59. nat/observability/mixin/tagging_config_mixin.py +26 -14
  60. nat/observability/mixin/type_introspection_mixin.py +401 -107
  61. nat/observability/processor/processor.py +3 -0
  62. nat/observability/processor/redaction/__init__.py +24 -0
  63. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  64. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  65. nat/observability/processor/redaction/redaction_processor.py +177 -0
  66. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  67. nat/observability/processor/span_tagging_processor.py +21 -14
  68. nat/profiler/decorators/framework_wrapper.py +9 -6
  69. nat/profiler/parameter_optimization/__init__.py +0 -0
  70. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  71. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  72. nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
  73. nat/profiler/parameter_optimization/parameter_selection.py +108 -0
  74. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  75. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  76. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  77. nat/profiler/utils.py +3 -1
  78. nat/tool/chat_completion.py +4 -1
  79. nat/tool/github_tools.py +450 -0
  80. nat/tool/register.py +2 -7
  81. nat/utils/callable_utils.py +70 -0
  82. nat/utils/exception_handlers/automatic_retries.py +103 -48
  83. nat/utils/type_utils.py +4 -0
  84. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
  85. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +91 -71
  86. nat/observability/processor/header_redaction_processor.py +0 -123
  87. nat/observability/processor/redaction_processor.py +0 -77
  88. nat/tool/github_tools/create_github_commit.py +0 -133
  89. nat/tool/github_tools/create_github_issue.py +0 -87
  90. nat/tool/github_tools/create_github_pr.py +0 -106
  91. nat/tool/github_tools/get_github_file.py +0 -106
  92. nat/tool/github_tools/get_github_issue.py +0 -166
  93. nat/tool/github_tools/get_github_pr.py +0 -256
  94. nat/tool/github_tools/update_github_issue.py +0 -100
  95. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  96. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
  97. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
  98. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  99. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
  100. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
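
The hunks below are from nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py (+243 −228), which moves background evaluation and async-generation jobs from in-process FastAPI background tasks to an optional Dask-backed JobStore.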
@@ -14,9 +14,9 @@
 # limitations under the License.
 
 import asyncio
+import json
 import logging
 import os
-import time
 import typing
 from abc import ABC
 from abc import abstractmethod
@@ -25,7 +25,6 @@ from collections.abc import Callable
 from contextlib import asynccontextmanager
 from pathlib import Path
 
-from fastapi import BackgroundTasks
 from fastapi import Body
 from fastapi import FastAPI
 from fastapi import Request
@@ -58,18 +57,30 @@ from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateRequest
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import EvaluateStatusResponse
 from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
-from nat.front_ends.fastapi.job_store import JobInfo
-from nat.front_ends.fastapi.job_store import JobStore
 from nat.front_ends.fastapi.message_handler import WebSocketMessageHandler
 from nat.front_ends.fastapi.response_helpers import generate_single_response
 from nat.front_ends.fastapi.response_helpers import generate_streaming_response_as_str
 from nat.front_ends.fastapi.response_helpers import generate_streaming_response_full_as_str
 from nat.front_ends.fastapi.step_adaptor import StepAdaptor
+from nat.front_ends.fastapi.utils import get_config_file_path
 from nat.object_store.models import ObjectStoreItem
+from nat.runtime.loader import load_workflow
 from nat.runtime.session import SessionManager
 
 logger = logging.getLogger(__name__)
 
+_DASK_AVAILABLE = False
+
+try:
+    from nat.front_ends.fastapi.job_store import JobInfo
+    from nat.front_ends.fastapi.job_store import JobStatus
+    from nat.front_ends.fastapi.job_store import JobStore
+    _DASK_AVAILABLE = True
+except ImportError:
+    JobInfo = None
+    JobStatus = None
+    JobStore = None
+
 
 class FastApiFrontEndPluginWorkerBase(ABC):
 
@@ -80,10 +91,29 @@ class FastApiFrontEndPluginWorkerBase(ABC):
                           FastApiFrontEndConfig), ("Front end config is not FastApiFrontEndConfig")
 
         self._front_end_config = config.general.front_end
-
-        self._cleanup_tasks: list[str] = []
-        self._cleanup_tasks_lock = asyncio.Lock()
+        self._dask_available = False
+        self._job_store = None
         self._http_flow_handler: HTTPAuthenticationFlowHandler | None = HTTPAuthenticationFlowHandler()
+        self._scheduler_address = os.environ.get("NAT_DASK_SCHEDULER_ADDRESS")
+        self._db_url = os.environ.get("NAT_JOB_STORE_DB_URL")
+        self._config_file_path = get_config_file_path()
+
+        if self._scheduler_address is not None:
+            if not _DASK_AVAILABLE:
+                raise RuntimeError("Dask is not available, please install it to use the FastAPI front end with Dask.")
+
+            if self._db_url is None:
+                raise RuntimeError(
+                    "NAT_JOB_STORE_DB_URL must be set when using Dask (configure a persistent JobStore database).")
+
+            try:
+                self._job_store = JobStore(scheduler_address=self._scheduler_address, db_url=self._db_url)
+                self._dask_available = True
+                logger.debug("Connected to Dask scheduler at %s", self._scheduler_address)
+            except Exception as e:
+                raise RuntimeError(f"Failed to connect to Dask scheduler at {self._scheduler_address}: {e}") from e
+        else:
+            logger.debug("No Dask scheduler address provided, running without Dask support.")
 
     @property
     def config(self) -> Config:
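
The worker now wires up Dask entirely from the environment. As a minimal sketch of opting in before the server starts (not part of the package source; it assumes the dask.distributed package is installed, and the scheduler port, worker count, and database URL are placeholders):

    import os

    from dask.distributed import LocalCluster

    # Start a local scheduler; any reachable Dask scheduler address works.
    cluster = LocalCluster(n_workers=2, scheduler_port=8786)

    # These two variables are read in FastApiFrontEndPluginWorkerBase.__init__ above;
    # without them the worker logs a debug message and runs without Dask support.
    os.environ["NAT_DASK_SCHEDULER_ADDRESS"] = cluster.scheduler_address  # e.g. "tcp://127.0.0.1:8786"
    os.environ["NAT_JOB_STORE_DB_URL"] = "sqlite:///nat_jobs.db"  # hypothetical persistent job database
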
@@ -107,20 +137,6 @@ class FastApiFrontEndPluginWorkerBase(ABC):
 
             yield
 
-            # If a cleanup task is running, cancel it
-            async with self._cleanup_tasks_lock:
-
-                # Cancel all cleanup tasks
-                for task_name in self._cleanup_tasks:
-                    cleanup_task: asyncio.Task | None = getattr(starting_app.state, task_name, None)
-                    if cleanup_task is not None:
-                        logger.info("Cancelling %s cleanup task", task_name)
-                        cleanup_task.cancel()
-                    else:
-                        logger.warning("No cleanup task found for %s", task_name)
-
-                self._cleanup_tasks.clear()
-
             logger.debug("Closing NAT server from process %s", os.getpid())
 
         nat_app = FastAPI(lifespan=lifespan)
@@ -208,32 +224,6 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         self._outstanding_flows: dict[str, FlowState] = {}
         self._outstanding_flows_lock = asyncio.Lock()
 
-    @staticmethod
-    async def _periodic_cleanup(name: str, job_store: JobStore, sleep_time_sec: int = 300):
-        while True:
-            try:
-                job_store.cleanup_expired_jobs()
-                logger.debug("Expired %s jobs cleaned up", name)
-            except Exception as e:
-                logger.exception("Error during %s job cleanup: %s", name, e)
-            await asyncio.sleep(sleep_time_sec)
-
-    async def create_cleanup_task(self, app: FastAPI, name: str, job_store: JobStore, sleep_time_sec: int = 300):
-        # Schedule periodic cleanup of expired jobs on first job creation
-        attr_name = f"{name}_cleanup_task"
-
-        # Cheap check, if it doesn't exist, we will need to re-check after we acquire the lock
-        if not hasattr(app.state, attr_name):
-            async with self._cleanup_tasks_lock:
-                if not hasattr(app.state, attr_name):
-                    logger.info("Starting %s periodic cleanup task", name)
-                    setattr(
-                        app.state,
-                        attr_name,
-                        asyncio.create_task(
-                            self._periodic_cleanup(name=name, job_store=job_store, sleep_time_sec=sleep_time_sec)))
-                    self._cleanup_tasks.append(attr_name)
-
     def get_step_adaptor(self) -> StepAdaptor:
 
         return StepAdaptor(self.front_end_config.step_adaptor)
@@ -276,52 +266,72 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             },
         }
 
-        # Create job store for tracking evaluation jobs
-        job_store = JobStore()
-        # Don't run multiple evaluations at the same time
-        evaluation_lock = asyncio.Lock()
-
-        async def run_evaluation(job_id: str, config_file: str, reps: int, session_manager: SessionManager):
+        # TODO: Find another way to limit the number of concurrent evaluations
+        async def run_evaluation(scheduler_address: str,
+                                 db_url: str,
+                                 workflow_config_file_path: str,
+                                 job_id: str,
+                                 eval_config_file: str,
+                                 reps: int):
             """Background task to run the evaluation."""
-            async with evaluation_lock:
-                try:
-                    # Create EvaluationRunConfig using the CLI defaults
-                    eval_config = EvaluationRunConfig(config_file=Path(config_file), dataset=None, reps=reps)
-
-                    # Create a new EvaluationRun with the evaluation-specific config
-                    job_store.update_status(job_id, "running")
-                    eval_runner = EvaluationRun(eval_config)
-                    output: EvaluationRunOutput = await eval_runner.run_and_evaluate(session_manager=session_manager,
-                                                                                     job_id=job_id)
-                    if output.workflow_interrupted:
-                        job_store.update_status(job_id, "interrupted")
-                    else:
-                        parent_dir = os.path.dirname(
-                            output.workflow_output_file) if output.workflow_output_file else None
-
-                        job_store.update_status(job_id, "success", output_path=str(parent_dir))
-                except Exception as e:
-                    logger.exception("Error in evaluation job %s: %s", job_id, str(e))
-                    job_store.update_status(job_id, "failure", error=str(e))
-
-        async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
+            job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
+
+            try:
+                # We have two config files, one for the workflow and one for the evaluation
+                # Create EvaluationRunConfig using the CLI defaults
+                eval_config = EvaluationRunConfig(config_file=Path(eval_config_file), dataset=None, reps=reps)
+
+                # Create a new EvaluationRun with the evaluation-specific config
+                await job_store.update_status(job_id, JobStatus.RUNNING)
+                eval_runner = EvaluationRun(eval_config)
+
+                async with load_workflow(workflow_config_file_path) as local_session_manager:
+                    output: EvaluationRunOutput = await eval_runner.run_and_evaluate(
+                        session_manager=local_session_manager, job_id=job_id)
+
+                if output.workflow_interrupted:
+                    await job_store.update_status(job_id, JobStatus.INTERRUPTED)
+                else:
+                    parent_dir = os.path.dirname(output.workflow_output_file) if output.workflow_output_file else None
+
+                    await job_store.update_status(job_id, JobStatus.SUCCESS, output_path=str(parent_dir))
+            except Exception as e:
+                logger.exception("Error in evaluation job %s", job_id)
+                await job_store.update_status(job_id, JobStatus.FAILURE, error=str(e))
+
+        async def start_evaluation(request: EvaluateRequest, http_request: Request):
             """Handle evaluation requests."""
 
             async with session_manager.session(http_connection=http_request):
 
                 # if job_id is present and already exists return the job info
+                # There is a race condition between this check and the actual job submission, however if the client is
+                # supplying their own job_ids, then it is their responsibility to ensure that the job_id is unique.
                 if request.job_id:
-                    job = job_store.get_job(request.job_id)
-                    if job:
-                        return EvaluateResponse(job_id=job.job_id, status=job.status)
-
-                job_id = job_store.create_job(request.config_file, request.job_id, request.expiry_seconds)
-                await self.create_cleanup_task(app=app, name="async_evaluation", job_store=job_store)
-                background_tasks.add_task(run_evaluation, job_id, request.config_file, request.reps, session_manager)
-
-                return EvaluateResponse(job_id=job_id, status="submitted")
-
-        def translate_job_to_response(job: JobInfo) -> EvaluateStatusResponse:
+                    job_status = await self._job_store.get_status(request.job_id)
+                    if job_status != JobStatus.NOT_FOUND:
+                        return EvaluateResponse(job_id=request.job_id, status=job_status)
+
+                job_id = self._job_store.ensure_job_id(request.job_id)
+
+                await self._job_store.submit_job(job_id=job_id,
+                                                 config_file=request.config_file,
+                                                 expiry_seconds=request.expiry_seconds,
+                                                 job_fn=run_evaluation,
+                                                 job_args=[
+                                                     self._scheduler_address,
+                                                     self._db_url,
+                                                     self._config_file_path,
+                                                     job_id,
+                                                     request.config_file,
+                                                     request.reps
+                                                 ])
+
+                logger.info("Submitted evaluation job %s with config %s", job_id, request.config_file)
+
+                return EvaluateResponse(job_id=job_id, status=JobStatus.SUBMITTED)
+
+        def translate_job_to_response(job: "JobInfo") -> EvaluateStatusResponse:
            """Translate a JobInfo object to an EvaluateStatusResponse."""
            return EvaluateStatusResponse(job_id=job.job_id,
                                          status=job.status,
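
With submission routed through the Dask-backed JobStore, an evaluation request now returns immediately with a job id that is polled via the job routes registered further down in this diff. A client-side sketch, assuming a server at localhost:8000 with evaluate.path set to /evaluate (host, port, and config file name are placeholders):

    import requests  # any HTTP client works; requests is assumed here

    BASE = "http://localhost:8000"

    # Submit an evaluation job; the response carries a job_id and a "submitted" status.
    resp = requests.post(f"{BASE}/evaluate", json={"config_file": "eval_config.yml", "reps": 1})
    job_id = resp.json()["job_id"]

    # Poll the job-status route for the result location.
    status = requests.get(f"{BASE}/evaluate/job/{job_id}").json()
    print(status["status"], status.get("output_path"))
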
@@ -330,7 +340,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                           output_path=str(job.output_path),
                                           created_at=job.created_at,
                                           updated_at=job.updated_at,
-                                          expires_at=job_store.get_expires_at(job))
+                                          expires_at=self._job_store.get_expires_at(job))
 
         async def get_job_status(job_id: str, http_request: Request) -> EvaluateStatusResponse:
             """Get the status of an evaluation job."""
@@ -338,7 +348,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             async with session_manager.session(http_connection=http_request):
 
-                job = job_store.get_job(job_id)
+                job = await self._job_store.get_job(job_id)
                 if not job:
                     logger.warning("Job %s not found", job_id)
                     raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
@@ -351,7 +361,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             async with session_manager.session(http_connection=http_request):
 
-                job = job_store.get_last_job()
+                job = await self._job_store.get_last_job()
                 if not job:
                     logger.warning("No jobs found when requesting last job status")
                     raise HTTPException(status_code=404, detail="No jobs found")
@@ -365,61 +375,65 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             if status is None:
                 logger.info("Getting all jobs")
-                jobs = job_store.get_all_jobs()
+                jobs = await self._job_store.get_all_jobs()
             else:
                 logger.info("Getting jobs with status %s", status)
-                jobs = job_store.get_jobs_by_status(status)
+                jobs = await self._job_store.get_jobs_by_status(JobStatus(status))
+
             logger.info("Found %d jobs", len(jobs))
             return [translate_job_to_response(job) for job in jobs]
 
         if self.front_end_config.evaluate.path:
-            # Add last job endpoint first (most specific)
-            app.add_api_route(
-                path=f"{self.front_end_config.evaluate.path}/job/last",
-                endpoint=get_last_job_status,
-                methods=["GET"],
-                response_model=EvaluateStatusResponse,
-                description="Get the status of the last created evaluation job",
-                responses={
-                    404: {
-                        "description": "No jobs found"
-                    }, 500: response_500
-                },
-            )
-
-            # Add specific job endpoint (least specific)
-            app.add_api_route(
-                path=f"{self.front_end_config.evaluate.path}/job/{{job_id}}",
-                endpoint=get_job_status,
-                methods=["GET"],
-                response_model=EvaluateStatusResponse,
-                description="Get the status of an evaluation job",
-                responses={
-                    404: {
-                        "description": "Job not found"
-                    }, 500: response_500
-                },
-            )
-
-            # Add jobs endpoint with optional status query parameter
-            app.add_api_route(
-                path=f"{self.front_end_config.evaluate.path}/jobs",
-                endpoint=get_jobs,
-                methods=["GET"],
-                response_model=list[EvaluateStatusResponse],
-                description="Get all jobs, optionally filtered by status",
-                responses={500: response_500},
-            )
+            if self._dask_available:
+                # Add last job endpoint first (most specific)
+                app.add_api_route(
+                    path=f"{self.front_end_config.evaluate.path}/job/last",
+                    endpoint=get_last_job_status,
+                    methods=["GET"],
+                    response_model=EvaluateStatusResponse,
+                    description="Get the status of the last created evaluation job",
+                    responses={
+                        404: {
+                            "description": "No jobs found"
+                        }, 500: response_500
+                    },
+                )
 
-            # Add HTTP endpoint for evaluation
-            app.add_api_route(
-                path=self.front_end_config.evaluate.path,
-                endpoint=start_evaluation,
-                methods=[self.front_end_config.evaluate.method],
-                response_model=EvaluateResponse,
-                description=self.front_end_config.evaluate.description,
-                responses={500: response_500},
-            )
+                # Add specific job endpoint (least specific)
+                app.add_api_route(
+                    path=f"{self.front_end_config.evaluate.path}/job/{{job_id}}",
+                    endpoint=get_job_status,
+                    methods=["GET"],
+                    response_model=EvaluateStatusResponse,
+                    description="Get the status of an evaluation job",
+                    responses={
+                        404: {
+                            "description": "Job not found"
+                        }, 500: response_500
+                    },
+                )
+
+                # Add jobs endpoint with optional status query parameter
+                app.add_api_route(
+                    path=f"{self.front_end_config.evaluate.path}/jobs",
+                    endpoint=get_jobs,
+                    methods=["GET"],
+                    response_model=list[EvaluateStatusResponse],
+                    description="Get all jobs, optionally filtered by status",
+                    responses={500: response_500},
+                )
+
+                # Add HTTP endpoint for evaluation
+                app.add_api_route(
+                    path=self.front_end_config.evaluate.path,
+                    endpoint=start_evaluation,
+                    methods=[self.front_end_config.evaluate.method],
+                    response_model=EvaluateResponse,
+                    description=self.front_end_config.evaluate.description,
+                    responses={500: response_500},
+                )
+            else:
+                logger.warning("Dask is not available, evaluation endpoints will not be added.")
 
     async def add_static_files_route(self, app: FastAPI, builder: WorkflowBuilder):
 
@@ -526,21 +540,23 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         GenerateStreamResponseType = workflow.streaming_output_schema
         GenerateSingleResponseType = workflow.single_output_schema
 
-        # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
-        # Consider prefixing these with "nat_" to avoid conflicts
-        class AsyncGenerateRequest(GenerateBodyType):
-            job_id: str | None = Field(default=None, description="Unique identifier for the evaluation job")
-            sync_timeout: int = Field(
-                default=0,
-                ge=0,
-                le=300,
-                description="Attempt to perform the job synchronously up until `sync_timeout` sectonds, "
-                "if the job hasn't been completed by then a job_id will be returned with a status code of 202.")
-            expiry_seconds: int = Field(default=JobStore.DEFAULT_EXPIRY,
-                                        ge=JobStore.MIN_EXPIRY,
-                                        le=JobStore.MAX_EXPIRY,
-                                        description="Optional time (in seconds) before the job expires. "
-                                        "Clamped between 600 (10 min) and 86400 (24h).")
+        if self._dask_available:
+            # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
+            # Consider prefixing these with "nat_" to avoid conflicts
+
+            class AsyncGenerateRequest(GenerateBodyType):
+                job_id: str | None = Field(default=None, description="Unique identifier for the evaluation job")
+                sync_timeout: int = Field(
+                    default=0,
+                    ge=0,
+                    le=300,
+                    description="Attempt to perform the job synchronously up until `sync_timeout` sectonds, "
+                    "if the job hasn't been completed by then a job_id will be returned with a status code of 202.")
+                expiry_seconds: int = Field(default=JobStore.DEFAULT_EXPIRY,
+                                            ge=JobStore.MIN_EXPIRY,
+                                            le=JobStore.MAX_EXPIRY,
+                                            description="Optional time (in seconds) before the job expires. "
+                                            "Clamped between 600 (10 min) and 86400 (24h).")
 
         # Ensure that the input is in the body. POD types are treated as query parameters
         if (not issubclass(GenerateBodyType, BaseModel)):
@@ -560,12 +576,6 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             },
         }
 
-        # Create job store for tracking async generation jobs
-        job_store = JobStore()
-
-        # Run up to max_running_async_jobs jobs at the same time
-        async_job_concurrency = asyncio.Semaphore(self._front_end_config.max_running_async_jobs)
-
        def get_single_endpoint(result_type: type | None):
 
            async def get_single(response: Response, request: Request):
@@ -726,34 +736,44 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             return post_openai_api_compatible
 
-        async def run_generation(job_id: str, payload: typing.Any, session_manager: SessionManager, result_type: type):
-            """Background task to run the evaluation."""
-            async with async_job_concurrency:
-                try:
-                    result = await generate_single_response(payload=payload,
-                                                            session_manager=session_manager,
-                                                            result_type=result_type)
-                    job_store.update_status(job_id, "success", output=result)
-                except Exception as e:
-                    logger.exception("Error in evaluation job %s: %s", job_id, e)
-                    job_store.update_status(job_id, "failure", error=str(e))
-
-        def _job_status_to_response(job: JobInfo) -> AsyncGenerationStatusResponse:
+        def _job_status_to_response(job: "JobInfo") -> AsyncGenerationStatusResponse:
             job_output = job.output
             if job_output is not None:
-                job_output = job_output.model_dump()
+                try:
+                    job_output = json.loads(job_output)
+                except json.JSONDecodeError:
+                    logger.error("Failed to parse job output as JSON: %s", job_output)
+                    job_output = {"error": "Output parsing failed"}
+
             return AsyncGenerationStatusResponse(job_id=job.job_id,
                                                  status=job.status,
                                                  error=job.error,
                                                  output=job_output,
                                                  created_at=job.created_at,
                                                  updated_at=job.updated_at,
-                                                 expires_at=job_store.get_expires_at(job))
+                                                 expires_at=self._job_store.get_expires_at(job))
+
+        async def run_generation(scheduler_address: str,
+                                 db_url: str,
+                                 config_file_path: str,
+                                 job_id: str,
+                                 payload: typing.Any):
+            """Background task to run the workflow."""
+            job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
+            try:
+                async with load_workflow(config_file_path) as local_session_manager:
+                    result = await generate_single_response(
+                        payload, local_session_manager, result_type=local_session_manager.workflow.single_output_schema)
 
-        def post_async_generation(request_type: type, final_result_type: type):
+                    await job_store.update_status(job_id, JobStatus.SUCCESS, output=result)
+            except Exception as e:
+                logger.exception("Error in async job %s", job_id)
+                await job_store.update_status(job_id, JobStatus.FAILURE, error=str(e))
+
+        def post_async_generation(request_type: type):
 
             async def start_async_generation(
-                    request: request_type, background_tasks: BackgroundTasks, response: Response,
+                    request: request_type, response: Response,
                     http_request: Request) -> AsyncGenerateResponse | AsyncGenerationStatusResponse:
                 """Handle async generation requests."""
@@ -761,41 +781,29 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
                 # if job_id is present and already exists return the job info
                 if request.job_id:
-                    job = job_store.get_job(request.job_id)
+                    job = await self._job_store.get_job(request.job_id)
                     if job:
                         return AsyncGenerateResponse(job_id=job.job_id, status=job.status)
 
-                job_id = job_store.create_job(job_id=request.job_id, expiry_seconds=request.expiry_seconds)
-                await self.create_cleanup_task(app=app, name="async_generation", job_store=job_store)
-
-                # The fastapi/starlette background tasks won't begin executing until after the response is sent
-                # to the client, so we need to wrap the task in a function, alowing us to start the task now,
-                # and allowing the background task function to await the results.
-                task = asyncio.create_task(
-                    run_generation(job_id=job_id,
-                                   payload=request,
-                                   session_manager=session_manager,
-                                   result_type=final_result_type))
-
-                async def wrapped_task(t: asyncio.Task):
-                    return await t
-
-                background_tasks.add_task(wrapped_task, task)
-
-                now = time.time()
-                sync_timeout = now + request.sync_timeout
-                while time.time() < sync_timeout:
-                    job = job_store.get_job(job_id)
-                    if job is not None and job.status not in job_store.ACTIVE_STATUS:
-                        # If the job is done, return the result
-                        response.status_code = 200
-                        return _job_status_to_response(job)
-
-                    # Sleep for a short time before checking again
-                    await asyncio.sleep(0.1)
+                job_id = self._job_store.ensure_job_id(request.job_id)
+                (_, job) = await self._job_store.submit_job(job_id=job_id,
+                                                            expiry_seconds=request.expiry_seconds,
+                                                            job_fn=run_generation,
+                                                            sync_timeout=request.sync_timeout,
+                                                            job_args=[
+                                                                self._scheduler_address,
+                                                                self._db_url,
+                                                                self._config_file_path,
+                                                                job_id,
+                                                                request.model_dump(mode="json")
+                                                            ])
+
+                if job is not None:
+                    response.status_code = 200
+                    return _job_status_to_response(job)
 
                 response.status_code = 202
-                return AsyncGenerateResponse(job_id=job_id, status="submitted")
+                return AsyncGenerateResponse(job_id=job_id, status=JobStatus.SUBMITTED)
 
             return start_async_generation
 
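When the job finishes within sync_timeout, submit_job returns the completed JobInfo and the route answers 200 with the result; otherwise the client gets a 202 and a job id to poll. A sketch of that contract from the client side, assuming a workflow endpoint at /generate, an input_message field in its schema, and lowercase JobStatus string values (all placeholders, not confirmed by this diff):

    import time

    import requests  # assumes the requests package

    BASE = "http://localhost:8000"

    # Ask the server to wait up to 10 seconds for a synchronous result.
    resp = requests.post(f"{BASE}/generate/async", json={"input_message": "hello", "sync_timeout": 10})
    body = resp.json()

    if resp.status_code == 202:  # not finished within sync_timeout; poll until terminal
        while body["status"] in ("submitted", "running"):
            time.sleep(1.0)
            body = requests.get(f"{BASE}/generate/async/job/{body['job_id']}").json()

    print(body["status"], body.get("output"))
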
@@ -805,8 +813,8 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             async with session_manager.session(http_connection=http_request):
 
-                job = job_store.get_job(job_id)
-                if not job:
+                job = await self._job_store.get_job(job_id)
+                if job is None:
                     logger.warning("Job %s not found", job_id)
                     raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
 
@@ -934,30 +942,33 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                 responses={500: response_500},
             )
 
-            app.add_api_route(
-                path=f"{endpoint.path}/async",
-                endpoint=post_async_generation(request_type=AsyncGenerateRequest,
-                                               final_result_type=GenerateSingleResponseType),
-                methods=[endpoint.method],
-                response_model=AsyncGenerateResponse | AsyncGenerationStatusResponse,
-                description="Start an async generate job",
-                responses={500: response_500},
-            )
+            if self._dask_available:
+                app.add_api_route(
+                    path=f"{endpoint.path}/async",
+                    endpoint=post_async_generation(request_type=AsyncGenerateRequest),
+                    methods=[endpoint.method],
+                    response_model=AsyncGenerateResponse | AsyncGenerationStatusResponse,
+                    description="Start an async generate job",
+                    responses={500: response_500},
+                )
+            else:
+                logger.warning("Dask is not available, async generation endpoints will not be added.")
         else:
             raise ValueError(f"Unsupported method {endpoint.method}")
 
-        app.add_api_route(
-            path=f"{endpoint.path}/async/job/{{job_id}}",
-            endpoint=get_async_job_status,
-            methods=["GET"],
-            response_model=AsyncGenerationStatusResponse,
-            description="Get the status of an async job",
-            responses={
-                404: {
-                    "description": "Job not found"
-                }, 500: response_500
-            },
-        )
+        if self._dask_available:
+            app.add_api_route(
+                path=f"{endpoint.path}/async/job/{{job_id}}",
+                endpoint=get_async_job_status,
+                methods=["GET"],
+                response_model=AsyncGenerationStatusResponse,
+                description="Get the status of an async job",
+                responses={
+                    404: {
+                        "description": "Job not found"
+                    }, 500: response_500
+                },
+            )
 
         if (endpoint.openai_api_path):
             if (endpoint.method == "GET"):
@@ -1084,3 +1095,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
     async def _remove_flow(self, state: str):
         async with self._outstanding_flows_lock:
             del self._outstanding_flows[state]
+
+
+# Prevent Sphinx from documenting items not a part of the public API
+__all__ = ["FastApiFrontEndPluginWorkerBase", "FastApiFrontEndPluginWorker", "RouteInfo"]