nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100) hide show
  1. nat/agent/base.py +9 -4
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +1 -1
  5. nat/agent/react_agent/register.py +15 -5
  6. nat/agent/reasoning_agent/reasoning_agent.py +6 -1
  7. nat/agent/register.py +2 -0
  8. nat/agent/rewoo_agent/agent.py +4 -2
  9. nat/agent/rewoo_agent/register.py +8 -3
  10. nat/agent/router_agent/__init__.py +0 -0
  11. nat/agent/router_agent/agent.py +329 -0
  12. nat/agent/router_agent/prompt.py +48 -0
  13. nat/agent/router_agent/register.py +97 -0
  14. nat/agent/tool_calling_agent/agent.py +69 -7
  15. nat/agent/tool_calling_agent/register.py +11 -3
  16. nat/builder/builder.py +27 -4
  17. nat/builder/component_utils.py +7 -3
  18. nat/builder/function.py +167 -0
  19. nat/builder/function_info.py +1 -1
  20. nat/builder/workflow.py +5 -0
  21. nat/builder/workflow_builder.py +213 -16
  22. nat/cli/commands/optimize.py +90 -0
  23. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  24. nat/cli/commands/workflow/workflow_commands.py +4 -7
  25. nat/cli/entrypoint.py +2 -0
  26. nat/cli/register_workflow.py +38 -4
  27. nat/cli/type_registry.py +71 -0
  28. nat/data_models/component.py +2 -0
  29. nat/data_models/component_ref.py +11 -0
  30. nat/data_models/config.py +40 -16
  31. nat/data_models/function.py +34 -0
  32. nat/data_models/function_dependencies.py +8 -0
  33. nat/data_models/optimizable.py +119 -0
  34. nat/data_models/optimizer.py +149 -0
  35. nat/data_models/temperature_mixin.py +4 -3
  36. nat/data_models/top_p_mixin.py +4 -3
  37. nat/embedder/nim_embedder.py +1 -1
  38. nat/embedder/openai_embedder.py +1 -1
  39. nat/eval/config.py +1 -1
  40. nat/eval/evaluate.py +5 -1
  41. nat/eval/register.py +4 -0
  42. nat/eval/runtime_evaluator/__init__.py +14 -0
  43. nat/eval/runtime_evaluator/evaluate.py +123 -0
  44. nat/eval/runtime_evaluator/register.py +100 -0
  45. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
  46. nat/front_ends/fastapi/dask_client_mixin.py +43 -0
  47. nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
  48. nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
  49. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
  50. nat/front_ends/fastapi/job_store.py +518 -99
  51. nat/front_ends/fastapi/main.py +11 -19
  52. nat/front_ends/fastapi/utils.py +57 -0
  53. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
  54. nat/llm/aws_bedrock_llm.py +14 -3
  55. nat/llm/nim_llm.py +14 -3
  56. nat/llm/openai_llm.py +8 -1
  57. nat/observability/exporter/processing_exporter.py +29 -55
  58. nat/observability/mixin/redaction_config_mixin.py +5 -4
  59. nat/observability/mixin/tagging_config_mixin.py +26 -14
  60. nat/observability/mixin/type_introspection_mixin.py +401 -107
  61. nat/observability/processor/processor.py +3 -0
  62. nat/observability/processor/redaction/__init__.py +24 -0
  63. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  64. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  65. nat/observability/processor/redaction/redaction_processor.py +177 -0
  66. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  67. nat/observability/processor/span_tagging_processor.py +21 -14
  68. nat/profiler/decorators/framework_wrapper.py +9 -6
  69. nat/profiler/parameter_optimization/__init__.py +0 -0
  70. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  71. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  72. nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
  73. nat/profiler/parameter_optimization/parameter_selection.py +108 -0
  74. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  75. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  76. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  77. nat/profiler/utils.py +3 -1
  78. nat/tool/chat_completion.py +4 -1
  79. nat/tool/github_tools.py +450 -0
  80. nat/tool/register.py +2 -7
  81. nat/utils/callable_utils.py +70 -0
  82. nat/utils/exception_handlers/automatic_retries.py +103 -48
  83. nat/utils/type_utils.py +4 -0
  84. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
  85. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +91 -71
  86. nat/observability/processor/header_redaction_processor.py +0 -123
  87. nat/observability/processor/redaction_processor.py +0 -77
  88. nat/tool/github_tools/create_github_commit.py +0 -133
  89. nat/tool/github_tools/create_github_issue.py +0 -87
  90. nat/tool/github_tools/create_github_pr.py +0 -106
  91. nat/tool/github_tools/get_github_file.py +0 -106
  92. nat/tool/github_tools/get_github_issue.py +0 -166
  93. nat/tool/github_tools/get_github_pr.py +0 -256
  94. nat/tool/github_tools/update_github_issue.py +0 -100
  95. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  96. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
  97. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
  98. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  99. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
  100. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
@@ -16,9 +16,10 @@
16
16
  import re
17
17
 
18
18
  from pydantic import BaseModel
19
- from pydantic import Field
20
19
 
21
20
  from nat.data_models.gated_field_mixin import GatedFieldMixin
21
+ from nat.data_models.optimizable import OptimizableField
22
+ from nat.data_models.optimizable import SearchSpace
22
23
 
23
24
 
24
25
  class TemperatureMixin(
@@ -35,9 +36,9 @@ class TemperatureMixin(
35
36
  Attributes:
36
37
  temperature: Sampling temperature in [0, 1]. Defaults to 0.0 when supported on the model.
37
38
  """
38
- temperature: float | None = Field(
39
+ temperature: float | None = OptimizableField(
39
40
  default=None,
40
41
  ge=0.0,
41
42
  le=1.0,
42
43
  description="Sampling temperature in [0, 1]. Defaults to 0.0 when supported on the model.",
43
- )
44
+ space=SearchSpace(high=0.9, low=0.1, step=0.2))
@@ -16,9 +16,10 @@
16
16
  import re
17
17
 
18
18
  from pydantic import BaseModel
19
- from pydantic import Field
20
19
 
21
20
  from nat.data_models.gated_field_mixin import GatedFieldMixin
21
+ from nat.data_models.optimizable import OptimizableField
22
+ from nat.data_models.optimizable import SearchSpace
22
23
 
23
24
 
24
25
  class TopPMixin(
@@ -35,9 +36,9 @@ class TopPMixin(
35
36
  Attributes:
36
37
  top_p: Top-p for distribution sampling. Defaults to 1.0 when supported on the model.
37
38
  """
38
- top_p: float | None = Field(
39
+ top_p: float | None = OptimizableField(
39
40
  default=None,
40
41
  ge=0.0,
41
42
  le=1.0,
42
43
  description="Top-p for distribution sampling. Defaults to 1.0 when supported on the model.",
43
- )
44
+ space=SearchSpace(high=1.0, low=0.5, step=0.1))
@@ -50,7 +50,7 @@ class NIMEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="nim"):
50
50
  description=("The truncation strategy if the input on the "
51
51
  "server side if it's too large."))
52
52
 
53
- model_config = ConfigDict(protected_namespaces=())
53
+ model_config = ConfigDict(protected_namespaces=(), extra="allow")
54
54
 
55
55
 
56
56
  @register_embedder_provider(config_type=NIMEmbedderModelConfig)
@@ -27,7 +27,7 @@ from nat.data_models.retry_mixin import RetryMixin
27
27
  class OpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="openai"):
28
28
  """An OpenAI LLM provider to be used with an LLM client."""
29
29
 
30
- model_config = ConfigDict(protected_namespaces=())
30
+ model_config = ConfigDict(protected_namespaces=(), extra="allow")
31
31
 
32
32
  api_key: str | None = Field(default=None, description="OpenAI API key to interact with hosted model.")
33
33
  base_url: str | None = Field(default=None, description="Base url to the hosted model.")
nat/eval/config.py CHANGED
@@ -27,7 +27,7 @@ class EvaluationRunConfig(BaseModel):
27
27
  """
28
28
  Parameters used for a single evaluation run.
29
29
  """
30
- config_file: Path
30
+ config_file: Path | BaseModel
31
31
  dataset: str | None = None # dataset file path can be specified in the config file
32
32
  result_json_path: str = "$"
33
33
  skip_workflow: bool = False
nat/eval/evaluate.py CHANGED
@@ -449,10 +449,14 @@ class EvaluationRun:
449
449
  from nat.runtime.loader import load_config
450
450
 
451
451
  # Load and override the config
452
- if self.config.override:
452
+ config = None
453
+ if isinstance(self.config.config_file, BaseModel):
454
+ config = self.config.config_file
455
+ elif self.config.override:
453
456
  config = self.apply_overrides()
454
457
  else:
455
458
  config = load_config(self.config.config_file)
459
+
456
460
  self.eval_config = config.eval
457
461
  workflow_alias = self._get_workflow_alias(config.workflow.type)
458
462
  logger.debug("Loaded %s evaluation configuration: %s", workflow_alias, self.eval_config)
nat/eval/register.py CHANGED
@@ -17,6 +17,10 @@
17
17
 
18
18
  # Import evaluators which need to be automatically registered here
19
19
  from .rag_evaluator.register import register_ragas_evaluator
20
+ from .runtime_evaluator.register import register_avg_llm_latency_evaluator
21
+ from .runtime_evaluator.register import register_avg_num_llm_calls_evaluator
22
+ from .runtime_evaluator.register import register_avg_tokens_per_llm_end_evaluator
23
+ from .runtime_evaluator.register import register_avg_workflow_runtime_evaluator
20
24
  from .swe_bench_evaluator.register import register_swe_bench_evaluator
21
25
  from .trajectory_evaluator.register import register_trajectory_evaluator
22
26
  from .tunable_rag_evaluator.register import register_tunable_rag_evaluator
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,123 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ from collections import defaultdict
19
+ from dataclasses import dataclass
20
+
21
+ from nat.data_models.intermediate_step import IntermediateStepType
22
+ from nat.eval.evaluator.base_evaluator import BaseEvaluator
23
+ from nat.eval.evaluator.evaluator_model import EvalInputItem
24
+ from nat.eval.evaluator.evaluator_model import EvalOutputItem
25
+ from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
26
+
27
+
28
@dataclass
class _CallTiming:
    """Start/end timestamps gathered for one LLM call.

    ``latency`` is the elapsed time between the two timestamps, clamped so it is never
    negative, or ``None`` while either timestamp is still missing.
    """

    start_ts: float | None = None
    end_ts: float | None = None

    @property
    def latency(self) -> float | None:
        """Return ``end_ts - start_ts`` clamped at zero, or ``None`` if incomplete."""
        complete = self.start_ts is not None and self.end_ts is not None
        if not complete:
            return None
        elapsed = self.end_ts - self.start_ts
        return elapsed if elapsed > 0.0 else 0.0
38
+
39
+
40
class AverageLLMLatencyEvaluator(BaseEvaluator):
    """
    Per-item mean latency of LLM calls.

    LLM_START and LLM_END events that share a UUID are paired. The score is the mean of
    their timestamp differences in seconds, rounded to 4 places (0.0 when no complete pair
    exists). ``reasoning`` reports how many complete calls were found and each latency.
    """

    def __init__(self, max_concurrency: int = 8):
        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg LLM Latency")

    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
        timings: dict[str, _CallTiming] = defaultdict(_CallTiming)

        for raw_step in item.trajectory:
            adapted = IntermediatePropertyAdaptor.from_intermediate_step(raw_step)
            kind = adapted.event_type
            if kind == IntermediateStepType.LLM_START:
                timings[adapted.UUID].start_ts = adapted.event_timestamp
            elif kind == IntermediateStepType.LLM_END:
                timings[adapted.UUID].end_ts = adapted.event_timestamp

        # Only calls that saw both a start and an end contribute a latency.
        per_call = []
        for timing in timings.values():
            value = timing.latency
            if value is not None:
                per_call.append(value)

        mean_latency = 0.0
        if per_call:
            mean_latency = sum(per_call) / len(per_call)

        return EvalOutputItem(id=item.id,
                              score=round(mean_latency, 4),
                              reasoning={
                                  "num_llm_calls": len(per_call),
                                  "latencies": per_call,
                              })
66
+
67
+
68
class AverageWorkflowRuntimeEvaluator(BaseEvaluator):
    """
    Per-item workflow runtime.

    The score is the span between the earliest and the latest ``event_timestamp`` in the
    trajectory, in seconds, rounded to 4 places. An empty trajectory scores 0.0.
    """

    def __init__(self, max_concurrency: int = 8):
        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg Workflow Runtime")

    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
        steps = item.trajectory
        if not steps:
            return EvalOutputItem(id=item.id, score=0.0, reasoning={"note": "no steps"})

        stamps = [step.event_timestamp for step in steps]
        earliest = min(stamps)
        latest = max(stamps)
        elapsed = max(0.0, latest - earliest)
        return EvalOutputItem(id=item.id, score=round(elapsed, 4), reasoning={"steps": len(stamps)})
84
+
85
+
86
class AverageNumberOfLLMCallsEvaluator(BaseEvaluator):
    """
    Per-item number of LLM calls.

    The number of calls is the count of LLM_END events in the trajectory, reported as a float score.
    """

    def __init__(self, max_concurrency: int = 8):
        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg # LLM Calls")

    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
        llm_end_steps = [step for step in item.trajectory if step.event_type == IntermediateStepType.LLM_END]
        llm_end_count = len(llm_end_steps)
        return EvalOutputItem(id=item.id, score=float(llm_end_count), reasoning={"num_llm_end": llm_end_count})
97
+
98
+
99
class AverageTokensPerLLMEndEvaluator(BaseEvaluator):
    """
    Per-item mean token usage across LLM_END events.

    Each LLM_END event contributes its ``total_tokens``. When that value is 0, it is rebuilt as
    ``prompt_tokens + completion_tokens``. The score is the mean over LLM_END events, rounded
    to 2 places, or 0.0 when there are none.
    """

    def __init__(self, max_concurrency: int = 8):
        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg Tokens/LLM_END")

    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
        per_end: list[int] = []
        for raw_step in item.trajectory:
            adapted = IntermediatePropertyAdaptor.from_intermediate_step(raw_step)
            if adapted.event_type != IntermediateStepType.LLM_END:
                continue
            usage = adapted.token_usage
            tokens = usage.total_tokens
            if tokens == 0:
                # Some frameworks leave total_tokens unset; fall back to the sum of its parts.
                tokens = usage.prompt_tokens + usage.completion_tokens
            per_end.append(tokens)

        mean_tokens = sum(per_end) / len(per_end) if per_end else 0.0
        return EvalOutputItem(id=item.id,
                              score=round(mean_tokens, 2),
                              reasoning={
                                  "num_llm_end": len(per_end),
                                  "totals": per_end,
                              })
@@ -0,0 +1,100 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pydantic import Field
17
+
18
+ from nat.builder.builder import EvalBuilder
19
+ from nat.builder.evaluator import EvaluatorInfo
20
+ from nat.cli.register_workflow import register_evaluator
21
+ from nat.data_models.evaluator import EvaluatorBaseConfig
22
+ from nat.eval.evaluator.evaluator_model import EvalInput
23
+ from nat.eval.evaluator.evaluator_model import EvalOutput
24
+
25
+
26
class AverageLLMLatencyConfig(EvaluatorBaseConfig, name="avg_llm_latency"):
    """Settings for the ``avg_llm_latency`` evaluator (mean LLM_START to LLM_END latency, paired by UUID)."""

    max_concurrency: int = Field(description="Max concurrency for evaluation.", default=8)
30
+
31
+
32
class AverageWorkflowRuntimeConfig(EvaluatorBaseConfig, name="avg_workflow_runtime"):
    """Settings for the ``avg_workflow_runtime`` evaluator (latest minus earliest event timestamp per item)."""

    max_concurrency: int = Field(description="Max concurrency for evaluation.", default=8)
36
+
37
+
38
class AverageNumberOfLLMCallsConfig(EvaluatorBaseConfig, name="avg_num_llm_calls"):
    """Settings for the ``avg_num_llm_calls`` evaluator (number of LLM_END events per item)."""

    max_concurrency: int = Field(description="Max concurrency for evaluation.", default=8)
42
+
43
+
44
class AverageTokensPerLLMEndConfig(EvaluatorBaseConfig, name="avg_tokens_per_llm_end"):
    """Settings for the ``avg_tokens_per_llm_end`` evaluator (mean total tokens per LLM_END event)."""

    max_concurrency: int = Field(description="Max concurrency for evaluation.", default=8)
48
+
49
+
50
@register_evaluator(config_type=AverageLLMLatencyConfig)
async def register_avg_llm_latency_evaluator(config: AverageLLMLatencyConfig, builder: EvalBuilder):
    """Build the ``avg_llm_latency`` evaluator and yield its ``EvaluatorInfo``.

    ``config.max_concurrency`` is used when it is truthy; otherwise the builder's max concurrency applies.
    """
    from .evaluate import AverageLLMLatencyEvaluator

    concurrency = config.max_concurrency or builder.get_max_concurrency()
    latency_evaluator = AverageLLMLatencyEvaluator(max_concurrency=concurrency)

    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
        # Hand the whole batch to the evaluator instance.
        return await latency_evaluator.evaluate(eval_input)

    yield EvaluatorInfo(
        config=config,
        evaluate_fn=evaluate_fn,
        description="Average LLM latency (s) from LLM_START to LLM_END",
    )
62
+
63
+
64
@register_evaluator(config_type=AverageWorkflowRuntimeConfig)
async def register_avg_workflow_runtime_evaluator(config: AverageWorkflowRuntimeConfig, builder: EvalBuilder):
    """Build the ``avg_workflow_runtime`` evaluator and yield its ``EvaluatorInfo``.

    ``config.max_concurrency`` is used when it is truthy; otherwise the builder's max concurrency applies.
    """
    from .evaluate import AverageWorkflowRuntimeEvaluator

    concurrency = config.max_concurrency or builder.get_max_concurrency()
    runtime_evaluator = AverageWorkflowRuntimeEvaluator(max_concurrency=concurrency)

    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
        # Hand the whole batch to the evaluator instance.
        return await runtime_evaluator.evaluate(eval_input)

    yield EvaluatorInfo(
        config=config,
        evaluate_fn=evaluate_fn,
        description="Average workflow runtime (s)",
    )
74
+
75
+
76
@register_evaluator(config_type=AverageNumberOfLLMCallsConfig)
async def register_avg_num_llm_calls_evaluator(config: AverageNumberOfLLMCallsConfig, builder: EvalBuilder):
    """Build the ``avg_num_llm_calls`` evaluator and yield its ``EvaluatorInfo``.

    ``config.max_concurrency`` is used when it is truthy; otherwise the builder's max concurrency applies.
    """
    from .evaluate import AverageNumberOfLLMCallsEvaluator

    concurrency = config.max_concurrency or builder.get_max_concurrency()
    call_count_evaluator = AverageNumberOfLLMCallsEvaluator(max_concurrency=concurrency)

    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
        # Hand the whole batch to the evaluator instance.
        return await call_count_evaluator.evaluate(eval_input)

    yield EvaluatorInfo(
        config=config,
        evaluate_fn=evaluate_fn,
        description="Average number of LLM calls",
    )
87
+
88
+
89
@register_evaluator(config_type=AverageTokensPerLLMEndConfig)
async def register_avg_tokens_per_llm_end_evaluator(config: AverageTokensPerLLMEndConfig, builder: EvalBuilder):
    """Build the ``avg_tokens_per_llm_end`` evaluator and yield its ``EvaluatorInfo``.

    ``config.max_concurrency`` is used when it is truthy; otherwise the builder's max concurrency applies.
    """
    from .evaluate import AverageTokensPerLLMEndEvaluator

    concurrency = config.max_concurrency or builder.get_max_concurrency()
    token_evaluator = AverageTokensPerLLMEndEvaluator(max_concurrency=concurrency)

    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
        # Hand the whole batch to the evaluator instance.
        return await token_evaluator.evaluate(eval_input)

    yield EvaluatorInfo(
        config=config,
        evaluate_fn=evaluate_fn,
        description="Average total tokens per LLM_END (prompt + completion)",
    )
@@ -97,7 +97,11 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
97
97
  f"function without a description.")
98
98
 
99
99
  # Get the function dependencies of the augmented function
100
- function_used_tools = builder.get_function_dependencies(config.augmented_fn).functions
100
+ function_dependencies = builder.get_function_dependencies(config.augmented_fn)
101
+ function_used_tools = set(function_dependencies.functions)
102
+ for function_group in function_dependencies.function_groups:
103
+ function_used_tools.update(builder.get_function_group_dependencies(function_group).functions)
104
+
101
105
  tool_list = "Tool: Description\n"
102
106
 
103
107
  for tool in function_used_tools:
@@ -0,0 +1,43 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import typing
17
+ from abc import ABC
18
+ from collections.abc import AsyncGenerator
19
+ from contextlib import asynccontextmanager
20
+
21
+ if typing.TYPE_CHECKING:
22
+ from dask.distributed import Client
23
+
24
+
25
class DaskClientMixin(ABC):
    """Mixin that provides an async context manager for connecting to a Dask scheduler."""

    @asynccontextmanager
    async def client(self, address: str) -> AsyncGenerator["Client"]:
        """
        Async context manager for obtaining a Dask client connection.

        Parameters
        ----------
        address : str
            Address of the Dask scheduler to connect to.

        Yields
        ------
        Client
            An active asynchronous Dask client connected to the scheduler. The client is closed when the
            context manager exits, including when the body of the ``async with`` block raises.
        """
        # Imported lazily so dask stays an optional dependency; only the type hint is imported eagerly,
        # and only under TYPE_CHECKING.
        from dask.distributed import Client
        client = await Client(address=address, asynchronous=True)

        try:
            yield client
        finally:
            # Without try/finally, an exception raised inside the `async with` body would skip this
            # close and leak the connection.
            await client.close()
@@ -197,9 +197,20 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
197
197
  port: int = Field(default=8000, description="Port to bind the server to", ge=0, le=65535)
198
198
  reload: bool = Field(default=False, description="Enable auto-reload for development")
199
199
  workers: int = Field(default=1, description="Number of workers to run", ge=1)
200
- max_running_async_jobs: int = Field(default=10,
201
- description="Maximum number of async jobs to run concurrently",
202
- ge=1)
200
+ scheduler_address: str | None = Field(
201
+ default=None,
202
+ description=("Address of the Dask scheduler to use for async jobs. If None, a Dask local cluster is created. "
203
+ "Note: This requires the optional dask dependency to be installed."))
204
+ db_url: str | None = Field(
205
+ default=None,
206
+ description=
207
+ "SQLAlchemy database URL for storing async job metadata, if unset a temporary SQLite database is used.")
208
+ max_running_async_jobs: int = Field(
209
+ default=10,
210
+ description=(
211
+ "Maximum number of async jobs to run concurrently, this controls the number of dask workers created. "
212
+ "This parameter is only used when scheduler_address is `None` and a Dask local cluster is created."),
213
+ ge=1)
203
214
  step_adaptor: StepAdaptorConfig = StepAdaptorConfig()
204
215
 
205
216
  workflow: typing.Annotated[EndpointBase, Field(description="Endpoint for the default workflow.")] = EndpointBase(
@@ -13,21 +13,36 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import asyncio
16
17
  import logging
17
18
  import os
19
+ import sys
18
20
  import tempfile
19
21
  import typing
20
22
 
21
23
  from nat.builder.front_end import FrontEndBase
24
+ from nat.front_ends.fastapi.dask_client_mixin import DaskClientMixin
22
25
  from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
23
26
  from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
24
27
  from nat.front_ends.fastapi.main import get_app
28
+ from nat.front_ends.fastapi.utils import get_class_name
25
29
  from nat.utils.io.yaml_tools import yaml_dump
26
30
 
31
+ if (typing.TYPE_CHECKING):
32
+ from nat.data_models.config import Config
33
+
27
34
  logger = logging.getLogger(__name__)
28
35
 
29
36
 
30
- class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
37
+ class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]):
38
+
39
def __init__(self, full_config: "Config"):
    super().__init__(full_config)

    # All three start empty and are filled in later, during run():
    # - _scheduler_address: the Dask scheduler actually in use (the configured one, or that of a
    #   LocalCluster created here).
    # - _cluster: set only when dask is installed and no external cluster is used
    #   (scheduler_address is None), i.e. when this plugin creates the LocalCluster itself.
    # - _periodic_cleanup_future: future of the periodic job-cleanup task submitted to Dask.
    self._scheduler_address = None
    self._cluster = None
    self._periodic_cleanup_future = None
31
46
 
32
47
  def get_worker_class(self) -> type[FastApiFrontEndPluginWorkerBase]:
33
48
  from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker
@@ -42,7 +57,38 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
42
57
 
43
58
  worker_class = self.get_worker_class()
44
59
 
45
- return f"{worker_class.__module__}.{worker_class.__qualname__}"
60
+ return get_class_name(worker_class)
61
+
62
@staticmethod
async def _periodic_cleanup(scheduler_address: str,
                            db_url: str,
                            sleep_time_sec: int = 300,
                            log_level: int = logging.INFO):
    """
    Loop forever, removing expired async jobs from the job store every ``sleep_time_sec`` seconds.

    This coroutine is submitted to the Dask cluster by ``_submit_cleanup_task``. It only stops when
    it is cancelled.

    Parameters
    ----------
    scheduler_address : str
        Address of the Dask scheduler, forwarded to ``JobStore``.
    db_url : str
        Database URL of the job store, forwarded to ``JobStore``.
    sleep_time_sec : int
        Delay before each cleanup pass.
    log_level : int
        Level passed to ``logging.basicConfig`` in the process running this coroutine.
    """
    from nat.front_ends.fastapi.job_store import JobStore

    job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)

    logging.basicConfig(level=log_level)
    logger.info("Starting periodic cleanup of expired jobs every %d seconds", sleep_time_sec)
    while True:
        await asyncio.sleep(sleep_time_sec)

        try:
            await job_store.cleanup_expired_jobs()
            logger.debug("Expired jobs cleaned up")
        except Exception:
            # Keep the loop alive on ordinary failures. A bare `except:` would also swallow
            # asyncio.CancelledError (a BaseException), so cancelling the task mid-pass would not
            # stop it; catching Exception lets cancellation propagate.
            logger.exception("Error during job cleanup")
81
+
82
async def _submit_cleanup_task(self, scheduler_address: str, db_url: str):
    """
    Submit ``_periodic_cleanup`` to the Dask cluster so expired jobs are removed periodically.

    The resulting future is stored in ``self._periodic_cleanup_future`` so it can be cancelled at shutdown.

    Parameters
    ----------
    scheduler_address : str
        Address of the Dask scheduler to submit to; also forwarded to the cleanup task.
    db_url : str
        Job-store database URL forwarded to the cleanup task.
    """
    logger.info("Submitting periodic cleanup task to Dask cluster at %s", scheduler_address)
    # Use the caller-supplied address for both the connection and the task, so the address that
    # is logged is the one actually used.
    async with self.client(scheduler_address) as client:
        # NOTE(review): the client that submitted this future is closed when this block exits. Confirm
        # that Dask keeps the long-running task alive afterwards (see dask.distributed.fire_and_forget).
        self._periodic_cleanup_future = client.submit(self._periodic_cleanup,
                                                      scheduler_address=scheduler_address,
                                                      db_url=db_url,
                                                      log_level=logger.getEffectiveLevel())

    logger.info("Submitted periodic cleanup task to Dask cluster at %s", scheduler_address)
46
92
 
47
93
  async def run(self):
48
94
 
@@ -52,6 +98,45 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
52
98
  # Get as dict
53
99
  config_dict = self.full_config.model_dump(mode="json", by_alias=True, round_trip=True)
54
100
 
101
+ # Three possible cases:
102
+ # 1. Dask is installed and scheduler_address is None, we create a LocalCluster
103
+ # 2. Dask is installed and scheduler_address is set, we use the existing cluster
104
+ # 3. Dask is not installed, we skip the cluster setup
105
+ self._scheduler_address = self.front_end_config.scheduler_address
106
+ if self._scheduler_address is None:
107
+ try:
108
+ from dask.distributed import LocalCluster
109
+
110
+ self._cluster = LocalCluster(n_workers=self.front_end_config.max_running_async_jobs,
111
+ threads_per_worker=1)
112
+
113
+ self._scheduler_address = self._cluster.scheduler.address
114
+ logger.info("Created local Dask cluster with scheduler at %s", self._scheduler_address)
115
+
116
+ except ImportError:
117
+ logger.warning("Dask is not installed, async execution and evaluation will not be available.")
118
+
119
+ if self._scheduler_address is not None:
120
+ # If we are here then either the user provided a scheduler address, or we created a LocalCluster
121
+
122
+ from nat.front_ends.fastapi.job_store import Base
123
+ from nat.front_ends.fastapi.job_store import get_db_engine
124
+
125
+ db_engine = get_db_engine(self.front_end_config.db_url, use_async=True)
126
+ async with db_engine.begin() as conn:
127
+ await conn.run_sync(Base.metadata.create_all, checkfirst=True) # create tables if they do not exist
128
+
129
+ # If self.front_end_config.db_url is None, then we need to get the actual url from the engine
130
+ db_url = str(db_engine.url)
131
+ await self._submit_cleanup_task(scheduler_address=self._scheduler_address, db_url=db_url)
132
+
133
+ # Set environment variables so that the worker subprocesses know how to connect to Dask and to
134
+ # the database
135
+ os.environ.update({
136
+ "NAT_DASK_SCHEDULER_ADDRESS": self._scheduler_address,
137
+ "NAT_JOB_STORE_DB_URL": db_url,
138
+ })
139
+
55
140
  # Write to YAML file
56
141
  yaml_dump(config_dict, config_file)
57
142
 
@@ -70,13 +155,25 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
70
155
 
71
156
  reload_excludes = ["./.*"]
72
157
 
158
+ # By default, Uvicorn uses "auto" event loop policy, which prefers `uvloop` if installed. However,
159
+ # uvloop’s event loop policy for macOS doesn’t provide a child watcher (which is needed for MCP server),
160
+ # so setting loop="asyncio" forces Uvicorn to use the standard event loop, which includes child-watcher
161
+ # support.
162
+ if sys.platform == "darwin" or sys.platform.startswith("linux"):
163
+ # For macOS
164
+ event_loop_policy = "asyncio"
165
+ else:
166
+ # For all other platforms (e.g. Windows)
167
+ event_loop_policy = "auto"
168
+
73
169
  uvicorn.run("nat.front_ends.fastapi.main:get_app",
74
170
  host=self.front_end_config.host,
75
171
  port=self.front_end_config.port,
76
172
  workers=self.front_end_config.workers,
77
173
  reload=self.front_end_config.reload,
78
174
  factory=True,
79
- reload_excludes=reload_excludes)
175
+ reload_excludes=reload_excludes,
176
+ loop=event_loop_policy)
80
177
 
81
178
  else:
82
179
  app = get_app()
@@ -110,6 +207,17 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
110
207
  StandaloneApplication(app, options=options).run()
111
208
 
112
209
  finally:
210
+ logger.debug("Shutting down")
211
+ if self._periodic_cleanup_future is not None:
212
+ logger.info("Cancelling periodic cleanup task.")
213
+ # Use the scheduler address, because self._cluster is None if an external cluster is used
214
+ async with self.client(self._scheduler_address) as client:
215
+ await client.cancel([self._periodic_cleanup_future], asynchronous=True, force=True)
216
+
217
+ if self._cluster is not None:
218
+ # Only shut down the cluster if we created it
219
+ logger.info("Closing Local Dask cluster.")
220
+ self._cluster.close()
113
221
  try:
114
222
  os.remove(config_file_name)
115
223
  except OSError as e: