nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +17 -14
- nat/agent/reasoning_agent/reasoning_agent.py +9 -7
- nat/agent/register.py +1 -0
- nat/agent/rewoo_agent/agent.py +9 -2
- nat/agent/rewoo_agent/register.py +16 -12
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +14 -13
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/context.py +28 -6
- nat/builder/function.py +313 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +215 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -7
- nat/cli/entrypoint.py +4 -9
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +167 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/authentication.py +38 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +18 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +134 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +5 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +111 -3
- nat/front_ends/mcp/tool_converter.py +3 -0
- nat/llm/aws_bedrock_llm.py +14 -3
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +4 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/log_levels.py +25 -0
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/METADATA +10 -1
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/RECORD +105 -76
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/entry_points.txt +1 -0
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/top_level.txt +0 -0
nat/builder/workflow_builder.py
CHANGED
|
@@ -16,7 +16,9 @@
|
|
|
16
16
|
import dataclasses
|
|
17
17
|
import inspect
|
|
18
18
|
import logging
|
|
19
|
+
import typing
|
|
19
20
|
import warnings
|
|
21
|
+
from collections.abc import Sequence
|
|
20
22
|
from contextlib import AbstractAsyncContextManager
|
|
21
23
|
from contextlib import AsyncExitStack
|
|
22
24
|
from contextlib import asynccontextmanager
|
|
@@ -31,6 +33,7 @@ from nat.builder.context import ContextState
|
|
|
31
33
|
from nat.builder.embedder import EmbedderProviderInfo
|
|
32
34
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
33
35
|
from nat.builder.function import Function
|
|
36
|
+
from nat.builder.function import FunctionGroup
|
|
34
37
|
from nat.builder.function import LambdaFunction
|
|
35
38
|
from nat.builder.function_info import FunctionInfo
|
|
36
39
|
from nat.builder.llm import LLMProviderInfo
|
|
@@ -42,6 +45,7 @@ from nat.data_models.authentication import AuthProviderBaseConfig
|
|
|
42
45
|
from nat.data_models.component import ComponentGroup
|
|
43
46
|
from nat.data_models.component_ref import AuthenticationRef
|
|
44
47
|
from nat.data_models.component_ref import EmbedderRef
|
|
48
|
+
from nat.data_models.component_ref import FunctionGroupRef
|
|
45
49
|
from nat.data_models.component_ref import FunctionRef
|
|
46
50
|
from nat.data_models.component_ref import LLMRef
|
|
47
51
|
from nat.data_models.component_ref import MemoryRef
|
|
@@ -52,6 +56,7 @@ from nat.data_models.config import Config
|
|
|
52
56
|
from nat.data_models.config import GeneralConfig
|
|
53
57
|
from nat.data_models.embedder import EmbedderBaseConfig
|
|
54
58
|
from nat.data_models.function import FunctionBaseConfig
|
|
59
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
55
60
|
from nat.data_models.function_dependencies import FunctionDependencies
|
|
56
61
|
from nat.data_models.llm import LLMBaseConfig
|
|
57
62
|
from nat.data_models.memory import MemoryBaseConfig
|
|
@@ -85,6 +90,12 @@ class ConfiguredFunction:
|
|
|
85
90
|
instance: Function
|
|
86
91
|
|
|
87
92
|
|
|
93
|
+
@dataclasses.dataclass
|
|
94
|
+
class ConfiguredFunctionGroup:
|
|
95
|
+
config: FunctionGroupBaseConfig
|
|
96
|
+
instance: FunctionGroup
|
|
97
|
+
|
|
98
|
+
|
|
88
99
|
@dataclasses.dataclass
|
|
89
100
|
class ConfiguredLLM:
|
|
90
101
|
config: LLMBaseConfig
|
|
@@ -145,6 +156,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
145
156
|
self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
|
|
146
157
|
|
|
147
158
|
self._functions: dict[str, ConfiguredFunction] = {}
|
|
159
|
+
self._function_groups: dict[str, ConfiguredFunctionGroup] = {}
|
|
148
160
|
self._workflow: ConfiguredFunction | None = None
|
|
149
161
|
|
|
150
162
|
self._llms: dict[str, ConfiguredLLM] = {}
|
|
@@ -161,7 +173,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
161
173
|
|
|
162
174
|
# Create a mapping to track function name -> other function names it depends on
|
|
163
175
|
self.function_dependencies: dict[str, FunctionDependencies] = {}
|
|
176
|
+
self.function_group_dependencies: dict[str, FunctionDependencies] = {}
|
|
164
177
|
self.current_function_building: str | None = None
|
|
178
|
+
self.current_function_group_building: str | None = None
|
|
165
179
|
|
|
166
180
|
async def __aenter__(self):
|
|
167
181
|
|
|
@@ -224,12 +238,32 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
224
238
|
if (self._workflow is None):
|
|
225
239
|
raise ValueError("Must set a workflow before building")
|
|
226
240
|
|
|
241
|
+
# Set of all functions which are "included" by function groups
|
|
242
|
+
included_functions = set()
|
|
243
|
+
# Dictionary of function configs
|
|
244
|
+
function_configs = dict()
|
|
245
|
+
# Dictionary of function group configs
|
|
246
|
+
function_group_configs = dict()
|
|
247
|
+
# Dictionary of function instances
|
|
248
|
+
function_instances = dict()
|
|
249
|
+
# Dictionary of function group instances
|
|
250
|
+
function_group_instances = dict()
|
|
251
|
+
|
|
252
|
+
for k, v in self._function_groups.items():
|
|
253
|
+
included_functions.update(v.instance.get_included_functions().keys())
|
|
254
|
+
function_group_configs[k] = v.config
|
|
255
|
+
function_group_instances[k] = v.instance
|
|
256
|
+
|
|
257
|
+
# Function configs need to be restricted to only the functions that are not in a function group
|
|
258
|
+
for k, v in self._functions.items():
|
|
259
|
+
if k not in included_functions:
|
|
260
|
+
function_configs[k] = v.config
|
|
261
|
+
function_instances[k] = v.instance
|
|
262
|
+
|
|
227
263
|
# Build the config from the added objects
|
|
228
264
|
config = Config(general=self.general_config,
|
|
229
|
-
functions=
|
|
230
|
-
|
|
231
|
-
for k, v in self._functions.items()
|
|
232
|
-
},
|
|
265
|
+
functions=function_configs,
|
|
266
|
+
function_groups=function_group_configs,
|
|
233
267
|
workflow=self._workflow.config,
|
|
234
268
|
llms={
|
|
235
269
|
k: v.config
|
|
@@ -263,10 +297,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
263
297
|
|
|
264
298
|
workflow = Workflow.from_entry_fn(config=config,
|
|
265
299
|
entry_fn=entry_fn_obj,
|
|
266
|
-
functions=
|
|
267
|
-
|
|
268
|
-
for k, v in self._functions.items()
|
|
269
|
-
},
|
|
300
|
+
functions=function_instances,
|
|
301
|
+
function_groups=function_group_instances,
|
|
270
302
|
llms={
|
|
271
303
|
k: v.instance
|
|
272
304
|
for k, v in self._llms.items()
|
|
@@ -347,11 +379,53 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
347
379
|
|
|
348
380
|
return ConfiguredFunction(config=config, instance=build_result)
|
|
349
381
|
|
|
382
|
+
async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
|
|
383
|
+
"""Build a function group from the provided configuration.
|
|
384
|
+
|
|
385
|
+
Args:
|
|
386
|
+
name: The name of the function group
|
|
387
|
+
config: The function group configuration
|
|
388
|
+
|
|
389
|
+
Returns:
|
|
390
|
+
ConfiguredFunctionGroup: The built function group
|
|
391
|
+
|
|
392
|
+
Raises:
|
|
393
|
+
ValueError: If the function group builder returns invalid results
|
|
394
|
+
"""
|
|
395
|
+
registration = self._registry.get_function_group(type(config))
|
|
396
|
+
|
|
397
|
+
inner_builder = ChildBuilder(self)
|
|
398
|
+
|
|
399
|
+
# Build the function group - use the same wrapping pattern as _build_function
|
|
400
|
+
llms = {k: v.instance for k, v in self._llms.items()}
|
|
401
|
+
function_frameworks = detect_llm_frameworks_in_build_fn(registration)
|
|
402
|
+
|
|
403
|
+
build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)
|
|
404
|
+
|
|
405
|
+
# Set the currently building function group so the ChildBuilder can track dependencies
|
|
406
|
+
self.current_function_group_building = config.type
|
|
407
|
+
# Empty set of dependencies for the current function group
|
|
408
|
+
self.function_group_dependencies[config.type] = FunctionDependencies()
|
|
409
|
+
|
|
410
|
+
build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))
|
|
411
|
+
|
|
412
|
+
self.function_group_dependencies[name] = inner_builder.dependencies
|
|
413
|
+
|
|
414
|
+
if not isinstance(build_result, FunctionGroup):
|
|
415
|
+
raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
|
|
416
|
+
f"Got {type(build_result)}")
|
|
417
|
+
|
|
418
|
+
# set the instance name for the function group based on the workflow-provided name
|
|
419
|
+
build_result.set_instance_name(name)
|
|
420
|
+
return ConfiguredFunctionGroup(config=config, instance=build_result)
|
|
421
|
+
|
|
350
422
|
@override
|
|
351
423
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
|
424
|
+
if isinstance(name, FunctionRef):
|
|
425
|
+
name = str(name)
|
|
352
426
|
|
|
353
|
-
if (name in self._functions):
|
|
354
|
-
raise ValueError(f"Function `{name}` already exists in the list of functions")
|
|
427
|
+
if (name in self._functions or name in self._function_groups):
|
|
428
|
+
raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
|
|
355
429
|
|
|
356
430
|
build_result = await self._build_function(name=name, config=config)
|
|
357
431
|
|
|
@@ -360,20 +434,66 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
360
434
|
return build_result.instance
|
|
361
435
|
|
|
362
436
|
@override
|
|
363
|
-
def
|
|
437
|
+
async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
438
|
+
if isinstance(name, FunctionGroupRef):
|
|
439
|
+
name = str(name)
|
|
440
|
+
|
|
441
|
+
if (name in self._function_groups or name in self._functions):
|
|
442
|
+
raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
|
|
443
|
+
|
|
444
|
+
# Build the function group
|
|
445
|
+
build_result = await self._build_function_group(name=name, config=config)
|
|
446
|
+
|
|
447
|
+
self._function_groups[name] = build_result
|
|
448
|
+
|
|
449
|
+
# If the function group exposes functions, add them to the global function registry
|
|
450
|
+
# If the function group exposes functions, record and add them to the registry
|
|
451
|
+
for k in build_result.instance.get_included_functions():
|
|
452
|
+
if k in self._functions:
|
|
453
|
+
raise ValueError(f"Exposed function `{k}` from group `{name}` conflicts with an existing function")
|
|
454
|
+
self._functions.update({
|
|
455
|
+
k: ConfiguredFunction(config=v.config, instance=v)
|
|
456
|
+
for k, v in build_result.instance.get_included_functions().items()
|
|
457
|
+
})
|
|
458
|
+
|
|
459
|
+
return build_result.instance
|
|
364
460
|
|
|
461
|
+
@override
|
|
462
|
+
def get_function(self, name: str | FunctionRef) -> Function:
|
|
463
|
+
if isinstance(name, FunctionRef):
|
|
464
|
+
name = str(name)
|
|
365
465
|
if name not in self._functions:
|
|
366
466
|
raise ValueError(f"Function `{name}` not found")
|
|
367
467
|
|
|
368
468
|
return self._functions[name].instance
|
|
369
469
|
|
|
470
|
+
@override
|
|
471
|
+
def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
|
|
472
|
+
if isinstance(name, FunctionGroupRef):
|
|
473
|
+
name = str(name)
|
|
474
|
+
if name not in self._function_groups:
|
|
475
|
+
raise ValueError(f"Function group `{name}` not found")
|
|
476
|
+
|
|
477
|
+
return self._function_groups[name].instance
|
|
478
|
+
|
|
370
479
|
@override
|
|
371
480
|
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
|
|
481
|
+
if isinstance(name, FunctionRef):
|
|
482
|
+
name = str(name)
|
|
372
483
|
if name not in self._functions:
|
|
373
484
|
raise ValueError(f"Function `{name}` not found")
|
|
374
485
|
|
|
375
486
|
return self._functions[name].config
|
|
376
487
|
|
|
488
|
+
@override
|
|
489
|
+
def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
|
|
490
|
+
if isinstance(name, FunctionGroupRef):
|
|
491
|
+
name = str(name)
|
|
492
|
+
if name not in self._function_groups:
|
|
493
|
+
raise ValueError(f"Function group `{name}` not found")
|
|
494
|
+
|
|
495
|
+
return self._function_groups[name].config
|
|
496
|
+
|
|
377
497
|
@override
|
|
378
498
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
379
499
|
|
|
@@ -403,16 +523,59 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
403
523
|
|
|
404
524
|
@override
|
|
405
525
|
def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
|
|
526
|
+
if isinstance(fn_name, FunctionRef):
|
|
527
|
+
fn_name = str(fn_name)
|
|
406
528
|
return self.function_dependencies[fn_name]
|
|
407
529
|
|
|
408
530
|
@override
|
|
409
|
-
def
|
|
531
|
+
def get_function_group_dependencies(self, fn_name: str | FunctionGroupRef) -> FunctionDependencies:
|
|
532
|
+
if isinstance(fn_name, FunctionGroupRef):
|
|
533
|
+
fn_name = str(fn_name)
|
|
534
|
+
return self.function_group_dependencies[fn_name]
|
|
535
|
+
|
|
536
|
+
@override
|
|
537
|
+
def get_tools(self,
|
|
538
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
539
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
540
|
+
tools = []
|
|
541
|
+
seen = set()
|
|
542
|
+
for n in tool_names:
|
|
543
|
+
is_function_group_ref = isinstance(n, FunctionGroupRef)
|
|
544
|
+
if isinstance(n, FunctionRef) or is_function_group_ref:
|
|
545
|
+
n = str(n)
|
|
546
|
+
if n in seen:
|
|
547
|
+
raise ValueError(f"Function or Function Group `{n}` already seen")
|
|
548
|
+
seen.add(n)
|
|
549
|
+
if n not in self._function_groups:
|
|
550
|
+
# the passed tool name is probably a function
|
|
551
|
+
if is_function_group_ref:
|
|
552
|
+
raise ValueError(f"Function group `{n}` not found in the list of function groups")
|
|
553
|
+
tools.append(self.get_tool(n, wrapper_type))
|
|
554
|
+
continue
|
|
555
|
+
|
|
556
|
+
# Using the registry, get the tool wrapper for the requested framework
|
|
557
|
+
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
558
|
+
|
|
559
|
+
current_function_group = self._function_groups[n]
|
|
560
|
+
|
|
561
|
+
# walk through all functions in the function group -- guaranteed to not be fallible
|
|
562
|
+
for fn_name, fn_instance in current_function_group.instance.get_accessible_functions().items():
|
|
563
|
+
try:
|
|
564
|
+
# Wrap in the correct wrapper and add to tools list
|
|
565
|
+
tools.append(tool_wrapper_reg.build_fn(fn_name, fn_instance, self))
|
|
566
|
+
except Exception:
|
|
567
|
+
logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
|
|
568
|
+
raise
|
|
569
|
+
|
|
570
|
+
return tools
|
|
410
571
|
|
|
572
|
+
@override
|
|
573
|
+
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
|
|
574
|
+
if isinstance(fn_name, FunctionRef):
|
|
575
|
+
fn_name = str(fn_name)
|
|
411
576
|
if fn_name not in self._functions:
|
|
412
577
|
raise ValueError(f"Function `{fn_name}` not found in list of functions")
|
|
413
|
-
|
|
414
578
|
fn = self._functions[fn_name]
|
|
415
|
-
|
|
416
579
|
try:
|
|
417
580
|
# Using the registry, get the tool wrapper for the requested framework
|
|
418
581
|
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
@@ -892,12 +1055,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
892
1055
|
# Instantiate a memory client
|
|
893
1056
|
elif component_instance.component_group == ComponentGroup.MEMORY:
|
|
894
1057
|
await self.add_memory_client(component_instance.name, component_instance.config)
|
|
895
|
-
|
|
1058
|
+
# Instantiate a object store client
|
|
896
1059
|
elif component_instance.component_group == ComponentGroup.OBJECT_STORES:
|
|
897
1060
|
await self.add_object_store(component_instance.name, component_instance.config)
|
|
898
1061
|
# Instantiate a retriever client
|
|
899
1062
|
elif component_instance.component_group == ComponentGroup.RETRIEVERS:
|
|
900
1063
|
await self.add_retriever(component_instance.name, component_instance.config)
|
|
1064
|
+
# Instantiate a function group
|
|
1065
|
+
elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
|
|
1066
|
+
await self.add_function_group(component_instance.name, component_instance.config)
|
|
901
1067
|
# Instantiate a function
|
|
902
1068
|
elif component_instance.component_group == ComponentGroup.FUNCTIONS:
|
|
903
1069
|
# If the function is the root, set it as the workflow later
|
|
@@ -956,6 +1122,10 @@ class ChildBuilder(Builder):
|
|
|
956
1122
|
async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
|
|
957
1123
|
return await self._workflow_builder.add_function(name, config)
|
|
958
1124
|
|
|
1125
|
+
@override
|
|
1126
|
+
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
1127
|
+
return await self._workflow_builder.add_function_group(name, config)
|
|
1128
|
+
|
|
959
1129
|
@override
|
|
960
1130
|
def get_function(self, name: str) -> Function:
|
|
961
1131
|
# If a function tries to get another function, we assume it uses it
|
|
@@ -965,10 +1135,23 @@ class ChildBuilder(Builder):
|
|
|
965
1135
|
|
|
966
1136
|
return fn
|
|
967
1137
|
|
|
1138
|
+
@override
|
|
1139
|
+
def get_function_group(self, name: str) -> FunctionGroup:
|
|
1140
|
+
# If a function tries to get a function group, we assume it uses it
|
|
1141
|
+
function_group = self._workflow_builder.get_function_group(name)
|
|
1142
|
+
|
|
1143
|
+
self._dependencies.add_function_group(name)
|
|
1144
|
+
|
|
1145
|
+
return function_group
|
|
1146
|
+
|
|
968
1147
|
@override
|
|
969
1148
|
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
|
970
1149
|
return self._workflow_builder.get_function_config(name)
|
|
971
1150
|
|
|
1151
|
+
@override
|
|
1152
|
+
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
|
|
1153
|
+
return self._workflow_builder.get_function_group_config(name)
|
|
1154
|
+
|
|
972
1155
|
@override
|
|
973
1156
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
974
1157
|
return await self._workflow_builder.set_workflow(config)
|
|
@@ -982,7 +1165,19 @@ class ChildBuilder(Builder):
|
|
|
982
1165
|
return self._workflow_builder.get_workflow_config()
|
|
983
1166
|
|
|
984
1167
|
@override
|
|
985
|
-
def
|
|
1168
|
+
def get_tools(self,
|
|
1169
|
+
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
|
|
1170
|
+
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
|
|
1171
|
+
tools = self._workflow_builder.get_tools(tool_names, wrapper_type)
|
|
1172
|
+
for tool_name in tool_names:
|
|
1173
|
+
if tool_name in self._workflow_builder._function_groups:
|
|
1174
|
+
self._dependencies.add_function_group(tool_name)
|
|
1175
|
+
else:
|
|
1176
|
+
self._dependencies.add_function(tool_name)
|
|
1177
|
+
return tools
|
|
1178
|
+
|
|
1179
|
+
@override
|
|
1180
|
+
def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
|
|
986
1181
|
# If a function tries to get another function as a tool, we assume it uses it
|
|
987
1182
|
fn = self._workflow_builder.get_tool(fn_name, wrapper_type)
|
|
988
1183
|
|
|
@@ -1111,3 +1306,7 @@ class ChildBuilder(Builder):
|
|
|
1111
1306
|
@override
|
|
1112
1307
|
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1113
1308
|
return self._workflow_builder.get_function_dependencies(fn_name)
|
|
1309
|
+
|
|
1310
|
+
@override
|
|
1311
|
+
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
1312
|
+
return self._workflow_builder.get_function_group_dependencies(fn_name)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
|
|
20
|
+
import click
|
|
21
|
+
|
|
22
|
+
from nat.data_models.optimizer import OptimizerRunConfig
|
|
23
|
+
from nat.profiler.parameter_optimization.optimizer_runtime import optimize_config
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@click.group(name=__name__, invoke_without_command=True, help="Optimize a workflow with the specified dataset.")
|
|
29
|
+
@click.option(
|
|
30
|
+
"--config_file",
|
|
31
|
+
type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
|
|
32
|
+
required=True,
|
|
33
|
+
help="A JSON/YAML file that sets the parameters for the workflow and evaluation.",
|
|
34
|
+
)
|
|
35
|
+
@click.option(
|
|
36
|
+
"--dataset",
|
|
37
|
+
type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
|
|
38
|
+
required=False,
|
|
39
|
+
help="A json file with questions and ground truth answers. This will override the dataset path in the config file.",
|
|
40
|
+
)
|
|
41
|
+
@click.option(
|
|
42
|
+
"--result_json_path",
|
|
43
|
+
type=str,
|
|
44
|
+
default="$",
|
|
45
|
+
help=("A JSON path to extract the result from the workflow. Use this when the workflow returns "
|
|
46
|
+
"multiple objects or a dictionary. For example, '$.output' will extract the 'output' field "
|
|
47
|
+
"from the result."),
|
|
48
|
+
)
|
|
49
|
+
@click.option(
|
|
50
|
+
"--endpoint",
|
|
51
|
+
type=str,
|
|
52
|
+
default=None,
|
|
53
|
+
help="Use endpoint for running the workflow. Example: http://localhost:8000/generate",
|
|
54
|
+
)
|
|
55
|
+
@click.option(
|
|
56
|
+
"--endpoint_timeout",
|
|
57
|
+
type=int,
|
|
58
|
+
default=300,
|
|
59
|
+
help="HTTP response timeout in seconds. Only relevant if endpoint is specified.",
|
|
60
|
+
)
|
|
61
|
+
@click.pass_context
|
|
62
|
+
def optimizer_command(ctx, **kwargs) -> None:
|
|
63
|
+
""" Optimize workflow with the specified dataset"""
|
|
64
|
+
pass
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
async def run_optimizer(config: OptimizerRunConfig):
|
|
68
|
+
await optimize_config(config)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@optimizer_command.result_callback(replace=True)
|
|
72
|
+
def run_optimizer_callback(
|
|
73
|
+
processors, # pylint: disable=unused-argument
|
|
74
|
+
*,
|
|
75
|
+
config_file: Path,
|
|
76
|
+
dataset: Path,
|
|
77
|
+
result_json_path: str,
|
|
78
|
+
endpoint: str,
|
|
79
|
+
endpoint_timeout: int,
|
|
80
|
+
):
|
|
81
|
+
"""Run the optimizer with the provided config file and dataset."""
|
|
82
|
+
config = OptimizerRunConfig(
|
|
83
|
+
config_file=config_file,
|
|
84
|
+
dataset=dataset,
|
|
85
|
+
result_json_path=result_json_path,
|
|
86
|
+
endpoint=endpoint,
|
|
87
|
+
endpoint_timeout=endpoint_timeout,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
asyncio.run(run_optimizer(config))
|
|
@@ -171,6 +171,9 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
171
171
|
workflow_dir (str): The directory to create the workflow package.
|
|
172
172
|
description (str): Description to pre-popluate the workflow docstring.
|
|
173
173
|
"""
|
|
174
|
+
# Fail fast with Click's standard exit code (2) for bad params.
|
|
175
|
+
if not workflow_name or not workflow_name.strip():
|
|
176
|
+
raise click.BadParameter("Workflow name cannot be empty.") # noqa: TRY003
|
|
174
177
|
try:
|
|
175
178
|
# Get the repository root
|
|
176
179
|
try:
|
|
@@ -217,15 +220,13 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
217
220
|
else:
|
|
218
221
|
install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
|
|
219
222
|
|
|
220
|
-
config_source = configs_dir / 'config.yml'
|
|
221
|
-
|
|
222
223
|
# List of templates and their destinations
|
|
223
224
|
files_to_render = {
|
|
224
225
|
'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
|
|
225
226
|
'register.py.j2': base_dir / 'register.py',
|
|
226
227
|
'workflow.py.j2': base_dir / f'{workflow_name}_function.py',
|
|
227
228
|
'__init__.py.j2': base_dir / '__init__.py',
|
|
228
|
-
'config.yml.j2':
|
|
229
|
+
'config.yml.j2': configs_dir / 'config.yml',
|
|
229
230
|
}
|
|
230
231
|
|
|
231
232
|
# Render templates
|
|
@@ -246,10 +247,6 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
246
247
|
with open(output_path, 'w', encoding="utf-8") as f:
|
|
247
248
|
f.write(content)
|
|
248
249
|
|
|
249
|
-
# Create symlink for config.yml
|
|
250
|
-
config_link = new_workflow_dir / 'configs' / 'config.yml'
|
|
251
|
-
os.symlink(config_source, config_link)
|
|
252
|
-
|
|
253
250
|
# Create symlinks for config and data directories
|
|
254
251
|
config_dir_source = configs_dir
|
|
255
252
|
config_dir_link = new_workflow_dir / 'configs'
|
nat/cli/entrypoint.py
CHANGED
|
@@ -30,10 +30,13 @@ import time
|
|
|
30
30
|
import click
|
|
31
31
|
import nest_asyncio
|
|
32
32
|
|
|
33
|
+
from nat.utils.log_levels import LOG_LEVELS
|
|
34
|
+
|
|
33
35
|
from .commands.configure.configure import configure_command
|
|
34
36
|
from .commands.evaluate import eval_command
|
|
35
37
|
from .commands.info.info import info_command
|
|
36
38
|
from .commands.object_store.object_store import object_store_command
|
|
39
|
+
from .commands.optimize import optimizer_command
|
|
37
40
|
from .commands.registry.registry import registry_command
|
|
38
41
|
from .commands.sizing.sizing import sizing
|
|
39
42
|
from .commands.start import start_command
|
|
@@ -44,15 +47,6 @@ from .commands.workflow.workflow import workflow_command
|
|
|
44
47
|
# Apply at the beginning of the file to avoid issues with asyncio
|
|
45
48
|
nest_asyncio.apply()
|
|
46
49
|
|
|
47
|
-
# Define log level choices
|
|
48
|
-
LOG_LEVELS = {
|
|
49
|
-
'DEBUG': logging.DEBUG,
|
|
50
|
-
'INFO': logging.INFO,
|
|
51
|
-
'WARNING': logging.WARNING,
|
|
52
|
-
'ERROR': logging.ERROR,
|
|
53
|
-
'CRITICAL': logging.CRITICAL
|
|
54
|
-
}
|
|
55
|
-
|
|
56
50
|
|
|
57
51
|
def setup_logging(log_level: str):
|
|
58
52
|
"""Configure logging with the specified level"""
|
|
@@ -108,6 +102,7 @@ cli.add_command(uninstall_command, name="uninstall")
|
|
|
108
102
|
cli.add_command(validate_command, name="validate")
|
|
109
103
|
cli.add_command(workflow_command, name="workflow")
|
|
110
104
|
cli.add_command(sizing, name="sizing")
|
|
105
|
+
cli.add_command(optimizer_command, name="optimize")
|
|
111
106
|
cli.add_command(object_store_command, name="object-store")
|
|
112
107
|
|
|
113
108
|
# Aliases
|
nat/cli/register_workflow.py
CHANGED
|
@@ -27,6 +27,8 @@ from nat.cli.type_registry import EvaluatorRegisteredCallableT
|
|
|
27
27
|
from nat.cli.type_registry import FrontEndBuildCallableT
|
|
28
28
|
from nat.cli.type_registry import FrontEndRegisteredCallableT
|
|
29
29
|
from nat.cli.type_registry import FunctionBuildCallableT
|
|
30
|
+
from nat.cli.type_registry import FunctionGroupBuildCallableT
|
|
31
|
+
from nat.cli.type_registry import FunctionGroupRegisteredCallableT
|
|
30
32
|
from nat.cli.type_registry import FunctionRegisteredCallableT
|
|
31
33
|
from nat.cli.type_registry import LLMClientBuildCallableT
|
|
32
34
|
from nat.cli.type_registry import LLMClientRegisteredCallableT
|
|
@@ -60,6 +62,7 @@ from nat.data_models.embedder import EmbedderBaseConfigT
|
|
|
60
62
|
from nat.data_models.evaluator import EvaluatorBaseConfigT
|
|
61
63
|
from nat.data_models.front_end import FrontEndConfigT
|
|
62
64
|
from nat.data_models.function import FunctionConfigT
|
|
65
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
63
66
|
from nat.data_models.llm import LLMBaseConfigT
|
|
64
67
|
from nat.data_models.memory import MemoryBaseConfigT
|
|
65
68
|
from nat.data_models.object_store import ObjectStoreBaseConfigT
|
|
@@ -155,10 +158,7 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
155
158
|
|
|
156
159
|
context_manager_fn = asynccontextmanager(fn)
|
|
157
160
|
|
|
158
|
-
|
|
159
|
-
framework_wrappers_list: list[str] = []
|
|
160
|
-
else:
|
|
161
|
-
framework_wrappers_list = list(framework_wrappers)
|
|
161
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
162
162
|
|
|
163
163
|
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
164
164
|
component_type=ComponentEnum.FUNCTION)
|
|
@@ -177,6 +177,40 @@ def register_function(config_type: type[FunctionConfigT],
|
|
|
177
177
|
return register_function_inner
|
|
178
178
|
|
|
179
179
|
|
|
180
|
+
def register_function_group(config_type: type[FunctionGroupConfigT],
|
|
181
|
+
framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
|
|
182
|
+
"""
|
|
183
|
+
Register a function group with optional framework_wrappers for automatic profiler hooking.
|
|
184
|
+
Function groups share configuration/resources across multiple functions.
|
|
185
|
+
"""
|
|
186
|
+
|
|
187
|
+
def register_function_group_inner(
|
|
188
|
+
fn: FunctionGroupBuildCallableT[FunctionGroupConfigT]
|
|
189
|
+
) -> FunctionGroupRegisteredCallableT[FunctionGroupConfigT]:
|
|
190
|
+
from .type_registry import GlobalTypeRegistry
|
|
191
|
+
from .type_registry import RegisteredFunctionGroupInfo
|
|
192
|
+
|
|
193
|
+
context_manager_fn = asynccontextmanager(fn)
|
|
194
|
+
|
|
195
|
+
framework_wrappers_list = list(framework_wrappers or [])
|
|
196
|
+
|
|
197
|
+
discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
|
|
198
|
+
component_type=ComponentEnum.FUNCTION_GROUP)
|
|
199
|
+
|
|
200
|
+
GlobalTypeRegistry.get().register_function_group(
|
|
201
|
+
RegisteredFunctionGroupInfo(
|
|
202
|
+
full_type=config_type.full_type,
|
|
203
|
+
config_type=config_type,
|
|
204
|
+
build_fn=context_manager_fn,
|
|
205
|
+
framework_wrappers=framework_wrappers_list,
|
|
206
|
+
discovery_metadata=discovery_metadata,
|
|
207
|
+
))
|
|
208
|
+
|
|
209
|
+
return context_manager_fn
|
|
210
|
+
|
|
211
|
+
return register_function_group_inner
|
|
212
|
+
|
|
213
|
+
|
|
180
214
|
def register_llm_provider(config_type: type[LLMBaseConfigT]):
|
|
181
215
|
|
|
182
216
|
def register_llm_provider_inner(
|