nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +15 -5
- nat/agent/reasoning_agent/reasoning_agent.py +6 -1
- nat/agent/register.py +2 -0
- nat/agent/rewoo_agent/agent.py +4 -2
- nat/agent/rewoo_agent/register.py +8 -3
- nat/agent/router_agent/__init__.py +0 -0
- nat/agent/router_agent/agent.py +329 -0
- nat/agent/router_agent/prompt.py +48 -0
- nat/agent/router_agent/register.py +97 -0
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +11 -3
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/function.py +167 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +213 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -7
- nat/cli/entrypoint.py +2 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +43 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
- nat/llm/aws_bedrock_llm.py +14 -3
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +401 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +4 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +91 -71
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0

nat/data_models/temperature_mixin.py
CHANGED

@@ -16,9 +16,10 @@
 import re
 
 from pydantic import BaseModel
-from pydantic import Field
 
 from nat.data_models.gated_field_mixin import GatedFieldMixin
+from nat.data_models.optimizable import OptimizableField
+from nat.data_models.optimizable import SearchSpace
 
 
 class TemperatureMixin(
@@ -35,9 +36,9 @@ class TemperatureMixin(
     Attributes:
         temperature: Sampling temperature in [0, 1]. Defaults to 0.0 when supported on the model.
     """
-    temperature: float | None = Field(
+    temperature: float | None = OptimizableField(
        default=None,
        ge=0.0,
        le=1.0,
        description="Sampling temperature in [0, 1]. Defaults to 0.0 when supported on the model.",
-    )
+       space=SearchSpace(high=0.9, low=0.1, step=0.2))

nat/data_models/top_p_mixin.py
CHANGED

@@ -16,9 +16,10 @@
 import re
 
 from pydantic import BaseModel
-from pydantic import Field
 
 from nat.data_models.gated_field_mixin import GatedFieldMixin
+from nat.data_models.optimizable import OptimizableField
+from nat.data_models.optimizable import SearchSpace
 
 
 class TopPMixin(
@@ -35,9 +36,9 @@ class TopPMixin(
     Attributes:
         top_p: Top-p for distribution sampling. Defaults to 1.0 when supported on the model.
     """
-    top_p: float | None = Field(
+    top_p: float | None = OptimizableField(
        default=None,
        ge=0.0,
        le=1.0,
        description="Top-p for distribution sampling. Defaults to 1.0 when supported on the model.",
-    )
+       space=SearchSpace(high=1.0, low=0.5, step=0.1))

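Both mixins now declare their sampling parameter with the new OptimizableField so the parameter optimizer can sweep it. A minimal sketch of the declaration pattern, using only the names visible in the diffs above (any OptimizableField behavior beyond what is shown here is an assumption):

```python
from pydantic import BaseModel

from nat.data_models.optimizable import OptimizableField
from nat.data_models.optimizable import SearchSpace


class MySamplerConfig(BaseModel):
    # Hypothetical config class for illustration. The field is declared like a
    # normal Pydantic field, plus a search space the optimizer can sweep:
    # here 0.1, 0.3, 0.5, 0.7, 0.9 (low/high/step, per the diff's defaults).
    temperature: float | None = OptimizableField(
        default=None,
        ge=0.0,
        le=1.0,
        description="Sampling temperature in [0, 1].",
        space=SearchSpace(high=0.9, low=0.1, step=0.2))
```
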
nat/embedder/nim_embedder.py
CHANGED

@@ -50,7 +50,7 @@ class NIMEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="nim"):
                      description=("The truncation strategy if the input on the "
                                   "server side if it's too large."))
 
-    model_config = ConfigDict(protected_namespaces=())
+    model_config = ConfigDict(protected_namespaces=(), extra="allow")
 
 
 @register_embedder_provider(config_type=NIMEmbedderModelConfig)

nat/embedder/openai_embedder.py
CHANGED

@@ -27,7 +27,7 @@ from nat.data_models.retry_mixin import RetryMixin
 class OpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="openai"):
     """An OpenAI LLM provider to be used with an LLM client."""
 
-    model_config = ConfigDict(protected_namespaces=())
+    model_config = ConfigDict(protected_namespaces=(), extra="allow")
 
     api_key: str | None = Field(default=None, description="OpenAI API key to interact with hosted model.")
     base_url: str | None = Field(default=None, description="Base url to the hosted model.")

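The only change to both embedder configs is `extra="allow"`: unknown keys in a user's config are now retained on the model instead of raising a validation error. This is standard Pydantic v2 behavior, sketched below with an illustrative field:

```python
from pydantic import BaseModel, ConfigDict


class ExampleConfig(BaseModel):
    model_config = ConfigDict(protected_namespaces=(), extra="allow")

    base_url: str | None = None


# "dimensions" is not declared above, but extra="allow" keeps it,
# so it can be forwarded to the underlying embedder client.
cfg = ExampleConfig(base_url="http://localhost:8080", dimensions=1024)
print(cfg.model_extra)  # {'dimensions': 1024}
```
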
nat/eval/config.py
CHANGED

@@ -27,7 +27,7 @@ class EvaluationRunConfig(BaseModel):
     """
     Parameters used for a single evaluation run.
     """
-    config_file: Path
+    config_file: Path | BaseModel
     dataset: str | None = None  # dataset file path can be specified in the config file
     result_json_path: str = "$"
     skip_workflow: bool = False

nat/eval/evaluate.py
CHANGED

@@ -449,10 +449,14 @@ class EvaluationRun:
         from nat.runtime.loader import load_config
 
         # Load and override the config
-        if self.config.override:
+        config = None
+        if isinstance(self.config.config_file, BaseModel):
+            config = self.config.config_file
+        elif self.config.override:
             config = self.apply_overrides()
         else:
             config = load_config(self.config.config_file)
+
         self.eval_config = config.eval
         workflow_alias = self._get_workflow_alias(config.workflow.type)
         logger.debug("Loaded %s evaluation configuration: %s", workflow_alias, self.eval_config)

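Together, these two changes let an evaluation run start from an in-memory config object instead of a YAML path. A rough usage sketch; it assumes EvaluationRunConfig's remaining fields keep the defaults shown above, and uses a stand-in model where a real NAT Config would go:

```python
from pathlib import Path

from pydantic import BaseModel

from nat.eval.config import EvaluationRunConfig

# Passing a filesystem path still works as before.
run_cfg = EvaluationRunConfig(config_file=Path("configs/eval.yml"))


class _StandInConfig(BaseModel):
    """Placeholder for illustration; in practice this would be a loaded NAT config."""


# New in this release: any BaseModel-derived config can be passed directly,
# skipping the load_config()/override path entirely.
run_cfg = EvaluationRunConfig(config_file=_StandInConfig())
```
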
nat/eval/register.py
CHANGED

@@ -17,6 +17,10 @@
 
 # Import evaluators which need to be automatically registered here
 from .rag_evaluator.register import register_ragas_evaluator
+from .runtime_evaluator.register import register_avg_llm_latency_evaluator
+from .runtime_evaluator.register import register_avg_num_llm_calls_evaluator
+from .runtime_evaluator.register import register_avg_tokens_per_llm_end_evaluator
+from .runtime_evaluator.register import register_avg_workflow_runtime_evaluator
 from .swe_bench_evaluator.register import register_swe_bench_evaluator
 from .trajectory_evaluator.register import register_trajectory_evaluator
 from .tunable_rag_evaluator.register import register_tunable_rag_evaluator

nat/eval/runtime_evaluator/__init__.py
ADDED

@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

nat/eval/runtime_evaluator/evaluate.py
ADDED

@@ -0,0 +1,123 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass
+
+from nat.data_models.intermediate_step import IntermediateStepType
+from nat.eval.evaluator.base_evaluator import BaseEvaluator
+from nat.eval.evaluator.evaluator_model import EvalInputItem
+from nat.eval.evaluator.evaluator_model import EvalOutputItem
+from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
+
+
+@dataclass
+class _CallTiming:
+    start_ts: float | None = None
+    end_ts: float | None = None
+
+    @property
+    def latency(self) -> float | None:
+        if self.start_ts is None or self.end_ts is None:
+            return None
+        return max(0.0, self.end_ts - self.start_ts)
+
+
+class AverageLLMLatencyEvaluator(BaseEvaluator):
+    """
+    Mean difference between connected LLM_START and LLM_END events (same UUID).
+    The score is the average latency in seconds for the item. Reasoning contains per-call latencies.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg LLM Latency")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        calls: dict[str, _CallTiming] = defaultdict(_CallTiming)
+
+        for step in (IntermediatePropertyAdaptor.from_intermediate_step(s) for s in item.trajectory):
+            if step.event_type == IntermediateStepType.LLM_START:
+                calls[step.UUID].start_ts = step.event_timestamp
+            elif step.event_type == IntermediateStepType.LLM_END:
+                calls[step.UUID].end_ts = step.event_timestamp
+
+        latencies = [ct.latency for ct in calls.values() if ct.latency is not None]
+        avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
+
+        reasoning = {
+            "num_llm_calls": len(latencies),
+            "latencies": latencies,
+        }
+        return EvalOutputItem(id=item.id, score=round(avg_latency, 4), reasoning=reasoning)
+
+
+class AverageWorkflowRuntimeEvaluator(BaseEvaluator):
+    """
+    Average workflow runtime per item: max(event_timestamp) - min(event_timestamp) across the trajectory.
+    The score is the runtime in seconds for the item.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg Workflow Runtime")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        if not item.trajectory:
+            return EvalOutputItem(id=item.id, score=0.0, reasoning={"note": "no steps"})
+
+        timestamps = [s.event_timestamp for s in item.trajectory]
+        runtime = max(timestamps) - min(timestamps)
+        return EvalOutputItem(id=item.id, score=round(max(0.0, runtime), 4), reasoning={"steps": len(timestamps)})
+
+
+class AverageNumberOfLLMCallsEvaluator(BaseEvaluator):
+    """
+    Average number of LLM calls per item. The score is the count for the item.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg # LLM Calls")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        num_calls = sum(1 for s in item.trajectory if s.event_type == IntermediateStepType.LLM_END)
+        return EvalOutputItem(id=item.id, score=float(num_calls), reasoning={"num_llm_end": num_calls})
+
+
+class AverageTokensPerLLMEndEvaluator(BaseEvaluator):
+    """
+    Average total tokens per LLM_END event: sum of prompt and completion tokens if available.
+    The score is the average tokens per LLM_END for the item (0 if none).
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg Tokens/LLM_END")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        totals: list[int] = []
+        for step in (IntermediatePropertyAdaptor.from_intermediate_step(s) for s in item.trajectory):
+            if step.event_type == IntermediateStepType.LLM_END:
+                total_tokens = step.token_usage.total_tokens
+                # If framework doesn't set total, compute from prompt+completion
+                if total_tokens == 0:
+                    total_tokens = step.token_usage.prompt_tokens + step.token_usage.completion_tokens
+                totals.append(total_tokens)
+
+        avg_tokens = (sum(totals) / len(totals)) if totals else 0.0
+        reasoning = {
+            "num_llm_end": len(totals),
+            "totals": totals,
+        }
+        return EvalOutputItem(id=item.id, score=round(avg_tokens, 2), reasoning=reasoning)

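To make the latency scoring concrete: LLM_START and LLM_END events are paired by UUID and the per-call deltas are averaged. A toy illustration of that arithmetic in plain Python (no NAT imports; the timestamps are made up):

```python
# Two LLM calls reconstructed from a trajectory, keyed by event UUID.
calls = {
    "uuid-1": (100.0, 102.5),  # start_ts, end_ts -> 2.5 s
    "uuid-2": (103.0, 104.0),  # -> 1.0 s
}

latencies = [max(0.0, end - start) for start, end in calls.values()]
avg_latency = sum(latencies) / len(latencies)  # (2.5 + 1.0) / 2 = 1.75
print(round(avg_latency, 4))  # 1.75 is the item's score
```
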
nat/eval/runtime_evaluator/register.py
ADDED

@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import Field
+
+from nat.builder.builder import EvalBuilder
+from nat.builder.evaluator import EvaluatorInfo
+from nat.cli.register_workflow import register_evaluator
+from nat.data_models.evaluator import EvaluatorBaseConfig
+from nat.eval.evaluator.evaluator_model import EvalInput
+from nat.eval.evaluator.evaluator_model import EvalOutput
+
+
+class AverageLLMLatencyConfig(EvaluatorBaseConfig, name="avg_llm_latency"):
+    """Mean difference between connected LLM_START and LLM_END events (same UUID)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+class AverageWorkflowRuntimeConfig(EvaluatorBaseConfig, name="avg_workflow_runtime"):
+    """Average workflow runtime per item (max timestamp - min timestamp)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+class AverageNumberOfLLMCallsConfig(EvaluatorBaseConfig, name="avg_num_llm_calls"):
+    """Average number of LLM calls per item (count of LLM_END)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+class AverageTokensPerLLMEndConfig(EvaluatorBaseConfig, name="avg_tokens_per_llm_end"):
+    """Average total tokens per LLM_END event (prompt + completion if available)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+@register_evaluator(config_type=AverageLLMLatencyConfig)
+async def register_avg_llm_latency_evaluator(config: AverageLLMLatencyConfig, builder: EvalBuilder):
+    from .evaluate import AverageLLMLatencyEvaluator
+
+    evaluator = AverageLLMLatencyEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config,
+                        evaluate_fn=evaluate_fn,
+                        description="Average LLM latency (s) from LLM_START to LLM_END")
+
+
+@register_evaluator(config_type=AverageWorkflowRuntimeConfig)
+async def register_avg_workflow_runtime_evaluator(config: AverageWorkflowRuntimeConfig, builder: EvalBuilder):
+    from .evaluate import AverageWorkflowRuntimeEvaluator
+
+    evaluator = AverageWorkflowRuntimeEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config, evaluate_fn=evaluate_fn, description="Average workflow runtime (s)")
+
+
+@register_evaluator(config_type=AverageNumberOfLLMCallsConfig)
+async def register_avg_num_llm_calls_evaluator(config: AverageNumberOfLLMCallsConfig, builder: EvalBuilder):
+    from .evaluate import AverageNumberOfLLMCallsEvaluator
+
+    evaluator = AverageNumberOfLLMCallsEvaluator(
+        max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config, evaluate_fn=evaluate_fn, description="Average number of LLM calls")
+
+
+@register_evaluator(config_type=AverageTokensPerLLMEndConfig)
+async def register_avg_tokens_per_llm_end_evaluator(config: AverageTokensPerLLMEndConfig, builder: EvalBuilder):
+    from .evaluate import AverageTokensPerLLMEndEvaluator
+
+    evaluator = AverageTokensPerLLMEndEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config,
+                        evaluate_fn=evaluate_fn,
+                        description="Average total tokens per LLM_END (prompt + completion)")

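Each config class registers its evaluator under the `name=` argument shown above (avg_llm_latency, avg_workflow_runtime, avg_num_llm_calls, avg_tokens_per_llm_end), so those are the type names an eval config would reference. A small sketch of constructing the configs directly; how they are wired into a full eval pipeline is assumed rather than shown in this diff:

```python
from nat.eval.runtime_evaluator.register import AverageLLMLatencyConfig
from nat.eval.runtime_evaluator.register import AverageTokensPerLLMEndConfig

# Evaluator configs are plain Pydantic models; in a NAT eval config they
# would be selected by their registered type names ("avg_llm_latency",
# "avg_tokens_per_llm_end"), per the standard registration pattern.
latency_cfg = AverageLLMLatencyConfig(max_concurrency=4)
tokens_cfg = AverageTokensPerLLMEndConfig()  # defaults to max_concurrency=8
```
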
nat/experimental/test_time_compute/functions/plan_select_execute_function.py
CHANGED

@@ -97,7 +97,11 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
                            f"function without a description.")
 
     # Get the function dependencies of the augmented function
-    function_used_tools = builder.get_function_dependencies(config.augmented_fn).functions
+    function_dependencies = builder.get_function_dependencies(config.augmented_fn)
+    function_used_tools = set(function_dependencies.functions)
+    for function_group in function_dependencies.function_groups:
+        function_used_tools.update(builder.get_function_group_dependencies(function_group).functions)
+
     tool_list = "Tool: Description\n"
 
     for tool in function_used_tools:

nat/front_ends/fastapi/dask_client_mixin.py
ADDED

@@ -0,0 +1,43 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing
+from abc import ABC
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+
+if typing.TYPE_CHECKING:
+    from dask.distributed import Client
+
+
+class DaskClientMixin(ABC):
+
+    @asynccontextmanager
+    async def client(self, address: str) -> AsyncGenerator["Client"]:
+        """
+        Async context manager for obtaining a Dask client connection.
+
+        Yields
+        ------
+        Client
+            An active Dask client connected to the scheduler. The client is automatically closed when exiting the
+            context manager.
+        """
+        from dask.distributed import Client
+        client = await Client(address=address, asynchronous=True)
+
+        yield client
+
+        await client.close()

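A rough sketch of how a class might consume the mixin; FastApiFrontEndPlugin below composes it the same way. The runner class and scheduler address here are illustrative, and the mixin closes the client when the context exits:

```python
from nat.front_ends.fastapi.dask_client_mixin import DaskClientMixin


class JobRunner(DaskClientMixin):
    """Hypothetical consumer of the mixin, for illustration only."""

    async def run_task(self, scheduler_address: str):
        # The mixin yields a connected dask.distributed.Client.
        async with self.client(scheduler_address) as client:
            future = client.submit(sum, [1, 2, 3])
            return await future  # -> 6
```
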
@@ -197,9 +197,20 @@ class FastApiFrontEndConfig(FrontEndBaseConfig, name="fastapi"):
|
|
|
197
197
|
port: int = Field(default=8000, description="Port to bind the server to", ge=0, le=65535)
|
|
198
198
|
reload: bool = Field(default=False, description="Enable auto-reload for development")
|
|
199
199
|
workers: int = Field(default=1, description="Number of workers to run", ge=1)
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
200
|
+
scheduler_address: str | None = Field(
|
|
201
|
+
default=None,
|
|
202
|
+
description=("Address of the Dask scheduler to use for async jobs. If None, a Dask local cluster is created. "
|
|
203
|
+
"Note: This requires the optional dask dependency to be installed."))
|
|
204
|
+
db_url: str | None = Field(
|
|
205
|
+
default=None,
|
|
206
|
+
description=
|
|
207
|
+
"SQLAlchemy database URL for storing async job metadata, if unset a temporary SQLite database is used.")
|
|
208
|
+
max_running_async_jobs: int = Field(
|
|
209
|
+
default=10,
|
|
210
|
+
description=(
|
|
211
|
+
"Maximum number of async jobs to run concurrently, this controls the number of dask workers created. "
|
|
212
|
+
"This parameter is only used when scheduler_address is `None` and a Dask local cluster is created."),
|
|
213
|
+
ge=1)
|
|
203
214
|
step_adaptor: StepAdaptorConfig = StepAdaptorConfig()
|
|
204
215
|
|
|
205
216
|
workflow: typing.Annotated[EndpointBase, Field(description="Endpoint for the default workflow.")] = EndpointBase(
|
|
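A sketch of the two deployment modes the new fields enable, assuming the config's other fields keep their defaults (only the lines shown in this diff are certain):

```python
from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig

# Mode 1: point async jobs at an existing Dask scheduler and a persistent DB.
cfg = FastApiFrontEndConfig(scheduler_address="tcp://127.0.0.1:8786",
                            db_url="sqlite+aiosqlite:///jobs.db")

# Mode 2: leave scheduler_address unset; a local Dask cluster is created
# with one worker per allowed concurrent job (requires the dask extra).
cfg = FastApiFrontEndConfig(max_running_async_jobs=4)
```
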
nat/front_ends/fastapi/fastapi_front_end_plugin.py
CHANGED

@@ -13,21 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import logging
 import os
+import sys
 import tempfile
 import typing
 
 from nat.builder.front_end import FrontEndBase
+from nat.front_ends.fastapi.dask_client_mixin import DaskClientMixin
 from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
 from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorkerBase
 from nat.front_ends.fastapi.main import get_app
+from nat.front_ends.fastapi.utils import get_class_name
 from nat.utils.io.yaml_tools import yaml_dump
 
+if (typing.TYPE_CHECKING):
+    from nat.data_models.config import Config
+
 logger = logging.getLogger(__name__)
 
 
-class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
+class FastApiFrontEndPlugin(DaskClientMixin, FrontEndBase[FastApiFrontEndConfig]):
+
+    def __init__(self, full_config: "Config"):
+        super().__init__(full_config)
+
+        # This attribute is set if dask is installed, and an external cluster is not used (scheduler_address is None)
+        self._cluster = None
+        self._periodic_cleanup_future = None
+        self._scheduler_address = None
 
     def get_worker_class(self) -> type[FastApiFrontEndPluginWorkerBase]:
         from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker
@@ -42,7 +57,38 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
 
         worker_class = self.get_worker_class()
 
-        return
+        return get_class_name(worker_class)
+
+    @staticmethod
+    async def _periodic_cleanup(scheduler_address: str,
+                                db_url: str,
+                                sleep_time_sec: int = 300,
+                                log_level: int = logging.INFO):
+        from nat.front_ends.fastapi.job_store import JobStore
+
+        job_store = JobStore(scheduler_address=scheduler_address, db_url=db_url)
+
+        logging.basicConfig(level=log_level)
+        logger.info("Starting periodic cleanup of expired jobs every %d seconds", sleep_time_sec)
+        while True:
+            await asyncio.sleep(sleep_time_sec)
+
+            try:
+                await job_store.cleanup_expired_jobs()
+                logger.debug("Expired jobs cleaned up")
+            except:  # noqa: E722
+                logger.exception("Error during job cleanup")
+
+    async def _submit_cleanup_task(self, scheduler_address: str, db_url: str):
+        """Submit a cleanup task to the cluster to remove the job after expiry."""
+        logger.info("Submitting periodic cleanup task to Dask cluster at %s", scheduler_address)
+        async with self.client(self._scheduler_address) as client:
+            self._periodic_cleanup_future = client.submit(self._periodic_cleanup,
+                                                          scheduler_address=self._scheduler_address,
+                                                          db_url=db_url,
+                                                          log_level=logger.getEffectiveLevel())
+
+        logger.info("Submitted periodic cleanup task to Dask cluster at %s", scheduler_address)
 
     async def run(self):
 
@@ -52,6 +98,45 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
         # Get as dict
         config_dict = self.full_config.model_dump(mode="json", by_alias=True, round_trip=True)
 
+        # Three possible cases:
+        # 1. Dask is installed and scheduler_address is None, we create a LocalCluster
+        # 2. Dask is installed and scheduler_address is set, we use the existing cluster
+        # 3. Dask is not installed, we skip the cluster setup
+        self._scheduler_address = self.front_end_config.scheduler_address
+        if self._scheduler_address is None:
+            try:
+                from dask.distributed import LocalCluster
+
+                self._cluster = LocalCluster(n_workers=self.front_end_config.max_running_async_jobs,
+                                             threads_per_worker=1)
+
+                self._scheduler_address = self._cluster.scheduler.address
+                logger.info("Created local Dask cluster with scheduler at %s", self._scheduler_address)
+
+            except ImportError:
+                logger.warning("Dask is not installed, async execution and evaluation will not be available.")
+
+        if self._scheduler_address is not None:
+            # If we are here then either the user provided a scheduler address, or we created a LocalCluster
+
+            from nat.front_ends.fastapi.job_store import Base
+            from nat.front_ends.fastapi.job_store import get_db_engine
+
+            db_engine = get_db_engine(self.front_end_config.db_url, use_async=True)
+            async with db_engine.begin() as conn:
+                await conn.run_sync(Base.metadata.create_all, checkfirst=True)  # create tables if they do not exist
+
+            # If self.front_end_config.db_url is None, then we need to get the actual url from the engine
+            db_url = str(db_engine.url)
+            await self._submit_cleanup_task(scheduler_address=self._scheduler_address, db_url=db_url)
+
+            # Set environment variables such that the worker subprocesses will know how to connect to dask and to
+            # the database
+            os.environ.update({
+                "NAT_DASK_SCHEDULER_ADDRESS": self._scheduler_address,
+                "NAT_JOB_STORE_DB_URL": db_url,
+            })
+
         # Write to YAML file
         yaml_dump(config_dict, config_file)
 
@@ -70,13 +155,25 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
 
             reload_excludes = ["./.*"]
 
+            # By default, Uvicorn uses "auto" event loop policy, which prefers `uvloop` if installed. However,
+            # uvloop's event loop policy for macOS doesn't provide a child watcher (which is needed for MCP server),
+            # so setting loop="asyncio" forces Uvicorn to use the standard event loop, which includes child-watcher
+            # support.
+            if sys.platform == "darwin" or sys.platform.startswith("linux"):
+                # For macOS
+                event_loop_policy = "asyncio"
+            else:
+                # For non-macOS platforms
+                event_loop_policy = "auto"
+
             uvicorn.run("nat.front_ends.fastapi.main:get_app",
                         host=self.front_end_config.host,
                         port=self.front_end_config.port,
                         workers=self.front_end_config.workers,
                         reload=self.front_end_config.reload,
                         factory=True,
-                        reload_excludes=reload_excludes)
+                        reload_excludes=reload_excludes,
+                        loop=event_loop_policy)
 
         else:
             app = get_app()
@@ -110,6 +207,17 @@ class FastApiFrontEndPlugin(FrontEndBase[FastApiFrontEndConfig]):
             StandaloneApplication(app, options=options).run()
 
         finally:
+            logger.debug("Shutting down")
+            if self._periodic_cleanup_future is not None:
+                logger.info("Cancelling periodic cleanup task.")
+                # Use the scheduler address, because self._cluster is None if an external cluster is used
+                async with self.client(self._scheduler_address) as client:
+                    await client.cancel([self._periodic_cleanup_future], asynchronous=True, force=True)
+
+            if self._cluster is not None:
+                # Only shut down the cluster if we created it
+                logger.info("Closing Local Dask cluster.")
+                self._cluster.close()
             try:
                 os.remove(config_file_name)
             except OSError as e: