nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +50 -22
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +54 -27
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +68 -17
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +62 -22
- nat/cli/entrypoint.py +8 -10
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +74 -66
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +19 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +106 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
nat/eval/remote_workflow.py
CHANGED
@@ -74,7 +74,7 @@ class EvaluationRemoteWorkflowHandler:
                         if chunk_data.get("value"):
                             final_response = chunk_data.get("value")
                     except json.JSONDecodeError as e:
-                        logger.
+                        logger.exception("Failed to parse generate response chunk: %s", e)
                         continue
                 elif line.startswith(INTERMEDIATE_DATA_PREFIX):
                     # This is an intermediate step
@@ -90,12 +90,12 @@ class EvaluationRemoteWorkflowHandler:
                                                          payload=payload)
                         intermediate_steps.append(intermediate_step)
                     except (json.JSONDecodeError, ValidationError) as e:
-                        logger.
+                        logger.exception("Failed to parse intermediate step: %s", e)
                         continue

         except aiohttp.ClientError as e:
             # Handle connection or HTTP-related errors
-            logger.
+            logger.exception("Request failed for question %s: %s", question, e)
             item.output_obj = None
             item.trajectory = []
             return
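Both hunks above replace incomplete logging calls with `logger.exception(...)`. As a point of reference (standard-library behavior, not code from this package), `logger.exception` logs at ERROR level and automatically appends the active traceback when called inside an `except` block, which is why it is preferred over a bare `logger.error` here. A minimal sketch:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    int("not a number")
except ValueError as e:
    logger.error("error only: %s", e)          # logs the message, no traceback
    logger.exception("with traceback: %s", e)  # logs the message plus the current traceback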
nat/eval/runtime_evaluator/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
nat/eval/runtime_evaluator/evaluate.py
ADDED
@@ -0,0 +1,123 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass
+
+from nat.data_models.intermediate_step import IntermediateStepType
+from nat.eval.evaluator.base_evaluator import BaseEvaluator
+from nat.eval.evaluator.evaluator_model import EvalInputItem
+from nat.eval.evaluator.evaluator_model import EvalOutputItem
+from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
+
+
+@dataclass
+class _CallTiming:
+    start_ts: float | None = None
+    end_ts: float | None = None
+
+    @property
+    def latency(self) -> float | None:
+        if self.start_ts is None or self.end_ts is None:
+            return None
+        return max(0.0, self.end_ts - self.start_ts)
+
+
+class AverageLLMLatencyEvaluator(BaseEvaluator):
+    """
+    Mean difference between connected LLM_START and LLM_END events (same UUID).
+    The score is the average latency in seconds for the item. Reasoning contains per-call latencies.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg LLM Latency")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        calls: dict[str, _CallTiming] = defaultdict(_CallTiming)
+
+        for step in (IntermediatePropertyAdaptor.from_intermediate_step(s) for s in item.trajectory):
+            if step.event_type == IntermediateStepType.LLM_START:
+                calls[step.UUID].start_ts = step.event_timestamp
+            elif step.event_type == IntermediateStepType.LLM_END:
+                calls[step.UUID].end_ts = step.event_timestamp
+
+        latencies = [ct.latency for ct in calls.values() if ct.latency is not None]
+        avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
+
+        reasoning = {
+            "num_llm_calls": len(latencies),
+            "latencies": latencies,
+        }
+        return EvalOutputItem(id=item.id, score=round(avg_latency, 4), reasoning=reasoning)
+
+
+class AverageWorkflowRuntimeEvaluator(BaseEvaluator):
+    """
+    Average workflow runtime per item: max(event_timestamp) - min(event_timestamp) across the trajectory.
+    The score is the runtime in seconds for the item.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg Workflow Runtime")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        if not item.trajectory:
+            return EvalOutputItem(id=item.id, score=0.0, reasoning={"note": "no steps"})
+
+        timestamps = [s.event_timestamp for s in item.trajectory]
+        runtime = max(timestamps) - min(timestamps)
+        return EvalOutputItem(id=item.id, score=round(max(0.0, runtime), 4), reasoning={"steps": len(timestamps)})
+
+
+class AverageNumberOfLLMCallsEvaluator(BaseEvaluator):
+    """
+    Average number of LLM calls per item. The score is the count for the item.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg # LLM Calls")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        num_calls = sum(1 for s in item.trajectory if s.event_type == IntermediateStepType.LLM_END)
+        return EvalOutputItem(id=item.id, score=float(num_calls), reasoning={"num_llm_end": num_calls})
+
+
+class AverageTokensPerLLMEndEvaluator(BaseEvaluator):
+    """
+    Average total tokens per LLM_END event: sum of prompt and completion tokens if available.
+    The score is the average tokens per LLM_END for the item (0 if none).
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg Tokens/LLM_END")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        totals: list[int] = []
+        for step in (IntermediatePropertyAdaptor.from_intermediate_step(s) for s in item.trajectory):
+            if step.event_type == IntermediateStepType.LLM_END:
+                total_tokens = step.token_usage.total_tokens
+                # If framework doesn't set total, compute from prompt+completion
+                if total_tokens == 0:
+                    total_tokens = step.token_usage.prompt_tokens + step.token_usage.completion_tokens
+                totals.append(total_tokens)
+
+        avg_tokens = (sum(totals) / len(totals)) if totals else 0.0
+        reasoning = {
+            "num_llm_end": len(totals),
+            "totals": totals,
+        }
+        return EvalOutputItem(id=item.id, score=round(avg_tokens, 2), reasoning=reasoning)
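`AverageLLMLatencyEvaluator` above pairs `LLM_START` and `LLM_END` events that share a UUID and averages the per-call deltas. A minimal standalone sketch of that pairing logic, using synthetic `(event_type, uuid, timestamp)` tuples rather than nat's intermediate-step objects:

from collections import defaultdict

# Synthetic events standing in for an item's trajectory (not real nat objects).
events = [
    ("LLM_START", "call-1", 10.0),
    ("LLM_END", "call-1", 11.5),
    ("LLM_START", "call-2", 12.0),
    ("LLM_END", "call-2", 12.8),
]

calls: dict[str, dict[str, float]] = defaultdict(dict)
for event_type, uuid, ts in events:
    calls[uuid]["start" if event_type == "LLM_START" else "end"] = ts

# Only calls with both endpoints contribute, mirroring the _CallTiming.latency guard.
latencies = [c["end"] - c["start"] for c in calls.values() if "start" in c and "end" in c]
print(round(sum(latencies) / len(latencies), 4))  # 1.15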
nat/eval/runtime_evaluator/register.py
ADDED
@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import Field
+
+from nat.builder.builder import EvalBuilder
+from nat.builder.evaluator import EvaluatorInfo
+from nat.cli.register_workflow import register_evaluator
+from nat.data_models.evaluator import EvaluatorBaseConfig
+from nat.eval.evaluator.evaluator_model import EvalInput
+from nat.eval.evaluator.evaluator_model import EvalOutput
+
+
+class AverageLLMLatencyConfig(EvaluatorBaseConfig, name="avg_llm_latency"):
+    """Mean difference between connected LLM_START and LLM_END events (same UUID)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+class AverageWorkflowRuntimeConfig(EvaluatorBaseConfig, name="avg_workflow_runtime"):
+    """Average workflow runtime per item (max timestamp - min timestamp)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+class AverageNumberOfLLMCallsConfig(EvaluatorBaseConfig, name="avg_num_llm_calls"):
+    """Average number of LLM calls per item (count of LLM_END)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+class AverageTokensPerLLMEndConfig(EvaluatorBaseConfig, name="avg_tokens_per_llm_end"):
+    """Average total tokens per LLM_END event (prompt + completion if available)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+@register_evaluator(config_type=AverageLLMLatencyConfig)
+async def register_avg_llm_latency_evaluator(config: AverageLLMLatencyConfig, builder: EvalBuilder):
+    from .evaluate import AverageLLMLatencyEvaluator
+
+    evaluator = AverageLLMLatencyEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config,
+                        evaluate_fn=evaluate_fn,
+                        description="Average LLM latency (s) from LLM_START to LLM_END")
+
+
+@register_evaluator(config_type=AverageWorkflowRuntimeConfig)
+async def register_avg_workflow_runtime_evaluator(config: AverageWorkflowRuntimeConfig, builder: EvalBuilder):
+    from .evaluate import AverageWorkflowRuntimeEvaluator
+
+    evaluator = AverageWorkflowRuntimeEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config, evaluate_fn=evaluate_fn, description="Average workflow runtime (s)")
+
+
+@register_evaluator(config_type=AverageNumberOfLLMCallsConfig)
+async def register_avg_num_llm_calls_evaluator(config: AverageNumberOfLLMCallsConfig, builder: EvalBuilder):
+    from .evaluate import AverageNumberOfLLMCallsEvaluator
+
+    evaluator = AverageNumberOfLLMCallsEvaluator(
+        max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config, evaluate_fn=evaluate_fn, description="Average number of LLM calls")
+
+
+@register_evaluator(config_type=AverageTokensPerLLMEndConfig)
+async def register_avg_tokens_per_llm_end_evaluator(config: AverageTokensPerLLMEndConfig, builder: EvalBuilder):
+    from .evaluate import AverageTokensPerLLMEndEvaluator
+
+    evaluator = AverageTokensPerLLMEndEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config,
+                        evaluate_fn=evaluate_fn,
+                        description="Average total tokens per LLM_END (prompt + completion)")
nat/eval/swe_bench_evaluator/evaluate.py
CHANGED
@@ -69,13 +69,13 @@ class SweBenchEvaluator:
         try:
             shutil.move(swe_bench_report_file, report_dir)
         except Exception as e:
-            logger.exception("Error moving report file: %s", e
+            logger.exception("Error moving report file: %s", e)

         try:
             dest_logs_dir = os.path.join(report_dir, 'logs')
             shutil.move(logs_dir, dest_logs_dir)
         except Exception as e:
-            logger.exception("Error moving logs directory: %s", e
+            logger.exception("Error moving logs directory: %s", e)

     def is_repo_supported(self, repo: str, version: str) -> bool:
         """Check if the repo is supported by swebench"""
@@ -106,7 +106,7 @@ class SweBenchEvaluator:
                 self._model_name_or_path = swebench_output.model_name_or_path

             except Exception as e:
-                logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e
+                logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e)

         # Filter out repos/version not supported by SWEBench
         supported_inputs = [
@@ -114,7 +114,7 @@ class SweBenchEvaluator:
         ]

         if not supported_inputs:
-            logger.
+            logger.exception("No supported instances; nothing to evaluate")
             return None, None

         if len(supported_inputs) < len(swebench_inputs):
@@ -135,7 +135,7 @@ class SweBenchEvaluator:
         filtered_outputs = [output for output in swebench_outputs if output.instance_id in valid_instance_ids]

         if not filtered_outputs:
-            logger.error("No supported outputs; nothing to evaluate")
+            logger.error("No supported outputs; nothing to evaluate", exc_info=True)
             return None, None

         # Write SWEBenchOutput to file
@@ -204,7 +204,7 @@ class SweBenchEvaluator:
         # if report file is not present, return empty EvalOutput
         avg_score = 0.0
         if report_file.exists():
-            with open(report_file,
+            with open(report_file, encoding="utf-8") as f:
                 report = json.load(f)
                 resolved_instances = report.get("resolved_instances", 0)
                 total_instances = report.get("total_instances", 0)
nat/eval/trajectory_evaluator/evaluate.py
CHANGED
@@ -65,7 +65,7 @@ class TrajectoryEvaluator(BaseEvaluator):
                 prediction=generated_answer,
             )
         except Exception as e:
-            logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e
+            logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e)
             return EvalOutputItem(id=item.id, score=0.0, reasoning=f"Error evaluating trajectory: {e}")

         reasoning = {
nat/eval/trajectory_evaluator/register.py
CHANGED
@@ -33,7 +33,7 @@ async def register_trajectory_evaluator(config: TrajectoryEvaluatorConfig, build

     from .evaluate import TrajectoryEvaluator
     llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
-    tools = builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+    tools = await builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)

     _evaluator = TrajectoryEvaluator(llm, tools, builder.get_max_concurrency())

nat/eval/tunable_rag_evaluator/evaluate.py
CHANGED
@@ -13,9 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import asyncio
 import logging
-from
+from collections.abc import Callable

 from langchain.output_parsers import ResponseSchema
 from langchain.output_parsers import StructuredOutputParser
@@ -23,7 +22,6 @@ from langchain.schema import HumanMessage
 from langchain.schema import SystemMessage
 from langchain_core.language_models import BaseChatModel
 from langchain_core.runnables import RunnableLambda
-from tqdm import tqdm

 from nat.eval.evaluator.base_evaluator import BaseEvaluator
 from nat.eval.evaluator.evaluator_model import EvalInputItem
@@ -31,7 +29,6 @@ from nat.eval.evaluator.evaluator_model import EvalOutputItem

 logger = logging.getLogger(__name__)

-# pylint: disable=line-too-long
 # flake8: noqa: E501


@@ -185,8 +182,8 @@ class TunableRagEvaluator(BaseEvaluator):
             relevance_score = parsed_response["relevance_score"]
             reasoning = parsed_response["reasoning"]
         except KeyError as e:
-            logger.
-
+            logger.exception("Missing required keys in default scoring response: %s",
+                             ", ".join(str(arg) for arg in e.args))
             reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"

         coverage_weight = self.default_score_weights.get("coverage", 1 / 3)
@@ -218,7 +215,7 @@ class TunableRagEvaluator(BaseEvaluator):
             reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
             raise
         except (KeyError, ValueError) as e:
-            logger.
+            logger.exception("Error parsing judge LLM response: %s", e)
             score = 0.0
             reasoning = "Error in evaluator from parsing judge LLM response."

nat/eval/utils/eval_trace_ctx.py
ADDED
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections.abc import Callable
+from contextlib import contextmanager
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+# Type alias for evaluation call objects that have an optional 'id' attribute
+EvalCallType = Any  # Could be Weave Call object or other tracing framework objects
+
+
+class EvalTraceContext:
+    """
+    Evaluation trace context manager for coordinating traces.
+
+    This class provides a framework-agnostic way to:
+    1. Track evaluation calls/contexts
+    2. Ensure proper parent-child relationships in traces
+    """
+
+    def __init__(self):
+        self.eval_call: EvalCallType | None = None  # Store the evaluation call/context for propagation
+
+    def set_eval_call(self, eval_call: EvalCallType | None) -> None:
+        """Set the evaluation call/context for propagation to traces."""
+        self.eval_call = eval_call
+        if eval_call:
+            logger.debug("Set evaluation call context: %s", getattr(eval_call, 'id', str(eval_call)))
+
+    def get_eval_call(self) -> EvalCallType | None:
+        """Get the current evaluation call/context."""
+        return self.eval_call
+
+    @contextmanager
+    def evaluation_context(self):
+        """
+        Context manager that can be overridden by framework-specific implementations.
+        Default implementation is a no-op.
+        """
+        yield
+
+
+class WeaveEvalTraceContext(EvalTraceContext):
+    """
+    Weave-specific implementation of evaluation trace context.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.available = False
+        self.set_call_stack: Callable[[list[EvalCallType]], Any] | None = None
+
+        try:
+            from weave.trace.context.call_context import set_call_stack
+            self.set_call_stack = set_call_stack
+            self.available = True
+        except ImportError:
+            self.available = False
+            logger.debug("Weave not available for trace context")
+
+    @contextmanager
+    def evaluation_context(self):
+        """Set the evaluation call as active context for Weave traces."""
+        if self.available and self.eval_call and self.set_call_stack:
+            try:
+                with self.set_call_stack([self.eval_call]):
+                    logger.debug("Set Weave evaluation call context: %s",
+                                 getattr(self.eval_call, 'id', str(self.eval_call)))
+                    yield
+            except Exception as e:
+                logger.warning("Failed to set Weave evaluation call context: %s", e)
+                yield
+        else:
+            yield
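A hedged usage sketch of the trace-context hook added above: with no evaluation call set and Weave not installed, `evaluation_context()` degrades to a no-op, so the snippet runs as-is once the package is installed. The `print` stands in for per-item evaluation work; in the package it is `weave_eval.py` (next diff) that calls `set_eval_call(...)` with the real Weave evaluation call.

from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext

trace_ctx = WeaveEvalTraceContext()  # self.available stays False when weave cannot be imported

with trace_ctx.evaluation_context():
    # Traces emitted here are parented under the evaluation call once one is set and Weave is available.
    print("running one evaluation item")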
nat/eval/utils/weave_eval.py
CHANGED
@@ -15,6 +15,7 @@

 import asyncio
 import logging
+from typing import TYPE_CHECKING
 from typing import Any

 from nat.eval.evaluator.evaluator_model import EvalInput
@@ -24,26 +25,28 @@ from nat.eval.usage_stats import UsageStats
 from nat.eval.usage_stats import UsageStatsItem
 from nat.profiler.data_models import ProfilerResults

+if TYPE_CHECKING:
+    from nat.eval.utils.eval_trace_ctx import EvalTraceContext
+
 logger = logging.getLogger(__name__)


-class WeaveEvaluationIntegration:
+class WeaveEvaluationIntegration:
     """
     Class to handle all Weave integration functionality.
     """

-    def __init__(self):
+    def __init__(self, eval_trace_context: "EvalTraceContext"):
         self.available = False
         self.client = None
         self.eval_logger = None
         self.pred_loggers = {}
+        self.eval_trace_context = eval_trace_context

         try:
-            from weave
-            from weave.flow.eval_imperative import ScoreLogger
+            from weave import EvaluationLogger
             from weave.trace.context import weave_client_context
-            self.
-            self.ScoreLogger = ScoreLogger
+            self.evaluation_logger_cls = EvaluationLogger
             self.weave_client_context = weave_client_context
             self.available = True
         except ImportError:
@@ -89,9 +92,15 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
             weave_dataset = self._get_weave_dataset(eval_input)
             config_dict = config.model_dump(mode="json")
             config_dict["name"] = workflow_alias
-            self.eval_logger = self.
+            self.eval_logger = self.evaluation_logger_cls(model=config_dict,
+                                                          dataset=weave_dataset,
+                                                          name=workflow_alias,
+                                                          eval_attributes={})
             self.pred_loggers = {}

+            # Capture the current evaluation call for context propagation
+            self.eval_trace_context.set_eval_call(self.eval_logger._evaluate_call)
+
             return True
         except Exception as e:
             self.eval_logger = None
@@ -137,7 +146,7 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
         await asyncio.gather(*coros)

     async def afinish_loggers(self):
-        """Finish all prediction loggers."""
+        """Finish all prediction loggers and wait for exports."""
         if not self.eval_logger:
             return

@@ -157,7 +166,6 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
         if profiler_results.workflow_runtime_metrics:
             profile_metrics["wf_runtime_p95"] = profiler_results.workflow_runtime_metrics.p95

-        # TODO:get the LLM tokens from the usage stats and log them
         profile_metrics["total_runtime"] = usage_stats.total_runtime

         return profile_metrics
@@ -182,3 +190,4 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
         # Log the summary to finish the evaluation, disable auto-summarize
         # as we will be adding profiler metrics to the summary
         self.eval_logger.log_summary(summary, auto_summarize=False)
+        logger.info("Logged Evaluation Summary to Weave")
nat/experimental/decorators/experimental_warning_decorator.py
CHANGED
@@ -16,7 +16,12 @@
 import functools
 import inspect
 import logging
+from collections.abc import AsyncGenerator
+from collections.abc import Callable
+from collections.abc import Generator
 from typing import Any
+from typing import TypeVar
+from typing import overload

 logger = logging.getLogger(__name__)

@@ -25,6 +30,9 @@ BASE_WARNING_MESSAGE = ("is experimental and the API may change in future releas

 _warning_issued = set()

+# Type variables for overloads
+F = TypeVar('F', bound=Callable[..., Any])
+

 def issue_experimental_warning(function_name: str,
                                feature_name: str | None = None,
@@ -53,7 +61,20 @@ def issue_experimental_warning(function_name: str,
         _warning_issued.add(function_name)


-
+# Overloads for different function types
+@overload
+def experimental(func: F, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
+    """Overload for when a function is passed directly."""
+    ...
+
+
+@overload
+def experimental(*, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
+    """Overload for decorator factory usage (when called with parentheses)."""
+    ...
+
+
+def experimental(func: Any = None, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> Any:
     """
     Decorator that can wrap any type of function (sync, async, generator,
     async generator) and logs a warning that the function is experimental.
@@ -90,7 +111,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
     # ---------------------

     @functools.wraps(func)
-    async def async_gen_wrapper(*args, **kwargs):
+    async def async_gen_wrapper(*args, **kwargs) -> AsyncGenerator[Any, Any]:
         issue_experimental_warning(function_name, feature_name, metadata)
         async for item in func(*args, **kwargs):
             yield item  # yield the original item
@@ -102,7 +123,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
     # ASYNC FUNCTION
     # ---------------------
     @functools.wraps(func)
-    async def async_wrapper(*args, **kwargs):
+    async def async_wrapper(*args, **kwargs) -> Any:
         issue_experimental_warning(function_name, feature_name, metadata)
         result = await func(*args, **kwargs)
         return result
@@ -114,15 +135,14 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
     # SYNC GENERATOR
     # ---------------------
     @functools.wraps(func)
-    def sync_gen_wrapper(*args, **kwargs):
+    def sync_gen_wrapper(*args, **kwargs) -> Generator[Any, Any, Any]:
         issue_experimental_warning(function_name, feature_name, metadata)
-
-        yield item  # yield the original item
+        yield from func(*args, **kwargs)  # yield the original item

     return sync_gen_wrapper

     @functools.wraps(func)
-    def sync_wrapper(*args, **kwargs):
+    def sync_wrapper(*args, **kwargs) -> Any:
         issue_experimental_warning(function_name, feature_name, metadata)
         result = func(*args, **kwargs)
         return result
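The new `@overload` declarations make both decorator spellings type-check: passing the function directly and calling the decorator factory with keyword arguments. A short usage sketch (the decorated functions below are illustrative, not part of the package):

from nat.experimental.decorators.experimental_warning_decorator import experimental


@experimental  # bare form: the function is passed directly (first overload)
def preview_feature(x: int) -> int:
    return x + 1


@experimental(feature_name="beta-search")  # factory form: returns a decorator (second overload)
async def preview_search(query: str) -> str:
    return f"results for {query}"


print(preview_feature(1))  # the first call also triggers the one-time experimental warning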
nat/experimental/test_time_compute/functions/plan_select_execute_function.py
CHANGED
@@ -86,7 +86,7 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
                           "This error can be resolved by installing nvidia-nat-langchain.")

     # Get the augmented function's description
-    augmented_function = builder.get_function(config.augmented_fn)
+    augmented_function = await builder.get_function(config.augmented_fn)

     # For now, we rely on runtime checking for type conversion

@@ -97,11 +97,15 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
                          f"function without a description.")

     # Get the function dependencies of the augmented function
-
+    function_dependencies = builder.get_function_dependencies(config.augmented_fn)
+    function_used_tools = set(function_dependencies.functions)
+    for function_group in function_dependencies.function_groups:
+        function_used_tools.update(builder.get_function_group_dependencies(function_group).functions)
+
     tool_list = "Tool: Description\n"

     for tool in function_used_tools:
-        tool_impl = builder.get_function(tool)
+        tool_impl = await builder.get_function(tool)
         tool_list += f"- {tool}: {tool_impl.description if hasattr(tool_impl, 'description') else ''}\n"

     # Draft the reasoning prompt for the augmented function
|