nvidia-nat 1.3.0a20250917__py3-none-any.whl → 1.3.0a20250923__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. nat/agent/react_agent/register.py +3 -10
  2. nat/agent/reasoning_agent/reasoning_agent.py +3 -6
  3. nat/agent/register.py +0 -1
  4. nat/agent/rewoo_agent/agent.py +6 -1
  5. nat/agent/rewoo_agent/register.py +9 -10
  6. nat/agent/tool_calling_agent/register.py +3 -10
  7. nat/authentication/credential_validator/__init__.py +14 -0
  8. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  9. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  10. nat/builder/context.py +28 -6
  11. nat/builder/function.py +165 -19
  12. nat/builder/workflow_builder.py +2 -0
  13. nat/cli/entrypoint.py +2 -9
  14. nat/control_flow/register.py +20 -0
  15. nat/control_flow/router_agent/__init__.py +0 -0
  16. nat/{agent → control_flow}/router_agent/agent.py +3 -3
  17. nat/{agent → control_flow}/router_agent/register.py +8 -14
  18. nat/control_flow/sequential_executor.py +167 -0
  19. nat/data_models/agent.py +34 -0
  20. nat/data_models/authentication.py +38 -0
  21. nat/front_ends/fastapi/dask_client_mixin.py +26 -4
  22. nat/front_ends/fastapi/fastapi_front_end_config.py +4 -0
  23. nat/front_ends/fastapi/fastapi_front_end_plugin.py +30 -7
  24. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  25. nat/front_ends/mcp/mcp_front_end_config.py +5 -1
  26. nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
  27. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +108 -1
  28. nat/front_ends/mcp/tool_converter.py +3 -0
  29. nat/observability/mixin/type_introspection_mixin.py +19 -0
  30. nat/profiler/parameter_optimization/parameter_optimizer.py +5 -1
  31. nat/utils/log_levels.py +25 -0
  32. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/METADATA +3 -1
  33. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/RECORD +40 -31
  34. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/entry_points.txt +1 -0
  35. /nat/{agent/router_agent → control_flow}/__init__.py +0 -0
  36. /nat/{agent → control_flow}/router_agent/prompt.py +0 -0
  37. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/WHEEL +0 -0
  38. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  39. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/licenses/LICENSE.md +0 -0
  40. {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250923.dist-info}/top_level.txt +0 -0
nat/builder/context.py CHANGED
@@ -69,12 +69,10 @@ class ContextState(metaclass=Singleton):
69
69
  self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
70
70
  self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
71
71
  self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
72
- self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes())
73
- self.event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=Subject())
74
- self.active_function: ContextVar[InvocationNode] = ContextVar("active_function",
75
- default=InvocationNode(function_id="root",
76
- function_name="root"))
77
- self.active_span_id_stack: ContextVar[list[str]] = ContextVar("active_span_id_stack", default=["root"])
72
+ self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
73
+ self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None)
74
+ self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None)
75
+ self._active_span_id_stack: ContextVar[list[str] | None] = ContextVar("active_span_id_stack", default=None)
78
76
 
79
77
  # Default is a lambda no-op which returns NoneType
80
78
  self.user_input_callback: ContextVar[Callable[[InteractionPrompt], Awaitable[HumanResponse | None]]
@@ -85,6 +83,30 @@ class ContextState(metaclass=Singleton):
85
83
  Awaitable[AuthenticatedContext]]
86
84
  | None] = ContextVar("user_auth_callback", default=None)
87
85
 
86
+ @property
87
+ def metadata(self) -> ContextVar[RequestAttributes]:
88
+ if self._metadata.get() is None:
89
+ self._metadata.set(RequestAttributes())
90
+ return typing.cast(ContextVar[RequestAttributes], self._metadata)
91
+
92
+ @property
93
+ def active_function(self) -> ContextVar[InvocationNode]:
94
+ if self._active_function.get() is None:
95
+ self._active_function.set(InvocationNode(function_id="root", function_name="root"))
96
+ return typing.cast(ContextVar[InvocationNode], self._active_function)
97
+
98
+ @property
99
+ def event_stream(self) -> ContextVar[Subject[IntermediateStep]]:
100
+ if self._event_stream.get() is None:
101
+ self._event_stream.set(Subject())
102
+ return typing.cast(ContextVar[Subject[IntermediateStep]], self._event_stream)
103
+
104
+ @property
105
+ def active_span_id_stack(self) -> ContextVar[list[str]]:
106
+ if self._active_span_id_stack.get() is None:
107
+ self._active_span_id_stack.set(["root"])
108
+ return typing.cast(ContextVar[list[str]], self._active_span_id_stack)
109
+
88
110
  @staticmethod
89
111
  def get() -> "ContextState":
90
112
  return ContextState()
nat/builder/function.py CHANGED
@@ -21,6 +21,7 @@ from abc import abstractmethod
21
21
  from collections.abc import AsyncGenerator
22
22
  from collections.abc import Awaitable
23
23
  from collections.abc import Callable
24
+ from collections.abc import Sequence
24
25
 
25
26
  from pydantic import BaseModel
26
27
 
@@ -352,7 +353,11 @@ class FunctionGroup:
352
353
  A group of functions that can be used together, sharing the same configuration, context, and resources.
353
354
  """
354
355
 
355
- def __init__(self, *, config: FunctionGroupBaseConfig, instance_name: str | None = None):
356
+ def __init__(self,
357
+ *,
358
+ config: FunctionGroupBaseConfig,
359
+ instance_name: str | None = None,
360
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None):
356
361
  """
357
362
  Creates a new function group.
358
363
 
@@ -362,10 +367,15 @@ class FunctionGroup:
362
367
  The configuration for the function group.
363
368
  instance_name : str | None, optional
364
369
  The name of the function group. If not provided, the type of the function group will be used.
370
+ filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
371
+ A callback function to additionally filter the functions in the function group dynamically when
372
+ the functions are accessed via any accessor method.
365
373
  """
366
374
  self._config = config
367
375
  self._instance_name = instance_name or config.type
368
- self._functions: dict[str, Function] = {}
376
+ self._functions: dict[str, Function] = dict()
377
+ self._filter_fn = filter_fn
378
+ self._per_function_filter_fn: dict[str, Callable[[str], bool]] = dict()
369
379
 
370
380
  def add_function(self,
371
381
  name: str,
@@ -373,7 +383,8 @@ class FunctionGroup:
373
383
  *,
374
384
  input_schema: type[BaseModel] | None = None,
375
385
  description: str | None = None,
376
- converters: list[Callable] | None = None):
386
+ converters: list[Callable] | None = None,
387
+ filter_fn: Callable[[str], bool] | None = None):
377
388
  """
378
389
  Adds a function to the function group.
379
390
 
@@ -389,6 +400,11 @@ class FunctionGroup:
389
400
  The description of the function.
390
401
  converters : list[Callable] | None, optional
391
402
  The converters to use for the function.
403
+ filter_fn : Callable[[str], bool] | None, optional
404
+ A callback to determine if the function should be included in the function group. The
405
+ callback will be called with the function name. The callback is invoked dynamically when
406
+ the functions are accessed via any accessor method such as `get_accessible_functions`,
407
+ `get_included_functions`, `get_excluded_functions`, `get_all_functions`.
392
408
 
393
409
  Raises
394
410
  ------
@@ -408,6 +424,8 @@ class FunctionGroup:
408
424
  full_name = self._get_fn_name(name)
409
425
  lambda_fn = LambdaFunction.from_info(config=EmptyFunctionConfig(), info=info, instance_name=full_name)
410
426
  self._functions[name] = lambda_fn
427
+ if filter_fn:
428
+ self._per_function_filter_fn[name] = filter_fn
411
429
 
412
430
  def get_config(self) -> FunctionGroupBaseConfig:
413
431
  """
@@ -423,24 +441,54 @@ class FunctionGroup:
423
441
  def _get_fn_name(self, name: str) -> str:
424
442
  return f"{self._instance_name}.{name}"
425
443
 
426
- def _get_all_but_excluded_functions(self) -> dict[str, Function]:
444
+ def _fn_should_be_included(self, name: str) -> bool:
445
+ return (name not in self._per_function_filter_fn or self._per_function_filter_fn[name](name))
446
+
447
+ def _get_all_but_excluded_functions(
448
+ self,
449
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
450
+ ) -> dict[str, Function]:
427
451
  """
428
452
  Returns a dictionary of all functions in the function group except the excluded functions.
429
453
  """
430
454
  missing = set(self._config.exclude) - set(self._functions.keys())
431
455
  if missing:
432
456
  raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
457
+ filter_fn = filter_fn or self._filter_fn or (lambda x: x)
433
458
  excluded = set(self._config.exclude)
434
- return {self._get_fn_name(name): self._functions[name] for name in self._functions if name not in excluded}
459
+ included = set(filter_fn(list(self._functions.keys())))
460
+
461
+ def predicate(name: str) -> bool:
462
+ if name in excluded:
463
+ return False
464
+ if not self._fn_should_be_included(name):
465
+ return False
466
+ return name in included
435
467
 
436
- def get_accessible_functions(self) -> dict[str, Function]:
468
+ return {self._get_fn_name(name): self._functions[name] for name in self._functions if predicate(name)}
469
+
470
+ def get_accessible_functions(
471
+ self,
472
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
473
+ ) -> dict[str, Function]:
437
474
  """
438
475
  Returns a dictionary of all accessible functions in the function group.
476
+
477
+ First, the functions are filtered by the function group's configuration.
439
478
  If the function group is configured to:
440
479
  - include some functions, this will return only the included functions.
441
480
  - not include or exclude any function, this will return all functions in the group.
442
481
  - exclude some functions, this will return all functions in the group except the excluded functions.
443
482
 
483
+ Then, the functions are filtered by filter function and per-function filter functions.
484
+
485
+ Parameters
486
+ ----------
487
+ filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
488
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
489
+ then fall back to the function group's filter function. If no filter function is set for the function group
490
+ all functions will be returned.
491
+
444
492
  Returns
445
493
  -------
446
494
  dict[str, Function]
@@ -452,15 +500,25 @@ class FunctionGroup:
452
500
  When the function group is configured to include functions that are not found in the group.
453
501
  """
454
502
  if self._config.include:
455
- return self.get_included_functions()
503
+ return self.get_included_functions(filter_fn=filter_fn)
456
504
  if self._config.exclude:
457
- return self._get_all_but_excluded_functions()
458
- return self.get_all_functions()
505
+ return self._get_all_but_excluded_functions(filter_fn=filter_fn)
506
+ return self.get_all_functions(filter_fn=filter_fn)
459
507
 
460
- def get_excluded_functions(self) -> dict[str, Function]:
508
+ def get_excluded_functions(
509
+ self,
510
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
511
+ ) -> dict[str, Function]:
461
512
  """
462
- Returns a dictionary of all functions in the function group which are configured to be excluded.
463
- If the function group is configured to not exclude any functions, this will return an empty dictionary.
513
+ Returns a dictionary of all functions in the function group which are configured to be excluded or filtered
514
+ out by a filter function or per-function filter function.
515
+
516
+ Parameters
517
+ ----------
518
+ filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
519
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
520
+ then fall back to the function group's filter function. If no filter function is set for the function group
521
+ then no functions will be added to the returned dictionary.
464
522
 
465
523
  Returns
466
524
  -------
@@ -475,14 +533,35 @@ class FunctionGroup:
475
533
  missing = set(self._config.exclude) - set(self._functions.keys())
476
534
  if missing:
477
535
  raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
478
- return {self._get_fn_name(name): self._functions[name] for name in self._config.exclude}
536
+ filter_fn = filter_fn or self._filter_fn or (lambda x: x)
537
+ excluded = set(self._config.exclude)
538
+ included = set(filter_fn(list(self._functions.keys())))
539
+
540
+ def predicate(name: str) -> bool:
541
+ if name in excluded:
542
+ return True
543
+ if not self._fn_should_be_included(name):
544
+ return True
545
+ return name not in included
479
546
 
480
- def get_included_functions(self) -> dict[str, Function]:
547
+ return {self._get_fn_name(name): self._functions[name] for name in self._functions if predicate(name)}
548
+
549
+ def get_included_functions(
550
+ self,
551
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
552
+ ) -> dict[str, Function]:
481
553
  """
482
554
  Returns a dictionary of all functions in the function group which are:
483
555
  - configured to be included and added to the global function registry
484
556
  - not configured to be excluded.
485
- If the function group is configured to not include any functions, this will return an empty dictionary.
557
+ - not filtered out by a filter function.
558
+
559
+ Parameters
560
+ ----------
561
+ filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
562
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
563
+ then fall back to the function group's filter function. If no filter function is set for the function group
564
+ all functions will be returned.
486
565
 
487
566
  Returns
488
567
  -------
@@ -497,15 +576,82 @@ class FunctionGroup:
497
576
  missing = set(self._config.include) - set(self._functions.keys())
498
577
  if missing:
499
578
  raise ValueError(f"Unknown included functions: {sorted(missing)}")
500
- return {self._get_fn_name(name): self._functions[name] for name in self._config.include}
501
-
502
- def get_all_functions(self) -> dict[str, Function]:
579
+ filter_fn = filter_fn or self._filter_fn or (lambda x: x)
580
+ included = set(filter_fn(list(self._config.include)))
581
+ included = {name for name in included if self._fn_should_be_included(name)}
582
+ return {self._get_fn_name(name): self._functions[name] for name in included}
583
+
584
+ def get_all_functions(
585
+ self,
586
+ filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
587
+ ) -> dict[str, Function]:
503
588
  """
504
589
  Returns a dictionary of all functions in the function group, regardless if they are included or excluded.
505
590
 
591
+ If a filter function has been set, the returned functions will additionally be filtered by the callback.
592
+
593
+ Parameters
594
+ ----------
595
+ filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
596
+ A callback function to additionally filter the functions in the function group dynamically. If not provided
597
+ then fall back to the function group's filter function. If no filter function is set for the function group
598
+ all functions will be returned.
599
+
506
600
  Returns
507
601
  -------
508
602
  dict[str, Function]
509
603
  A dictionary of all functions in the function group.
510
604
  """
511
- return {self._get_fn_name(name): self._functions[name] for name in self._functions}
605
+ filter_fn = filter_fn or self._filter_fn or (lambda x: x)
606
+ included = set(filter_fn(list(self._functions.keys())))
607
+ included = {name for name in included if self._fn_should_be_included(name)}
608
+ return {self._get_fn_name(name): self._functions[name] for name in included}
609
+
610
+ def set_filter_fn(self, filter_fn: Callable[[Sequence[str]], Sequence[str]]):
611
+ """
612
+ Sets the filter function for the function group.
613
+
614
+ Parameters
615
+ ----------
616
+ filter_fn : Callable[[Sequence[str]], Sequence[str]]
617
+ The filter function to set for the function group.
618
+ """
619
+ self._filter_fn = filter_fn
620
+
621
+ def set_per_function_filter_fn(self, name: str, filter_fn: Callable[[str], bool]):
622
+ """
623
+ Sets the a per-function filter function for the a function within the function group.
624
+
625
+ Parameters
626
+ ----------
627
+ name : str
628
+ The name of the function.
629
+ filter_fn : Callable[[str], bool]
630
+ The per-function filter function to set for the function group.
631
+
632
+ Raises
633
+ ------
634
+ ValueError
635
+ When the function is not found in the function group.
636
+ """
637
+ if name not in self._functions:
638
+ raise ValueError(f"Function {name} not found in function group {self._instance_name}")
639
+ self._per_function_filter_fn[name] = filter_fn
640
+
641
+ def set_instance_name(self, instance_name: str):
642
+ """
643
+ Sets the instance name for the function group.
644
+
645
+ Parameters
646
+ ----------
647
+ instance_name : str
648
+ The instance name to set for the function group.
649
+ """
650
+ self._instance_name = instance_name
651
+
652
+ @property
653
+ def instance_name(self) -> str:
654
+ """
655
+ Returns the instance name for the function group.
656
+ """
657
+ return self._instance_name
@@ -415,6 +415,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
415
415
  raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
416
416
  f"Got {type(build_result)}")
417
417
 
418
+ # set the instance name for the function group based on the workflow-provided name
419
+ build_result.set_instance_name(name)
418
420
  return ConfiguredFunctionGroup(config=config, instance=build_result)
419
421
 
420
422
  @override
nat/cli/entrypoint.py CHANGED
@@ -30,6 +30,8 @@ import time
30
30
  import click
31
31
  import nest_asyncio
32
32
 
33
+ from nat.utils.log_levels import LOG_LEVELS
34
+
33
35
  from .commands.configure.configure import configure_command
34
36
  from .commands.evaluate import eval_command
35
37
  from .commands.info.info import info_command
@@ -45,15 +47,6 @@ from .commands.workflow.workflow import workflow_command
45
47
  # Apply at the beginning of the file to avoid issues with asyncio
46
48
  nest_asyncio.apply()
47
49
 
48
- # Define log level choices
49
- LOG_LEVELS = {
50
- 'DEBUG': logging.DEBUG,
51
- 'INFO': logging.INFO,
52
- 'WARNING': logging.WARNING,
53
- 'ERROR': logging.ERROR,
54
- 'CRITICAL': logging.CRITICAL
55
- }
56
-
57
50
 
58
51
  def setup_logging(log_level: str):
59
52
  """Configure logging with the specified level"""
@@ -0,0 +1,20 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # flake8: noqa
17
+
18
+ # Import any control flows which need to be automatically registered here
19
+ from . import sequential_executor
20
+ from .router_agent import register
File without changes
@@ -31,7 +31,7 @@ from nat.agent.base import AGENT_LOG_PREFIX
31
31
  from nat.agent.base import BaseAgent
32
32
 
33
33
  if typing.TYPE_CHECKING:
34
- from nat.agent.router_agent.register import RouterAgentWorkflowConfig
34
+ from nat.control_flow.router_agent.register import RouterAgentWorkflowConfig
35
35
 
36
36
  logger = logging.getLogger(__name__)
37
37
 
@@ -304,8 +304,8 @@ def create_router_agent_prompt(config: "RouterAgentWorkflowConfig") -> ChatPromp
304
304
  Raises:
305
305
  ValueError: If the system_prompt or user_prompt validation fails.
306
306
  """
307
- from nat.agent.router_agent.prompt import SYSTEM_PROMPT
308
- from nat.agent.router_agent.prompt import USER_PROMPT
307
+ from nat.control_flow.router_agent.prompt import SYSTEM_PROMPT
308
+ from nat.control_flow.router_agent.prompt import USER_PROMPT
309
309
  # the Router Agent prompt can be customized via config option system_prompt and user_prompt.
310
310
 
311
311
  if config.system_prompt:
@@ -16,35 +16,29 @@
16
16
  import logging
17
17
 
18
18
  from pydantic import Field
19
- from pydantic import PositiveInt
20
19
 
21
20
  from nat.builder.builder import Builder
22
21
  from nat.builder.framework_enum import LLMFrameworkEnum
23
22
  from nat.builder.function_info import FunctionInfo
24
23
  from nat.cli.register_workflow import register_function
24
+ from nat.data_models.agent import AgentBaseConfig
25
25
  from nat.data_models.component_ref import FunctionRef
26
- from nat.data_models.component_ref import LLMRef
27
- from nat.data_models.function import FunctionBaseConfig
28
26
 
29
27
  logger = logging.getLogger(__name__)
30
28
 
31
29
 
32
- class RouterAgentWorkflowConfig(FunctionBaseConfig, name="router_agent"):
30
+ class RouterAgentWorkflowConfig(AgentBaseConfig, name="router_agent"):
33
31
  """
34
32
  A router agent takes in the incoming message, combines it with a prompt and the list of branches,
35
33
  and ask a LLM about which branch to take.
36
34
  """
35
+ description: str = Field(default="Router Agent Workflow", description="Description of this functions use.")
37
36
  branches: list[FunctionRef] = Field(default_factory=list,
38
37
  description="The list of branches to provide to the router agent.")
39
- llm_name: LLMRef = Field(description="The LLM model to use with the routing agent.")
40
- description: str = Field(default="Router Agent Workflow", description="Description of this functions use.")
41
38
  system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
42
39
  user_prompt: str | None = Field(default=None, description="Provides the prompt to use with the agent.")
43
40
  max_router_retries: int = Field(
44
41
  default=3, description="Maximum number of retries if the router agent fails to choose a branch.")
45
- detailed_logs: bool = Field(default=False, description="Set the verbosity of the router agent's logging.")
46
- log_response_max_chars: PositiveInt = Field(
47
- default=1000, description="Maximum number of characters to display in logs when logging branch responses.")
48
42
 
49
43
 
50
44
  @register_function(config_type=RouterAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
@@ -53,9 +47,9 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
53
47
  from langgraph.graph.state import CompiledStateGraph
54
48
 
55
49
  from nat.agent.base import AGENT_LOG_PREFIX
56
- from nat.agent.router_agent.agent import RouterAgentGraph
57
- from nat.agent.router_agent.agent import RouterAgentGraphState
58
- from nat.agent.router_agent.agent import create_router_agent_prompt
50
+ from nat.control_flow.router_agent.agent import RouterAgentGraph
51
+ from nat.control_flow.router_agent.agent import RouterAgentGraphState
52
+ from nat.control_flow.router_agent.agent import create_router_agent_prompt
59
53
 
60
54
  prompt = create_router_agent_prompt(config)
61
55
  llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
@@ -68,7 +62,7 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
68
62
  branches=branches,
69
63
  prompt=prompt,
70
64
  max_router_retries=config.max_router_retries,
71
- detailed_logs=config.detailed_logs,
65
+ detailed_logs=config.verbose,
72
66
  log_response_max_chars=config.log_response_max_chars,
73
67
  ).build_graph()
74
68
 
@@ -85,7 +79,7 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
85
79
 
86
80
  except Exception as ex:
87
81
  logger.exception("%s Router Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
88
- if config.detailed_logs:
82
+ if config.verbose:
89
83
  return str(ex)
90
84
  return "Router agent failed with exception: %s" % ex
91
85
 
@@ -0,0 +1,167 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import logging
17
+ import typing
18
+
19
+ from langchain_core.tools.base import BaseTool
20
+ from pydantic import BaseModel
21
+ from pydantic import Field
22
+
23
+ from nat.builder.builder import Builder
24
+ from nat.builder.framework_enum import LLMFrameworkEnum
25
+ from nat.builder.function import Function
26
+ from nat.builder.function_info import FunctionInfo
27
+ from nat.cli.register_workflow import register_function
28
+ from nat.data_models.component_ref import FunctionRef
29
+ from nat.data_models.function import FunctionBaseConfig
30
+ from nat.utils.type_utils import DecomposedType
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class ToolExecutionConfig(BaseModel):
36
+ """Configuration for individual tool execution within sequential execution."""
37
+
38
+ use_streaming: bool = Field(default=False, description="Whether to use streaming output for the tool.")
39
+
40
+
41
+ class SequentialExecutorConfig(FunctionBaseConfig, name="sequential_executor"):
42
+ """Configuration for sequential execution of a list of functions."""
43
+
44
+ tool_list: list[FunctionRef] = Field(default_factory=list,
45
+ description="A list of functions to execute sequentially.")
46
+ tool_execution_config: dict[str, ToolExecutionConfig] = Field(default_factory=dict,
47
+ description="Optional configuration for each"
48
+ "tool in the sequential execution tool list."
49
+ "Keys must match the tool names from the"
50
+ "tool_list.")
51
+ raise_type_incompatibility: bool = Field(
52
+ default=False,
53
+ description="Default to False. Check if the adjacent tools are type compatible,"
54
+ "which means the output type of the previous function is compatible with the input type of the next function."
55
+ "If set to True, any incompatibility will raise an exception. If set to false, the incompatibility will only"
56
+ "generate a warning message and the sequential execution will continue.")
57
+
58
+
59
+ def _get_function_output_type(function: Function, tool_execution_config: dict[str, ToolExecutionConfig]) -> type:
60
+ function_config = tool_execution_config.get(function.instance_name, None)
61
+ if function_config:
62
+ return function.streaming_output_type if function_config.use_streaming else function.single_output_type
63
+ else:
64
+ return function.single_output_type
65
+
66
+
67
+ def _validate_function_type_compatibility(src_fn: Function,
68
+ target_fn: Function,
69
+ tool_execution_config: dict[str, ToolExecutionConfig]) -> None:
70
+ src_output_type = _get_function_output_type(src_fn, tool_execution_config)
71
+ target_input_type = target_fn.input_type
72
+ logger.debug(
73
+ f"The output type of the {src_fn.instance_name} function is {str(src_output_type)}, is not compatible with"
74
+ f"the input type of the {target_fn.instance_name} function, which is {str(target_input_type)}")
75
+
76
+ is_compatible = DecomposedType.is_type_compatible(src_output_type, target_input_type)
77
+ if not is_compatible:
78
+ raise ValueError(
79
+ f"The output type of the {src_fn.instance_name} function is {str(src_output_type)}, is not compatible with"
80
+ f"the input type of the {target_fn.instance_name} function, which is {str(target_input_type)}")
81
+
82
+
83
+ def _validate_tool_list_type_compatibility(sequential_executor_config: SequentialExecutorConfig,
84
+ builder: Builder) -> tuple[type, type]:
85
+ tool_list = sequential_executor_config.tool_list
86
+ tool_execution_config = sequential_executor_config.tool_execution_config
87
+
88
+ function_list: list[Function] = []
89
+ for function_ref in tool_list:
90
+ function_list.append(builder.get_function(function_ref))
91
+ if not function_list:
92
+ raise RuntimeError("The function list is empty")
93
+ input_type = function_list[0].input_type
94
+
95
+ if len(function_list) > 1:
96
+ for src_fn, target_fn in zip(function_list[0:-1], function_list[1:]):
97
+ try:
98
+ _validate_function_type_compatibility(src_fn, target_fn, tool_execution_config)
99
+ except ValueError as e:
100
+ raise ValueError(f"The sequential tool list has incompatible types: {e}")
101
+
102
+ output_type = _get_function_output_type(function_list[-1], tool_execution_config)
103
+ logger.debug(f"The input type of the sequential executor tool list is {str(input_type)},"
104
+ f"the output type is {str(output_type)}")
105
+
106
+ return (input_type, output_type)
107
+
108
+
109
+ @register_function(config_type=SequentialExecutorConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
110
+ async def sequential_execution(config: SequentialExecutorConfig, builder: Builder):
111
+ logger.debug(f"Initializing sequential executor with tool list: {config.tool_list}")
112
+
113
+ tools: list[BaseTool] = builder.get_tools(tool_names=config.tool_list, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
114
+ tools_dict: dict[str, BaseTool] = {tool.name: tool for tool in tools}
115
+
116
+ try:
117
+ input_type, output_type = _validate_tool_list_type_compatibility(config, builder)
118
+ except ValueError as e:
119
+ if config.raise_type_incompatibility:
120
+ logger.error(f"The sequential executor tool list has incompatible types: {e}")
121
+ raise
122
+ else:
123
+ logger.warning(f"The sequential executor tool list has incompatible types: {e}")
124
+ input_type = typing.Any
125
+ output_type = typing.Any
126
+ except Exception as e:
127
+ raise ValueError(f"Error with the sequential executor tool list: {e}")
128
+
129
+ # The type annotation of _sequential_function_execution is dynamically set according to the tool list
130
+ async def _sequential_function_execution(initial_tool_input):
131
+ logger.debug(f"Executing sequential executor with tool list: {config.tool_list}")
132
+
133
+ tool_list: list[FunctionRef] = config.tool_list
134
+ tool_input = initial_tool_input
135
+ tool_response = None
136
+
137
+ for tool_name in tool_list:
138
+ tool = tools_dict[tool_name]
139
+ tool_execution_config = config.tool_execution_config.get(tool_name, None)
140
+ logger.debug(f"Executing tool {tool_name} with input: {tool_input}")
141
+ try:
142
+ if tool_execution_config:
143
+ if tool_execution_config.use_streaming:
144
+ output = ""
145
+ async for chunk in tool.astream(tool_input):
146
+ output += chunk.content
147
+ tool_response = output
148
+ else:
149
+ tool_response = await tool.ainvoke(tool_input)
150
+ else:
151
+ tool_response = await tool.ainvoke(tool_input)
152
+ except Exception as e:
153
+ logger.error(f"Error with tool {tool_name}: {e}")
154
+ raise
155
+
156
+ # The input of the next tool is the response of the previous tool
157
+ tool_input = tool_response
158
+
159
+ return tool_response
160
+
161
+ # Dynamically set the annotations for the function
162
+ _sequential_function_execution.__annotations__ = {"initial_tool_input": input_type, "return": output_type}
163
+ logger.debug(f"Sequential executor function annotations: {_sequential_function_execution.__annotations__}")
164
+
165
+ yield FunctionInfo.from_fn(_sequential_function_execution,
166
+ description="Executes a list of functions sequentially."
167
+ "The input of the next tool is the response of the previous tool.")