nvidia-nat 1.3a20250819__py3-none-any.whl → 1.3.0a20250823__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/base.py +16 -0
- nat/agent/react_agent/agent.py +38 -13
- nat/agent/react_agent/prompt.py +4 -1
- nat/agent/react_agent/register.py +1 -1
- nat/agent/register.py +0 -1
- nat/agent/rewoo_agent/agent.py +6 -3
- nat/agent/rewoo_agent/prompt.py +3 -0
- nat/agent/rewoo_agent/register.py +4 -3
- nat/agent/tool_calling_agent/agent.py +92 -22
- nat/agent/tool_calling_agent/register.py +9 -13
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +1 -1
- nat/builder/context.py +9 -1
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +5 -7
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +3 -0
- nat/builder/workflow_builder.py +0 -1
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/info/list_mcp.py +3 -4
- nat/cli/commands/registry/search.py +14 -16
- nat/cli/commands/start.py +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +3 -0
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +0 -1
- nat/cli/type_registry.py +7 -9
- nat/data_models/config.py +1 -1
- nat/data_models/evaluate.py +1 -1
- nat/data_models/function_dependencies.py +6 -6
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/model_gated_field_mixin.py +125 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +36 -0
- nat/data_models/top_p_mixin.py +36 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/openai_embedder.py +1 -2
- nat/embedder/register.py +1 -1
- nat/eval/config.py +2 -0
- nat/eval/dataset_handler/dataset_handler.py +5 -6
- nat/eval/evaluate.py +64 -20
- nat/eval/rag_evaluator/register.py +2 -2
- nat/eval/register.py +0 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +14 -7
- nat/experimental/test_time_compute/models/strategy_base.py +3 -2
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/register.py +0 -1
- nat/llm/aws_bedrock_llm.py +3 -3
- nat/llm/azure_openai_llm.py +49 -0
- nat/llm/nim_llm.py +4 -4
- nat/llm/openai_llm.py +4 -4
- nat/llm/register.py +1 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/meta/pypi.md +9 -9
- nat/object_store/models.py +2 -0
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/register.py +3 -3
- nat/profiler/callbacks/langchain_callback_handler.py +9 -2
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +1 -4
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/profile_runner.py +13 -8
- nat/registry_handlers/package_utils.py +0 -1
- nat/registry_handlers/pypi/pypi_handler.py +20 -23
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +8 -9
- nat/retriever/register.py +0 -1
- nat/runtime/session.py +23 -8
- nat/settings/global_settings.py +13 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +1 -1
- nat/tool/mcp/mcp_tool.py +1 -1
- nat/tool/register.py +0 -1
- nat/utils/data_models/schema_validator.py +2 -2
- nat/utils/exception_handlers/automatic_retries.py +0 -2
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +2 -2
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +4 -6
- nat/utils/type_utils.py +4 -4
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/METADATA +17 -15
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/RECORD +107 -100
- nvidia_nat-1.3.0a20250823.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/top_level.txt +1 -0
- nvidia_nat-1.3a20250819.dist-info/licenses/LICENSE-3rd-party.txt +0 -3686
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE.md +0 -0
nat/eval/evaluate.py
CHANGED
@@ -42,7 +42,7 @@ from nat.runtime.session import SessionManager
 logger = logging.getLogger(__name__)
 
 
-class EvaluationRun:
+class EvaluationRun:
     """
     Instantiated for each evaluation run and used to store data for that single run.
 
@@ -63,7 +63,16 @@ class EvaluationRun:  # pylint: disable=too-many-public-methods
 
         # Helpers
         self.intermediate_step_adapter: IntermediateStepAdapter = IntermediateStepAdapter()
-
+
+        # Create evaluation trace context
+        try:
+            from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
+            self.eval_trace_context = WeaveEvalTraceContext()
+        except Exception:
+            from nat.eval.utils.eval_trace_ctx import EvalTraceContext
+            self.eval_trace_context = EvalTraceContext()
+
+        self.weave_eval: WeaveEvaluationIntegration = WeaveEvaluationIntegration(self.eval_trace_context)
         # Metadata
         self.eval_input: EvalInput | None = None
         self.workflow_interrupted: bool = False
@@ -310,7 +319,7 @@ class EvaluationRun:  # pylint: disable=too-many-public-methods
         except Exception as e:
             logger.exception("Failed to delete old job directory: %s: %s", dir_to_delete, e, exc_info=True)
 
-    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
+    def write_output(self, dataset_handler: DatasetHandler, profiler_results: ProfilerResults):
         workflow_output_file = self.eval_config.general.output_dir / "workflow_output.json"
         workflow_output_file.parent.mkdir(parents=True, exist_ok=True)
 
@@ -401,6 +410,33 @@ class EvaluationRun:  # pylint: disable=too-many-public-methods
 
         return workflow_type
 
+    async def wait_for_all_export_tasks_local(self, session_manager: SessionManager, timeout: float) -> None:
+        """Wait for all trace export tasks to complete for local workflows.
+
+        This only works for local workflows where we have direct access to the
+        SessionManager and its underlying workflow with exporter manager.
+        """
+        try:
+            workflow = session_manager.workflow
+            all_exporters = await workflow.get_all_exporters()
+            if not all_exporters:
+                logger.debug("No exporters to wait for")
+                return
+
+            logger.info("Waiting for export tasks from %d local exporters (timeout: %ds)", len(all_exporters), timeout)
+
+            for name, exporter in all_exporters.items():
+                try:
+                    await exporter.wait_for_tasks(timeout=timeout)
+                    logger.info("Export tasks completed for exporter: %s", name)
+                except Exception as e:
+                    logger.warning("Error waiting for export tasks from %s: %s", name, e)
+
+            logger.info("All local export task waiting completed")
+
+        except Exception as e:
+            logger.warning("Failed to wait for local export tasks: %s", e)
+
     async def run_and_evaluate(self,
                                session_manager: SessionManager | None = None,
                                job_id: str | None = None) -> EvaluationRunOutput:
@@ -442,11 +478,13 @@ class EvaluationRun:  # pylint: disable=too-many-public-methods
         dataset_config = self.eval_config.general.dataset  # Currently only one dataset is supported
         if not dataset_config:
             logger.info("No dataset found, nothing to evaluate")
-            return EvaluationRunOutput(
-
-
-
-
+            return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
+                                       evaluator_output_files=self.evaluator_output_files,
+                                       workflow_interrupted=self.workflow_interrupted,
+                                       eval_input=EvalInput(eval_input_items=[]),
+                                       evaluation_results=[],
+                                       usage_stats=UsageStats(),
+                                       profiler_results=ProfilerResults())
 
         dataset_handler = DatasetHandler(dataset_config=dataset_config,
                                          reps=self.config.reps,
@@ -456,11 +494,13 @@ class EvaluationRun:  # pylint: disable=too-many-public-methods
         self.eval_input = dataset_handler.get_eval_input_from_dataset(self.config.dataset)
         if not self.eval_input.eval_input_items:
             logger.info("Dataset is empty. Nothing to evaluate.")
-            return EvaluationRunOutput(
-
-
-
-
+            return EvaluationRunOutput(workflow_output_file=self.workflow_output_file,
+                                       evaluator_output_files=self.evaluator_output_files,
+                                       workflow_interrupted=self.workflow_interrupted,
+                                       eval_input=self.eval_input,
+                                       evaluation_results=self.evaluation_results,
+                                       usage_stats=self.usage_stats,
+                                       profiler_results=ProfilerResults())
 
         # Run workflow and evaluate
         async with WorkflowEvalBuilder.from_config(config=config) as eval_workflow:
@@ -468,18 +508,22 @@ class EvaluationRun:  # pylint: disable=too-many-public-methods
             self.weave_eval.initialize_logger(workflow_alias, self.eval_input, config)
 
             # Run workflow
-
-
-
-
+            with self.eval_trace_context.evaluation_context():
+                if self.config.endpoint:
+                    await self.run_workflow_remote()
+                elif not self.config.skip_workflow:
                     if session_manager is None:
                         session_manager = SessionManager(eval_workflow.build(),
                                                          max_concurrency=self.eval_config.general.max_concurrency)
                     await self.run_workflow_local(session_manager)
 
-
-
-
+            # Evaluate
+            evaluators = {name: eval_workflow.get_evaluator(name) for name in self.eval_config.evaluators}
+            await self.run_evaluators(evaluators)
+
+            # Wait for all trace export tasks to complete (local workflows only)
+            if session_manager and not self.config.endpoint:
+                await self.wait_for_all_export_tasks_local(session_manager, timeout=self.config.export_timeout)
 
         # Profile the workflow
         profiler_results = await self.profile_workflow()
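The new `wait_for_all_export_tasks_local` helper shown above drains every exporter's pending trace-export tasks before profiling runs. Below is a minimal standalone sketch of the same drain-with-timeout pattern; `DummyExporter` and its delay are illustrative assumptions, not the package's real exporter interface, which the diff only shows exposing `wait_for_tasks(timeout=...)`.

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("export_drain_sketch")


class DummyExporter:
    """Hypothetical stand-in for an exporter with pending export tasks."""

    def __init__(self, delay: float):
        self._delay = delay

    async def wait_for_tasks(self, timeout: float) -> None:
        # Block until the pending "exports" finish or the timeout elapses.
        await asyncio.wait_for(asyncio.sleep(self._delay), timeout=timeout)


async def drain_exporters(exporters: dict[str, DummyExporter], timeout: float) -> None:
    """Wait on each exporter in turn, logging failures instead of raising."""
    for name, exporter in exporters.items():
        try:
            await exporter.wait_for_tasks(timeout=timeout)
            logger.info("Export tasks completed for exporter: %s", name)
        except Exception as e:
            logger.warning("Error waiting for export tasks from %s: %s", name, e)


if __name__ == "__main__":
    asyncio.run(drain_exporters({"fast": DummyExporter(0.01), "slow": DummyExporter(5.0)}, timeout=0.1))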
nat/eval/rag_evaluator/register.py
CHANGED
@@ -73,7 +73,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return self.metric
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.keys()))
+            return next(iter(self.metric.keys()))
         return ""
 
     @property
@@ -82,7 +82,7 @@ class RagasEvaluatorConfig(EvaluatorBaseConfig, name="ragas"):
         if isinstance(self.metric, str):
             return RagasMetricConfig()  # Default config when only a metric name is given
         if isinstance(self.metric, dict) and self.metric:
-            return next(iter(self.metric.values()))
+            return next(iter(self.metric.values()))
         return RagasMetricConfig()  # Default config when an invalid type is provided
 
 
nat/eval/register.py
CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 import logging
 from typing import Callable
 
@@ -23,7 +22,6 @@ from langchain.schema import HumanMessage
 from langchain.schema import SystemMessage
 from langchain_core.language_models import BaseChatModel
 from langchain_core.runnables import RunnableLambda
-from tqdm import tqdm
 
 from nat.eval.evaluator.base_evaluator import BaseEvaluator
 from nat.eval.evaluator.evaluator_model import EvalInputItem
@@ -31,7 +29,6 @@ from nat.eval.evaluator.evaluator_model import EvalOutputItem
 
 logger = logging.getLogger(__name__)
 
-# pylint: disable=line-too-long
 # flake8: noqa: E501
 
 
nat/eval/utils/eval_trace_ctx.py
ADDED
@@ -0,0 +1,89 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections.abc import Callable
+from contextlib import contextmanager
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+# Type alias for evaluation call objects that have an optional 'id' attribute
+EvalCallType = Any  # Could be Weave Call object or other tracing framework objects
+
+
+class EvalTraceContext:
+    """
+    Evaluation trace context manager for coordinating traces.
+
+    This class provides a framework-agnostic way to:
+    1. Track evaluation calls/contexts
+    2. Ensure proper parent-child relationships in traces
+    """
+
+    def __init__(self):
+        self.eval_call: EvalCallType | None = None  # Store the evaluation call/context for propagation
+
+    def set_eval_call(self, eval_call: EvalCallType | None) -> None:
+        """Set the evaluation call/context for propagation to traces."""
+        self.eval_call = eval_call
+        if eval_call:
+            logger.debug("Set evaluation call context: %s", getattr(eval_call, 'id', str(eval_call)))
+
+    def get_eval_call(self) -> EvalCallType | None:
+        """Get the current evaluation call/context."""
+        return self.eval_call
+
+    @contextmanager
+    def evaluation_context(self):
+        """
+        Context manager that can be overridden by framework-specific implementations.
+        Default implementation is a no-op.
+        """
+        yield
+
+
+class WeaveEvalTraceContext(EvalTraceContext):
+    """
+    Weave-specific implementation of evaluation trace context.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.available = False
+        self.set_call_stack: Callable[[list[EvalCallType]], Any] | None = None
+
+        try:
+            from weave.trace.context.call_context import set_call_stack
+            self.set_call_stack = set_call_stack
+            self.available = True
+        except ImportError:
+            self.available = False
+            logger.debug("Weave not available for trace context")
+
+    @contextmanager
+    def evaluation_context(self):
+        """Set the evaluation call as active context for Weave traces."""
+        if self.available and self.eval_call and self.set_call_stack:
+            try:
+                with self.set_call_stack([self.eval_call]):
+                    logger.debug("Set Weave evaluation call context: %s",
+                                 getattr(self.eval_call, 'id', str(self.eval_call)))
+                    yield
+            except Exception as e:
+                logger.warning("Failed to set Weave evaluation call context: %s", e)
+                yield
+        else:
+            yield
nat/eval/utils/weave_eval.py
CHANGED
@@ -15,6 +15,7 @@
 
 import asyncio
 import logging
+from typing import TYPE_CHECKING
 from typing import Any
 
 from nat.eval.evaluator.evaluator_model import EvalInput
@@ -24,26 +25,30 @@ from nat.eval.usage_stats import UsageStats
 from nat.eval.usage_stats import UsageStatsItem
 from nat.profiler.data_models import ProfilerResults
 
+if TYPE_CHECKING:
+    from nat.eval.utils.eval_trace_ctx import EvalTraceContext
+
 logger = logging.getLogger(__name__)
 
 
-class WeaveEvaluationIntegration:
+class WeaveEvaluationIntegration:
     """
     Class to handle all Weave integration functionality.
     """
 
-    def __init__(self):
+    def __init__(self, eval_trace_context: "EvalTraceContext"):
         self.available = False
         self.client = None
         self.eval_logger = None
         self.pred_loggers = {}
+        self.eval_trace_context = eval_trace_context
 
         try:
             from weave.flow.eval_imperative import EvaluationLogger
             from weave.flow.eval_imperative import ScoreLogger
             from weave.trace.context import weave_client_context
-            self.
-            self.
+            self.evaluation_logger_cls = EvaluationLogger
+            self.score_logger_cls = ScoreLogger
             self.weave_client_context = weave_client_context
             self.available = True
         except ImportError:
@@ -89,9 +94,12 @@ class WeaveEvaluationIntegration:  # pylint: disable=too-many-public-methods
             weave_dataset = self._get_weave_dataset(eval_input)
             config_dict = config.model_dump(mode="json")
             config_dict["name"] = workflow_alias
-            self.eval_logger = self.
+            self.eval_logger = self.evaluation_logger_cls(model=config_dict, dataset=weave_dataset)
             self.pred_loggers = {}
 
+            # Capture the current evaluation call for context propagation
+            self.eval_trace_context.set_eval_call(self.eval_logger._evaluate_call)
+
             return True
         except Exception as e:
             self.eval_logger = None
@@ -137,7 +145,7 @@ class WeaveEvaluationIntegration:  # pylint: disable=too-many-public-methods
             await asyncio.gather(*coros)
 
     async def afinish_loggers(self):
-        """Finish all prediction loggers."""
+        """Finish all prediction loggers and wait for exports."""
        if not self.eval_logger:
            return
 
@@ -157,7 +165,7 @@ class WeaveEvaluationIntegration:  # pylint: disable=too-many-public-methods
        if profiler_results.workflow_runtime_metrics:
            profile_metrics["wf_runtime_p95"] = profiler_results.workflow_runtime_metrics.p95
 
-        # TODO:get the LLM tokens from the usage stats and log them
        profile_metrics["total_runtime"] = usage_stats.total_runtime
 
        return profile_metrics
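The constructor change means the Weave integration must now be handed the trace context up front, so `initialize_logger()` can publish the evaluation call through `set_eval_call()`. The sketch below condenses the wiring that the evaluate.py hunk above performs in `EvaluationRun.__init__`; it adds no behavior beyond what the diff shows and assumes the package is importable.

from nat.eval.utils.eval_trace_ctx import EvalTraceContext
from nat.eval.utils.weave_eval import WeaveEvaluationIntegration

try:
    # Prefer the Weave-aware context; it degrades gracefully if weave is absent.
    from nat.eval.utils.eval_trace_ctx import WeaveEvalTraceContext
    eval_trace_context: EvalTraceContext = WeaveEvalTraceContext()
except Exception:
    eval_trace_context = EvalTraceContext()

# WeaveEvaluationIntegration now requires the trace context so that
# initialize_logger() can propagate the evaluation call into workflow traces.
weave_eval = WeaveEvaluationIntegration(eval_trace_context)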
nat/experimental/test_time_compute/models/strategy_base.py
CHANGED
@@ -17,9 +17,10 @@ from abc import ABC
 from abc import abstractmethod
 
 from nat.builder.builder import Builder
-from nat.experimental.test_time_compute.models.ttc_item import TTCItem
-from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum, PipelineTypeEnum
 from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
+from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
+from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
+from nat.experimental.test_time_compute.models.ttc_item import TTCItem
 
 
 class StrategyBase(ABC):
nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py
CHANGED
@@ -135,8 +135,6 @@ class LLMBasedOutputMergingSelector(StrategyBase):
         except Exception as e:
             logger.error(f"Error parsing merged output: {e}")
             raise ValueError("Failed to parse merged output.")
-        else:
-            merged_output = merged_output
 
         logger.info("Merged output: %s", str(merged_output))
 
nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py
CHANGED
@@ -307,7 +307,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
             """Handle evaluation requests."""
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 # if job_id is present and already exists return the job info
                 if request.job_id:
@@ -336,7 +336,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of an evaluation job."""
             logger.info("Getting status for job %s", job_id)
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 job = job_store.get_job(job_id)
                 if not job:
@@ -349,7 +349,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of the last created evaluation job."""
             logger.info("Getting last job status")
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 job = job_store.get_last_job()
                 if not job:
@@ -361,7 +361,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
         async def get_jobs(http_request: Request, status: str | None = None) -> list[EvaluateStatusResponse]:
             """Get all jobs, optionally filtered by status."""
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 if status is None:
                     logger.info("Getting all jobs")
@@ -522,9 +522,9 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
         workflow = session_manager.workflow
 
-        GenerateBodyType = workflow.input_schema
-        GenerateStreamResponseType = workflow.streaming_output_schema
-        GenerateSingleResponseType = workflow.single_output_schema
+        GenerateBodyType = workflow.input_schema
+        GenerateStreamResponseType = workflow.streaming_output_schema
+        GenerateSingleResponseType = workflow.single_output_schema
 
         # Append job_id and expiry_seconds to the input schema, this effectively makes these reserved keywords
         # Consider prefixing these with "nat_" to avoid conflicts
@@ -572,7 +572,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             response.headers["Content-Type"] = "application/json"
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return await generate_single_response(None, session_manager, result_type=result_type)
@@ -583,7 +583,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
         async def get_stream(request: Request):
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -618,7 +618,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
             response.headers["Content-Type"] = "application/json"
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return await generate_single_response(payload, session_manager, result_type=result_type)
@@ -632,7 +632,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
 
         async def post_stream(request: Request, payload: request_type):
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=request,
                                                user_authentication_callback=self._http_flow_handler.authenticate):
 
                 return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -677,7 +677,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             # Check if streaming is requested
             stream_requested = getattr(payload, 'stream', False)
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=request):
                 if stream_requested:
                     # Return streaming response
                     return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -688,42 +688,41 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                                              step_adaptor=self.get_step_adaptor(),
                                              result_type=ChatResponseChunk,
                                              output_type=ChatResponseChunk))
-
-
-
+
+                # Return single response - check if workflow supports non-streaming
+                try:
+                    response.headers["Content-Type"] = "application/json"
+                    return await generate_single_response(payload, session_manager, result_type=ChatResponse)
+                except ValueError as e:
+                    if "Cannot get a single output value for streaming workflows" in str(e):
+                        # Workflow only supports streaming, but client requested non-streaming
+                        # Fall back to streaming and collect the result
+                        chunks = []
+                        async for chunk_str in generate_streaming_response_as_str(
+                                payload,
+                                session_manager=session_manager,
+                                streaming=True,
+                                step_adaptor=self.get_step_adaptor(),
+                                result_type=ChatResponseChunk,
+                                output_type=ChatResponseChunk):
+                            if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
+                                chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
+                                if chunk_data:
+                                    try:
+                                        chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
+                                        if (chunk_json.choices and len(chunk_json.choices) > 0
+                                                and chunk_json.choices[0].delta
+                                                and chunk_json.choices[0].delta.content is not None):
+                                            chunks.append(chunk_json.choices[0].delta.content)
+                                    except Exception:
+                                        continue
+
+                        # Create a single response from collected chunks
+                        content = "".join(chunks)
+                        single_response = ChatResponse.from_string(content)
                         response.headers["Content-Type"] = "application/json"
-                        return
-
-                        if "Cannot get a single output value for streaming workflows" in str(e):
-                            # Workflow only supports streaming, but client requested non-streaming
-                            # Fall back to streaming and collect the result
-                            chunks = []
-                            async for chunk_str in generate_streaming_response_as_str(
-                                    payload,
-                                    session_manager=session_manager,
-                                    streaming=True,
-                                    step_adaptor=self.get_step_adaptor(),
-                                    result_type=ChatResponseChunk,
-                                    output_type=ChatResponseChunk):
-                                if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
-                                    chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
-                                    if chunk_data:
-                                        try:
-                                            chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
-                                            if (chunk_json.choices and len(chunk_json.choices) > 0
-                                                    and chunk_json.choices[0].delta
-                                                    and chunk_json.choices[0].delta.content is not None):
-                                                chunks.append(chunk_json.choices[0].delta.content)
-                                        except Exception:
-                                            continue
-
-                            # Create a single response from collected chunks
-                            content = "".join(chunks)
-                            single_response = ChatResponse.from_string(content)
-                            response.headers["Content-Type"] = "application/json"
-                            return single_response
-                        else:
-                            raise
+                        return single_response
+                    raise
 
         return post_openai_api_compatible
 
@@ -758,7 +757,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
                               http_request: Request) -> AsyncGenerateResponse | AsyncGenerationStatusResponse:
             """Handle async generation requests."""
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 # if job_id is present and already exists return the job info
                 if request.job_id:
@@ -804,7 +803,7 @@ class FastApiFrontEndPluginWorker(FastApiFrontEndPluginWorkerBase):
             """Get the status of an async job."""
             logger.info("Getting status for job %s", job_id)
 
-            async with session_manager.session(
+            async with session_manager.session(http_connection=http_request):
 
                 job = job_store.get_job(job_id)
                 if not job:
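The OpenAI-compatible endpoint now catches the "Cannot get a single output value for streaming workflows" error, falls back to streaming, and stitches the SSE chunks back into one response. Below is a simplified, self-contained sketch of just that chunk-collection step; it parses plain JSON dictionaries instead of the package's ChatResponseChunk model, so the field layout it assumes is illustrative.

import json


def collect_sse_content(chunk_strs: list[str]) -> str:
    """Join the delta contents of OpenAI-style SSE chunks into one string."""
    chunks: list[str] = []
    for chunk_str in chunk_strs:
        if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
            chunk_data = chunk_str[6:].strip()  # Remove "data: " prefix
            if not chunk_data:
                continue
            try:
                payload = json.loads(chunk_data)
                content = payload["choices"][0]["delta"].get("content")
                if content is not None:
                    chunks.append(content)
            except Exception:
                continue  # Skip malformed chunks, as the worker does
    return "".join(chunks)


print(collect_sse_content([
    'data: {"choices": [{"delta": {"content": "Hello"}}]}',
    'data: {"choices": [{"delta": {"content": ", world"}}]}',
    "data: [DONE]",
]))  # -> "Hello, world"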
nat/front_ends/fastapi/message_handler.py
CHANGED
@@ -86,7 +86,7 @@ class WebSocketMessageHandler:
 
     async def __aexit__(self, exc_type, exc_value, traceback) -> None:
 
-        # TODO: Handle the exit
+        # TODO: Handle the exit
         pass
 
     async def run(self) -> None:
@@ -105,12 +105,10 @@ class WebSocketMessageHandler:
                 if (isinstance(validated_message, WebSocketUserMessage)):
                     await self.process_workflow_request(validated_message)
 
-                elif isinstance(
-
-
-
-                        WebSocketSystemIntermediateStepMessage,
-                        WebSocketSystemInteractionMessage)):
+                elif isinstance(validated_message,
+                                (WebSocketSystemResponseTokenMessage,
+                                 WebSocketSystemIntermediateStepMessage,
+                                 WebSocketSystemInteractionMessage)):
                     # These messages are already handled by self.create_websocket_message(data_model=value, …)
                     # No further processing is needed here.
                     pass
@@ -119,11 +117,9 @@ class WebSocketMessageHandler:
                     user_content = await self.process_user_message_content(validated_message)
                     self._user_interaction_response.set_result(user_content)
             except (asyncio.CancelledError, WebSocketDisconnect):
-                # TODO: Handle the disconnect
+                # TODO: Handle the disconnect
                 break
 
-        return None
-
     async def process_user_message_content(
             self, user_content: WebSocketUserMessage | WebSocketUserInteractionResponseMessage) -> BaseModel | None:
         """
@@ -162,12 +158,13 @@ class WebSocketMessageHandler:
 
         if isinstance(content, TextContent) and (self._running_workflow_task is None):
 
-            def _done_callback(task: asyncio.Task):
+            def _done_callback(task: asyncio.Task):
                 self._running_workflow_task = None
 
             self._running_workflow_task = asyncio.create_task(
-                self._run_workflow(content.text,
-                                   self.
+                self._run_workflow(payload=content.text,
+                                   user_message_id=self._message_parent_id,
+                                   conversation_id=self._conversation_id,
                                    result_type=self._schema_output_mapping[self._workflow_schema_type],
                                    output_type=self._schema_output_mapping[
                                        self._workflow_schema_type])).add_done_callback(_done_callback)
@@ -290,14 +287,16 @@ class WebSocketMessageHandler:
 
     async def _run_workflow(self,
                             payload: typing.Any,
+                            user_message_id: str | None = None,
                             conversation_id: str | None = None,
                             result_type: type | None = None,
                             output_type: type | None = None) -> None:
 
         try:
             async with self._session_manager.session(
+                    user_message_id=user_message_id,
                     conversation_id=conversation_id,
-
+                    http_connection=self._socket,
                     user_input_callback=self.human_interaction_callback,
                     user_authentication_callback=(self._flow_handler.authenticate
                                                   if self._flow_handler else None)) as session: