nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
|
|
20
|
+
from nat.builder.builder import Builder
|
|
21
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
22
|
+
from nat.builder.function_info import FunctionInfo
|
|
23
|
+
from nat.cli.register_workflow import register_function
|
|
24
|
+
from nat.data_models.agent import AgentBaseConfig
|
|
25
|
+
from nat.data_models.component_ref import FunctionRef
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RouterAgentWorkflowConfig(AgentBaseConfig, name="router_agent"):
    """
    Configuration for the router agent workflow.

    A router agent takes the incoming message, combines it with a prompt and the list of branches,
    and asks an LLM which branch to take.
    """
    # Overrides the required `description` field of AgentBaseConfig with a default value.
    description: str = Field(default="Router Agent Workflow", description="Description of this function's use.")
    # Functions the router may dispatch the incoming message to.
    branches: list[FunctionRef] = Field(default_factory=list,
                                        description="The list of branches to provide to the router agent.")
    system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
    user_prompt: str | None = Field(default=None, description="Provides the prompt to use with the agent.")
    max_router_retries: int = Field(
        default=3, description="Maximum number of retries if the router agent fails to choose a branch.")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@register_function(config_type=RouterAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Builder):
    """
    Build the router agent workflow and expose it as a single string-in / string-out function.

    The LangChain-wrapped LLM and branch tools named in ``config`` are resolved through ``builder`` and
    compiled into a ``RouterAgentGraph``. Errors raised while running the graph are logged and returned
    to the caller as text instead of being propagated.

    Raises:
        ValueError: If no branch tools are resolved from ``config.branches``.
    """
    # Imported inside the function so these dependencies are only loaded when the workflow is built.
    from langchain_core.messages.human import HumanMessage
    from langgraph.graph.state import CompiledStateGraph

    from nat.agent.base import AGENT_LOG_PREFIX
    from nat.control_flow.router_agent.agent import RouterAgentGraph
    from nat.control_flow.router_agent.agent import RouterAgentGraphState
    from nat.control_flow.router_agent.agent import create_router_agent_prompt

    router_prompt = create_router_agent_prompt(config)
    router_llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    branch_tools = await builder.get_tools(tool_names=config.branches, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    if not branch_tools:
        raise ValueError(f"No branches specified for Router Agent '{config.llm_name}'")

    router_graph = RouterAgentGraph(
        llm=router_llm,
        branches=branch_tools,
        prompt=router_prompt,
        max_router_retries=config.max_router_retries,
        detailed_logs=config.verbose,
        log_response_max_chars=config.log_response_max_chars,
    )
    graph: CompiledStateGraph = await router_graph.build_graph()

    async def _response_fn(input_message: str) -> str:
        # Run the graph on the wrapped user message and return the content of the final message.
        try:
            initial_state = RouterAgentGraphState(forward_message=HumanMessage(content=input_message))
            final_state = RouterAgentGraphState(**await graph.ainvoke(initial_state))
            return str(final_state.messages[-1].content)
        except Exception as ex:
            logger.exception("%s Router Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
            # Verbose mode returns the bare exception text; otherwise it is prefixed with context.
            return str(ex) if config.verbose else f"Router agent failed with exception: {ex}"

    try:
        yield FunctionInfo.from_fn(_response_fn, description=config.description)
    except GeneratorExit:
        logger.exception("%s Workflow exited early!", AGENT_LOG_PREFIX)
    finally:
        logger.debug("%s Cleaning up router_agent workflow.", AGENT_LOG_PREFIX)
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import typing
|
|
18
|
+
|
|
19
|
+
from langchain_core.tools.base import BaseTool
|
|
20
|
+
from pydantic import BaseModel
|
|
21
|
+
from pydantic import Field
|
|
22
|
+
|
|
23
|
+
from nat.builder.builder import Builder
|
|
24
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
25
|
+
from nat.builder.function import Function
|
|
26
|
+
from nat.builder.function_info import FunctionInfo
|
|
27
|
+
from nat.cli.register_workflow import register_function
|
|
28
|
+
from nat.data_models.component_ref import FunctionRef
|
|
29
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
30
|
+
from nat.utils.type_utils import DecomposedType
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ToolExecutionConfig(BaseModel):
    """Per-tool options that control how a single tool runs inside a sequential execution."""

    # When True, the sequential executor consumes the tool via `astream` instead of `ainvoke`.
    use_streaming: bool = Field(False, description="Whether to use streaming output for the tool.")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SequentialExecutorConfig(FunctionBaseConfig, name="sequential_executor"):
    """Configuration for sequential execution of a list of functions."""

    # Functions run in list order; each one's output becomes the next one's input.
    tool_list: list[FunctionRef] = Field(default_factory=list,
                                         description="A list of functions to execute sequentially.")
    # Optional per-tool options keyed by tool name.
    # The implicitly concatenated string fragments below end with a space so that words do not run together.
    tool_execution_config: dict[str, ToolExecutionConfig] = Field(default_factory=dict,
                                                                  description="Optional configuration for each "
                                                                  "tool in the sequential execution tool list. "
                                                                  "Keys must match the tool names from the "
                                                                  "tool_list.")
    raise_type_incompatibility: bool = Field(
        default=False,
        description="Default to False. Check if the adjacent tools are type compatible, "
        "which means the output type of the previous function is compatible with the input type of the next function. "
        "If set to True, any incompatibility will raise an exception. If set to False, the incompatibility will only "
        "generate a warning message and the sequential execution will continue.")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _get_function_output_type(function: Function, tool_execution_config: dict[str, ToolExecutionConfig]) -> type:
    """
    Return the output type that ``function`` produces when run by the sequential executor.

    The streaming output type is used only when ``tool_execution_config`` has an entry for the function's
    instance name with ``use_streaming`` enabled. In every other case the single output type is used.
    """
    options = tool_execution_config.get(function.instance_name)
    if options and options.use_streaming:
        return function.streaming_output_type
    return function.single_output_type
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _validate_function_type_compatibility(src_fn: Function,
                                          target_fn: Function,
                                          tool_execution_config: dict[str, ToolExecutionConfig]) -> None:
    """
    Check that the output of ``src_fn`` can be passed as the input of ``target_fn``.

    Args:
        src_fn: The function whose output is forwarded.
        target_fn: The function that receives ``src_fn``'s output as its input.
        tool_execution_config: Per-tool execution options. They decide whether the streaming or single
            output type of ``src_fn`` is compared.

    Raises:
        ValueError: If ``DecomposedType.is_type_compatible`` reports the two types as incompatible.
    """
    src_output_type = _get_function_output_type(src_fn, tool_execution_config)
    target_input_type = target_fn.input_type
    # This is logged before the check runs, so the message must not claim that the types are incompatible.
    logger.debug("Checking type compatibility: the output type of the %s function is %s, "
                 "the input type of the %s function is %s",
                 src_fn.instance_name,
                 src_output_type,
                 target_fn.instance_name,
                 target_input_type)

    if not DecomposedType.is_type_compatible(src_output_type, target_input_type):
        raise ValueError(
            f"The output type of the {src_fn.instance_name} function is {str(src_output_type)}, is not compatible with "
            f"the input type of the {target_fn.instance_name} function, which is {str(target_input_type)}")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
async def _validate_tool_list_type_compatibility(sequential_executor_config: SequentialExecutorConfig,
                                                 builder: Builder) -> tuple[type, type]:
    """
    Validate that every function in the tool list accepts the output of the function before it.

    Args:
        sequential_executor_config: Provides the ordered tool list and the per-tool execution options.
        builder: Used to resolve the tool list into ``Function`` objects.

    Returns:
        A ``(input_type, output_type)`` tuple. The first item is the input type of the first function; the
        second is the output type (streaming or single) of the last function.

    Raises:
        RuntimeError: If ``builder.get_functions`` returns no functions.
        ValueError: If any adjacent pair of functions has incompatible types. The original error is chained.
    """
    tool_list = sequential_executor_config.tool_list
    tool_execution_config = sequential_executor_config.tool_execution_config

    function_list = await builder.get_functions(tool_list)
    if not function_list:
        raise RuntimeError("The function list is empty")
    input_type = function_list[0].input_type

    # Zipping the list with itself offset by one yields each adjacent (source, target) pair.
    # A single-function list yields no pairs, so it needs no special case.
    for src_fn, target_fn in zip(function_list[:-1], function_list[1:]):
        try:
            _validate_function_type_compatibility(src_fn, target_fn, tool_execution_config)
        except ValueError as e:
            raise ValueError(f"The sequential tool list has incompatible types: {e}") from e

    output_type = _get_function_output_type(function_list[-1], tool_execution_config)
    logger.debug("The input type of the sequential executor tool list is %s, the output type is %s",
                 input_type,
                 output_type)

    return (input_type, output_type)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@register_function(config_type=SequentialExecutorConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def sequential_execution(config: SequentialExecutorConfig, builder: Builder):
    """
    Build a function that runs the configured tools one after another, feeding each output into the next tool.

    The wrapped function's input and return annotations come from the first tool's input type and the last
    tool's output type. If adjacent tools are type-incompatible and ``config.raise_type_incompatibility`` is
    False, a warning is logged and both annotations fall back to ``typing.Any``.

    Raises:
        ValueError: In either of two cases:
            - The tool list has incompatible types and ``config.raise_type_incompatibility`` is True.
            - Validating the tool list fails for another reason, e.g. the list resolves to no functions.
    """
    logger.debug("Initializing sequential executor with tool list: %s", config.tool_list)

    tools: list[BaseTool] = await builder.get_tools(tool_names=config.tool_list,
                                                    wrapper_type=LLMFrameworkEnum.LANGCHAIN)
    # Tools are looked up by name at call time, in the order given by config.tool_list.
    tools_dict: dict[str, BaseTool] = {tool.name: tool for tool in tools}

    try:
        input_type, output_type = await _validate_tool_list_type_compatibility(config, builder)
    except ValueError as e:
        if config.raise_type_incompatibility:
            logger.error("The sequential executor tool list has incompatible types: %s", e)
            raise
        logger.warning("The sequential executor tool list has incompatible types: %s", e)
        input_type = typing.Any
        output_type = typing.Any
    except Exception as e:
        raise ValueError(f"Error with the sequential executor tool list: {e}") from e

    # The type annotation of _sequential_function_execution is dynamically set according to the tool list
    async def _sequential_function_execution(initial_tool_input):
        logger.debug("Executing sequential executor with tool list: %s", config.tool_list)

        tool_input = initial_tool_input
        tool_response = None

        for tool_name in config.tool_list:
            tool = tools_dict[tool_name]
            execution_options = config.tool_execution_config.get(tool_name)
            logger.debug("Executing tool %s with input: %s", tool_name, tool_input)
            try:
                if execution_options and execution_options.use_streaming:
                    # Concatenate the `content` of every streamed chunk into one response string.
                    output = ""
                    async for chunk in tool.astream(tool_input):
                        output += chunk.content
                    tool_response = output
                else:
                    tool_response = await tool.ainvoke(tool_input)
            except Exception as e:
                logger.error("Error with tool %s: %s", tool_name, e)
                raise

            # The input of the next tool is the response of the previous tool
            tool_input = tool_response

        return tool_response

    # Dynamically set the annotations for the function
    _sequential_function_execution.__annotations__ = {"initial_tool_input": input_type, "return": output_type}
    logger.debug("Sequential executor function annotations: %s", _sequential_function_execution.__annotations__)

    yield FunctionInfo.from_fn(_sequential_function_execution,
                               description="Executes a list of functions sequentially. "
                               "The input of the next tool is the response of the previous tool.")
|
nat/data_models/agent.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import Field
|
|
17
|
+
from pydantic import PositiveInt
|
|
18
|
+
|
|
19
|
+
from nat.data_models.component_ref import LLMRef
|
|
20
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AgentBaseConfig(FunctionBaseConfig):
    """Shared configuration fields inherited by every NAT agent configuration."""

    # Optional name under which the agent is exposed as a tool when it is configured as a workflow.
    workflow_alias: str | None = Field(
        None,
        description=("The alias of the workflow. Useful when the agent is configured as a workflow "
                     "and needs to expose a customized name as a tool."))
    # Required reference to the LLM the agent uses.
    llm_name: LLMRef = Field(description="The LLM model to use with the agent.")
    verbose: bool = Field(False, description="Set the verbosity of the agent's logging.")
    # Required here; subclasses may override it with a default value.
    description: str = Field(description="The description of this function's use.")
    # Must be a positive integer (PositiveInt).
    log_response_max_chars: PositiveInt = Field(
        1000, description="Maximum number of characters to display in logs when logging responses.")
|
nat/data_models/api_server.py
CHANGED
|
@@ -273,7 +273,7 @@ class ChatResponse(ResponseBaseModelOutput):
|
|
|
273
273
|
if model is None:
|
|
274
274
|
model = ""
|
|
275
275
|
if created is None:
|
|
276
|
-
created = datetime.datetime.now(datetime.
|
|
276
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
277
277
|
|
|
278
278
|
return ChatResponse(id=id_,
|
|
279
279
|
object=object_,
|
|
@@ -317,7 +317,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
317
317
|
if id_ is None:
|
|
318
318
|
id_ = str(uuid.uuid4())
|
|
319
319
|
if created is None:
|
|
320
|
-
created = datetime.datetime.now(datetime.
|
|
320
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
321
321
|
if model is None:
|
|
322
322
|
model = ""
|
|
323
323
|
if object_ is None:
|
|
@@ -343,7 +343,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
|
|
|
343
343
|
if id_ is None:
|
|
344
344
|
id_ = str(uuid.uuid4())
|
|
345
345
|
if created is None:
|
|
346
|
-
created = datetime.datetime.now(datetime.
|
|
346
|
+
created = datetime.datetime.now(datetime.UTC)
|
|
347
347
|
if model is None:
|
|
348
348
|
model = ""
|
|
349
349
|
|
|
@@ -485,7 +485,7 @@ class WebSocketUserMessage(BaseModel):
|
|
|
485
485
|
security: Security = Security()
|
|
486
486
|
error: Error = Error()
|
|
487
487
|
schema_version: str = "1.0.0"
|
|
488
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
488
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
489
489
|
|
|
490
490
|
|
|
491
491
|
class WebSocketUserInteractionResponseMessage(BaseModel):
|
|
@@ -501,7 +501,7 @@ class WebSocketUserInteractionResponseMessage(BaseModel):
|
|
|
501
501
|
security: Security = Security()
|
|
502
502
|
error: Error = Error()
|
|
503
503
|
schema_version: str = "1.0.0"
|
|
504
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
504
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
505
505
|
|
|
506
506
|
|
|
507
507
|
class SystemIntermediateStepContent(BaseModel):
|
|
@@ -527,7 +527,7 @@ class WebSocketSystemIntermediateStepMessage(BaseModel):
|
|
|
527
527
|
conversation_id: str | None = None
|
|
528
528
|
content: SystemIntermediateStepContent
|
|
529
529
|
status: WebSocketMessageStatus
|
|
530
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
530
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
531
531
|
|
|
532
532
|
|
|
533
533
|
class SystemResponseContent(BaseModel):
|
|
@@ -551,7 +551,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
|
|
|
551
551
|
conversation_id: str | None = None
|
|
552
552
|
content: SystemResponseContent | Error | GenerateResponse
|
|
553
553
|
status: WebSocketMessageStatus
|
|
554
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
554
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
555
555
|
|
|
556
556
|
@field_validator("content")
|
|
557
557
|
@classmethod
|
|
@@ -560,7 +560,7 @@ class WebSocketSystemResponseTokenMessage(BaseModel):
|
|
|
560
560
|
raise ValueError(f"Field: content must be 'Error' when type is {WebSocketMessageType.ERROR_MESSAGE}")
|
|
561
561
|
|
|
562
562
|
if info.data.get("type") == WebSocketMessageType.RESPONSE_MESSAGE and not isinstance(
|
|
563
|
-
value,
|
|
563
|
+
value, SystemResponseContent | GenerateResponse):
|
|
564
564
|
raise ValueError(
|
|
565
565
|
f"Field: content must be 'SystemResponseContent' when type is {WebSocketMessageType.RESPONSE_MESSAGE}")
|
|
566
566
|
return value
|
|
@@ -582,7 +582,7 @@ class WebSocketSystemInteractionMessage(BaseModel):
|
|
|
582
582
|
conversation_id: str | None = None
|
|
583
583
|
content: HumanPrompt
|
|
584
584
|
status: WebSocketMessageStatus
|
|
585
|
-
timestamp: str = str(datetime.datetime.now(datetime.
|
|
585
|
+
timestamp: str = str(datetime.datetime.now(datetime.UTC))
|
|
586
586
|
|
|
587
587
|
|
|
588
588
|
# ======== GenerateResponse Converters ========
|
|
@@ -688,7 +688,7 @@ GlobalTypeConverter.register_converter(_string_to_nat_chat_response_chunk)
|
|
|
688
688
|
|
|
689
689
|
# ======== AINodeMessageChunk Converters ========
|
|
690
690
|
def _ai_message_chunk_to_nat_chat_response_chunk(data) -> ChatResponseChunk:
|
|
691
|
-
'''Converts LangChain AINodeMessageChunk to ChatResponseChunk'''
|
|
691
|
+
'''Converts LangChain/LangGraph AINodeMessageChunk to ChatResponseChunk'''
|
|
692
692
|
content = ""
|
|
693
693
|
if hasattr(data, 'content') and data.content is not None:
|
|
694
694
|
content = str(data.content)
|
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import typing
|
|
17
|
+
from datetime import UTC
|
|
17
18
|
from datetime import datetime
|
|
18
|
-
from datetime import timezone
|
|
19
19
|
from enum import Enum
|
|
20
20
|
|
|
21
21
|
import httpx
|
|
@@ -166,17 +166,31 @@ class BearerTokenCred(_CredBase):
|
|
|
166
166
|
|
|
167
167
|
|
|
168
168
|
Credential = typing.Annotated[
|
|
169
|
-
|
|
170
|
-
HeaderCred,
|
|
171
|
-
QueryCred,
|
|
172
|
-
CookieCred,
|
|
173
|
-
BasicAuthCred,
|
|
174
|
-
BearerTokenCred,
|
|
175
|
-
],
|
|
169
|
+
HeaderCred | QueryCred | CookieCred | BasicAuthCred | BearerTokenCred,
|
|
176
170
|
Field(discriminator="kind"),
|
|
177
171
|
]
|
|
178
172
|
|
|
179
173
|
|
|
174
|
+
class TokenValidationResult(BaseModel):
|
|
175
|
+
"""
|
|
176
|
+
Standard result for Bearer Token Validation.
|
|
177
|
+
"""
|
|
178
|
+
model_config = ConfigDict(extra="forbid")
|
|
179
|
+
|
|
180
|
+
client_id: str | None = Field(description="OAuth2 client identifier")
|
|
181
|
+
scopes: list[str] | None = Field(default=None, description="List of granted scopes (introspection only)")
|
|
182
|
+
expires_at: int | None = Field(default=None, description="Token expiration time (Unix timestamp)")
|
|
183
|
+
audience: list[str] | None = Field(default=None, description="Token audiences (aud claim)")
|
|
184
|
+
subject: str | None = Field(default=None, description="Token subject (sub claim)")
|
|
185
|
+
issuer: str | None = Field(default=None, description="Token issuer (iss claim)")
|
|
186
|
+
token_type: str = Field(description="Token type")
|
|
187
|
+
active: bool | None = Field(default=True, description="Token active status")
|
|
188
|
+
nbf: int | None = Field(default=None, description="Not before time (Unix timestamp)")
|
|
189
|
+
iat: int | None = Field(default=None, description="Issued at time (Unix timestamp)")
|
|
190
|
+
jti: str | None = Field(default=None, description="JWT ID")
|
|
191
|
+
username: str | None = Field(default=None, description="Username (introspection only)")
|
|
192
|
+
|
|
193
|
+
|
|
180
194
|
class AuthResult(BaseModel):
|
|
181
195
|
"""
|
|
182
196
|
Represents the result of an authentication process.
|
|
@@ -193,7 +207,7 @@ class AuthResult(BaseModel):
|
|
|
193
207
|
"""
|
|
194
208
|
Checks if the authentication token has expired.
|
|
195
209
|
"""
|
|
196
|
-
return bool(self.token_expires_at and datetime.now(
|
|
210
|
+
return bool(self.token_expires_at and datetime.now(UTC) >= self.token_expires_at)
|
|
197
211
|
|
|
198
212
|
def as_requests_kwargs(self) -> dict[str, typing.Any]:
|
|
199
213
|
"""
|
nat/data_models/common.py
CHANGED
|
@@ -160,7 +160,7 @@ class TypedBaseModel(BaseModel):
|
|
|
160
160
|
|
|
161
161
|
@staticmethod
|
|
162
162
|
def discriminator(v: typing.Any) -> str | None:
|
|
163
|
-
# If
|
|
163
|
+
# If it's serialized, then we use the alias
|
|
164
164
|
if isinstance(v, dict):
|
|
165
165
|
return v.get("_type", v.get("type"))
|
|
166
166
|
|
nat/data_models/component.py
CHANGED
|
@@ -27,6 +27,7 @@ class ComponentEnum(StrEnum):
|
|
|
27
27
|
EVALUATOR = "evaluator"
|
|
28
28
|
FRONT_END = "front_end"
|
|
29
29
|
FUNCTION = "function"
|
|
30
|
+
FUNCTION_GROUP = "function_group"
|
|
30
31
|
TTC_STRATEGY = "ttc_strategy"
|
|
31
32
|
LLM_CLIENT = "llm_client"
|
|
32
33
|
LLM_PROVIDER = "llm_provider"
|
|
@@ -47,6 +48,7 @@ class ComponentGroup(StrEnum):
|
|
|
47
48
|
AUTHENTICATION = "authentication"
|
|
48
49
|
EMBEDDERS = "embedders"
|
|
49
50
|
FUNCTIONS = "functions"
|
|
51
|
+
FUNCTION_GROUPS = "function_groups"
|
|
50
52
|
TTC_STRATEGIES = "ttc_strategies"
|
|
51
53
|
LLMS = "llms"
|
|
52
54
|
MEMORY = "memory"
|
nat/data_models/component_ref.py
CHANGED
|
@@ -102,6 +102,17 @@ class FunctionRef(ComponentRef):
|
|
|
102
102
|
return ComponentGroup.FUNCTIONS
|
|
103
103
|
|
|
104
104
|
|
|
105
|
+
class FunctionGroupRef(ComponentRef):
|
|
106
|
+
"""
|
|
107
|
+
A reference to a function group in a NAT configuration object.
|
|
108
|
+
"""
|
|
109
|
+
|
|
110
|
+
@property
|
|
111
|
+
@override
|
|
112
|
+
def component_group(self):
|
|
113
|
+
return ComponentGroup.FUNCTION_GROUPS
|
|
114
|
+
|
|
115
|
+
|
|
105
116
|
class LLMRef(ComponentRef):
|
|
106
117
|
"""
|
|
107
118
|
A reference to an LLM in a NAT configuration object.
|
nat/data_models/config.py
CHANGED
|
@@ -20,6 +20,7 @@ import typing
|
|
|
20
20
|
from pydantic import BaseModel
|
|
21
21
|
from pydantic import ConfigDict
|
|
22
22
|
from pydantic import Discriminator
|
|
23
|
+
from pydantic import Field
|
|
23
24
|
from pydantic import ValidationError
|
|
24
25
|
from pydantic import ValidationInfo
|
|
25
26
|
from pydantic import ValidatorFunctionWrapHandler
|
|
@@ -29,7 +30,9 @@ from nat.data_models.evaluate import EvalConfig
|
|
|
29
30
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
30
31
|
from nat.data_models.function import EmptyFunctionConfig
|
|
31
32
|
from nat.data_models.function import FunctionBaseConfig
|
|
33
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
32
34
|
from nat.data_models.logging import LoggingBaseConfig
|
|
35
|
+
from nat.data_models.optimizer import OptimizerConfig
|
|
33
36
|
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
34
37
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
35
38
|
from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
@@ -47,7 +50,7 @@ logger = logging.getLogger(__name__)
|
|
|
47
50
|
|
|
48
51
|
|
|
49
52
|
def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
|
50
|
-
from nat.cli.type_registry import GlobalTypeRegistry
|
|
53
|
+
from nat.cli.type_registry import GlobalTypeRegistry
|
|
51
54
|
|
|
52
55
|
new_errors = []
|
|
53
56
|
logged_once = False
|
|
@@ -57,9 +60,10 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
57
60
|
error_type = e['type']
|
|
58
61
|
if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
|
|
59
62
|
requested_type = e["ctx"]["tag"]
|
|
60
|
-
|
|
61
63
|
if (info.field_name in ('workflow', 'functions')):
|
|
62
64
|
registered_keys = GlobalTypeRegistry.get().get_registered_functions()
|
|
65
|
+
elif (info.field_name == "function_groups"):
|
|
66
|
+
registered_keys = GlobalTypeRegistry.get().get_registered_function_groups()
|
|
63
67
|
elif (info.field_name == "authentication"):
|
|
64
68
|
registered_keys = GlobalTypeRegistry.get().get_registered_auth_providers()
|
|
65
69
|
elif (info.field_name == "llms"):
|
|
@@ -135,8 +139,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
135
139
|
|
|
136
140
|
class TelemetryConfig(BaseModel):
|
|
137
141
|
|
|
138
|
-
logging: dict[str, LoggingBaseConfig] =
|
|
139
|
-
tracing: dict[str, TelemetryExporterBaseConfig] =
|
|
142
|
+
logging: dict[str, LoggingBaseConfig] = Field(default_factory=dict)
|
|
143
|
+
tracing: dict[str, TelemetryExporterBaseConfig] = Field(default_factory=dict)
|
|
140
144
|
|
|
141
145
|
@field_validator("logging", "tracing", mode="wrap")
|
|
142
146
|
@classmethod
|
|
@@ -185,10 +189,14 @@ class GeneralConfig(BaseModel):
|
|
|
185
189
|
|
|
186
190
|
model_config = ConfigDict(protected_namespaces=())
|
|
187
191
|
|
|
188
|
-
use_uvloop: bool =
|
|
192
|
+
use_uvloop: bool | None = Field(
|
|
193
|
+
default=None,
|
|
194
|
+
deprecated=
|
|
195
|
+
"`use_uvloop` field is deprecated and will be removed in a future release. The use of `uv_loop` is now" +
|
|
196
|
+
"automatically determined based on platform")
|
|
189
197
|
"""
|
|
190
|
-
|
|
191
|
-
|
|
198
|
+
This field is deprecated and ignored. It previously controlled whether to use uvloop as the event loop. uvloop
|
|
199
|
+
usage is now determined automatically based on the platform.
|
|
192
200
|
"""
|
|
193
201
|
|
|
194
202
|
telemetry: TelemetryConfig = TelemetryConfig()
|
|
@@ -240,31 +248,37 @@ class Config(HashableBaseModel):
|
|
|
240
248
|
general: GeneralConfig = GeneralConfig()
|
|
241
249
|
|
|
242
250
|
# Functions Configuration
|
|
243
|
-
functions: dict[str, FunctionBaseConfig] =
|
|
251
|
+
functions: dict[str, FunctionBaseConfig] = Field(default_factory=dict)
|
|
252
|
+
|
|
253
|
+
# Function Groups Configuration
|
|
254
|
+
function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
|
|
244
255
|
|
|
245
256
|
# LLMs Configuration
|
|
246
|
-
llms: dict[str, LLMBaseConfig] =
|
|
257
|
+
llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
|
|
247
258
|
|
|
248
259
|
# Embedders Configuration
|
|
249
|
-
embedders: dict[str, EmbedderBaseConfig] =
|
|
260
|
+
embedders: dict[str, EmbedderBaseConfig] = Field(default_factory=dict)
|
|
250
261
|
|
|
251
262
|
# Memory Configuration
|
|
252
|
-
memory: dict[str, MemoryBaseConfig] =
|
|
263
|
+
memory: dict[str, MemoryBaseConfig] = Field(default_factory=dict)
|
|
253
264
|
|
|
254
265
|
# Object Stores Configuration
|
|
255
|
-
object_stores: dict[str, ObjectStoreBaseConfig] =
|
|
266
|
+
object_stores: dict[str, ObjectStoreBaseConfig] = Field(default_factory=dict)
|
|
267
|
+
|
|
268
|
+
# Optimizer Configuration
|
|
269
|
+
optimizer: OptimizerConfig = OptimizerConfig()
|
|
256
270
|
|
|
257
271
|
# Retriever Configuration
|
|
258
|
-
retrievers: dict[str, RetrieverBaseConfig] =
|
|
272
|
+
retrievers: dict[str, RetrieverBaseConfig] = Field(default_factory=dict)
|
|
259
273
|
|
|
260
274
|
# TTC Strategies
|
|
261
|
-
ttc_strategies: dict[str, TTCStrategyBaseConfig] =
|
|
275
|
+
ttc_strategies: dict[str, TTCStrategyBaseConfig] = Field(default_factory=dict)
|
|
262
276
|
|
|
263
277
|
# Workflow Configuration
|
|
264
278
|
workflow: FunctionBaseConfig = EmptyFunctionConfig()
|
|
265
279
|
|
|
266
280
|
# Authentication Configuration
|
|
267
|
-
authentication: dict[str, AuthProviderBaseConfig] =
|
|
281
|
+
authentication: dict[str, AuthProviderBaseConfig] = Field(default_factory=dict)
|
|
268
282
|
|
|
269
283
|
# Evaluation Options
|
|
270
284
|
eval: EvalConfig = EvalConfig()
|
|
@@ -278,6 +292,7 @@ class Config(HashableBaseModel):
|
|
|
278
292
|
stream.write(f"Workflow Type: {self.workflow.type}\n")
|
|
279
293
|
|
|
280
294
|
stream.write(f"Number of Functions: {len(self.functions)}\n")
|
|
295
|
+
stream.write(f"Number of Function Groups: {len(self.function_groups)}\n")
|
|
281
296
|
stream.write(f"Number of LLMs: {len(self.llms)}\n")
|
|
282
297
|
stream.write(f"Number of Embedders: {len(self.embedders)}\n")
|
|
283
298
|
stream.write(f"Number of Memory: {len(self.memory)}\n")
|
|
@@ -287,6 +302,7 @@ class Config(HashableBaseModel):
|
|
|
287
302
|
stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")
|
|
288
303
|
|
|
289
304
|
@field_validator("functions",
|
|
305
|
+
"function_groups",
|
|
290
306
|
"llms",
|
|
291
307
|
"embedders",
|
|
292
308
|
"memory",
|
|
@@ -328,6 +344,10 @@ class Config(HashableBaseModel):
|
|
|
328
344
|
typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
329
345
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
330
346
|
|
|
347
|
+
FunctionGroupsAnnotation = dict[str,
|
|
348
|
+
typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
|
|
349
|
+
Discriminator(TypedBaseModel.discriminator)]]
|
|
350
|
+
|
|
331
351
|
MemoryAnnotation = dict[str,
|
|
332
352
|
typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
|
|
333
353
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -335,7 +355,6 @@ class Config(HashableBaseModel):
|
|
|
335
355
|
ObjectStoreAnnotation = dict[str,
|
|
336
356
|
typing.Annotated[type_registry.compute_annotation(ObjectStoreBaseConfig),
|
|
337
357
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
338
|
-
|
|
339
358
|
RetrieverAnnotation = dict[str,
|
|
340
359
|
typing.Annotated[type_registry.compute_annotation(RetrieverBaseConfig),
|
|
341
360
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -344,7 +363,7 @@ class Config(HashableBaseModel):
|
|
|
344
363
|
typing.Annotated[type_registry.compute_annotation(TTCStrategyBaseConfig),
|
|
345
364
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
346
365
|
|
|
347
|
-
WorkflowAnnotation = typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
366
|
+
WorkflowAnnotation = typing.Annotated[(type_registry.compute_annotation(FunctionBaseConfig)),
|
|
348
367
|
Discriminator(TypedBaseModel.discriminator)]
|
|
349
368
|
|
|
350
369
|
should_rebuild = False
|
|
@@ -369,6 +388,11 @@ class Config(HashableBaseModel):
|
|
|
369
388
|
functions_field.annotation = FunctionsAnnotation
|
|
370
389
|
should_rebuild = True
|
|
371
390
|
|
|
391
|
+
function_groups_field = cls.model_fields.get("function_groups")
|
|
392
|
+
if function_groups_field is not None and function_groups_field.annotation != FunctionGroupsAnnotation:
|
|
393
|
+
function_groups_field.annotation = FunctionGroupsAnnotation
|
|
394
|
+
should_rebuild = True
|
|
395
|
+
|
|
372
396
|
memory_field = cls.model_fields.get("memory")
|
|
373
397
|
if memory_field is not None and memory_field.annotation != MemoryAnnotation:
|
|
374
398
|
memory_field.annotation = MemoryAnnotation
|
|
@@ -80,7 +80,7 @@ class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):
|
|
|
80
80
|
|
|
81
81
|
|
|
82
82
|
def read_jsonl(file_path: FilePath):
|
|
83
|
-
with open(file_path,
|
|
83
|
+
with open(file_path, encoding='utf-8') as f:
|
|
84
84
|
data = [json.loads(line) for line in f]
|
|
85
85
|
return pd.DataFrame(data)
|
|
86
86
|
|