nvidia-nat 1.3.0a20250917__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/react_agent/register.py +3 -10
- nat/agent/reasoning_agent/reasoning_agent.py +3 -6
- nat/agent/register.py +0 -1
- nat/agent/rewoo_agent/agent.py +6 -1
- nat/agent/rewoo_agent/register.py +9 -10
- nat/agent/tool_calling_agent/register.py +3 -10
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/builder/context.py +28 -6
- nat/builder/function.py +165 -19
- nat/builder/workflow_builder.py +2 -0
- nat/cli/entrypoint.py +2 -9
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/{agent → control_flow}/router_agent/agent.py +3 -3
- nat/{agent → control_flow}/router_agent/register.py +8 -14
- nat/control_flow/sequential_executor.py +167 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/authentication.py +38 -0
- nat/front_ends/fastapi/dask_client_mixin.py +26 -4
- nat/front_ends/fastapi/fastapi_front_end_config.py +4 -0
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +30 -7
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +5 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +108 -1
- nat/front_ends/mcp/tool_converter.py +3 -0
- nat/observability/mixin/type_introspection_mixin.py +19 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +5 -1
- nat/utils/log_levels.py +25 -0
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/METADATA +3 -1
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/RECORD +40 -31
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/entry_points.txt +1 -0
- /nat/{agent/router_agent → control_flow}/__init__.py +0 -0
- /nat/{agent → control_flow}/router_agent/prompt.py +0 -0
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250917.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/top_level.txt +0 -0
nat/builder/context.py
CHANGED
@@ -69,12 +69,10 @@ class ContextState(metaclass=Singleton):
         self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
         self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
         self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
-        self.
-        self.
-        self.
-
-            function_name="root"))
-        self.active_span_id_stack: ContextVar[list[str]] = ContextVar("active_span_id_stack", default=["root"])
+        self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
+        self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None)
+        self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None)
+        self._active_span_id_stack: ContextVar[list[str] | None] = ContextVar("active_span_id_stack", default=None)
 
         # Default is a lambda no-op which returns NoneType
         self.user_input_callback: ContextVar[Callable[[InteractionPrompt], Awaitable[HumanResponse | None]]
@@ -85,6 +83,30 @@ class ContextState(metaclass=Singleton):
                                               Awaitable[AuthenticatedContext]]
                                               | None] = ContextVar("user_auth_callback", default=None)
 
+    @property
+    def metadata(self) -> ContextVar[RequestAttributes]:
+        if self._metadata.get() is None:
+            self._metadata.set(RequestAttributes())
+        return typing.cast(ContextVar[RequestAttributes], self._metadata)
+
+    @property
+    def active_function(self) -> ContextVar[InvocationNode]:
+        if self._active_function.get() is None:
+            self._active_function.set(InvocationNode(function_id="root", function_name="root"))
+        return typing.cast(ContextVar[InvocationNode], self._active_function)
+
+    @property
+    def event_stream(self) -> ContextVar[Subject[IntermediateStep]]:
+        if self._event_stream.get() is None:
+            self._event_stream.set(Subject())
+        return typing.cast(ContextVar[Subject[IntermediateStep]], self._event_stream)
+
+    @property
+    def active_span_id_stack(self) -> ContextVar[list[str]]:
+        if self._active_span_id_stack.get() is None:
+            self._active_span_id_stack.set(["root"])
+        return typing.cast(ContextVar[list[str]], self._active_span_id_stack)
+
     @staticmethod
     def get() -> "ContextState":
         return ContextState()
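The properties added above replace eagerly-constructed ContextVars with lazily-initialized ones: each private ContextVar defaults to None, and the public property fills it in on first access before returning it. Below is a minimal standalone sketch of that pattern; the AppState and RequestInfo names are illustrative only, not part of nvidia-nat.

import typing
from contextvars import ContextVar
from dataclasses import dataclass, field


@dataclass
class RequestInfo:
    headers: dict[str, str] = field(default_factory=dict)


class AppState:

    def __init__(self):
        # Construction does no work; the property initializes the value on first use.
        self._metadata: ContextVar[RequestInfo | None] = ContextVar("request_info", default=None)

    @property
    def metadata(self) -> ContextVar[RequestInfo]:
        if self._metadata.get() is None:
            self._metadata.set(RequestInfo())
        return typing.cast(ContextVar[RequestInfo], self._metadata)


state = AppState()
state.metadata.get().headers["x-trace-id"] = "abc123"
print(state.metadata.get().headers)  # {'x-trace-id': 'abc123'}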
nat/builder/function.py
CHANGED
@@ -21,6 +21,7 @@ from abc import abstractmethod
 from collections.abc import AsyncGenerator
 from collections.abc import Awaitable
 from collections.abc import Callable
+from collections.abc import Sequence
 
 from pydantic import BaseModel
 
@@ -352,7 +353,11 @@ class FunctionGroup:
     A group of functions that can be used together, sharing the same configuration, context, and resources.
     """
 
-    def __init__(self,
+    def __init__(self,
+                 *,
+                 config: FunctionGroupBaseConfig,
+                 instance_name: str | None = None,
+                 filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None):
         """
         Creates a new function group.
 
@@ -362,10 +367,15 @@ class FunctionGroup:
             The configuration for the function group.
         instance_name : str | None, optional
             The name of the function group. If not provided, the type of the function group will be used.
+        filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
+            A callback function to additionally filter the functions in the function group dynamically when
+            the functions are accessed via any accessor method.
         """
         self._config = config
         self._instance_name = instance_name or config.type
-        self._functions: dict[str, Function] =
+        self._functions: dict[str, Function] = dict()
+        self._filter_fn = filter_fn
+        self._per_function_filter_fn: dict[str, Callable[[str], bool]] = dict()
 
     def add_function(self,
                      name: str,
@@ -373,7 +383,8 @@ class FunctionGroup:
                      *,
                      input_schema: type[BaseModel] | None = None,
                      description: str | None = None,
-                     converters: list[Callable] | None = None
+                     converters: list[Callable] | None = None,
+                     filter_fn: Callable[[str], bool] | None = None):
         """
         Adds a function to the function group.
 
@@ -389,6 +400,11 @@ class FunctionGroup:
             The description of the function.
         converters : list[Callable] | None, optional
             The converters to use for the function.
+        filter_fn : Callable[[str], bool] | None, optional
+            A callback to determine if the function should be included in the function group. The
+            callback will be called with the function name. The callback is invoked dynamically when
+            the functions are accessed via any accessor method such as `get_accessible_functions`,
+            `get_included_functions`, `get_excluded_functions`, `get_all_functions`.
 
         Raises
         ------
@@ -408,6 +424,8 @@ class FunctionGroup:
         full_name = self._get_fn_name(name)
         lambda_fn = LambdaFunction.from_info(config=EmptyFunctionConfig(), info=info, instance_name=full_name)
         self._functions[name] = lambda_fn
+        if filter_fn:
+            self._per_function_filter_fn[name] = filter_fn
 
     def get_config(self) -> FunctionGroupBaseConfig:
         """
@@ -423,24 +441,54 @@ class FunctionGroup:
     def _get_fn_name(self, name: str) -> str:
         return f"{self._instance_name}.{name}"
 
-    def
+    def _fn_should_be_included(self, name: str) -> bool:
+        return (name not in self._per_function_filter_fn or self._per_function_filter_fn[name](name))
+
+    def _get_all_but_excluded_functions(
+        self,
+        filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
+    ) -> dict[str, Function]:
         """
         Returns a dictionary of all functions in the function group except the excluded functions.
         """
         missing = set(self._config.exclude) - set(self._functions.keys())
         if missing:
             raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
+        filter_fn = filter_fn or self._filter_fn or (lambda x: x)
         excluded = set(self._config.exclude)
-
+        included = set(filter_fn(list(self._functions.keys())))
+
+        def predicate(name: str) -> bool:
+            if name in excluded:
+                return False
+            if not self._fn_should_be_included(name):
+                return False
+            return name in included
 
-
+        return {self._get_fn_name(name): self._functions[name] for name in self._functions if predicate(name)}
+
+    def get_accessible_functions(
+        self,
+        filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
+    ) -> dict[str, Function]:
         """
         Returns a dictionary of all accessible functions in the function group.
+
+        First, the functions are filtered by the function group's configuration.
         If the function group is configured to:
         - include some functions, this will return only the included functions.
        - not include or exclude any function, this will return all functions in the group.
         - exclude some functions, this will return all functions in the group except the excluded functions.
 
+        Then, the functions are filtered by filter function and per-function filter functions.
+
+        Parameters
+        ----------
+        filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
+            A callback function to additionally filter the functions in the function group dynamically. If not provided
+            then fall back to the function group's filter function. If no filter function is set for the function group
+            all functions will be returned.
+
         Returns
         -------
         dict[str, Function]
@@ -452,15 +500,25 @@ class FunctionGroup:
             When the function group is configured to include functions that are not found in the group.
         """
         if self._config.include:
-            return self.get_included_functions()
+            return self.get_included_functions(filter_fn=filter_fn)
         if self._config.exclude:
-            return self._get_all_but_excluded_functions()
-        return self.get_all_functions()
+            return self._get_all_but_excluded_functions(filter_fn=filter_fn)
+        return self.get_all_functions(filter_fn=filter_fn)
 
-    def get_excluded_functions(
+    def get_excluded_functions(
+        self,
+        filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
+    ) -> dict[str, Function]:
         """
-        Returns a dictionary of all functions in the function group which are configured to be excluded
-
+        Returns a dictionary of all functions in the function group which are configured to be excluded or filtered
+        out by a filter function or per-function filter function.
+
+        Parameters
+        ----------
+        filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
+            A callback function to additionally filter the functions in the function group dynamically. If not provided
+            then fall back to the function group's filter function. If no filter function is set for the function group
+            then no functions will be added to the returned dictionary.
 
         Returns
         -------
@@ -475,14 +533,35 @@ class FunctionGroup:
         missing = set(self._config.exclude) - set(self._functions.keys())
         if missing:
             raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
-
+        filter_fn = filter_fn or self._filter_fn or (lambda x: x)
+        excluded = set(self._config.exclude)
+        included = set(filter_fn(list(self._functions.keys())))
+
+        def predicate(name: str) -> bool:
+            if name in excluded:
+                return True
+            if not self._fn_should_be_included(name):
+                return True
+            return name not in included
 
-
+        return {self._get_fn_name(name): self._functions[name] for name in self._functions if predicate(name)}
+
+    def get_included_functions(
+        self,
+        filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
+    ) -> dict[str, Function]:
         """
         Returns a dictionary of all functions in the function group which are:
         - configured to be included and added to the global function registry
         - not configured to be excluded.
-
+        - not filtered out by a filter function.
+
+        Parameters
+        ----------
+        filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
+            A callback function to additionally filter the functions in the function group dynamically. If not provided
+            then fall back to the function group's filter function. If no filter function is set for the function group
+            all functions will be returned.
 
         Returns
         -------
@@ -497,15 +576,82 @@ class FunctionGroup:
         missing = set(self._config.include) - set(self._functions.keys())
         if missing:
             raise ValueError(f"Unknown included functions: {sorted(missing)}")
-
-
-
+        filter_fn = filter_fn or self._filter_fn or (lambda x: x)
+        included = set(filter_fn(list(self._config.include)))
+        included = {name for name in included if self._fn_should_be_included(name)}
+        return {self._get_fn_name(name): self._functions[name] for name in included}
+
+    def get_all_functions(
+        self,
+        filter_fn: Callable[[Sequence[str]], Sequence[str]] | None = None,
+    ) -> dict[str, Function]:
         """
         Returns a dictionary of all functions in the function group, regardless if they are included or excluded.
 
+        If a filter function has been set, the returned functions will additionally be filtered by the callback.
+
+        Parameters
+        ----------
+        filter_fn : Callable[[Sequence[str]], Sequence[str]] | None, optional
+            A callback function to additionally filter the functions in the function group dynamically. If not provided
+            then fall back to the function group's filter function. If no filter function is set for the function group
+            all functions will be returned.
+
         Returns
         -------
         dict[str, Function]
             A dictionary of all functions in the function group.
         """
-
+        filter_fn = filter_fn or self._filter_fn or (lambda x: x)
+        included = set(filter_fn(list(self._functions.keys())))
+        included = {name for name in included if self._fn_should_be_included(name)}
+        return {self._get_fn_name(name): self._functions[name] for name in included}
+
+    def set_filter_fn(self, filter_fn: Callable[[Sequence[str]], Sequence[str]]):
+        """
+        Sets the filter function for the function group.
+
+        Parameters
+        ----------
+        filter_fn : Callable[[Sequence[str]], Sequence[str]]
+            The filter function to set for the function group.
+        """
+        self._filter_fn = filter_fn
+
+    def set_per_function_filter_fn(self, name: str, filter_fn: Callable[[str], bool]):
+        """
+        Sets the a per-function filter function for the a function within the function group.
+
+        Parameters
+        ----------
+        name : str
+            The name of the function.
+        filter_fn : Callable[[str], bool]
+            The per-function filter function to set for the function group.
+
+        Raises
+        ------
+        ValueError
+            When the function is not found in the function group.
+        """
+        if name not in self._functions:
+            raise ValueError(f"Function {name} not found in function group {self._instance_name}")
+        self._per_function_filter_fn[name] = filter_fn
+
+    def set_instance_name(self, instance_name: str):
+        """
+        Sets the instance name for the function group.
+
+        Parameters
+        ----------
+        instance_name : str
+            The instance name to set for the function group.
+        """
+        self._instance_name = instance_name
+
+    @property
+    def instance_name(self) -> str:
+        """
+        Returns the instance name for the function group.
+        """
+        return self._instance_name
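The new filtering hooks compose three layers whenever functions are listed: the config-level include/exclude lists, a group-level filter_fn over the full sequence of names, and per-function boolean callbacks. The standalone sketch below mirrors the predicate logic of _get_all_but_excluded_functions without importing nvidia-nat; the group name and function names are made up for illustration.

from collections.abc import Callable, Sequence

functions = {"search": object(), "summarize": object(), "_internal": object()}
exclude = {"summarize"}                                      # config-level exclusion
group_filter: Callable[[Sequence[str]], Sequence[str]] = (   # group-level dynamic filter
    lambda names: [n for n in names if not n.startswith("_")])
per_function: dict[str, Callable[[str], bool]] = {}          # optional per-function callbacks

included = set(group_filter(list(functions)))


def accessible(name: str) -> bool:
    if name in exclude:
        return False
    if name in per_function and not per_function[name](name):
        return False
    return name in included


print({f"my_group.{name}": fn for name, fn in functions.items() if accessible(name)})
# Only "my_group.search" survives: "summarize" is excluded, "_internal" is filtered out.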
nat/builder/workflow_builder.py
CHANGED
@@ -415,6 +415,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
             raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
                              f"Got {type(build_result)}")
 
+        # set the instance name for the function group based on the workflow-provided name
+        build_result.set_instance_name(name)
         return ConfiguredFunctionGroup(config=config, instance=build_result)
 
     @override
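The builder now renames the group to the workflow-provided instance name right after it is built. This matters because exposed function names are prefixed with the instance name (see _get_fn_name in function.py above), so the prefix should match the key used in the workflow configuration. A small illustrative sketch, with made-up names:

class Group:

    def __init__(self, instance_name: str):
        self._instance_name = instance_name

    def set_instance_name(self, instance_name: str) -> None:
        self._instance_name = instance_name

    def qualified(self, name: str) -> str:
        # mirrors FunctionGroup._get_fn_name
        return f"{self._instance_name}.{name}"


group = Group(instance_name="my_group_type")  # default would be the config type
group.set_instance_name("retrieval_tools")    # hypothetical workflow-provided name
print(group.qualified("search"))              # retrieval_tools.search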
nat/cli/entrypoint.py
CHANGED
@@ -30,6 +30,8 @@ import time
 import click
 import nest_asyncio
 
+from nat.utils.log_levels import LOG_LEVELS
+
 from .commands.configure.configure import configure_command
 from .commands.evaluate import eval_command
 from .commands.info.info import info_command
@@ -45,15 +47,6 @@ from .commands.workflow.workflow import workflow_command
 # Apply at the beginning of the file to avoid issues with asyncio
 nest_asyncio.apply()
 
-# Define log level choices
-LOG_LEVELS = {
-    'DEBUG': logging.DEBUG,
-    'INFO': logging.INFO,
-    'WARNING': logging.WARNING,
-    'ERROR': logging.ERROR,
-    'CRITICAL': logging.CRITICAL
-}
-
 
 def setup_logging(log_level: str):
     """Configure logging with the specified level"""
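The LOG_LEVELS mapping is no longer defined inline; entrypoint.py now imports it from the new nat/utils/log_levels.py module (listed above with +25 lines but not shown in this diff). Based on the lines removed here, the shared module presumably carries a mapping like the sketch below; the setup_logging body is likewise illustrative, since only its signature and docstring appear in the diff.

import logging

LOG_LEVELS = {
    'DEBUG': logging.DEBUG,
    'INFO': logging.INFO,
    'WARNING': logging.WARNING,
    'ERROR': logging.ERROR,
    'CRITICAL': logging.CRITICAL,
}


def setup_logging(log_level: str) -> None:
    """Configure logging with the specified level (illustrative body, not the package's)."""
    logging.basicConfig(level=LOG_LEVELS[log_level.upper()])


setup_logging("INFO")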
nat/control_flow/register.py
ADDED
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+
+# Import any control flows which need to be automatically registered here
+from . import sequential_executor
+from .router_agent import register
nat/{agent → control_flow}/router_agent/agent.py
CHANGED
@@ -31,7 +31,7 @@ from nat.agent.base import AGENT_LOG_PREFIX
 from nat.agent.base import BaseAgent
 
 if typing.TYPE_CHECKING:
-    from nat.
+    from nat.control_flow.router_agent.register import RouterAgentWorkflowConfig
 
 logger = logging.getLogger(__name__)
 
@@ -304,8 +304,8 @@ def create_router_agent_prompt(config: "RouterAgentWorkflowConfig") -> ChatPromp
     Raises:
         ValueError: If the system_prompt or user_prompt validation fails.
     """
-    from nat.
-    from nat.
+    from nat.control_flow.router_agent.prompt import SYSTEM_PROMPT
+    from nat.control_flow.router_agent.prompt import USER_PROMPT
     # the Router Agent prompt can be customized via config option system_prompt and user_prompt.
 
     if config.system_prompt:
nat/{agent → control_flow}/router_agent/register.py
CHANGED
@@ -16,35 +16,29 @@
 import logging
 
 from pydantic import Field
-from pydantic import PositiveInt
 
 from nat.builder.builder import Builder
 from nat.builder.framework_enum import LLMFrameworkEnum
 from nat.builder.function_info import FunctionInfo
 from nat.cli.register_workflow import register_function
+from nat.data_models.agent import AgentBaseConfig
 from nat.data_models.component_ref import FunctionRef
-from nat.data_models.component_ref import LLMRef
-from nat.data_models.function import FunctionBaseConfig
 
 logger = logging.getLogger(__name__)
 
 
-class RouterAgentWorkflowConfig(
+class RouterAgentWorkflowConfig(AgentBaseConfig, name="router_agent"):
     """
     A router agent takes in the incoming message, combines it with a prompt and the list of branches,
     and ask a LLM about which branch to take.
     """
+    description: str = Field(default="Router Agent Workflow", description="Description of this functions use.")
     branches: list[FunctionRef] = Field(default_factory=list,
                                         description="The list of branches to provide to the router agent.")
-    llm_name: LLMRef = Field(description="The LLM model to use with the routing agent.")
-    description: str = Field(default="Router Agent Workflow", description="Description of this functions use.")
     system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
     user_prompt: str | None = Field(default=None, description="Provides the prompt to use with the agent.")
     max_router_retries: int = Field(
         default=3, description="Maximum number of retries if the router agent fails to choose a branch.")
-    detailed_logs: bool = Field(default=False, description="Set the verbosity of the router agent's logging.")
-    log_response_max_chars: PositiveInt = Field(
-        default=1000, description="Maximum number of characters to display in logs when logging branch responses.")
 
 
 @register_function(config_type=RouterAgentWorkflowConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
@@ -53,9 +47,9 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
     from langgraph.graph.state import CompiledStateGraph
 
     from nat.agent.base import AGENT_LOG_PREFIX
-    from nat.
-    from nat.
-    from nat.
+    from nat.control_flow.router_agent.agent import RouterAgentGraph
+    from nat.control_flow.router_agent.agent import RouterAgentGraphState
+    from nat.control_flow.router_agent.agent import create_router_agent_prompt
 
     prompt = create_router_agent_prompt(config)
     llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
@@ -68,7 +62,7 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
         branches=branches,
         prompt=prompt,
         max_router_retries=config.max_router_retries,
-        detailed_logs=config.
+        detailed_logs=config.verbose,
         log_response_max_chars=config.log_response_max_chars,
     ).build_graph()
 
@@ -85,7 +79,7 @@ async def router_agent_workflow(config: RouterAgentWorkflowConfig, builder: Buil
 
     except Exception as ex:
         logger.exception("%s Router Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
-        if config.
+        if config.verbose:
            return str(ex)
        return "Router agent failed with exception: %s" % ex
 
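RouterAgentWorkflowConfig now derives from AgentBaseConfig (new in nat/data_models/agent.py, +34 lines, not shown in this diff). The fields dropped here are still read later as config.llm_name, config.verbose, and config.log_response_max_chars, so the base class presumably declares them. A hedged sketch of what such a base might look like; the class name, field defaults, and the use of a plain pydantic BaseModel are assumptions:

from pydantic import BaseModel, Field, PositiveInt


class AgentBaseConfigSketch(BaseModel):
    # Assumed shared agent fields, inferred from the references that remain in register.py.
    llm_name: str = Field(description="The LLM model to use with the agent.")
    verbose: bool = Field(default=False, description="Set the verbosity of the agent's logging.")
    log_response_max_chars: PositiveInt = Field(
        default=1000, description="Maximum number of characters to display when logging responses.")


cfg = AgentBaseConfigSketch(llm_name="nim_llm")
print(cfg.verbose, cfg.log_response_max_chars)  # False 1000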
nat/control_flow/sequential_executor.py
ADDED
@@ -0,0 +1,167 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import typing
+
+from langchain_core.tools.base import BaseTool
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.builder.builder import Builder
+from nat.builder.framework_enum import LLMFrameworkEnum
+from nat.builder.function import Function
+from nat.builder.function_info import FunctionInfo
+from nat.cli.register_workflow import register_function
+from nat.data_models.component_ref import FunctionRef
+from nat.data_models.function import FunctionBaseConfig
+from nat.utils.type_utils import DecomposedType
+
+logger = logging.getLogger(__name__)
+
+
+class ToolExecutionConfig(BaseModel):
+    """Configuration for individual tool execution within sequential execution."""
+
+    use_streaming: bool = Field(default=False, description="Whether to use streaming output for the tool.")
+
+
+class SequentialExecutorConfig(FunctionBaseConfig, name="sequential_executor"):
+    """Configuration for sequential execution of a list of functions."""
+
+    tool_list: list[FunctionRef] = Field(default_factory=list,
+                                         description="A list of functions to execute sequentially.")
+    tool_execution_config: dict[str, ToolExecutionConfig] = Field(default_factory=dict,
+                                                                  description="Optional configuration for each"
+                                                                  "tool in the sequential execution tool list."
+                                                                  "Keys must match the tool names from the"
+                                                                  "tool_list.")
+    raise_type_incompatibility: bool = Field(
+        default=False,
+        description="Default to False. Check if the adjacent tools are type compatible,"
+        "which means the output type of the previous function is compatible with the input type of the next function."
+        "If set to True, any incompatibility will raise an exception. If set to false, the incompatibility will only"
+        "generate a warning message and the sequential execution will continue.")
+
+
+def _get_function_output_type(function: Function, tool_execution_config: dict[str, ToolExecutionConfig]) -> type:
+    function_config = tool_execution_config.get(function.instance_name, None)
+    if function_config:
+        return function.streaming_output_type if function_config.use_streaming else function.single_output_type
+    else:
+        return function.single_output_type
+
+
+def _validate_function_type_compatibility(src_fn: Function,
+                                          target_fn: Function,
+                                          tool_execution_config: dict[str, ToolExecutionConfig]) -> None:
+    src_output_type = _get_function_output_type(src_fn, tool_execution_config)
+    target_input_type = target_fn.input_type
+    logger.debug(
+        f"The output type of the {src_fn.instance_name} function is {str(src_output_type)}, is not compatible with"
+        f"the input type of the {target_fn.instance_name} function, which is {str(target_input_type)}")
+
+    is_compatible = DecomposedType.is_type_compatible(src_output_type, target_input_type)
+    if not is_compatible:
+        raise ValueError(
+            f"The output type of the {src_fn.instance_name} function is {str(src_output_type)}, is not compatible with"
+            f"the input type of the {target_fn.instance_name} function, which is {str(target_input_type)}")
+
+
+def _validate_tool_list_type_compatibility(sequential_executor_config: SequentialExecutorConfig,
+                                           builder: Builder) -> tuple[type, type]:
+    tool_list = sequential_executor_config.tool_list
+    tool_execution_config = sequential_executor_config.tool_execution_config
+
+    function_list: list[Function] = []
+    for function_ref in tool_list:
+        function_list.append(builder.get_function(function_ref))
+    if not function_list:
+        raise RuntimeError("The function list is empty")
+    input_type = function_list[0].input_type
+
+    if len(function_list) > 1:
+        for src_fn, target_fn in zip(function_list[0:-1], function_list[1:]):
+            try:
+                _validate_function_type_compatibility(src_fn, target_fn, tool_execution_config)
+            except ValueError as e:
+                raise ValueError(f"The sequential tool list has incompatible types: {e}")
+
+    output_type = _get_function_output_type(function_list[-1], tool_execution_config)
+    logger.debug(f"The input type of the sequential executor tool list is {str(input_type)},"
+                 f"the output type is {str(output_type)}")
+
+    return (input_type, output_type)
+
+
+@register_function(config_type=SequentialExecutorConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
+async def sequential_execution(config: SequentialExecutorConfig, builder: Builder):
+    logger.debug(f"Initializing sequential executor with tool list: {config.tool_list}")
+
+    tools: list[BaseTool] = builder.get_tools(tool_names=config.tool_list, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+    tools_dict: dict[str, BaseTool] = {tool.name: tool for tool in tools}
+
+    try:
+        input_type, output_type = _validate_tool_list_type_compatibility(config, builder)
+    except ValueError as e:
+        if config.raise_type_incompatibility:
+            logger.error(f"The sequential executor tool list has incompatible types: {e}")
+            raise
+        else:
+            logger.warning(f"The sequential executor tool list has incompatible types: {e}")
+            input_type = typing.Any
+            output_type = typing.Any
+    except Exception as e:
+        raise ValueError(f"Error with the sequential executor tool list: {e}")
+
+    # The type annotation of _sequential_function_execution is dynamically set according to the tool list
+    async def _sequential_function_execution(initial_tool_input):
+        logger.debug(f"Executing sequential executor with tool list: {config.tool_list}")
+
+        tool_list: list[FunctionRef] = config.tool_list
+        tool_input = initial_tool_input
+        tool_response = None
+
+        for tool_name in tool_list:
+            tool = tools_dict[tool_name]
+            tool_execution_config = config.tool_execution_config.get(tool_name, None)
+            logger.debug(f"Executing tool {tool_name} with input: {tool_input}")
+            try:
+                if tool_execution_config:
+                    if tool_execution_config.use_streaming:
+                        output = ""
+                        async for chunk in tool.astream(tool_input):
+                            output += chunk.content
+                        tool_response = output
+                    else:
+                        tool_response = await tool.ainvoke(tool_input)
+                else:
+                    tool_response = await tool.ainvoke(tool_input)
+            except Exception as e:
+                logger.error(f"Error with tool {tool_name}: {e}")
+                raise
+
+            # The input of the next tool is the response of the previous tool
+            tool_input = tool_response
+
+        return tool_response
+
+    # Dynamically set the annotations for the function
+    _sequential_function_execution.__annotations__ = {"initial_tool_input": input_type, "return": output_type}
+    logger.debug(f"Sequential executor function annotations: {_sequential_function_execution.__annotations__}")
+
+    yield FunctionInfo.from_fn(_sequential_function_execution,
+                               description="Executes a list of functions sequentially."
+                               "The input of the next tool is the response of the previous tool.")
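The core of the new sequential executor is the chaining loop in _sequential_function_execution: each tool's response becomes the next tool's input, and the wrapper's type annotations are derived from the first tool's input type and the last tool's output type. Below is a standalone sketch of that chaining idea, using plain async callables instead of nvidia-nat/LangChain tools; the tool functions are illustrative stand-ins.

import asyncio
from collections.abc import Awaitable, Callable


async def to_upper(text: str) -> str:
    return text.upper()


async def add_exclamation(text: str) -> str:
    return text + "!"


async def run_sequentially(tools: list[Callable[[str], Awaitable[str]]], initial_input: str) -> str:
    tool_input = initial_input
    tool_response = initial_input
    for tool in tools:
        tool_response = await tool(tool_input)
        # the input of the next tool is the response of the previous tool
        tool_input = tool_response
    return tool_response


print(asyncio.run(run_sequentially([to_upper, add_exclamation], "hello")))  # HELLO!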