nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
nat/cli/register_workflow.py
CHANGED
|
@@ -27,6 +27,8 @@ from nat.cli.type_registry import EvaluatorRegisteredCallableT
|
|
|
27
27
|
from nat.cli.type_registry import FrontEndBuildCallableT
|
|
28
28
|
from nat.cli.type_registry import FrontEndRegisteredCallableT
|
|
29
29
|
from nat.cli.type_registry import FunctionBuildCallableT
|
|
30
|
+
from nat.cli.type_registry import FunctionGroupBuildCallableT
|
|
31
|
+
from nat.cli.type_registry import FunctionGroupRegisteredCallableT
|
|
30
32
|
from nat.cli.type_registry import FunctionRegisteredCallableT
|
|
31
33
|
from nat.cli.type_registry import LLMClientBuildCallableT
|
|
32
34
|
from nat.cli.type_registry import LLMClientRegisteredCallableT
|
|
@@ -60,6 +62,7 @@ from nat.data_models.embedder import EmbedderBaseConfigT
|
|
|
60
62
|
from nat.data_models.evaluator import EvaluatorBaseConfigT
|
|
61
63
|
from nat.data_models.front_end import FrontEndConfigT
|
|
62
64
|
from nat.data_models.function import FunctionConfigT
|
|
65
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
63
66
|
from nat.data_models.llm import LLMBaseConfigT
|
|
64
67
|
from nat.data_models.memory import MemoryBaseConfigT
|
|
65
68
|
from nat.data_models.object_store import ObjectStoreBaseConfigT
|
|
@@ -155,10 +158,7 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
155
158
|
|
|
156
159
|
context_manager_fn = asynccontextmanager(fn)
|
|
157
160
|
|
|
158
|
-
|
|
159
|
-
framework_wrappers_list: list[str] = []
|
|
160
|
-
else:
|
|
161
|
-
framework_wrappers_list = list(framework_wrappers)
|
|
161
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
162
162
|
|
|
163
163
|
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
164
164
|
component_type=ComponentEnum.FUNCTION)
|
|
@@ -177,6 +177,40 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
177
177
|
return register_function_inner
|
|
178
178
|
|
|
179
179
|
|
|
180
|
+
def register_function_group(config_type: type[FunctionGroupConfigT],
|
|
181
|
+
framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
|
|
182
|
+
"""
|
|
183
|
+
Register a function group with optional framework_wrappers for automatic profiler hooking.
|
|
184
|
+
Function groups share configuration/resources across multiple functions.
|
|
185
|
+
"""
|
|
186
|
+
|
|
187
|
+
def register_function_group_inner(
|
|
188
|
+
fn: FunctionGroupBuildCallableT[FunctionGroupConfigT]
|
|
189
|
+
) -> FunctionGroupRegisteredCallableT[FunctionGroupConfigT]:
|
|
190
|
+
from .type_registry import GlobalTypeRegistry
|
|
191
|
+
from .type_registry import RegisteredFunctionGroupInfo
|
|
192
|
+
|
|
193
|
+
context_manager_fn = asynccontextmanager(fn)
|
|
194
|
+
|
|
195
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
196
|
+
|
|
197
|
+
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
198
|
+
component_type=ComponentEnum.FUNCTION_GROUP)
|
|
199
|
+
|
|
200
|
+
GlobalTypeRegistry.get().register_function_group(
|
|
201
|
+
RegisteredFunctionGroupInfo(
|
|
202
|
+
full_type=config_type.full_type,
|
|
203
|
+
config_type=config_type,
|
|
204
|
+
build_fn=context_manager_fn,
|
|
205
|
+
framework_wrappers=framework_wrappers_list,
|
|
206
|
+
discovery_metadata=discovery_metadata,
|
|
207
|
+
))
|
|
208
|
+
|
|
209
|
+
return context_manager_fn
|
|
210
|
+
|
|
211
|
+
return register_function_group_inner
|
|
212
|
+
|
|
213
|
+
|
|
180
214
|
def register_llm_provider(config_type: type[LLMBaseConfigT]):
|
|
181
215
|
|
|
182
216
|
def register_llm_provider_inner(
|
nat/cli/type_registry.py
CHANGED
|
@@ -37,6 +37,7 @@ from nat.builder.embedder import EmbedderProviderInfo
|
|
|
37
37
|
from nat.builder.evaluator import EvaluatorInfo
|
|
38
38
|
from nat.builder.front_end import FrontEndBase
|
|
39
39
|
from nat.builder.function import Function
|
|
40
|
+
from nat.builder.function import FunctionGroup
|
|
40
41
|
from nat.builder.function_base import FunctionBase
|
|
41
42
|
from nat.builder.function_info import FunctionInfo
|
|
42
43
|
from nat.builder.llm import LLMProviderInfo
|
|
@@ -55,6 +56,8 @@ from nat.data_models.front_end import FrontEndBaseConfig
|
|
|
55
56
|
from nat.data_models.front_end import FrontEndConfigT
|
|
56
57
|
from nat.data_models.function import FunctionBaseConfig
|
|
57
58
|
from nat.data_models.function import FunctionConfigT
|
|
59
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
60
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
58
61
|
from nat.data_models.llm import LLMBaseConfig
|
|
59
62
|
from nat.data_models.llm import LLMBaseConfigT
|
|
60
63
|
from nat.data_models.logging import LoggingBaseConfig
|
|
@@ -85,6 +88,7 @@ EmbedderProviderBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncI
|
|
|
85
88
|
EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIterator[EvaluatorInfo]]
|
|
86
89
|
FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
|
|
87
90
|
FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
|
|
91
|
+
FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
|
|
88
92
|
TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
|
|
89
93
|
LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
|
|
90
94
|
LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
|
|
@@ -106,6 +110,7 @@ EvaluatorRegisteredCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], Abs
|
|
|
106
110
|
FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncContextManager[FrontEndBase]]
|
|
107
111
|
FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
|
|
108
112
|
AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
|
|
113
|
+
FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
|
|
109
114
|
TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
|
|
110
115
|
LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
|
|
111
116
|
LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
|
|
@@ -178,6 +183,16 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
|
|
|
178
183
|
framework_wrappers: list[str] = Field(default_factory=list)
|
|
179
184
|
|
|
180
185
|
|
|
186
|
+
class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
|
|
187
|
+
"""
|
|
188
|
+
Represents a registered function group. Function groups are collections of functions that share configuration
|
|
189
|
+
and resources.
|
|
190
|
+
"""
|
|
191
|
+
|
|
192
|
+
build_fn: FunctionGroupRegisteredCallableT = Field(repr=False)
|
|
193
|
+
framework_wrappers: list[str] = Field(default_factory=list)
|
|
194
|
+
|
|
195
|
+
|
|
181
196
|
class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
|
|
182
197
|
"""
|
|
183
198
|
Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
|
|
@@ -298,7 +313,7 @@ class RegisteredPackage(BaseModel):
|
|
|
298
313
|
discovery_metadata: DiscoveryMetadata
|
|
299
314
|
|
|
300
315
|
|
|
301
|
-
class TypeRegistry:
|
|
316
|
+
class TypeRegistry:
|
|
302
317
|
|
|
303
318
|
def __init__(self) -> None:
|
|
304
319
|
# Telemetry Exporters
|
|
@@ -313,6 +328,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
313
328
|
# Functions
|
|
314
329
|
self._registered_functions: dict[type[FunctionBaseConfig], RegisteredFunctionInfo] = {}
|
|
315
330
|
|
|
331
|
+
# Function Groups
|
|
332
|
+
self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
|
|
333
|
+
|
|
316
334
|
# LLMs
|
|
317
335
|
self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
|
|
318
336
|
self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
|
|
@@ -478,6 +496,50 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
478
496
|
|
|
479
497
|
return list(self._registered_functions.values())
|
|
480
498
|
|
|
499
|
+
def register_function_group(self, registration: RegisteredFunctionGroupInfo):
|
|
500
|
+
"""Register a function group with the type registry.
|
|
501
|
+
|
|
502
|
+
Args:
|
|
503
|
+
registration: The function group registration information
|
|
504
|
+
|
|
505
|
+
Raises:
|
|
506
|
+
ValueError: If a function group with the same config type is already registered
|
|
507
|
+
"""
|
|
508
|
+
if (registration.config_type in self._registered_function_groups):
|
|
509
|
+
raise ValueError(
|
|
510
|
+
f"A function group with the same config type `{registration.config_type}` has already been "
|
|
511
|
+
"registered.")
|
|
512
|
+
|
|
513
|
+
self._registered_function_groups[registration.config_type] = registration
|
|
514
|
+
|
|
515
|
+
self._registration_changed()
|
|
516
|
+
|
|
517
|
+
def get_function_group(self, config_type: type[FunctionGroupBaseConfig]) -> RegisteredFunctionGroupInfo:
|
|
518
|
+
"""Get a registered function group by its config type.
|
|
519
|
+
|
|
520
|
+
Args:
|
|
521
|
+
config_type: The function group configuration type
|
|
522
|
+
|
|
523
|
+
Returns:
|
|
524
|
+
RegisteredFunctionGroupInfo: The registered function group information
|
|
525
|
+
|
|
526
|
+
Raises:
|
|
527
|
+
KeyError: If no function group is registered for the given config type
|
|
528
|
+
"""
|
|
529
|
+
try:
|
|
530
|
+
return self._registered_function_groups[config_type]
|
|
531
|
+
except KeyError as err:
|
|
532
|
+
raise KeyError(f"Could not find a registered function group for config `{config_type}`. "
|
|
533
|
+
f"Registered configs: {set(self._registered_function_groups.keys())}") from err
|
|
534
|
+
|
|
535
|
+
def get_registered_function_groups(self) -> list[RegisteredInfo[FunctionGroupBaseConfig]]:
|
|
536
|
+
"""Get all registered function groups.
|
|
537
|
+
|
|
538
|
+
Returns:
|
|
539
|
+
list[RegisteredInfo[FunctionGroupBaseConfig]]: List of all registered function groups
|
|
540
|
+
"""
|
|
541
|
+
return list(self._registered_function_groups.values())
|
|
542
|
+
|
|
481
543
|
def register_llm_provider(self, info: RegisteredLLMProviderInfo):
|
|
482
544
|
|
|
483
545
|
if (info.config_type in self._registered_llm_provider_infos):
|
|
@@ -779,7 +841,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
779
841
|
|
|
780
842
|
self._registration_changed()
|
|
781
843
|
|
|
782
|
-
def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
|
|
844
|
+
def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
|
|
783
845
|
|
|
784
846
|
if component_type == ComponentEnum.FRONT_END:
|
|
785
847
|
return self._registered_front_end_infos
|
|
@@ -790,6 +852,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
790
852
|
if component_type == ComponentEnum.FUNCTION:
|
|
791
853
|
return self._registered_functions
|
|
792
854
|
|
|
855
|
+
if component_type == ComponentEnum.FUNCTION_GROUP:
|
|
856
|
+
return self._registered_function_groups
|
|
857
|
+
|
|
793
858
|
if component_type == ComponentEnum.TOOL_WRAPPER:
|
|
794
859
|
return self._registered_tool_wrappers
|
|
795
860
|
|
|
@@ -849,12 +914,14 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
849
914
|
|
|
850
915
|
raise ValueError(f"Supplied an unsupported component type {component_type}")
|
|
851
916
|
|
|
852
|
-
def get_registered_types_by_component_type(
|
|
853
|
-
self, component_type: ComponentEnum) -> list[str]:
|
|
917
|
+
def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
|
|
854
918
|
|
|
855
919
|
if component_type == ComponentEnum.FUNCTION:
|
|
856
920
|
return [i.static_type() for i in self._registered_functions]
|
|
857
921
|
|
|
922
|
+
if component_type == ComponentEnum.FUNCTION_GROUP:
|
|
923
|
+
return [i.static_type() for i in self._registered_function_groups]
|
|
924
|
+
|
|
858
925
|
if component_type == ComponentEnum.TOOL_WRAPPER:
|
|
859
926
|
return list(self._registered_tool_wrappers)
|
|
860
927
|
|
|
@@ -925,8 +992,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
925
992
|
if (short_names[key.local_name] == 1):
|
|
926
993
|
type_list.append((key.local_name, key.config_type))
|
|
927
994
|
|
|
928
|
-
|
|
929
|
-
return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
995
|
+
return typing.Union[*tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
930
996
|
|
|
931
997
|
def compute_annotation(self, cls: type[TypedBaseModelT]):
|
|
932
998
|
|
|
@@ -945,6 +1011,9 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
945
1011
|
if issubclass(cls, FunctionBaseConfig):
|
|
946
1012
|
return self._do_compute_annotation(cls, self.get_registered_functions())
|
|
947
1013
|
|
|
1014
|
+
if issubclass(cls, FunctionGroupBaseConfig):
|
|
1015
|
+
return self._do_compute_annotation(cls, self.get_registered_function_groups())
|
|
1016
|
+
|
|
948
1017
|
if issubclass(cls, LLMBaseConfig):
|
|
949
1018
|
return self._do_compute_annotation(cls, self.get_registered_llm_providers())
|
|
950
1019
|
|
|
File without changes
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# flake8: noqa
|
|
17
|
+
|
|
18
|
+
# Import any control flows which need to be automatically registered here
|
|
19
|
+
from . import sequential_executor
|
|
20
|
+
from .router_agent import register
|
|
File without changes
|
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import typing
|
|
18
|
+
|
|
19
|
+
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
20
|
+
from langchain_core.language_models import BaseChatModel
|
|
21
|
+
from langchain_core.messages.base import BaseMessage
|
|
22
|
+
from langchain_core.messages.human import HumanMessage
|
|
23
|
+
from langchain_core.prompts.chat import ChatPromptTemplate
|
|
24
|
+
from langchain_core.tools import BaseTool
|
|
25
|
+
from langgraph.graph import StateGraph
|
|
26
|
+
from pydantic import BaseModel
|
|
27
|
+
from pydantic import Field
|
|
28
|
+
|
|
29
|
+
from nat.agent.base import AGENT_CALL_LOG_MESSAGE
|
|
30
|
+
from nat.agent.base import AGENT_LOG_PREFIX
|
|
31
|
+
from nat.agent.base import BaseAgent
|
|
32
|
+
|
|
33
|
+
if typing.TYPE_CHECKING:
|
|
34
|
+
from nat.control_flow.router_agent.register import RouterAgentWorkflowConfig
|
|
35
|
+
|
|
36
|
+
logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RouterAgentGraphState(BaseModel):
|
|
40
|
+
"""State schema for the Router Agent Graph.
|
|
41
|
+
|
|
42
|
+
This class defines the state structure used throughout the Router Agent's
|
|
43
|
+
execution graph, containing messages, routing information, and branch selection.
|
|
44
|
+
|
|
45
|
+
Attributes:
|
|
46
|
+
messages: A list of messages representing the conversation history.
|
|
47
|
+
forward_message: The message to be forwarded to the chosen branch.
|
|
48
|
+
chosen_branch: The name of the branch selected by the router agent.
|
|
49
|
+
"""
|
|
50
|
+
messages: list[BaseMessage] = Field(default_factory=list)
|
|
51
|
+
forward_message: BaseMessage = Field(default_factory=lambda: HumanMessage(content=""))
|
|
52
|
+
chosen_branch: str = Field(default="")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class RouterAgentGraph(BaseAgent):
|
|
56
|
+
"""Configurable Router Agent for routing requests to different branches.
|
|
57
|
+
|
|
58
|
+
A Router Agent analyzes incoming requests and routes them to one of the
|
|
59
|
+
configured branches based on the content and context. It makes a single
|
|
60
|
+
routing decision and executes only the selected branch before returning.
|
|
61
|
+
|
|
62
|
+
This agent is useful for creating multi-path workflows where different
|
|
63
|
+
types of requests need to be handled by specialized sub-agents or tools.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
def __init__(
|
|
67
|
+
self,
|
|
68
|
+
llm: BaseChatModel,
|
|
69
|
+
branches: list[BaseTool],
|
|
70
|
+
prompt: ChatPromptTemplate,
|
|
71
|
+
max_router_retries: int = 3,
|
|
72
|
+
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
73
|
+
detailed_logs: bool = False,
|
|
74
|
+
log_response_max_chars: int = 1000,
|
|
75
|
+
):
|
|
76
|
+
"""Initialize the Router Agent.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
llm: The language model to use for routing decisions.
|
|
80
|
+
branches: List of tools/branches that the agent can route to.
|
|
81
|
+
prompt: The chat prompt template for the routing agent.
|
|
82
|
+
max_router_retries: Maximum number of retries if branch selection fails.
|
|
83
|
+
callbacks: Optional list of async callback handlers.
|
|
84
|
+
detailed_logs: Whether to enable detailed logging.
|
|
85
|
+
log_response_max_chars: Maximum characters to log in responses.
|
|
86
|
+
"""
|
|
87
|
+
super().__init__(llm=llm,
|
|
88
|
+
tools=branches,
|
|
89
|
+
callbacks=callbacks,
|
|
90
|
+
detailed_logs=detailed_logs,
|
|
91
|
+
log_response_max_chars=log_response_max_chars)
|
|
92
|
+
|
|
93
|
+
self._branches = branches
|
|
94
|
+
self._branches_dict = {branch.name: branch for branch in branches}
|
|
95
|
+
branch_names = ",".join([branch.name for branch in branches])
|
|
96
|
+
branch_names_and_descriptions = "\n".join([f"{branch.name}: {branch.description}" for branch in branches])
|
|
97
|
+
|
|
98
|
+
prompt = prompt.partial(branches=branch_names_and_descriptions, branch_names=branch_names)
|
|
99
|
+
self.agent = prompt | self.llm
|
|
100
|
+
|
|
101
|
+
self.max_router_retries = max_router_retries
|
|
102
|
+
|
|
103
|
+
def _get_branch(self, branch_name: str) -> BaseTool | None:
    """Return the configured branch whose name is exactly ``branch_name``, or ``None``."""
    try:
        return self._branches_dict[branch_name]
    except KeyError:
        return None
|
|
105
|
+
|
|
106
|
+
async def agent_node(self, state: RouterAgentGraphState):
    """Ask the LLM which branch should handle the forwarded request.

    The LLM is called up to ``max_router_retries`` times with the same request
    and chat history. Every response is appended to ``state.messages``. The
    first response that names a configured branch (see ``_match_branch``) sets
    ``state.chosen_branch``. A branch already present on the state when this
    node runs is kept as-is, and the node returns after a single LLM call.

    Args:
        state: The current state of the router agent graph.

    Returns:
        RouterAgentGraphState: Updated state with the chosen branch.

    Raises:
        RuntimeError: If no response names a branch within ``max_router_retries`` attempts.
        Exception: Any error raised by the LLM call is logged and re-raised.
    """
    logger.debug("%s Starting the Router Agent Node", AGENT_LOG_PREFIX)
    chat_history = self._get_chat_history(state.messages)
    request = state.forward_message.content
    for attempt in range(1, self.max_router_retries + 1):
        try:
            agent_response = await self._call_llm(self.agent, {"request": request, "chat_history": chat_history})
        except Exception as ex:
            logger.error("%s Router Agent failed to call agent_node: %s", AGENT_LOG_PREFIX, ex)
            raise
        if self.detailed_logs:
            logger.info(AGENT_CALL_LOG_MESSAGE, request, agent_response)

        state.messages += [agent_response]

        # A branch selected before this node ran is never overridden; return right
        # away instead of re-querying the LLM for answers that would be ignored.
        if state.chosen_branch != "":
            return state

        chosen_branch = self._match_branch(str(agent_response.content))
        if chosen_branch is not None:
            state.chosen_branch = chosen_branch
            if self.detailed_logs:
                logger.info("%s Router Agent has chosen branch: %s", AGENT_LOG_PREFIX, chosen_branch)
            return state

        # The agent failed to choose a branch on this attempt.
        if attempt == self.max_router_retries:
            logger.error("%s Router Agent has empty chosen branch", AGENT_LOG_PREFIX)
            raise RuntimeError("Router Agent failed to choose a branch")
        logger.warning("%s Router Agent failed to choose a branch, retrying %d out of %d",
                       AGENT_LOG_PREFIX,
                       attempt,
                       self.max_router_retries)

    # Only reached when max_router_retries < 1: the LLM is never called and
    # chosen_branch is left unchanged.
    return state


def _match_branch(self, response_text: str) -> str | None:
    """Return the configured branch name selected by ``response_text``, or ``None``.

    Matching is case-insensitive. Among the branch names contained in the
    response, the longest one wins. As a result, a branch whose name is a
    substring of another (e.g. ``calc`` vs ``calculator_tool``) cannot shadow
    the more specific one, and an exact answer is always the longest match.
    Ties keep configuration order.
    """
    normalized = response_text.lower()
    matches = [branch.name for branch in self._branches if branch.name.lower() in normalized]
    return max(matches, key=len, default=None)
|
|
157
|
+
|
|
158
|
+
async def branch_node(self, state: RouterAgentGraphState):
    """Invoke the branch chosen by the agent node on the forwarded message.

    The branch's response is appended to ``state.messages``; when detailed
    logging is on, the call and its response are logged as well.

    Args:
        state: The current state containing the chosen branch and message.

    Returns:
        RouterAgentGraphState: Updated state with the branch response.

    Raises:
        RuntimeError: If no branch was chosen or branch execution fails.
        ValueError: If the requested tool is not found in the configuration.
    """
    logger.debug("%s Starting Router Agent Tool Node", AGENT_LOG_PREFIX)
    try:
        branch_name = state.chosen_branch
        if branch_name == "":
            logger.error("%s Router Agent has empty chosen branch", AGENT_LOG_PREFIX)
            raise RuntimeError("Router Agent failed to choose a branch")

        branch = self._get_branch(branch_name)
        if not branch:
            logger.error("%s Router Agent wants to call tool %s but it is not in the config file",
                         AGENT_LOG_PREFIX,
                         branch_name)
            raise ValueError("Tool not found in config file")

        payload = state.forward_message.content
        result = await self._call_tool(branch, payload)
        state.messages += [result]
        if self.detailed_logs:
            self._log_tool_response(branch.name, payload, result.content)
        return state
    except Exception as ex:
        # Log every failure at this node boundary, then let it propagate unchanged.
        logger.error("%s Router Agent throws exception during branch node execution: %s", AGENT_LOG_PREFIX, ex)
        raise
|
|
197
|
+
|
|
198
|
+
async def _build_graph(self, state_schema):
    """Wire ``agent`` -> ``branch`` into a StateGraph, compile it, and store it on ``self.graph``.

    Args:
        state_schema: State type handed to ``StateGraph``.

    Returns:
        The compiled graph (the same object assigned to ``self.graph``).
    """
    logger.debug("%s Building and compiling the Router Agent Graph", AGENT_LOG_PREFIX)

    builder = StateGraph(state_schema)
    for node_name, node_fn in (("agent", self.agent_node), ("branch", self.branch_node)):
        builder.add_node(node_name, node_fn)
    builder.add_edge("agent", "branch")
    builder.set_entry_point("agent")

    compiled = builder.compile()
    self.graph = compiled
    logger.debug("%s Router Agent Graph built and compiled successfully", AGENT_LOG_PREFIX)
    return compiled
|
|
211
|
+
|
|
212
|
+
async def build_graph(self):
    """Build and compile the router agent graph using ``RouterAgentGraphState``.

    Returns:
        The compiled execution graph stored on ``self.graph``.

    Raises:
        Exception: Any error from building or compiling is logged and re-raised.
    """
    try:
        await self._build_graph(state_schema=RouterAgentGraphState)
        compiled_graph = self.graph
        return compiled_graph
    except Exception as ex:
        logger.error("%s Router Agent failed to build graph: %s", AGENT_LOG_PREFIX, ex)
        raise
|
|
230
|
+
|
|
231
|
+
@staticmethod
def validate_system_prompt(system_prompt: str) -> bool:
    """Check that ``system_prompt`` includes the ``{branches}`` and ``{branch_names}`` placeholders.

    Each missing placeholder contributes one message; all messages are logged
    together as a single error.

    Args:
        system_prompt: The system prompt string to validate.

    Returns:
        True if both placeholders are present, False otherwise.
    """
    required_prompt_variables = {
        "{branches}": "The system prompt must contain {branches} so the agent knows about configured branches.",
        "{branch_names}": "The system prompt must contain {branch_names} so the agent knows branch names."
    }
    problems = [
        message for placeholder, message in required_prompt_variables.items() if placeholder not in system_prompt
    ]
    if not problems:
        return True
    logger.error("%s %s", AGENT_LOG_PREFIX, "\n".join(problems))
    return False
|
|
257
|
+
|
|
258
|
+
@staticmethod
def validate_user_prompt(user_prompt: str) -> bool:
    """Check that ``user_prompt`` is non-empty and has ``{chat_history}`` and ``{request}``.

    All problems found are logged together as a single error.

    Args:
        user_prompt: The user prompt string to validate.

    Returns:
        True if the prompt is valid, False otherwise.
    """
    if user_prompt:
        required_prompt_variables = {
            "{chat_history}":
                "The user prompt must contain {chat_history} so the agent knows about the conversation history.",
            "{request}":
                "The user prompt must contain {request} so the agent sees the current request.",
        }
        problems = [
            message for placeholder, message in required_prompt_variables.items() if placeholder not in user_prompt
        ]
    else:
        problems = ["The user prompt cannot be empty."]

    if not problems:
        return True
    logger.error("%s %s", AGENT_LOG_PREFIX, "\n".join(problems))
    return False
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def create_router_agent_prompt(config: "RouterAgentWorkflowConfig") -> ChatPromptTemplate:
    """Build the Router Agent's chat prompt from ``config``.

    ``config.system_prompt`` and ``config.user_prompt`` override the packaged
    defaults when they are set (non-empty). Both prompts are validated before
    the template is built.

    Args:
        config: The router agent workflow configuration containing prompt settings.

    Returns:
        A ChatPromptTemplate with a system message followed by a user message.

    Raises:
        ValueError: If the system_prompt or user_prompt validation fails.
    """
    from nat.control_flow.router_agent.prompt import SYSTEM_PROMPT
    from nat.control_flow.router_agent.prompt import USER_PROMPT

    # Falsy config values (None or "") fall back to the packaged defaults.
    system_prompt = config.system_prompt or SYSTEM_PROMPT
    user_prompt = config.user_prompt or USER_PROMPT

    # The system prompt is validated first; the user prompt is only checked if it passes.
    if not RouterAgentGraph.validate_system_prompt(system_prompt):
        logger.error("%s Invalid system_prompt", AGENT_LOG_PREFIX)
        raise ValueError("Invalid system_prompt")
    if not RouterAgentGraph.validate_user_prompt(user_prompt):
        logger.error("%s Invalid user_prompt", AGENT_LOG_PREFIX)
        raise ValueError("Invalid user_prompt")

    messages = [("system", system_prompt), ("user", user_prompt)]
    return ChatPromptTemplate(messages)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# Default system prompt for the Router Agent. The ``{branches}`` and ``{branch_names}``
# placeholders are filled from the configured branches via ``prompt.partial(...)``.
# NOTE(review): presumably parsed as an f-string-style template by ChatPromptTemplate,
# so any other literal braces added here would need doubling -- confirm.
# The few-shot examples use placeholder branch names. They are explicitly marked as
# illustrative so the model does not echo a name that is not configured.
SYSTEM_PROMPT = """
You are a Router Agent responsible for analyzing incoming requests and routing them to the most appropriate branch.

Available branches:
{branches}

CRITICAL INSTRUCTIONS:
- Analyze the user's request carefully
- Select exactly ONE branch that best handles the request from: [{branch_names}]
- Respond with ONLY the exact branch name, nothing else
- Be decisive - choose the single best match. If the request could fit multiple branches,
  choose the most specific/specialized one
- If no branch perfectly fits, choose the closest match

Your response MUST contain ONLY the branch name. Do not include any explanations, reasoning, or additional text.

Examples (the branch names below are illustrative only; always respond with one of [{branch_names}]):
User: "How do I calculate 15 + 25?"
Response: calculator_tool

User: "What's the weather like today?"
Response: weather_service

User: "Send an email to John"
Response: email_tool"""

# Default user prompt; ``{chat_history}`` and ``{request}`` are supplied on each
# routing call by the agent node.
USER_PROMPT = """
Previous conversation history:
{chat_history}

To respond to the request: {request}, which branch should be chosen?

Respond with only the branch name."""
|