nvidia-nat 1.3.0a20250909__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +11 -6
- nat/agent/dual_node.py +2 -2
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +17 -7
- nat/agent/reasoning_agent/reasoning_agent.py +6 -1
- nat/agent/register.py +2 -0
- nat/agent/rewoo_agent/agent.py +6 -3
- nat/agent/rewoo_agent/register.py +16 -10
- nat/agent/router_agent/__init__.py +0 -0
- nat/agent/router_agent/agent.py +329 -0
- nat/agent/router_agent/prompt.py +48 -0
- nat/agent/router_agent/register.py +97 -0
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +17 -9
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/function.py +167 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +213 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +5 -8
- nat/cli/entrypoint.py +2 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/data_models/api_server.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +43 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
- nat/llm/aws_bedrock_llm.py +15 -4
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +401 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +5 -2
- nat/tool/document_search.py +1 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +94 -74
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
nat/builder/workflow_builder.py
CHANGED
|
@@ -16,7 +16,9 @@
|
|
|
16
16
|
import dataclasses
|
|
17
17
|
import inspect
|
|
18
18
|
import logging
|
|
19
|
+
import typing
|
|
19
20
|
import warnings
|
|
21
|
+
from collections.abc import Sequence
|
|
20
22
|
from contextlib import AbstractAsyncContextManager
|
|
21
23
|
from contextlib import AsyncExitStack
|
|
22
24
|
from contextlib import asynccontextmanager
|
|
@@ -31,6 +33,7 @@ from nat.builder.context import ContextState
|
|
|
31
33
|
from nat.builder.embedder import EmbedderProviderInfo
|
|
32
34
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
33
35
|
from nat.builder.function import Function
|
|
36
|
+
from nat.builder.function import FunctionGroup
|
|
34
37
|
from nat.builder.function import LambdaFunction
|
|
35
38
|
from nat.builder.function_info import FunctionInfo
|
|
36
39
|
from nat.builder.llm import LLMProviderInfo
|
|
@@ -42,6 +45,7 @@ from nat.data_models.authentication import AuthProviderBaseConfig
|
|
|
42
45
|
from nat.data_models.component import ComponentGroup
|
|
43
46
|
from nat.data_models.component_ref import AuthenticationRef
|
|
44
47
|
from nat.data_models.component_ref import EmbedderRef
|
|
48
|
+
from nat.data_models.component_ref import FunctionGroupRef
|
|
45
49
|
from nat.data_models.component_ref import FunctionRef
|
|
46
50
|
from nat.data_models.component_ref import LLMRef
|
|
47
51
|
from nat.data_models.component_ref import MemoryRef
|
|
@@ -52,6 +56,7 @@ from nat.data_models.config import Config
|
|
|
52
56
|
from nat.data_models.config import GeneralConfig
|
|
53
57
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
54
58
|
from nat.data_models.function import FunctionBaseConfig
|
|
59
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
55
60
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
56
61
|
from nat.data_models.llm import LLMBaseConfig
|
|
57
62
|
from nat.data_models.memory import MemoryBaseConfig
|
|
@@ -85,6 +90,12 @@ class ConfiguredFunction:
|
|
|
85
90
|
instance: Function
|
|
86
91
|
|
|
87
92
|
|
|
93
|
+
@dataclasses.dataclass
|
|
94
|
+
class ConfiguredFunctionGroup:
|
|
95
|
+
config: FunctionGroupBaseConfig
|
|
96
|
+
instance: FunctionGroup
|
|
97
|
+
|
|
98
|
+
|
|
88
99
|
@dataclasses.dataclass
|
|
89
100
|
class ConfiguredLLM:
|
|
90
101
|
config: LLMBaseConfig
|
|
@@ -145,6 +156,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
145
156
|
self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
|
|
146
157
|
|
|
147
158
|
self._functions: dict[str, ConfiguredFunction] = {}
|
|
159
|
+
self._function_groups: dict[str, ConfiguredFunctionGroup] = {}
|
|
148
160
|
self._workflow: ConfiguredFunction | None = None
|
|
149
161
|
|
|
150
162
|
self._llms: dict[str, ConfiguredLLM] = {}
|
|
@@ -161,7 +173,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
161
173
|
|
|
162
174
|
# Create a mapping to track function name -> other function names it depends on
|
|
163
175
|
self.function_dependencies: dict[str, FunctionDependencies] = {}
|
|
176
|
+
self.function_group_dependencies: dict[str, FunctionDependencies] = {}
|
|
164
177
|
self.current_function_building: str | None = None
|
|
178
|
+
self.current_function_group_building: str | None = None
|
|
165
179
|
|
|
166
180
|
async def __aenter__(self):
|
|
167
181
|
|
|
@@ -224,12 +238,32 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
224
238
|
if (self._workflow is None):
|
|
225
239
|
raise ValueError("Must set a workflow before building")
|
|
226
240
|
|
|
241
|
+
# Set of all functions which are "included" by function groups
|
|
242
|
+
included_functions = set()
|
|
243
|
+
# Dictionary of function configs
|
|
244
|
+
function_configs = dict()
|
|
245
|
+
# Dictionary of function group configs
|
|
246
|
+
function_group_configs = dict()
|
|
247
|
+
# Dictionary of function instances
|
|
248
|
+
function_instances = dict()
|
|
249
|
+
# Dictionary of function group instances
|
|
250
|
+
function_group_instances = dict()
|
|
251
|
+
|
|
252
|
+
for k, v in self._function_groups.items():
|
|
253
|
+
included_functions.update(v.instance.get_included_functions().keys())
|
|
254
|
+
function_group_configs[k] = v.config
|
|
255
|
+
function_group_instances[k] = v.instance
|
|
256
|
+
|
|
257
|
+
# Function configs need to be restricted to only the functions that are not in a function group
|
|
258
|
+
for k, v in self._functions.items():
|
|
259
|
+
if k not in included_functions:
|
|
260
|
+
function_configs[k] = v.config
|
|
261
|
+
function_instances[k] = v.instance
|
|
262
|
+
|
|
227
263
|
# Build the config from the added objects
|
|
228
264
|
config = Config(general=self.general_config,
|
|
229
|
-
functions=
|
|
230
|
-
|
|
231
|
-
for k, v in self._functions.items()
|
|
232
|
-
},
|
|
265
|
+
functions=function_configs,
|
|
266
|
+
function_groups=function_group_configs,
|
|
233
267
|
workflow=self._workflow.config,
|
|
234
268
|
llms={
|
|
235
269
|
k: v.config
|
|
@@ -263,10 +297,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
263
297
|
|
|
264
298
|
workflow = Workflow.from_entry_fn(config=config,
|
|
265
299
|
entry_fn=entry_fn_obj,
|
|
266
|
-
functions=
|
|
267
|
-
|
|
268
|
-
for k, v in self._functions.items()
|
|
269
|
-
},
|
|
300
|
+
functions=function_instances,
|
|
301
|
+
function_groups=function_group_instances,
|
|
270
302
|
llms={
|
|
271
303
|
k: v.instance
|
|
272
304
|
for k, v in self._llms.items()
|
|
@@ -347,11 +379,51 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
347
379
|
|
|
348
380
|
return ConfiguredFunction(config=config, instance=build_result)
|
|
349
381
|
|
|
382
|
+
async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
|
|
383
|
+
"""Build a function group from the provided configuration.
|
|
384
|
+
|
|
385
|
+
Args:
|
|
386
|
+
name: The name of the function group
|
|
387
|
+
config: The function group configuration
|
|
388
|
+
|
|
389
|
+
Returns:
|
|
390
|
+
ConfiguredFunctionGroup: The built function group
|
|
391
|
+
|
|
392
|
+
Raises:
|
|
393
|
+
ValueError: If the function group builder returns invalid results
|
|
394
|
+
"""
|
|
395
|
+
registration = self._registry.get_function_group(type(config))
|
|
396
|
+
|
|
397
|
+
inner_builder = ChildBuilder(self)
|
|
398
|
+
|
|
399
|
+
# Build the function group - use the same wrapping pattern as _build_function
|
|
400
|
+
llms = {k: v.instance for k, v in self._llms.items()}
|
|
401
|
+
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
402
|
+
|
|
403
|
+
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
404
|
+
|
|
405
|
+
# Set the currently building function group so the ChildBuilder can track dependencies
|
|
406
|
+
self.current_function_group_building = config.type
|
|
407
|
+
# Empty set of dependencies for the current function group
|
|
408
|
+
self.function_group_dependencies[config.type] = FunctionDependencies()
|
|
409
|
+
|
|
410
|
+
build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
|
|
411
|
+
|
|
412
|
+
self.function_group_dependencies[name] = inner_builder.dependencies
|
|
413
|
+
|
|
414
|
+
if not isinstance(build_result, FunctionGroup):
|
|
415
|
+
raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
|
|
416
|
+
f"Got {type(build_result)}")
|
|
417
|
+
|
|
418
|
+
return ConfiguredFunctionGroup(config=config, instance=build_result)
|
|
419
|
+
|
|
350
420
|
@override
|
|
351
421
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
|
422
|
+
if isinstance(name, FunctionRef):
|
|
423
|
+
name = str(name)
|
|
352
424
|
|
|
353
|
-
if (name in self._functions):
|
|
354
|
-
raise ValueError(f"Function `{name}` already exists in the list of functions")
|
|
425
|
+
if (name in self._functions or name in self._function_groups):
|
|
426
|
+
raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
|
|
355
427
|
|
|
356
428
|
build_result = await self._build_function(name=name, config=config)
|
|
357
429
|
|
|
@@ -360,20 +432,66 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
360
432
|
return build_result.instance
|
|
361
433
|
|
|
362
434
|
@override
|
|
363
|
-
def
|
|
435
|
+
async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
436
|
+
if isinstance(name, FunctionGroupRef):
|
|
437
|
+
name = str(name)
|
|
438
|
+
|
|
439
|
+
if (name in self._function_groups or name in self._functions):
|
|
440
|
+
raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
|
|
441
|
+
|
|
442
|
+
# Build the function group
|
|
443
|
+
build_result = await self._build_function_group(name=name, config=config)
|
|
444
|
+
|
|
445
|
+
self._function_groups[name] = build_result
|
|
446
|
+
|
|
447
|
+
# If the function group exposes functions, add them to the global function registry
|
|
448
|
+
# If the function group exposes functions, record and add them to the registry
|
|
449
|
+
for k in build_result.instance.get_included_functions():
|
|
450
|
+
if k in self._functions:
|
|
451
|
+
raise ValueError(f"Exposed function `{k}` from group `{name}` conflicts with an existing function")
|
|
452
|
+
self._functions.update({
|
|
453
|
+
k: ConfiguredFunction(config=v.config, instance=v)
|
|
454
|
+
for k, v in build_result.instance.get_included_functions().items()
|
|
455
|
+
})
|
|
456
|
+
|
|
457
|
+
return build_result.instance
|
|
364
458
|
|
|
459
|
+
@override
|
|
460
|
+
def get_function(self, name: str | FunctionRef) -> Function:
|
|
461
|
+
if isinstance(name, FunctionRef):
|
|
462
|
+
name = str(name)
|
|
365
463
|
if name not in self._functions:
|
|
366
464
|
raise ValueError(f"Function `{name}` not found")
|
|
367
465
|
|
|
368
466
|
return self._functions[name].instance
|
|
369
467
|
|
|
468
|
+
@override
|
|
469
|
+
def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
470
|
+
if isinstance(name, FunctionGroupRef):
|
|
471
|
+
name = str(name)
|
|
472
|
+
if name not in self._function_groups:
|
|
473
|
+
raise ValueError(f"Function group `{name}` not found")
|
|
474
|
+
|
|
475
|
+
return self._function_groups[name].instance
|
|
476
|
+
|
|
370
477
|
@override
|
|
371
478
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
479
|
+
if isinstance(name, FunctionRef):
|
|
480
|
+
name = str(name)
|
|
372
481
|
if name not in self._functions:
|
|
373
482
|
raise ValueError(f"Function `{name}` not found")
|
|
374
483
|
|
|
375
484
|
return self._functions[name].config
|
|
376
485
|
|
|
486
|
+
@override
|
|
487
|
+
def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
|
|
488
|
+
if isinstance(name, FunctionGroupRef):
|
|
489
|
+
name = str(name)
|
|
490
|
+
if name not in self._function_groups:
|
|
491
|
+
raise ValueError(f"Function group `{name}` not found")
|
|
492
|
+
|
|
493
|
+
return self._function_groups[name].config
|
|
494
|
+
|
|
377
495
|
@override
|
|
378
496
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
379
497
|
|
|
@@ -403,16 +521,59 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
403
521
|
|
|
404
522
|
@override
|
|
405
523
|
def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
|
|
524
|
+
if isinstance(fn_name, FunctionRef):
|
|
525
|
+
fn_name = str(fn_name)
|
|
406
526
|
return self.function_dependencies[fn_name]
|
|
407
527
|
|
|
408
528
|
@override
|
|
409
|
-
def
|
|
529
|
+
def get_function_group_dependencies(self, fn_name: str | FunctionGroupRef) -> FunctionDependencies:
|
|
530
|
+
if isinstance(fn_name, FunctionGroupRef):
|
|
531
|
+
fn_name = str(fn_name)
|
|
532
|
+
return self.function_group_dependencies[fn_name]
|
|
533
|
+
|
|
534
|
+
@override
|
|
535
|
+
def get_tools(self,
|
|
536
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
537
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
538
|
+
tools = []
|
|
539
|
+
seen = set()
|
|
540
|
+
for n in tool_names:
|
|
541
|
+
is_function_group_ref = isinstance(n, FunctionGroupRef)
|
|
542
|
+
if isinstance(n, FunctionRef) or is_function_group_ref:
|
|
543
|
+
n = str(n)
|
|
544
|
+
if n in seen:
|
|
545
|
+
raise ValueError(f"Function or Function Group `{n}` already seen")
|
|
546
|
+
seen.add(n)
|
|
547
|
+
if n not in self._function_groups:
|
|
548
|
+
# the passed tool name is probably a function
|
|
549
|
+
if is_function_group_ref:
|
|
550
|
+
raise ValueError(f"Function group `{n}` not found in the list of function groups")
|
|
551
|
+
tools.append(self.get_tool(n, wrapper_type))
|
|
552
|
+
continue
|
|
553
|
+
|
|
554
|
+
# Using the registry, get the tool wrapper for the requested framework
|
|
555
|
+
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
556
|
+
|
|
557
|
+
current_function_group = self._function_groups[n]
|
|
558
|
+
|
|
559
|
+
# walk through all functions in the function group -- guaranteed to not be fallible
|
|
560
|
+
for fn_name, fn_instance in current_function_group.instance.get_accessible_functions().items():
|
|
561
|
+
try:
|
|
562
|
+
# Wrap in the correct wrapper and add to tools list
|
|
563
|
+
tools.append(tool_wrapper_reg.build_fn(fn_name, fn_instance, self))
|
|
564
|
+
except Exception:
|
|
565
|
+
logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
|
|
566
|
+
raise
|
|
567
|
+
|
|
568
|
+
return tools
|
|
410
569
|
|
|
570
|
+
@override
|
|
571
|
+
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
|
|
572
|
+
if isinstance(fn_name, FunctionRef):
|
|
573
|
+
fn_name = str(fn_name)
|
|
411
574
|
if fn_name not in self._functions:
|
|
412
575
|
raise ValueError(f"Function `{fn_name}` not found in list of functions")
|
|
413
|
-
|
|
414
576
|
fn = self._functions[fn_name]
|
|
415
|
-
|
|
416
577
|
try:
|
|
417
578
|
# Using the registry, get the tool wrapper for the requested framework
|
|
418
579
|
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
@@ -892,12 +1053,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
892
1053
|
# Instantiate a memory client
|
|
893
1054
|
elif component_instance.component_group == ComponentGroup.MEMORY:
|
|
894
1055
|
await self.add_memory_client(component_instance.name, component_instance.config)
|
|
895
|
-
|
|
1056
|
+
# Instantiate a object store client
|
|
896
1057
|
elif component_instance.component_group == ComponentGroup.OBJECT_STORES:
|
|
897
1058
|
await self.add_object_store(component_instance.name, component_instance.config)
|
|
898
1059
|
# Instantiate a retriever client
|
|
899
1060
|
elif component_instance.component_group == ComponentGroup.RETRIEVERS:
|
|
900
1061
|
await self.add_retriever(component_instance.name, component_instance.config)
|
|
1062
|
+
# Instantiate a function group
|
|
1063
|
+
elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
|
|
1064
|
+
await self.add_function_group(component_instance.name, component_instance.config)
|
|
901
1065
|
# Instantiate a function
|
|
902
1066
|
elif component_instance.component_group == ComponentGroup.FUNCTIONS:
|
|
903
1067
|
# If the function is the root, set it as the workflow later
|
|
@@ -956,6 +1120,10 @@ class ChildBuilder(Builder):
|
|
|
956
1120
|
async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
|
|
957
1121
|
return await self._workflow_builder.add_function(name, config)
|
|
958
1122
|
|
|
1123
|
+
@override
|
|
1124
|
+
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
1125
|
+
return await self._workflow_builder.add_function_group(name, config)
|
|
1126
|
+
|
|
959
1127
|
@override
|
|
960
1128
|
def get_function(self, name: str) -> Function:
|
|
961
1129
|
# If a function tries to get another function, we assume it uses it
|
|
@@ -965,10 +1133,23 @@ class ChildBuilder(Builder):
|
|
|
965
1133
|
|
|
966
1134
|
return fn
|
|
967
1135
|
|
|
1136
|
+
@override
|
|
1137
|
+
def get_function_group(self, name: str) -> FunctionGroup:
|
|
1138
|
+
# If a function tries to get a function group, we assume it uses it
|
|
1139
|
+
function_group = self._workflow_builder.get_function_group(name)
|
|
1140
|
+
|
|
1141
|
+
self._dependencies.add_function_group(name)
|
|
1142
|
+
|
|
1143
|
+
return function_group
|
|
1144
|
+
|
|
968
1145
|
@override
|
|
969
1146
|
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
|
970
1147
|
return self._workflow_builder.get_function_config(name)
|
|
971
1148
|
|
|
1149
|
+
@override
|
|
1150
|
+
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
|
|
1151
|
+
return self._workflow_builder.get_function_group_config(name)
|
|
1152
|
+
|
|
972
1153
|
@override
|
|
973
1154
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
974
1155
|
return await self._workflow_builder.set_workflow(config)
|
|
@@ -982,7 +1163,19 @@ class ChildBuilder(Builder):
|
|
|
982
1163
|
return self._workflow_builder.get_workflow_config()
|
|
983
1164
|
|
|
984
1165
|
@override
|
|
985
|
-
def
|
|
1166
|
+
def get_tools(self,
|
|
1167
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
1168
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
1169
|
+
tools = self._workflow_builder.get_tools(tool_names, wrapper_type)
|
|
1170
|
+
for tool_name in tool_names:
|
|
1171
|
+
if tool_name in self._workflow_builder._function_groups:
|
|
1172
|
+
self._dependencies.add_function_group(tool_name)
|
|
1173
|
+
else:
|
|
1174
|
+
self._dependencies.add_function(tool_name)
|
|
1175
|
+
return tools
|
|
1176
|
+
|
|
1177
|
+
@override
|
|
1178
|
+
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
|
|
986
1179
|
# If a function tries to get another function as a tool, we assume it uses it
|
|
987
1180
|
fn = self._workflow_builder.get_tool(fn_name, wrapper_type)
|
|
988
1181
|
|
|
@@ -1111,3 +1304,7 @@ class ChildBuilder(Builder):
|
|
|
1111
1304
|
@override
|
|
1112
1305
|
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1113
1306
|
return self._workflow_builder.get_function_dependencies(fn_name)
|
|
1307
|
+
|
|
1308
|
+
@override
|
|
1309
|
+
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1310
|
+
return self._workflow_builder.get_function_group_dependencies(fn_name)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
import click
|
|
21
|
+
|
|
22
|
+
from nat.data_models.optimizer import OptimizerRunConfig
|
|
23
|
+
from nat.profiler.parameter_optimization.optimizer_runtime import optimize_config
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@click.group(name=__name__, invoke_without_command=True, help="Optimize a workflow with the specified dataset.")
|
|
29
|
+
@click.option(
|
|
30
|
+
"--config_file",
|
|
31
|
+
type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
|
|
32
|
+
required=True,
|
|
33
|
+
help="A JSON/YAML file that sets the parameters for the workflow and evaluation.",
|
|
34
|
+
)
|
|
35
|
+
@click.option(
|
|
36
|
+
"--dataset",
|
|
37
|
+
type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
|
|
38
|
+
required=False,
|
|
39
|
+
help="A json file with questions and ground truth answers. This will override the dataset path in the config file.",
|
|
40
|
+
)
|
|
41
|
+
@click.option(
|
|
42
|
+
"--result_json_path",
|
|
43
|
+
type=str,
|
|
44
|
+
default="$",
|
|
45
|
+
help=("A JSON path to extract the result from the workflow. Use this when the workflow returns "
|
|
46
|
+
"multiple objects or a dictionary. For example, '$.output' will extract the 'output' field "
|
|
47
|
+
"from the result."),
|
|
48
|
+
)
|
|
49
|
+
@click.option(
|
|
50
|
+
"--endpoint",
|
|
51
|
+
type=str,
|
|
52
|
+
default=None,
|
|
53
|
+
help="Use endpoint for running the workflow. Example: http://localhost:8000/generate",
|
|
54
|
+
)
|
|
55
|
+
@click.option(
|
|
56
|
+
"--endpoint_timeout",
|
|
57
|
+
type=int,
|
|
58
|
+
default=300,
|
|
59
|
+
help="HTTP response timeout in seconds. Only relevant if endpoint is specified.",
|
|
60
|
+
)
|
|
61
|
+
@click.pass_context
|
|
62
|
+
def optimizer_command(ctx, **kwargs) -> None:
|
|
63
|
+
""" Optimize workflow with the specified dataset"""
|
|
64
|
+
pass
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
async def run_optimizer(config: OptimizerRunConfig):
|
|
68
|
+
await optimize_config(config)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@optimizer_command.result_callback(replace=True)
|
|
72
|
+
def run_optimizer_callback(
|
|
73
|
+
processors, # pylint: disable=unused-argument
|
|
74
|
+
*,
|
|
75
|
+
config_file: Path,
|
|
76
|
+
dataset: Path,
|
|
77
|
+
result_json_path: str,
|
|
78
|
+
endpoint: str,
|
|
79
|
+
endpoint_timeout: int,
|
|
80
|
+
):
|
|
81
|
+
"""Run the optimizer with the provided config file and dataset."""
|
|
82
|
+
config = OptimizerRunConfig(
|
|
83
|
+
config_file=config_file,
|
|
84
|
+
dataset=dataset,
|
|
85
|
+
result_json_path=result_json_path,
|
|
86
|
+
endpoint=endpoint,
|
|
87
|
+
endpoint_timeout=endpoint_timeout,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
asyncio.run(run_optimizer(config))
|
|
@@ -37,7 +37,7 @@ def _get_nat_dependency(versioned: bool = True) -> str:
|
|
|
37
37
|
Returns:
|
|
38
38
|
str: The dependency string to use in pyproject.toml
|
|
39
39
|
"""
|
|
40
|
-
# Assume the default dependency is
|
|
40
|
+
# Assume the default dependency is LangChain/LangGraph
|
|
41
41
|
dependency = "nvidia-nat[langchain]"
|
|
42
42
|
|
|
43
43
|
if not versioned:
|
|
@@ -171,6 +171,9 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
171
171
|
workflow_dir (str): The directory to create the workflow package.
|
|
172
172
|
description (str): Description to pre-popluate the workflow docstring.
|
|
173
173
|
"""
|
|
174
|
+
# Fail fast with Click's standard exit code (2) for bad params.
|
|
175
|
+
if not workflow_name or not workflow_name.strip():
|
|
176
|
+
raise click.BadParameter("Workflow name cannot be empty.") # noqa: TRY003
|
|
174
177
|
try:
|
|
175
178
|
# Get the repository root
|
|
176
179
|
try:
|
|
@@ -217,15 +220,13 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
217
220
|
else:
|
|
218
221
|
install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
|
|
219
222
|
|
|
220
|
-
config_source = configs_dir / 'config.yml'
|
|
221
|
-
|
|
222
223
|
# List of templates and their destinations
|
|
223
224
|
files_to_render = {
|
|
224
225
|
'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
|
|
225
226
|
'register.py.j2': base_dir / 'register.py',
|
|
226
227
|
'workflow.py.j2': base_dir / f'{workflow_name}_function.py',
|
|
227
228
|
'__init__.py.j2': base_dir / '__init__.py',
|
|
228
|
-
'config.yml.j2':
|
|
229
|
+
'config.yml.j2': configs_dir / 'config.yml',
|
|
229
230
|
}
|
|
230
231
|
|
|
231
232
|
# Render templates
|
|
@@ -246,10 +247,6 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
246
247
|
with open(output_path, 'w', encoding="utf-8") as f:
|
|
247
248
|
f.write(content)
|
|
248
249
|
|
|
249
|
-
# Create symlink for config.yml
|
|
250
|
-
config_link = new_workflow_dir / 'configs' / 'config.yml'
|
|
251
|
-
os.symlink(config_source, config_link)
|
|
252
|
-
|
|
253
250
|
# Create symlinks for config and data directories
|
|
254
251
|
config_dir_source = configs_dir
|
|
255
252
|
config_dir_link = new_workflow_dir / 'configs'
|
nat/cli/entrypoint.py
CHANGED
|
@@ -34,6 +34,7 @@ from .commands.configure.configure import configure_command
|
|
|
34
34
|
from .commands.evaluate import eval_command
|
|
35
35
|
from .commands.info.info import info_command
|
|
36
36
|
from .commands.object_store.object_store import object_store_command
|
|
37
|
+
from .commands.optimize import optimizer_command
|
|
37
38
|
from .commands.registry.registry import registry_command
|
|
38
39
|
from .commands.sizing.sizing import sizing
|
|
39
40
|
from .commands.start import start_command
|
|
@@ -108,6 +109,7 @@ cli.add_command(uninstall_command, name="uninstall")
|
|
|
108
109
|
cli.add_command(validate_command, name="validate")
|
|
109
110
|
cli.add_command(workflow_command, name="workflow")
|
|
110
111
|
cli.add_command(sizing, name="sizing")
|
|
112
|
+
cli.add_command(optimizer_command, name="optimize")
|
|
111
113
|
cli.add_command(object_store_command, name="object-store")
|
|
112
114
|
|
|
113
115
|
# Aliases
|
nat/cli/register_workflow.py
CHANGED
|
@@ -27,6 +27,8 @@ from nat.cli.type_registry import EvaluatorRegisteredCallableT
|
|
|
27
27
|
from nat.cli.type_registry import FrontEndBuildCallableT
|
|
28
28
|
from nat.cli.type_registry import FrontEndRegisteredCallableT
|
|
29
29
|
from nat.cli.type_registry import FunctionBuildCallableT
|
|
30
|
+
from nat.cli.type_registry import FunctionGroupBuildCallableT
|
|
31
|
+
from nat.cli.type_registry import FunctionGroupRegisteredCallableT
|
|
30
32
|
from nat.cli.type_registry import FunctionRegisteredCallableT
|
|
31
33
|
from nat.cli.type_registry import LLMClientBuildCallableT
|
|
32
34
|
from nat.cli.type_registry import LLMClientRegisteredCallableT
|
|
@@ -60,6 +62,7 @@ from nat.data_models.embedder import EmbedderBaseConfigT
|
|
|
60
62
|
from nat.data_models.evaluator import EvaluatorBaseConfigT
|
|
61
63
|
from nat.data_models.front_end import FrontEndConfigT
|
|
62
64
|
from nat.data_models.function import FunctionConfigT
|
|
65
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
63
66
|
from nat.data_models.llm import LLMBaseConfigT
|
|
64
67
|
from nat.data_models.memory import MemoryBaseConfigT
|
|
65
68
|
from nat.data_models.object_store import ObjectStoreBaseConfigT
|
|
@@ -155,10 +158,7 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
155
158
|
|
|
156
159
|
context_manager_fn = asynccontextmanager(fn)
|
|
157
160
|
|
|
158
|
-
|
|
159
|
-
framework_wrappers_list: list[str] = []
|
|
160
|
-
else:
|
|
161
|
-
framework_wrappers_list = list(framework_wrappers)
|
|
161
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
162
162
|
|
|
163
163
|
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
164
164
|
component_type=ComponentEnum.FUNCTION)
|
|
@@ -177,6 +177,40 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
177
177
|
return register_function_inner
|
|
178
178
|
|
|
179
179
|
|
|
180
|
+
def register_function_group(config_type: type[FunctionGroupConfigT],
|
|
181
|
+
framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
|
|
182
|
+
"""
|
|
183
|
+
Register a function group with optional framework_wrappers for automatic profiler hooking.
|
|
184
|
+
Function groups share configuration/resources across multiple functions.
|
|
185
|
+
"""
|
|
186
|
+
|
|
187
|
+
def register_function_group_inner(
|
|
188
|
+
fn: FunctionGroupBuildCallableT[FunctionGroupConfigT]
|
|
189
|
+
) -> FunctionGroupRegisteredCallableT[FunctionGroupConfigT]:
|
|
190
|
+
from .type_registry import GlobalTypeRegistry
|
|
191
|
+
from .type_registry import RegisteredFunctionGroupInfo
|
|
192
|
+
|
|
193
|
+
context_manager_fn = asynccontextmanager(fn)
|
|
194
|
+
|
|
195
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
196
|
+
|
|
197
|
+
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
198
|
+
component_type=ComponentEnum.FUNCTION_GROUP)
|
|
199
|
+
|
|
200
|
+
GlobalTypeRegistry.get().register_function_group(
|
|
201
|
+
RegisteredFunctionGroupInfo(
|
|
202
|
+
full_type=config_type.full_type,
|
|
203
|
+
config_type=config_type,
|
|
204
|
+
build_fn=context_manager_fn,
|
|
205
|
+
framework_wrappers=framework_wrappers_list,
|
|
206
|
+
discovery_metadata=discovery_metadata,
|
|
207
|
+
))
|
|
208
|
+
|
|
209
|
+
return context_manager_fn
|
|
210
|
+
|
|
211
|
+
return register_function_group_inner
|
|
212
|
+
|
|
213
|
+
|
|
180
214
|
def register_llm_provider(config_type: type[LLMBaseConfigT]):
|
|
181
215
|
|
|
182
216
|
def register_llm_provider_inner(
|