nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
nat/eval/evaluate.py
CHANGED
@@ -42,7 +42,7 @@ from nat.runtime.session import SessionManager
 logger = logging.getLogger(__name__)


-class EvaluationRun:
+class EvaluationRun:
     """
     Instantiated for each evaluation run and used to store data for that single run.

@@ -63,7 +63,16 @@ class EvaluationRun: # pylint: disable=too-many-public-methods

         # Helpers
         self.intermediate_step_adapter: IntermediateStepAdapter = IntermediateStepAdapter()
-
+
+        # Create evaluation trace context
+        try:
+            from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
+            self.eval_trace_context = WeaveEvalTraceContext()
+        except Exception:
+            from nat.eval.utils.eval_trace_ctx import EvalTraceContext
+            self.eval_trace_context = EvalTraceContext()
+
+        self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration(self.eval_trace_context)
         # Metadata
         self.eval_input: EvalInput | None = None
         self.workflow_interrupted: bool = False
@@ -159,17 +168,17 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
         intermediate_future = None

         try:
-
             # Start usage stats and intermediate steps collection in parallel
             intermediate_future = pull_intermediate()
             runner_result = runner.result()
             base_output = await runner_result
             intermediate_steps = await intermediate_future
         except NotImplementedError as e:
+            logger.error("Failed to run the workflow: %s", e)
             # raise original error
-            raise
+            raise
         except Exception as e:
-            logger.exception("Failed to run the workflow: %s", e
+            logger.exception("Failed to run the workflow: %s", e)
             # stop processing if a workflow error occurs
             self.workflow_interrupted = True

@@ -308,9 +317,9 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
                 logger.info("Deleting old job directory: %s", dir_to_delete)
                 shutil.rmtree(dir_to_delete)
             except Exception as e:
-                logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e
+                logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e)

-    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
+    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
         workflow_output_file = self.eval_config.general.output_dir / "workflow_output.json"
         workflow_output_file.parent.mkdir(parents=True, exist_ok=True)

@@ -358,7 +367,7 @@ class EvaluationRun: # pylint: disable=too-many-public-methods

             await self.weave_eval.alog_score(eval_output, evaluator_name)
         except Exception as e:
-            logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e
+            logger.exception("An error occurred while running evaluator %s: %s", evaluator_name, e)

     async def run_evaluators(self, evaluators: dict[str, Any]):
         """Run all configured evaluators asynchronously."""
@@ -371,7 +380,7 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
         try:
             await asyncio.gather(*tasks)
         except Exception as e:
-            logger.
+            logger.error("An error occurred while running evaluators: %s", e)
             raise
         finally:
             # Finish prediction loggers in Weave
@@ -401,6 +410,33 @@ class EvaluationRun: # pylint: disable=too-many-public-methods

         return workflow_type

+    async def wait_for_all_export_tasks_local(self, session_manager: SessionManager, timeout: float) -> None:
+        """Wait for all trace export tasks to complete for local workflows.
+
+        This only works for local workflows where we have direct access to the
+        SessionManager and its underlying workflow with exporter manager.
+        """
+        try:
+            workflow = session_manager.workflow
+            all_exporters = await workflow.get_all_exporters()
+            if not all_exporters:
+                logger.debug("No exporters to wait for")
+                return
+
+            logger.info("Waiting for export tasks from %d local exporters (timeout: %ds)", len(all_exporters), timeout)
+
+            for name, exporter in all_exporters.items():
+                try:
+                    await exporter.wait_for_tasks(timeout=timeout)
+                    logger.info("Export tasks completed for exporter: %s", name)
+                except Exception as e:
+                    logger.warning("Error waiting for export tasks from %s: %s", name, e)
+
+            logger.info("All local export task waiting completed")
+
+        except Exception as e:
+            logger.warning("Failed to wait for local export tasks: %s", e)
+
     async def run_and_evaluate(self,
                                session_manager: SessionManager | None = None,
                                job_id: str | None = None) -> EvaluationRunOutput:
@@ -413,10 +449,14 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
         from nat.runtime.loader import load_config

         # Load and override the config
-
+        config = None
+        if isinstance(self.config.config_file, BaseModel):
+            config = self.config.config_file
+        elif self.config.override:
             config = self.apply_overrides()
         else:
             config = load_config(self.config.config_file)
+
         self.eval_config = config.eval
         workflow_alias = self._get_workflow_alias(config.workflow.type)
         logger.debug("Loaded %s evaluation configuration: %s", workflow_alias, self.eval_config)
@@ -442,44 +482,59 @@ class EvaluationRun: # pylint: disable=too-many-public-methods
         dataset_config = self.eval_config.general.dataset  # Currently only one dataset is supported
         if not dataset_config:
             logger.info("No dataset found, nothing to evaluate")
-            return EvaluationRunOutput(
-
-
-
-
-
+            return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
+                                       evaluator_output_files=self.evaluator_output_files,
+                                       workflow_interrupted=self.workflow_interrupted,
+                                       eval_input=EvalInput(eval_input_items=[]),
+                                       evaluation_results=[],
+                                       usage_stats=UsageStats(),
+                                       profiler_results=ProfilerResults())
+
+        custom_pre_eval_process_function = self.eval_config.general.output.custom_pre_eval_process_function \
+            if self.eval_config.general.output else None
         dataset_handler = DatasetHandler(dataset_config=dataset_config,
                                          reps=self.config.reps,
                                          concurrency=self.eval_config.general.max_concurrency,
                                          num_passes=self.config.num_passes,
-                                         adjust_dataset_size=self.config.adjust_dataset_size
+                                         adjust_dataset_size=self.config.adjust_dataset_size,
+                                         custom_pre_eval_process_function=custom_pre_eval_process_function)
         self.eval_input = dataset_handler.get_eval_input_from_dataset(self.config.dataset)
         if not self.eval_input.eval_input_items:
             logger.info("Dataset is empty. Nothing to evaluate.")
-            return EvaluationRunOutput(
-
-
-
-
+            return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
+                                       evaluator_output_files=self.evaluator_output_files,
+                                       workflow_interrupted=self.workflow_interrupted,
+                                       eval_input=self.eval_input,
+                                       evaluation_results=self.evaluation_results,
+                                       usage_stats=self.usage_stats,
+                                       profiler_results=ProfilerResults())

         # Run workflow and evaluate
         async with WorkflowEvalBuilder.from_config(config=config) as eval_workflow:
             # Initialize Weave integration
             self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config)

-
-
-
-
-
+            with self.eval_trace_context.evaluation_context():
+                # Run workflow
+                if self.config.endpoint:
+                    await self.run_workflow_remote()
+                elif not self.config.skip_workflow:
                    if session_manager is None:
-
+                        workflow = await eval_workflow.build()
+                        session_manager = SessionManager(workflow,
                                                          max_concurrency=self.eval_config.general.max_concurrency)
                    await self.run_workflow_local(session_manager)

-
-
-
+                # Pre-evaluation process the workflow output
+                self.eval_input = dataset_handler.pre_eval_process_eval_input(self.eval_input)
+
+                # Evaluate
+                evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
+                await self.run_evaluators(evaluators)
+
+                # Wait for all trace export tasks to complete (local workflows only)
+                if session_manager and not self.config.endpoint:
+                    await self.wait_for_all_export_tasks_local(session_manager, timeout=self.config.export_timeout)

         # Profile the workflow
         profiler_results = await self.profile_workflow()
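For illustration, the optional-backend pattern introduced in EvaluationRun.__init__ above (prefer the Weave-backed trace context, fall back to the base context when Weave or its dependencies are missing, then run evaluation inside evaluation_context()) can be sketched standalone. The class names below are stand-ins, not the real nat.eval.utils.eval_trace_ctx API:

from contextlib import contextmanager


class TraceContext:
    """No-op default: evaluation behaves the same with tracing disabled."""

    @contextmanager
    def evaluation_context(self):
        yield


class WeaveTraceContext(TraceContext):
    """Raises at construction time when the optional backend is unavailable."""

    def __init__(self):
        import weave  # optional dependency; ImportError if not installed
        self._weave = weave


try:
    trace_ctx: TraceContext = WeaveTraceContext()
except Exception:  # fall back silently, mirroring the try/except in __init__ above
    trace_ctx = TraceContext()

with trace_ctx.evaluation_context():
    pass  # run the workflow and evaluators inside the (possibly no-op) context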
nat/eval/evaluator/base_evaluator.py
CHANGED
@@ -71,7 +71,7 @@ class BaseEvaluator(ABC):
             TqdmPositionRegistry.release(tqdm_position)

         # Compute average if possible
-        numeric_scores = [item.score for item in output_items if isinstance(item.score,
+        numeric_scores = [item.score for item in output_items if isinstance(item.score, int | float)]
         avg_score = round(sum(numeric_scores) / len(numeric_scores), 2) if numeric_scores else None

         return EvalOutput(average_score=avg_score, eval_output_items=output_items)
nat/eval/evaluator/evaluator_model.py
CHANGED
@@ -29,6 +29,19 @@ class EvalInputItem(BaseModel):
     trajectory: list[IntermediateStep] = []  # populated by the workflow
     full_dataset_entry: typing.Any

+    def copy_with_updates(self, **updates) -> "EvalInputItem":
+        """
+        Copy EvalInputItem with optional field updates.
+        """
+        # Get all current fields
+        item_data = self.model_dump()
+
+        # Apply any updates
+        item_data.update(updates)
+
+        # Create new item with all fields
+        return EvalInputItem(**item_data)
+

 class EvalInput(BaseModel):
     eval_input_items: list[EvalInputItem]
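The copy_with_updates helper added above follows a common pydantic copy-and-revalidate pattern: dump the model, overlay the updates, and rebuild. A minimal standalone sketch of that pattern, using a stand-in model rather than the real EvalInputItem (whose full field list is not shown in this diff):

from pydantic import BaseModel


class Item(BaseModel):
    # Stand-in fields; the real EvalInputItem has more (id, output_obj, trajectory, ...)
    id: int
    output_obj: str | None = None

    def copy_with_updates(self, **updates) -> "Item":
        data = self.model_dump()  # dump every current field to a plain dict
        data.update(updates)      # overlay the requested updates
        return Item(**data)       # rebuild a new, re-validated instance


item = Item(id=1, output_obj="old")
updated = item.copy_with_updates(output_obj="new")
print(item.output_obj, updated.output_obj)  # prints "old new" -- the original is untouched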
nat/eval/intermediate_step_adapter.py
CHANGED
@@ -40,7 +40,7 @@ class IntermediateStepAdapter:
             try:
                 validated_steps.append(IntermediateStep.model_validate(step_data))
             except Exception as e:
-                logger.exception("Validation failed for step: %r, Error: %s", step_data, e
+                logger.exception("Validation failed for step: %r, Error: %s", step_data, e)
         return validated_steps

     def serialize_intermediate_steps(self, intermediate_steps: list[IntermediateStep]) -> list[dict]:
nat/eval/rag_evaluator/evaluate.py
CHANGED
@@ -102,7 +102,7 @@ class RAGEvaluator:
         """Converts the ragas EvaluationResult to nat EvalOutput"""

         if not results_dataset:
-            logger.error("Ragas evaluation failed with no results")
+            logger.error("Ragas evaluation failed with no results", exc_info=True)
             return EvalOutput(average_score=0.0, eval_output_items=[])

         scores: list[dict[str, float]] = results_dataset.scores
@@ -169,7 +169,7 @@ class RAGEvaluator:
                 _pbar=pbar)
         except Exception as e:
             # On exception we still continue with other evaluators. Log and return an avg_score of 0.0
-            logger.exception("Error evaluating ragas metric, Error: %s", e
+            logger.exception("Error evaluating ragas metric, Error: %s", e)
             results_dataset = None
         finally:
             pbar.close()
nat/eval/rag_evaluator/register.py
CHANGED
@@ -73,7 +73,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return self.metric
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.keys()))
+            return next(iter(self.metric.keys()))
         return ""

     @property
@@ -82,7 +82,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return RagasMetricConfig()  # Default config when only a metric name is given
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.values()))
+            return next(iter(self.metric.values()))
         return RagasMetricConfig()  # Default config when an invalid type is provided


@@ -104,7 +104,7 @@ async def register_ragas_evaluator(config: RagasEvaluatorConfig, builder: EvalBu
         raise ValueError(message) from e
     except AttributeError as e:
         message = f"Ragas metric {metric_name} not found {e}."
-        logger.
+        logger.exception(message)
         return None

     async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
nat/eval/register.py
CHANGED
@@ -14,10 +14,13 @@
 # limitations under the License.

 # flake8: noqa
-# pylint: disable=unused-import

 # Import evaluators which need to be automatically registered here
 from .rag_evaluator.register import register_ragas_evaluator
+from .runtime_evaluator.register import register_avg_llm_latency_evaluator
+from .runtime_evaluator.register import register_avg_num_llm_calls_evaluator
+from .runtime_evaluator.register import register_avg_tokens_per_llm_end_evaluator
+from .runtime_evaluator.register import register_avg_workflow_runtime_evaluator
 from .swe_bench_evaluator.register import register_swe_bench_evaluator
 from .trajectory_evaluator.register import register_trajectory_evaluator
 from .tunable_rag_evaluator.register import register_tunable_rag_evaluator
nat/eval/remote_workflow.py
CHANGED
@@ -74,7 +74,7 @@ class EvaluationRemoteWorkflowHandler:
                             if chunk_data.get("value"):
                                 final_response = chunk_data.get("value")
                         except json.JSONDecodeError as e:
-                            logger.
+                            logger.exception("Failed to parse generate response chunk: %s", e)
                             continue
                     elif line.startswith(INTERMEDIATE_DATA_PREFIX):
                         # This is an intermediate step
@@ -90,12 +90,12 @@ class EvaluationRemoteWorkflowHandler:
                                                                  payload=payload)
                             intermediate_steps.append(intermediate_step)
                         except (json.JSONDecodeError, ValidationError) as e:
-                            logger.
+                            logger.exception("Failed to parse intermediate step: %s", e)
                             continue

         except aiohttp.ClientError as e:
             # Handle connection or HTTP-related errors
-            logger.
+            logger.exception("Request failed for question %s: %s", question, e)
             item.output_obj = None
             item.trajectory = []
             return
nat/eval/runtime_evaluator/__init__.py
ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
nat/eval/runtime_evaluator/evaluate.py
ADDED
@@ -0,0 +1,123 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass
+
+from nat.data_models.intermediate_step import IntermediateStepType
+from nat.eval.evaluator.base_evaluator import BaseEvaluator
+from nat.eval.evaluator.evaluator_model import EvalInputItem
+from nat.eval.evaluator.evaluator_model import EvalOutputItem
+from nat.profiler.intermediate_property_adapter import IntermediatePropertyAdaptor
+
+
+@dataclass
+class _CallTiming:
+    start_ts: float | None = None
+    end_ts: float | None = None
+
+    @property
+    def latency(self) -> float | None:
+        if self.start_ts is None or self.end_ts is None:
+            return None
+        return max(0.0, self.end_ts - self.start_ts)
+
+
+class AverageLLMLatencyEvaluator(BaseEvaluator):
+    """
+    Mean difference between connected LLM_START and LLM_END events (same UUID).
+    The score is the average latency in seconds for the item. Reasoning contains per-call latencies.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg LLM Latency")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        calls: dict[str, _CallTiming] = defaultdict(_CallTiming)
+
+        for step in (IntermediatePropertyAdaptor.from_intermediate_step(s) for s in item.trajectory):
+            if step.event_type == IntermediateStepType.LLM_START:
+                calls[step.UUID].start_ts = step.event_timestamp
+            elif step.event_type == IntermediateStepType.LLM_END:
+                calls[step.UUID].end_ts = step.event_timestamp
+
+        latencies = [ct.latency for ct in calls.values() if ct.latency is not None]
+        avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
+
+        reasoning = {
+            "num_llm_calls": len(latencies),
+            "latencies": latencies,
+        }
+        return EvalOutputItem(id=item.id, score=round(avg_latency, 4), reasoning=reasoning)
+
+
+class AverageWorkflowRuntimeEvaluator(BaseEvaluator):
+    """
+    Average workflow runtime per item: max(event_timestamp) - min(event_timestamp) across the trajectory.
+    The score is the runtime in seconds for the item.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg Workflow Runtime")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        if not item.trajectory:
+            return EvalOutputItem(id=item.id, score=0.0, reasoning={"note": "no steps"})
+
+        timestamps = [s.event_timestamp for s in item.trajectory]
+        runtime = max(timestamps) - min(timestamps)
+        return EvalOutputItem(id=item.id, score=round(max(0.0, runtime), 4), reasoning={"steps": len(timestamps)})
+
+
+class AverageNumberOfLLMCallsEvaluator(BaseEvaluator):
+    """
+    Average number of LLM calls per item. The score is the count for the item.
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg # LLM Calls")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        num_calls = sum(1 for s in item.trajectory if s.event_type == IntermediateStepType.LLM_END)
+        return EvalOutputItem(id=item.id, score=float(num_calls), reasoning={"num_llm_end": num_calls})
+
+
+class AverageTokensPerLLMEndEvaluator(BaseEvaluator):
+    """
+    Average total tokens per LLM_END event: sum of prompt and completion tokens if available.
+    The score is the average tokens per LLM_END for the item (0 if none).
+    """
+
+    def __init__(self, max_concurrency: int = 8):
+        super().__init__(max_concurrency=max_concurrency, tqdm_desc="Evaluating Avg Tokens/LLM_END")
+
+    async def evaluate_item(self, item: EvalInputItem) -> EvalOutputItem:  # noqa: D401
+        totals: list[int] = []
+        for step in (IntermediatePropertyAdaptor.from_intermediate_step(s) for s in item.trajectory):
+            if step.event_type == IntermediateStepType.LLM_END:
+                total_tokens = step.token_usage.total_tokens
+                # If framework doesn't set total, compute from prompt+completion
+                if total_tokens == 0:
+                    total_tokens = step.token_usage.prompt_tokens + step.token_usage.completion_tokens
+                totals.append(total_tokens)
+
+        avg_tokens = (sum(totals) / len(totals)) if totals else 0.0
+        reasoning = {
+            "num_llm_end": len(totals),
+            "totals": totals,
+        }
+        return EvalOutputItem(id=item.id, score=round(avg_tokens, 2), reasoning=reasoning)
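For illustration, the START/END pairing used by AverageLLMLatencyEvaluator above can be exercised standalone; plain (uuid, event_type, timestamp) tuples stand in for nat intermediate steps, and the class below mirrors _CallTiming without the nat imports:

from collections import defaultdict
from dataclasses import dataclass


@dataclass
class CallTiming:
    start_ts: float | None = None
    end_ts: float | None = None

    @property
    def latency(self) -> float | None:
        # Only complete START/END pairs yield a latency
        if self.start_ts is None or self.end_ts is None:
            return None
        return max(0.0, self.end_ts - self.start_ts)


events = [("a", "LLM_START", 10.0), ("a", "LLM_END", 10.8), ("b", "LLM_START", 11.0)]

calls: dict[str, CallTiming] = defaultdict(CallTiming)
for uuid, event_type, ts in events:
    if event_type == "LLM_START":
        calls[uuid].start_ts = ts
    elif event_type == "LLM_END":
        calls[uuid].end_ts = ts

latencies = [c.latency for c in calls.values() if c.latency is not None]
avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
print(round(avg_latency, 4))  # 0.8; the unmatched call "b" is ignored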
nat/eval/runtime_evaluator/register.py
ADDED
@@ -0,0 +1,100 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import Field
+
+from nat.builder.builder import EvalBuilder
+from nat.builder.evaluator import EvaluatorInfo
+from nat.cli.register_workflow import register_evaluator
+from nat.data_models.evaluator import EvaluatorBaseConfig
+from nat.eval.evaluator.evaluator_model import EvalInput
+from nat.eval.evaluator.evaluator_model import EvalOutput
+
+
+class AverageLLMLatencyConfig(EvaluatorBaseConfig, name="avg_llm_latency"):
+    """Mean difference between connected LLM_START and LLM_END events (same UUID)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+class AverageWorkflowRuntimeConfig(EvaluatorBaseConfig, name="avg_workflow_runtime"):
+    """Average workflow runtime per item (max timestamp - min timestamp)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+class AverageNumberOfLLMCallsConfig(EvaluatorBaseConfig, name="avg_num_llm_calls"):
+    """Average number of LLM calls per item (count of LLM_END)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+class AverageTokensPerLLMEndConfig(EvaluatorBaseConfig, name="avg_tokens_per_llm_end"):
+    """Average total tokens per LLM_END event (prompt + completion if available)."""
+
+    max_concurrency: int = Field(default=8, description="Max concurrency for evaluation.")
+
+
+@register_evaluator(config_type=AverageLLMLatencyConfig)
+async def register_avg_llm_latency_evaluator(config: AverageLLMLatencyConfig, builder: EvalBuilder):
+    from .evaluate import AverageLLMLatencyEvaluator
+
+    evaluator = AverageLLMLatencyEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config,
+                        evaluate_fn=evaluate_fn,
+                        description="Average LLM latency (s) from LLM_START to LLM_END")
+
+
+@register_evaluator(config_type=AverageWorkflowRuntimeConfig)
+async def register_avg_workflow_runtime_evaluator(config: AverageWorkflowRuntimeConfig, builder: EvalBuilder):
+    from .evaluate import AverageWorkflowRuntimeEvaluator
+
+    evaluator = AverageWorkflowRuntimeEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config, evaluate_fn=evaluate_fn, description="Average workflow runtime (s)")
+
+
+@register_evaluator(config_type=AverageNumberOfLLMCallsConfig)
+async def register_avg_num_llm_calls_evaluator(config: AverageNumberOfLLMCallsConfig, builder: EvalBuilder):
+    from .evaluate import AverageNumberOfLLMCallsEvaluator
+
+    evaluator = AverageNumberOfLLMCallsEvaluator(
+        max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config, evaluate_fn=evaluate_fn, description="Average number of LLM calls")
+
+
+@register_evaluator(config_type=AverageTokensPerLLMEndConfig)
+async def register_avg_tokens_per_llm_end_evaluator(config: AverageTokensPerLLMEndConfig, builder: EvalBuilder):
+    from .evaluate import AverageTokensPerLLMEndEvaluator
+
+    evaluator = AverageTokensPerLLMEndEvaluator(max_concurrency=config.max_concurrency or builder.get_max_concurrency())
+
+    async def evaluate_fn(eval_input: EvalInput) -> EvalOutput:
+        return await evaluator.evaluate(eval_input)
+
+    yield EvaluatorInfo(config=config,
+                        evaluate_fn=evaluate_fn,
+                        description="Average total tokens per LLM_END (prompt + completion)")
nat/eval/swe_bench_evaluator/evaluate.py
CHANGED
@@ -69,13 +69,13 @@ class SweBenchEvaluator:
         try:
             shutil.move(swe_bench_report_file, report_dir)
         except Exception as e:
-            logger.exception("Error moving report file: %s", e
+            logger.exception("Error moving report file: %s", e)

         try:
             dest_logs_dir = os.path.join(report_dir, 'logs')
             shutil.move(logs_dir, dest_logs_dir)
         except Exception as e:
-            logger.exception("Error moving logs directory: %s", e
+            logger.exception("Error moving logs directory: %s", e)

     def is_repo_supported(self, repo: str, version: str) -> bool:
         """Check if the repo is supported by swebench"""
@@ -106,7 +106,7 @@ class SweBenchEvaluator:
                 self._model_name_or_path = swebench_output.model_name_or_path

             except Exception as e:
-                logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e
+                logger.exception("Failed to parse EvalInputItem %s: %s", item.id, e)

         # Filter out repos/version not supported by SWEBench
         supported_inputs = [
@@ -114,7 +114,7 @@ class SweBenchEvaluator:
         ]

         if not supported_inputs:
-            logger.
+            logger.exception("No supported instances; nothing to evaluate")
             return None, None

         if len(supported_inputs) < len(swebench_inputs):
@@ -135,7 +135,7 @@ class SweBenchEvaluator:
         filtered_outputs = [output for output in swebench_outputs if output.instance_id in valid_instance_ids]

         if not filtered_outputs:
-            logger.error("No supported outputs; nothing to evaluate")
+            logger.error("No supported outputs; nothing to evaluate", exc_info=True)
             return None, None

         # Write SWEBenchOutput to file
@@ -204,7 +204,7 @@ class SweBenchEvaluator:
         # if report file is not present, return empty EvalOutput
         avg_score = 0.0
         if report_file.exists():
-            with open(report_file,
+            with open(report_file, encoding="utf-8") as f:
                 report = json.load(f)
                 resolved_instances = report.get("resolved_instances", 0)
                 total_instances = report.get("total_instances", 0)
nat/eval/trajectory_evaluator/evaluate.py
CHANGED
@@ -65,7 +65,7 @@ class TrajectoryEvaluator(BaseEvaluator):
                 prediction=generated_answer,
             )
         except Exception as e:
-            logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e
+            logger.exception("Error evaluating trajectory for question: %s, Error: %s", question, e)
             return EvalOutputItem(id=item.id, score=0.0, reasoning=f"Error evaluating trajectory: {e}")

         reasoning = {
nat/eval/trajectory_evaluator/register.py
CHANGED
@@ -33,7 +33,7 @@ async def register_trajectory_evaluator(config: TrajectoryEvaluatorConfig, build

     from .evaluate import TrajectoryEvaluator
     llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
-    tools = builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+    tools = await builder.get_all_tools(wrapper_type=LLMFrameworkEnum.LANGCHAIN)

     _evaluator = TrajectoryEvaluator(llm, tools, builder.get_max_concurrency())
