nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +15 -5
- nat/agent/reasoning_agent/reasoning_agent.py +6 -1
- nat/agent/register.py +2 -0
- nat/agent/rewoo_agent/agent.py +4 -2
- nat/agent/rewoo_agent/register.py +8 -3
- nat/agent/router_agent/__init__.py +0 -0
- nat/agent/router_agent/agent.py +329 -0
- nat/agent/router_agent/prompt.py +48 -0
- nat/agent/router_agent/register.py +97 -0
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +11 -3
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/function.py +167 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +213 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -7
- nat/cli/entrypoint.py +2 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +43 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
- nat/llm/aws_bedrock_llm.py +14 -3
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +401 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +4 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +91 -71
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
nat/builder/builder.py
CHANGED
|
@@ -24,9 +24,11 @@ from nat.authentication.interfaces import AuthProviderBase
|
|
|
24
24
|
from nat.builder.context import Context
|
|
25
25
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
26
26
|
from nat.builder.function import Function
|
|
27
|
+
from nat.builder.function import FunctionGroup
|
|
27
28
|
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
28
29
|
from nat.data_models.component_ref import AuthenticationRef
|
|
29
30
|
from nat.data_models.component_ref import EmbedderRef
|
|
31
|
+
from nat.data_models.component_ref import FunctionGroupRef
|
|
30
32
|
from nat.data_models.component_ref import FunctionRef
|
|
31
33
|
from nat.data_models.component_ref import LLMRef
|
|
32
34
|
from nat.data_models.component_ref import MemoryRef
|
|
@@ -36,6 +38,7 @@ from nat.data_models.component_ref import TTCStrategyRef
|
|
|
36
38
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
37
39
|
from nat.data_models.evaluator import EvaluatorBaseConfig
|
|
38
40
|
from nat.data_models.function import FunctionBaseConfig
|
|
41
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
39
42
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
40
43
|
from nat.data_models.llm import LLMBaseConfig
|
|
41
44
|
from nat.data_models.memory import MemoryBaseConfig
|
|
@@ -64,18 +67,33 @@ class Builder(ABC):
|
|
|
64
67
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
|
65
68
|
pass
|
|
66
69
|
|
|
70
|
+
@abstractmethod
|
|
71
|
+
async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
72
|
+
pass
|
|
73
|
+
|
|
67
74
|
@abstractmethod
|
|
68
75
|
def get_function(self, name: str | FunctionRef) -> Function:
|
|
69
76
|
pass
|
|
70
77
|
|
|
78
|
+
@abstractmethod
|
|
79
|
+
def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
80
|
+
pass
|
|
81
|
+
|
|
71
82
|
def get_functions(self, function_names: Sequence[str | FunctionRef]) -> list[Function]:
|
|
72
83
|
|
|
73
84
|
return [self.get_function(name) for name in function_names]
|
|
74
85
|
|
|
86
|
+
def get_function_groups(self, function_group_names: Sequence[str | FunctionGroupRef]) -> list[FunctionGroup]:
|
|
87
|
+
return [self.get_function_group(name) for name in function_group_names]
|
|
88
|
+
|
|
75
89
|
@abstractmethod
|
|
76
90
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
77
91
|
pass
|
|
78
92
|
|
|
93
|
+
@abstractmethod
|
|
94
|
+
def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
|
|
95
|
+
pass
|
|
96
|
+
|
|
79
97
|
@abstractmethod
|
|
80
98
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
81
99
|
pass
|
|
@@ -88,10 +106,11 @@ class Builder(ABC):
|
|
|
88
106
|
def get_workflow_config(self) -> FunctionBaseConfig:
|
|
89
107
|
pass
|
|
90
108
|
|
|
91
|
-
|
|
109
|
+
@abstractmethod
|
|
110
|
+
def get_tools(self,
|
|
111
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
92
112
|
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
93
|
-
|
|
94
|
-
return [self.get_tool(fn_name=n, wrapper_type=wrapper_type) for n in tool_names]
|
|
113
|
+
pass
|
|
95
114
|
|
|
96
115
|
@abstractmethod
|
|
97
116
|
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str) -> typing.Any:
|
|
@@ -257,8 +276,12 @@ class Builder(ABC):
|
|
|
257
276
|
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
258
277
|
pass
|
|
259
278
|
|
|
279
|
+
@abstractmethod
|
|
280
|
+
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
281
|
+
pass
|
|
282
|
+
|
|
260
283
|
|
|
261
|
-
class EvalBuilder(
|
|
284
|
+
class EvalBuilder(ABC):
|
|
262
285
|
|
|
263
286
|
@abstractmethod
|
|
264
287
|
async def add_evaluator(self, name: str, config: EvaluatorBaseConfig):
|
nat/builder/component_utils.py
CHANGED
|
@@ -30,6 +30,7 @@ from nat.data_models.component_ref import generate_instance_id
|
|
|
30
30
|
from nat.data_models.config import Config
|
|
31
31
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
32
32
|
from nat.data_models.function import FunctionBaseConfig
|
|
33
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
33
34
|
from nat.data_models.llm import LLMBaseConfig
|
|
34
35
|
from nat.data_models.memory import MemoryBaseConfig
|
|
35
36
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
@@ -48,6 +49,7 @@ _component_group_order = [
|
|
|
48
49
|
ComponentGroup.OBJECT_STORES,
|
|
49
50
|
ComponentGroup.RETRIEVERS,
|
|
50
51
|
ComponentGroup.TTC_STRATEGIES,
|
|
52
|
+
ComponentGroup.FUNCTION_GROUPS,
|
|
51
53
|
ComponentGroup.FUNCTIONS,
|
|
52
54
|
]
|
|
53
55
|
|
|
@@ -107,6 +109,8 @@ def group_from_component(component: TypedBaseModel) -> ComponentGroup | None:
|
|
|
107
109
|
return ComponentGroup.EMBEDDERS
|
|
108
110
|
if (isinstance(component, FunctionBaseConfig)):
|
|
109
111
|
return ComponentGroup.FUNCTIONS
|
|
112
|
+
if (isinstance(component, FunctionGroupBaseConfig)):
|
|
113
|
+
return ComponentGroup.FUNCTION_GROUPS
|
|
110
114
|
if (isinstance(component, LLMBaseConfig)):
|
|
111
115
|
return ComponentGroup.LLMS
|
|
112
116
|
if (isinstance(component, MemoryBaseConfig)):
|
|
@@ -254,9 +258,9 @@ def build_dependency_sequence(config: "Config") -> list[ComponentInstanceData]:
|
|
|
254
258
|
runtime instance references.
|
|
255
259
|
"""
|
|
256
260
|
|
|
257
|
-
total_node_count = len(config.embedders) + len(config.functions) + len(config.
|
|
258
|
-
|
|
259
|
-
|
|
261
|
+
total_node_count = (len(config.embedders) + len(config.functions) + len(config.function_groups) + len(config.llms) +
|
|
262
|
+
len(config.memory) + len(config.object_stores) + len(config.retrievers) +
|
|
263
|
+
len(config.ttc_strategies) + len(config.authentication) + 1) # +1 for the workflow
|
|
260
264
|
|
|
261
265
|
dependency_map: dict
|
|
262
266
|
dependency_graph: nx.DiGraph
|
nat/builder/function.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
|
+
import re
|
|
17
18
|
import typing
|
|
18
19
|
from abc import ABC
|
|
19
20
|
from abc import abstractmethod
|
|
@@ -29,7 +30,9 @@ from nat.builder.function_base import InputT
|
|
|
29
30
|
from nat.builder.function_base import SingleOutputT
|
|
30
31
|
from nat.builder.function_base import StreamingOutputT
|
|
31
32
|
from nat.builder.function_info import FunctionInfo
|
|
33
|
+
from nat.data_models.function import EmptyFunctionConfig
|
|
32
34
|
from nat.data_models.function import FunctionBaseConfig
|
|
35
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
33
36
|
|
|
34
37
|
_InvokeFnT = Callable[[InputT], Awaitable[SingleOutputT]]
|
|
35
38
|
_StreamFnT = Callable[[InputT], AsyncGenerator[StreamingOutputT]]
|
|
@@ -342,3 +345,167 @@ class LambdaFunction(Function[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
342
345
|
pass
|
|
343
346
|
|
|
344
347
|
return FunctionImpl(config=config, info=info, instance_name=instance_name)
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class FunctionGroup:
|
|
351
|
+
"""
|
|
352
|
+
A group of functions that can be used together, sharing the same configuration, context, and resources.
|
|
353
|
+
"""
|
|
354
|
+
|
|
355
|
+
def __init__(self, *, config: FunctionGroupBaseConfig, instance_name: str | None = None):
|
|
356
|
+
"""
|
|
357
|
+
Creates a new function group.
|
|
358
|
+
|
|
359
|
+
Parameters
|
|
360
|
+
----------
|
|
361
|
+
config : FunctionGroupBaseConfig
|
|
362
|
+
The configuration for the function group.
|
|
363
|
+
instance_name : str | None, optional
|
|
364
|
+
The name of the function group. If not provided, the type of the function group will be used.
|
|
365
|
+
"""
|
|
366
|
+
self._config = config
|
|
367
|
+
self._instance_name = instance_name or config.type
|
|
368
|
+
self._functions: dict[str, Function] = {}
|
|
369
|
+
|
|
370
|
+
def add_function(self,
|
|
371
|
+
name: str,
|
|
372
|
+
fn: Callable,
|
|
373
|
+
*,
|
|
374
|
+
input_schema: type[BaseModel] | None = None,
|
|
375
|
+
description: str | None = None,
|
|
376
|
+
converters: list[Callable] | None = None):
|
|
377
|
+
"""
|
|
378
|
+
Adds a function to the function group.
|
|
379
|
+
|
|
380
|
+
Parameters
|
|
381
|
+
----------
|
|
382
|
+
name : str
|
|
383
|
+
The name of the function.
|
|
384
|
+
fn : Callable
|
|
385
|
+
The function to add to the function group.
|
|
386
|
+
input_schema : type[BaseModel] | None, optional
|
|
387
|
+
The input schema for the function.
|
|
388
|
+
description : str | None, optional
|
|
389
|
+
The description of the function.
|
|
390
|
+
converters : list[Callable] | None, optional
|
|
391
|
+
The converters to use for the function.
|
|
392
|
+
|
|
393
|
+
Raises
|
|
394
|
+
------
|
|
395
|
+
ValueError
|
|
396
|
+
When the function name is empty or blank.
|
|
397
|
+
When the function name contains invalid characters.
|
|
398
|
+
When the function already exists in the function group.
|
|
399
|
+
"""
|
|
400
|
+
if not name.strip():
|
|
401
|
+
raise ValueError("Function name cannot be empty or blank")
|
|
402
|
+
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
|
|
403
|
+
raise ValueError(f"Function name can only contain letters, numbers, underscores, and hyphens: {name}")
|
|
404
|
+
if name in self._functions:
|
|
405
|
+
raise ValueError(f"Function {name} already exists in function group {self._instance_name}")
|
|
406
|
+
|
|
407
|
+
info = FunctionInfo.from_fn(fn, input_schema=input_schema, description=description, converters=converters)
|
|
408
|
+
full_name = self._get_fn_name(name)
|
|
409
|
+
lambda_fn = LambdaFunction.from_info(config=EmptyFunctionConfig(), info=info, instance_name=full_name)
|
|
410
|
+
self._functions[name] = lambda_fn
|
|
411
|
+
|
|
412
|
+
def get_config(self) -> FunctionGroupBaseConfig:
|
|
413
|
+
"""
|
|
414
|
+
Returns the configuration for the function group.
|
|
415
|
+
|
|
416
|
+
Returns
|
|
417
|
+
-------
|
|
418
|
+
FunctionGroupBaseConfig
|
|
419
|
+
The configuration for the function group.
|
|
420
|
+
"""
|
|
421
|
+
return self._config
|
|
422
|
+
|
|
423
|
+
def _get_fn_name(self, name: str) -> str:
|
|
424
|
+
return f"{self._instance_name}.{name}"
|
|
425
|
+
|
|
426
|
+
def _get_all_but_excluded_functions(self) -> dict[str, Function]:
|
|
427
|
+
"""
|
|
428
|
+
Returns a dictionary of all functions in the function group except the excluded functions.
|
|
429
|
+
"""
|
|
430
|
+
missing = set(self._config.exclude) - set(self._functions.keys())
|
|
431
|
+
if missing:
|
|
432
|
+
raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
|
|
433
|
+
excluded = set(self._config.exclude)
|
|
434
|
+
return {self._get_fn_name(name): self._functions[name] for name in self._functions if name not in excluded}
|
|
435
|
+
|
|
436
|
+
def get_accessible_functions(self) -> dict[str, Function]:
|
|
437
|
+
"""
|
|
438
|
+
Returns a dictionary of all accessible functions in the function group.
|
|
439
|
+
If the function group is configured to:
|
|
440
|
+
- include some functions, this will return only the included functions.
|
|
441
|
+
- not include or exclude any function, this will return all functions in the group.
|
|
442
|
+
- exclude some functions, this will return all functions in the group except the excluded functions.
|
|
443
|
+
|
|
444
|
+
Returns
|
|
445
|
+
-------
|
|
446
|
+
dict[str, Function]
|
|
447
|
+
A dictionary of all accessible functions in the function group.
|
|
448
|
+
|
|
449
|
+
Raises
|
|
450
|
+
------
|
|
451
|
+
ValueError
|
|
452
|
+
When the function group is configured to include functions that are not found in the group.
|
|
453
|
+
"""
|
|
454
|
+
if self._config.include:
|
|
455
|
+
return self.get_included_functions()
|
|
456
|
+
if self._config.exclude:
|
|
457
|
+
return self._get_all_but_excluded_functions()
|
|
458
|
+
return self.get_all_functions()
|
|
459
|
+
|
|
460
|
+
def get_excluded_functions(self) -> dict[str, Function]:
|
|
461
|
+
"""
|
|
462
|
+
Returns a dictionary of all functions in the function group which are configured to be excluded.
|
|
463
|
+
If the function group is configured to not exclude any functions, this will return an empty dictionary.
|
|
464
|
+
|
|
465
|
+
Returns
|
|
466
|
+
-------
|
|
467
|
+
dict[str, Function]
|
|
468
|
+
A dictionary of all excluded functions in the function group.
|
|
469
|
+
|
|
470
|
+
Raises
|
|
471
|
+
------
|
|
472
|
+
ValueError
|
|
473
|
+
When the function group is configured to exclude functions that are not found in the group.
|
|
474
|
+
"""
|
|
475
|
+
missing = set(self._config.exclude) - set(self._functions.keys())
|
|
476
|
+
if missing:
|
|
477
|
+
raise ValueError(f"Unknown excluded functions: {sorted(missing)}")
|
|
478
|
+
return {self._get_fn_name(name): self._functions[name] for name in self._config.exclude}
|
|
479
|
+
|
|
480
|
+
def get_included_functions(self) -> dict[str, Function]:
|
|
481
|
+
"""
|
|
482
|
+
Returns a dictionary of all functions in the function group which are:
|
|
483
|
+
- configured to be included and added to the global function registry
|
|
484
|
+
- not configured to be excluded.
|
|
485
|
+
If the function group is configured to not include any functions, this will return an empty dictionary.
|
|
486
|
+
|
|
487
|
+
Returns
|
|
488
|
+
-------
|
|
489
|
+
dict[str, Function]
|
|
490
|
+
A dictionary of all included functions in the function group.
|
|
491
|
+
|
|
492
|
+
Raises
|
|
493
|
+
------
|
|
494
|
+
ValueError
|
|
495
|
+
When the function group is configured to include functions that are not found in the group.
|
|
496
|
+
"""
|
|
497
|
+
missing = set(self._config.include) - set(self._functions.keys())
|
|
498
|
+
if missing:
|
|
499
|
+
raise ValueError(f"Unknown included functions: {sorted(missing)}")
|
|
500
|
+
return {self._get_fn_name(name): self._functions[name] for name in self._config.include}
|
|
501
|
+
|
|
502
|
+
def get_all_functions(self) -> dict[str, Function]:
|
|
503
|
+
"""
|
|
504
|
+
Returns a dictionary of all functions in the function group, regardless if they are included or excluded.
|
|
505
|
+
|
|
506
|
+
Returns
|
|
507
|
+
-------
|
|
508
|
+
dict[str, Function]
|
|
509
|
+
A dictionary of all functions in the function group.
|
|
510
|
+
"""
|
|
511
|
+
return {self._get_fn_name(name): self._functions[name] for name in self._functions}
|
nat/builder/function_info.py
CHANGED
|
@@ -233,7 +233,7 @@ class FunctionDescriptor:
|
|
|
233
233
|
|
|
234
234
|
is_input_typed = all([a != sig.empty for a in annotations])
|
|
235
235
|
|
|
236
|
-
input_type = tuple[*annotations] if is_input_typed else None
|
|
236
|
+
input_type = tuple[*annotations] if is_input_typed else None
|
|
237
237
|
|
|
238
238
|
# Get the base type here removing all annotations and async generators
|
|
239
239
|
output_annotation_decomp = DecomposedType(sig.return_annotation).get_base_type()
|
nat/builder/workflow.py
CHANGED
|
@@ -20,6 +20,7 @@ from typing import Any
|
|
|
20
20
|
from nat.builder.context import ContextState
|
|
21
21
|
from nat.builder.embedder import EmbedderProviderInfo
|
|
22
22
|
from nat.builder.function import Function
|
|
23
|
+
from nat.builder.function import FunctionGroup
|
|
23
24
|
from nat.builder.function_base import FunctionBase
|
|
24
25
|
from nat.builder.function_base import InputT
|
|
25
26
|
from nat.builder.function_base import SingleOutputT
|
|
@@ -44,6 +45,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
44
45
|
config: Config,
|
|
45
46
|
entry_fn: Function[InputT, StreamingOutputT, SingleOutputT],
|
|
46
47
|
functions: dict[str, Function] | None = None,
|
|
48
|
+
function_groups: dict[str, FunctionGroup] | None = None,
|
|
47
49
|
llms: dict[str, LLMProviderInfo] | None = None,
|
|
48
50
|
embeddings: dict[str, EmbedderProviderInfo] | None = None,
|
|
49
51
|
memory: dict[str, MemoryEditor] | None = None,
|
|
@@ -59,6 +61,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
59
61
|
|
|
60
62
|
self.config = config
|
|
61
63
|
self.functions = functions or {}
|
|
64
|
+
self.function_groups = function_groups or {}
|
|
62
65
|
self.llms = llms or {}
|
|
63
66
|
self.embeddings = embeddings or {}
|
|
64
67
|
self.memory = memory or {}
|
|
@@ -126,6 +129,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
126
129
|
config: Config,
|
|
127
130
|
entry_fn: Function[InputT, StreamingOutputT, SingleOutputT],
|
|
128
131
|
functions: dict[str, Function] | None = None,
|
|
132
|
+
function_groups: dict[str, FunctionGroup] | None = None,
|
|
129
133
|
llms: dict[str, LLMProviderInfo] | None = None,
|
|
130
134
|
embeddings: dict[str, EmbedderProviderInfo] | None = None,
|
|
131
135
|
memory: dict[str, MemoryEditor] | None = None,
|
|
@@ -145,6 +149,7 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
145
149
|
return WorkflowImpl(config=config,
|
|
146
150
|
entry_fn=entry_fn,
|
|
147
151
|
functions=functions,
|
|
152
|
+
function_groups=function_groups,
|
|
148
153
|
llms=llms,
|
|
149
154
|
embeddings=embeddings,
|
|
150
155
|
memory=memory,
|