nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -13,9 +13,8 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import asyncio
|
|
17
16
|
import logging
|
|
18
|
-
from
|
|
17
|
+
from collections.abc import Callable
|
|
19
18
|
|
|
20
19
|
from langchain.output_parsers import ResponseSchema
|
|
21
20
|
from langchain.output_parsers import StructuredOutputParser
|
|
@@ -23,7 +22,6 @@ from langchain.schema import HumanMessage
|
|
|
23
22
|
from langchain.schema import SystemMessage
|
|
24
23
|
from langchain_core.language_models import BaseChatModel
|
|
25
24
|
from langchain_core.runnables import RunnableLambda
|
|
26
|
-
from tqdm import tqdm
|
|
27
25
|
|
|
28
26
|
from nat.eval.evaluator.base_evaluator import BaseEvaluator
|
|
29
27
|
from nat.eval.evaluator.evaluator_model import EvalInputItem
|
|
@@ -31,7 +29,6 @@ from nat.eval.evaluator.evaluator_model import EvalOutputItem
|
|
|
31
29
|
|
|
32
30
|
logger = logging.getLogger(__name__)
|
|
33
31
|
|
|
34
|
-
# pylint: disable=line-too-long
|
|
35
32
|
# flake8: noqa: E501
|
|
36
33
|
|
|
37
34
|
|
|
@@ -185,8 +182,8 @@ class TunableRagEvaluator(BaseEvaluator):
|
|
|
185
182
|
relevance_score = parsed_response["relevance_score"]
|
|
186
183
|
reasoning = parsed_response["reasoning"]
|
|
187
184
|
except KeyError as e:
|
|
188
|
-
logger.
|
|
189
|
-
|
|
185
|
+
logger.exception("Missing required keys in default scoring response: %s",
|
|
186
|
+
", ".join(str(arg) for arg in e.args))
|
|
190
187
|
reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
|
|
191
188
|
|
|
192
189
|
coverage_weight = self.default_score_weights.get("coverage", 1 / 3)
|
|
@@ -218,7 +215,7 @@ class TunableRagEvaluator(BaseEvaluator):
|
|
|
218
215
|
reasoning = f"Error in evaluator from parsing judge LLM response. Missing required key(s): {', '.join(str(arg) for arg in e.args)}"
|
|
219
216
|
raise
|
|
220
217
|
except (KeyError, ValueError) as e:
|
|
221
|
-
logger.
|
|
218
|
+
logger.exception("Error parsing judge LLM response: %s", e)
|
|
222
219
|
score = 0.0
|
|
223
220
|
reasoning = "Error in evaluator from parsing judge LLM response."
|
|
224
221
|
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
from contextlib import contextmanager
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# Type alias for evaluation call objects that have an optional 'id' attribute
|
|
24
|
+
EvalCallType = Any # Could be Weave Call object or other tracing framework objects
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class EvalTraceContext:
|
|
28
|
+
"""
|
|
29
|
+
Evaluation trace context manager for coordinating traces.
|
|
30
|
+
|
|
31
|
+
This class provides a framework-agnostic way to:
|
|
32
|
+
1. Track evaluation calls/contexts
|
|
33
|
+
2. Ensure proper parent-child relationships in traces
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
def __init__(self):
|
|
37
|
+
self.eval_call: EvalCallType | None = None # Store the evaluation call/context for propagation
|
|
38
|
+
|
|
39
|
+
def set_eval_call(self, eval_call: EvalCallType | None) -> None:
|
|
40
|
+
"""Set the evaluation call/context for propagation to traces."""
|
|
41
|
+
self.eval_call = eval_call
|
|
42
|
+
if eval_call:
|
|
43
|
+
logger.debug("Set evaluation call context: %s", getattr(eval_call, 'id', str(eval_call)))
|
|
44
|
+
|
|
45
|
+
def get_eval_call(self) -> EvalCallType | None:
|
|
46
|
+
"""Get the current evaluation call/context."""
|
|
47
|
+
return self.eval_call
|
|
48
|
+
|
|
49
|
+
@contextmanager
|
|
50
|
+
def evaluation_context(self):
|
|
51
|
+
"""
|
|
52
|
+
Context manager that can be overridden by framework-specific implementations.
|
|
53
|
+
Default implementation is a no-op.
|
|
54
|
+
"""
|
|
55
|
+
yield
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class WeaveEvalTraceContext(EvalTraceContext):
|
|
59
|
+
"""
|
|
60
|
+
Weave-specific implementation of evaluation trace context.
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
def __init__(self):
|
|
64
|
+
super().__init__()
|
|
65
|
+
self.available = False
|
|
66
|
+
self.set_call_stack: Callable[[list[EvalCallType]], Any] | None = None
|
|
67
|
+
|
|
68
|
+
try:
|
|
69
|
+
from weave.trace.context.call_context import set_call_stack
|
|
70
|
+
self.set_call_stack = set_call_stack
|
|
71
|
+
self.available = True
|
|
72
|
+
except ImportError:
|
|
73
|
+
self.available = False
|
|
74
|
+
logger.debug("Weave not available for trace context")
|
|
75
|
+
|
|
76
|
+
@contextmanager
|
|
77
|
+
def evaluation_context(self):
|
|
78
|
+
"""Set the evaluation call as active context for Weave traces."""
|
|
79
|
+
if self.available and self.eval_call and self.set_call_stack:
|
|
80
|
+
try:
|
|
81
|
+
with self.set_call_stack([self.eval_call]):
|
|
82
|
+
logger.debug("Set Weave evaluation call context: %s",
|
|
83
|
+
getattr(self.eval_call, 'id', str(self.eval_call)))
|
|
84
|
+
yield
|
|
85
|
+
except Exception as e:
|
|
86
|
+
logger.warning("Failed to set Weave evaluation call context: %s", e)
|
|
87
|
+
yield
|
|
88
|
+
else:
|
|
89
|
+
yield
|
nat/eval/utils/weave_eval.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
|
|
16
16
|
import asyncio
|
|
17
17
|
import logging
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
18
19
|
from typing import Any
|
|
19
20
|
|
|
20
21
|
from nat.eval.evaluator.evaluator_model import EvalInput
|
|
@@ -24,26 +25,28 @@ from nat.eval.usage_stats import UsageStats
|
|
|
24
25
|
from nat.eval.usage_stats import UsageStatsItem
|
|
25
26
|
from nat.profiler.data_models import ProfilerResults
|
|
26
27
|
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from nat.eval.utils.eval_trace_ctx import EvalTraceContext
|
|
30
|
+
|
|
27
31
|
logger = logging.getLogger(__name__)
|
|
28
32
|
|
|
29
33
|
|
|
30
|
-
class WeaveEvaluationIntegration:
|
|
34
|
+
class WeaveEvaluationIntegration:
|
|
31
35
|
"""
|
|
32
36
|
Class to handle all Weave integration functionality.
|
|
33
37
|
"""
|
|
34
38
|
|
|
35
|
-
def __init__(self):
|
|
39
|
+
def __init__(self, eval_trace_context: "EvalTraceContext"):
|
|
36
40
|
self.available = False
|
|
37
41
|
self.client = None
|
|
38
42
|
self.eval_logger = None
|
|
39
43
|
self.pred_loggers = {}
|
|
44
|
+
self.eval_trace_context = eval_trace_context
|
|
40
45
|
|
|
41
46
|
try:
|
|
42
|
-
from weave
|
|
43
|
-
from weave.flow.eval_imperative import ScoreLogger
|
|
47
|
+
from weave import EvaluationLogger
|
|
44
48
|
from weave.trace.context import weave_client_context
|
|
45
|
-
self.
|
|
46
|
-
self.ScoreLogger = ScoreLogger
|
|
49
|
+
self.evaluation_logger_cls = EvaluationLogger
|
|
47
50
|
self.weave_client_context = weave_client_context
|
|
48
51
|
self.available = True
|
|
49
52
|
except ImportError:
|
|
@@ -89,9 +92,15 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
|
|
|
89
92
|
weave_dataset = self._get_weave_dataset(eval_input)
|
|
90
93
|
config_dict = config.model_dump(mode="json")
|
|
91
94
|
config_dict["name"] = workflow_alias
|
|
92
|
-
self.eval_logger = self.
|
|
95
|
+
self.eval_logger = self.evaluation_logger_cls(model=config_dict,
|
|
96
|
+
dataset=weave_dataset,
|
|
97
|
+
name=workflow_alias,
|
|
98
|
+
eval_attributes={})
|
|
93
99
|
self.pred_loggers = {}
|
|
94
100
|
|
|
101
|
+
# Capture the current evaluation call for context propagation
|
|
102
|
+
self.eval_trace_context.set_eval_call(self.eval_logger._evaluate_call)
|
|
103
|
+
|
|
95
104
|
return True
|
|
96
105
|
except Exception as e:
|
|
97
106
|
self.eval_logger = None
|
|
@@ -137,7 +146,7 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
|
|
|
137
146
|
await asyncio.gather(*coros)
|
|
138
147
|
|
|
139
148
|
async def afinish_loggers(self):
|
|
140
|
-
"""Finish all prediction loggers."""
|
|
149
|
+
"""Finish all prediction loggers and wait for exports."""
|
|
141
150
|
if not self.eval_logger:
|
|
142
151
|
return
|
|
143
152
|
|
|
@@ -157,7 +166,6 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
|
|
|
157
166
|
if profiler_results.workflow_runtime_metrics:
|
|
158
167
|
profile_metrics["wf_runtime_p95"] = profiler_results.workflow_runtime_metrics.p95
|
|
159
168
|
|
|
160
|
-
# TODO:get the LLM tokens from the usage stats and log them
|
|
161
169
|
profile_metrics["total_runtime"] = usage_stats.total_runtime
|
|
162
170
|
|
|
163
171
|
return profile_metrics
|
|
@@ -182,3 +190,4 @@ class WeaveEvaluationIntegration: # pylint: disable=too-many-public-methods
|
|
|
182
190
|
# Log the summary to finish the evaluation, disable auto-summarize
|
|
183
191
|
# as we will be adding profiler metrics to the summary
|
|
184
192
|
self.eval_logger.log_summary(summary, auto_summarize=False)
|
|
193
|
+
logger.info("Logged Evaluation Summary to Weave")
|
|
@@ -16,7 +16,12 @@
|
|
|
16
16
|
import functools
|
|
17
17
|
import inspect
|
|
18
18
|
import logging
|
|
19
|
+
from collections.abc import AsyncGenerator
|
|
20
|
+
from collections.abc import Callable
|
|
21
|
+
from collections.abc import Generator
|
|
19
22
|
from typing import Any
|
|
23
|
+
from typing import TypeVar
|
|
24
|
+
from typing import overload
|
|
20
25
|
|
|
21
26
|
logger = logging.getLogger(__name__)
|
|
22
27
|
|
|
@@ -25,6 +30,9 @@ BASE_WARNING_MESSAGE = ("is experimental and the API may change in future releas
|
|
|
25
30
|
|
|
26
31
|
_warning_issued = set()
|
|
27
32
|
|
|
33
|
+
# Type variables for overloads
|
|
34
|
+
F = TypeVar('F', bound=Callable[..., Any])
|
|
35
|
+
|
|
28
36
|
|
|
29
37
|
def issue_experimental_warning(function_name: str,
|
|
30
38
|
feature_name: str | None = None,
|
|
@@ -53,7 +61,20 @@ def issue_experimental_warning(function_name: str,
|
|
|
53
61
|
_warning_issued.add(function_name)
|
|
54
62
|
|
|
55
63
|
|
|
56
|
-
|
|
64
|
+
# Overloads for different function types
|
|
65
|
+
@overload
|
|
66
|
+
def experimental(func: F, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
|
|
67
|
+
"""Overload for when a function is passed directly."""
|
|
68
|
+
...
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@overload
|
|
72
|
+
def experimental(*, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
|
|
73
|
+
"""Overload for decorator factory usage (when called with parentheses)."""
|
|
74
|
+
...
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def experimental(func: Any = None, *, feature_name: str | None = None, metadata: dict[str, Any] | None = None) -> Any:
|
|
57
78
|
"""
|
|
58
79
|
Decorator that can wrap any type of function (sync, async, generator,
|
|
59
80
|
async generator) and logs a warning that the function is experimental.
|
|
@@ -90,7 +111,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
|
|
|
90
111
|
# ---------------------
|
|
91
112
|
|
|
92
113
|
@functools.wraps(func)
|
|
93
|
-
async def async_gen_wrapper(*args, **kwargs):
|
|
114
|
+
async def async_gen_wrapper(*args, **kwargs) -> AsyncGenerator[Any, Any]:
|
|
94
115
|
issue_experimental_warning(function_name, feature_name, metadata)
|
|
95
116
|
async for item in func(*args, **kwargs):
|
|
96
117
|
yield item # yield the original item
|
|
@@ -102,7 +123,7 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
|
|
|
102
123
|
# ASYNC FUNCTION
|
|
103
124
|
# ---------------------
|
|
104
125
|
@functools.wraps(func)
|
|
105
|
-
async def async_wrapper(*args, **kwargs):
|
|
126
|
+
async def async_wrapper(*args, **kwargs) -> Any:
|
|
106
127
|
issue_experimental_warning(function_name, feature_name, metadata)
|
|
107
128
|
result = await func(*args, **kwargs)
|
|
108
129
|
return result
|
|
@@ -114,15 +135,14 @@ def experimental(func: Any = None, *, feature_name: str | None = None, metadata:
|
|
|
114
135
|
# SYNC GENERATOR
|
|
115
136
|
# ---------------------
|
|
116
137
|
@functools.wraps(func)
|
|
117
|
-
def sync_gen_wrapper(*args, **kwargs):
|
|
138
|
+
def sync_gen_wrapper(*args, **kwargs) -> Generator[Any, Any, Any]:
|
|
118
139
|
issue_experimental_warning(function_name, feature_name, metadata)
|
|
119
|
-
|
|
120
|
-
yield item # yield the original item
|
|
140
|
+
yield from func(*args, **kwargs) # yield the original item
|
|
121
141
|
|
|
122
142
|
return sync_gen_wrapper
|
|
123
143
|
|
|
124
144
|
@functools.wraps(func)
|
|
125
|
-
def sync_wrapper(*args, **kwargs):
|
|
145
|
+
def sync_wrapper(*args, **kwargs) -> Any:
|
|
126
146
|
issue_experimental_warning(function_name, feature_name, metadata)
|
|
127
147
|
result = func(*args, **kwargs)
|
|
128
148
|
return result
|
|
@@ -86,7 +86,7 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
|
|
|
86
86
|
"This error can be resolved by installing nvidia-nat-langchain.")
|
|
87
87
|
|
|
88
88
|
# Get the augmented function's description
|
|
89
|
-
augmented_function = builder.get_function(config.augmented_fn)
|
|
89
|
+
augmented_function = await builder.get_function(config.augmented_fn)
|
|
90
90
|
|
|
91
91
|
# For now, we rely on runtime checking for type conversion
|
|
92
92
|
|
|
@@ -97,11 +97,15 @@ async def plan_select_execute_function(config: PlanSelectExecuteFunctionConfig,
|
|
|
97
97
|
f"function without a description.")
|
|
98
98
|
|
|
99
99
|
# Get the function dependencies of the augmented function
|
|
100
|
-
|
|
100
|
+
function_dependencies = builder.get_function_dependencies(config.augmented_fn)
|
|
101
|
+
function_used_tools = set(function_dependencies.functions)
|
|
102
|
+
for function_group in function_dependencies.function_groups:
|
|
103
|
+
function_used_tools.update(builder.get_function_group_dependencies(function_group).functions)
|
|
104
|
+
|
|
101
105
|
tool_list = "Tool: Description\n"
|
|
102
106
|
|
|
103
107
|
for tool in function_used_tools:
|
|
104
|
-
tool_impl = builder.get_function(tool)
|
|
108
|
+
tool_impl = await builder.get_function(tool)
|
|
105
109
|
tool_list += f"- {tool}: {tool_impl.description if hasattr(tool_impl, 'description') else ''}\n"
|
|
106
110
|
|
|
107
111
|
# Draft the reasoning prompt for the augmented function
|
|
@@ -82,7 +82,7 @@ async def register_ttc_tool_orchestration_function(
|
|
|
82
82
|
function_map = {}
|
|
83
83
|
for fn_ref in config.augmented_fns:
|
|
84
84
|
# Retrieve the actual function from the builder
|
|
85
|
-
fn_obj = builder.get_function(fn_ref)
|
|
85
|
+
fn_obj = await builder.get_function(fn_ref)
|
|
86
86
|
function_map[fn_ref] = fn_obj
|
|
87
87
|
|
|
88
88
|
# 2) Instantiate search, editing, scoring, selection strategies (if any)
|
|
@@ -148,13 +148,13 @@ async def register_ttc_tool_orchestration_function(
|
|
|
148
148
|
result = await fn.acall_invoke(item.output)
|
|
149
149
|
return item, result, None
|
|
150
150
|
except Exception as e:
|
|
151
|
-
logger.
|
|
151
|
+
logger.exception(f"Error invoking function '{item.name}': {e}")
|
|
152
152
|
return item, None, str(e)
|
|
153
153
|
|
|
154
154
|
tasks = []
|
|
155
155
|
for item in ttc_items:
|
|
156
156
|
if item.name not in function_map:
|
|
157
|
-
logger.error(f"Function '{item.name}' not found in function map.")
|
|
157
|
+
logger.error(f"Function '{item.name}' not found in function map.", exc_info=True)
|
|
158
158
|
item.output = f"Error: Function '{item.name}' not found in function map. Check your input"
|
|
159
159
|
else:
|
|
160
160
|
fn = function_map[item.name]
|
|
@@ -80,7 +80,7 @@ async def register_ttc_tool_wrapper_function(
|
|
|
80
80
|
raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
|
|
81
81
|
"This error can be resolved by installing nvidia-nat-langchain.")
|
|
82
82
|
|
|
83
|
-
augmented_function: Function = builder.get_function(config.augmented_fn)
|
|
83
|
+
augmented_function: Function = await builder.get_function(config.augmented_fn)
|
|
84
84
|
input_llm: BaseChatModel = await builder.get_llm(config.input_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
85
85
|
|
|
86
86
|
if not augmented_function.has_single_output:
|
|
@@ -17,9 +17,10 @@ from abc import ABC
|
|
|
17
17
|
from abc import abstractmethod
|
|
18
18
|
|
|
19
19
|
from nat.builder.builder import Builder
|
|
20
|
-
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
21
|
-
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum, PipelineTypeEnum
|
|
22
20
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
21
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
22
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
23
|
+
from nat.experimental.test_time_compute.models.ttc_item import TTCItem
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class StrategyBase(ABC):
|
|
@@ -45,11 +46,11 @@ class StrategyBase(ABC):
|
|
|
45
46
|
items: list[TTCItem],
|
|
46
47
|
original_prompt: str | None = None,
|
|
47
48
|
agent_context: str | None = None,
|
|
48
|
-
**kwargs) -> [TTCItem]:
|
|
49
|
+
**kwargs) -> list[TTCItem]:
|
|
49
50
|
pass
|
|
50
51
|
|
|
51
52
|
@abstractmethod
|
|
52
|
-
def supported_pipeline_types(self) -> [PipelineTypeEnum]:
|
|
53
|
+
def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
|
|
53
54
|
"""Return the stage types supported by this selector."""
|
|
54
55
|
pass
|
|
55
56
|
|
|
@@ -71,7 +71,7 @@ class LLMBasedOutputMergingSelector(StrategyBase):
|
|
|
71
71
|
raise ImportError("langchain-core is not installed. Please install it to use SingleShotMultiPlanPlanner.\n"
|
|
72
72
|
"This error can be resolved by installing nvidia-nat-langchain.")
|
|
73
73
|
|
|
74
|
-
from
|
|
74
|
+
from collections.abc import Callable
|
|
75
75
|
|
|
76
76
|
from pydantic import BaseModel
|
|
77
77
|
|
|
@@ -135,8 +135,6 @@ class LLMBasedOutputMergingSelector(StrategyBase):
|
|
|
135
135
|
except Exception as e:
|
|
136
136
|
logger.error(f"Error parsing merged output: {e}")
|
|
137
137
|
raise ValueError("Failed to parse merged output.")
|
|
138
|
-
else:
|
|
139
|
-
merged_output = merged_output
|
|
140
138
|
|
|
141
139
|
logger.info("Merged output: %s", str(merged_output))
|
|
142
140
|
|
|
@@ -14,13 +14,16 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import asyncio
|
|
17
|
+
import logging
|
|
17
18
|
import secrets
|
|
18
19
|
import webbrowser
|
|
19
20
|
from dataclasses import dataclass
|
|
20
21
|
from dataclasses import field
|
|
21
22
|
|
|
22
23
|
import click
|
|
24
|
+
import httpx
|
|
23
25
|
import pkce
|
|
26
|
+
from authlib.common.errors import AuthlibBaseError as OAuthError
|
|
24
27
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
25
28
|
from fastapi import FastAPI
|
|
26
29
|
from fastapi import Request
|
|
@@ -32,6 +35,8 @@ from nat.data_models.authentication import AuthFlowType
|
|
|
32
35
|
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
33
36
|
from nat.front_ends.fastapi.fastapi_front_end_controller import _FastApiFrontEndController
|
|
34
37
|
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
35
40
|
|
|
36
41
|
# --------------------------------------------------------------------------- #
|
|
37
42
|
# Helpers #
|
|
@@ -87,17 +92,53 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
87
92
|
"""
|
|
88
93
|
Separated for easy overriding in tests (to inject ASGITransport).
|
|
89
94
|
"""
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
95
|
+
try:
|
|
96
|
+
client = AsyncOAuth2Client(
|
|
97
|
+
client_id=cfg.client_id,
|
|
98
|
+
client_secret=cfg.client_secret,
|
|
99
|
+
redirect_uri=cfg.redirect_uri,
|
|
100
|
+
scope=" ".join(cfg.scopes) if cfg.scopes else None,
|
|
101
|
+
token_endpoint=cfg.token_url,
|
|
102
|
+
token_endpoint_auth_method=cfg.token_endpoint_auth_method,
|
|
103
|
+
code_challenge_method="S256" if cfg.use_pkce else None,
|
|
104
|
+
)
|
|
105
|
+
self._oauth_client = client
|
|
106
|
+
return client
|
|
107
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
108
|
+
raise RuntimeError(f"Invalid OAuth2 configuration: {e}") from e
|
|
109
|
+
except Exception as e:
|
|
110
|
+
raise RuntimeError(f"Failed to create OAuth2 client: {e}") from e
|
|
111
|
+
|
|
112
|
+
def _create_authorization_url(self,
|
|
113
|
+
client: AsyncOAuth2Client,
|
|
114
|
+
config: OAuth2AuthCodeFlowProviderConfig,
|
|
115
|
+
state: str,
|
|
116
|
+
verifier: str | None = None,
|
|
117
|
+
challenge: str | None = None) -> str:
|
|
118
|
+
"""
|
|
119
|
+
Create OAuth authorization URL with proper error handling.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
client: The OAuth2 client instance
|
|
123
|
+
config: OAuth2 configuration
|
|
124
|
+
state: OAuth state parameter
|
|
125
|
+
verifier: PKCE verifier (if using PKCE)
|
|
126
|
+
challenge: PKCE challenge (if using PKCE)
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
The authorization URL
|
|
130
|
+
"""
|
|
131
|
+
try:
|
|
132
|
+
auth_url, _ = client.create_authorization_url(
|
|
133
|
+
config.authorization_url,
|
|
134
|
+
state=state,
|
|
135
|
+
code_verifier=verifier if config.use_pkce else None,
|
|
136
|
+
code_challenge=challenge if config.use_pkce else None,
|
|
137
|
+
**(config.authorization_kwargs or {})
|
|
138
|
+
)
|
|
139
|
+
return auth_url
|
|
140
|
+
except (OAuthError, ValueError, TypeError) as e:
|
|
141
|
+
raise RuntimeError(f"Error creating OAuth authorization URL: {e}") from e
|
|
101
142
|
|
|
102
143
|
# --------------------------- HTTP Basic ------------------------------ #
|
|
103
144
|
@staticmethod
|
|
@@ -131,13 +172,12 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
131
172
|
flow_state.verifier = verifier
|
|
132
173
|
flow_state.challenge = challenge
|
|
133
174
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
)
|
|
175
|
+
# Create authorization URL using helper function
|
|
176
|
+
auth_url = self._create_authorization_url(client=client,
|
|
177
|
+
config=cfg,
|
|
178
|
+
state=state,
|
|
179
|
+
verifier=flow_state.verifier,
|
|
180
|
+
challenge=flow_state.challenge)
|
|
141
181
|
|
|
142
182
|
# Register flow + maybe spin up redirect handler
|
|
143
183
|
async with self._server_lock:
|
|
@@ -149,14 +189,18 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
149
189
|
self._flows[state] = flow_state
|
|
150
190
|
self._active_flows += 1
|
|
151
191
|
|
|
152
|
-
|
|
153
|
-
|
|
192
|
+
try:
|
|
193
|
+
webbrowser.open(auth_url)
|
|
194
|
+
click.echo("Your browser has been opened for authentication.")
|
|
195
|
+
except Exception as e:
|
|
196
|
+
logger.error("Browser open failed: %s", e)
|
|
197
|
+
raise RuntimeError(f"Browser open failed: {e}") from e
|
|
154
198
|
|
|
155
199
|
# Wait for the redirect to land
|
|
156
200
|
try:
|
|
157
201
|
token = await asyncio.wait_for(flow_state.future, timeout=300)
|
|
158
|
-
except
|
|
159
|
-
raise RuntimeError("Authentication timed out (5 min).")
|
|
202
|
+
except TimeoutError as exc:
|
|
203
|
+
raise RuntimeError("Authentication timed out (5 min).") from exc
|
|
160
204
|
finally:
|
|
161
205
|
async with self._server_lock:
|
|
162
206
|
self._flows.pop(state, None)
|
|
@@ -175,9 +219,9 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
175
219
|
# --------------- redirect server / in‑process app -------------------- #
|
|
176
220
|
async def _build_redirect_app(self) -> FastAPI:
|
|
177
221
|
"""
|
|
178
|
-
* If cfg.run_redirect_local_server == True → start a
|
|
179
|
-
* Else → only build the
|
|
180
|
-
for in‑process testing
|
|
222
|
+
* If cfg.run_redirect_local_server == True → start a local server.
|
|
223
|
+
* Else → only build the redirect app and save it to `self._redirect_app`
|
|
224
|
+
for in‑process testing.
|
|
181
225
|
"""
|
|
182
226
|
app = FastAPI()
|
|
183
227
|
|
|
@@ -195,8 +239,16 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
195
239
|
state=state,
|
|
196
240
|
)
|
|
197
241
|
flow_state.future.set_result(token)
|
|
198
|
-
except
|
|
199
|
-
flow_state.future.set_exception(
|
|
242
|
+
except OAuthError as e:
|
|
243
|
+
flow_state.future.set_exception(
|
|
244
|
+
RuntimeError(f"Authorization server rejected request: {e.error} ({e.description})"))
|
|
245
|
+
return "Authentication failed: Authorization server rejected the request. You may close this tab."
|
|
246
|
+
except httpx.HTTPError as e:
|
|
247
|
+
flow_state.future.set_exception(RuntimeError(f"Network error during token fetch: {e}"))
|
|
248
|
+
return "Authentication failed: Network error occurred. You may close this tab."
|
|
249
|
+
except Exception as e:
|
|
250
|
+
flow_state.future.set_exception(RuntimeError(f"Authentication failed: {e}"))
|
|
251
|
+
return "Authentication failed: An unexpected error occurred. You may close this tab."
|
|
200
252
|
return "Authentication successful – you may close this tab."
|
|
201
253
|
|
|
202
254
|
return app
|
|
@@ -213,7 +265,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
213
265
|
|
|
214
266
|
asyncio.create_task(self._server_controller.start_server(host="localhost", port=8000))
|
|
215
267
|
|
|
216
|
-
# Give
|
|
268
|
+
# Give the server a moment to bind sockets before we return
|
|
217
269
|
await asyncio.sleep(0.3)
|
|
218
270
|
except Exception as exc: # noqa: BLE001
|
|
219
271
|
raise RuntimeError(f"Failed to start redirect server: {exc}") from exc
|
|
@@ -227,7 +279,7 @@ class ConsoleAuthenticationFlowHandler(FlowHandlerBase):
|
|
|
227
279
|
@property
|
|
228
280
|
def redirect_app(self) -> FastAPI | None:
|
|
229
281
|
"""
|
|
230
|
-
In
|
|
231
|
-
app is exposed
|
|
282
|
+
In test mode (run_redirect_local_server=False) the in‑memory redirect
|
|
283
|
+
app is exposed for testing purposes.
|
|
232
284
|
"""
|
|
233
285
|
return self._redirect_app
|
|
@@ -55,9 +55,10 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
|
|
|
55
55
|
self.auth_flow_handler = ConsoleAuthenticationFlowHandler()
|
|
56
56
|
|
|
57
57
|
async def pre_run(self):
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
58
|
+
if (self.front_end_config.input_query is not None and self.front_end_config.input_file is not None):
|
|
59
|
+
raise click.UsageError("Must specify either --input or --input_file, not both")
|
|
60
|
+
if (self.front_end_config.input_query is None and self.front_end_config.input_file is None):
|
|
61
|
+
raise click.UsageError("Must specify either --input or --input_file")
|
|
61
62
|
|
|
62
63
|
async def run_workflow(self, session_manager: SessionManager):
|
|
63
64
|
|
|
@@ -80,12 +81,14 @@ class ConsoleFrontEndPlugin(SimpleFrontEndPluginBase[ConsoleFrontEndConfig]):
|
|
|
80
81
|
input_list = list(self.front_end_config.input_query)
|
|
81
82
|
logger.debug("Processing input: %s", self.front_end_config.input_query)
|
|
82
83
|
|
|
83
|
-
|
|
84
|
+
# Make `return_exceptions=False` explicit; all exceptions are raised instead of being silenced
|
|
85
|
+
runner_outputs = await asyncio.gather(*[run_single_query(query) for query in input_list],
|
|
86
|
+
return_exceptions=False)
|
|
84
87
|
|
|
85
88
|
elif (self.front_end_config.input_file):
|
|
86
89
|
|
|
87
90
|
# Run the workflow
|
|
88
|
-
with open(self.front_end_config.input_file,
|
|
91
|
+
with open(self.front_end_config.input_file, encoding="utf-8") as f:
|
|
89
92
|
|
|
90
93
|
async with session_manager.workflow.run(f) as runner:
|
|
91
94
|
runner_outputs = await runner.result(to_type=str)
|