nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250922__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. nat/agent/base.py +9 -4
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +1 -1
  5. nat/agent/react_agent/register.py +17 -14
  6. nat/agent/reasoning_agent/reasoning_agent.py +9 -7
  7. nat/agent/register.py +1 -0
  8. nat/agent/rewoo_agent/agent.py +9 -2
  9. nat/agent/rewoo_agent/register.py +16 -12
  10. nat/agent/tool_calling_agent/agent.py +69 -7
  11. nat/agent/tool_calling_agent/register.py +14 -13
  12. nat/authentication/credential_validator/__init__.py +14 -0
  13. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  14. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  15. nat/builder/builder.py +27 -4
  16. nat/builder/component_utils.py +7 -3
  17. nat/builder/context.py +28 -6
  18. nat/builder/function.py +313 -0
  19. nat/builder/function_info.py +1 -1
  20. nat/builder/workflow.py +5 -0
  21. nat/builder/workflow_builder.py +215 -16
  22. nat/cli/commands/optimize.py +90 -0
  23. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  24. nat/cli/commands/workflow/workflow_commands.py +4 -7
  25. nat/cli/entrypoint.py +4 -9
  26. nat/cli/register_workflow.py +38 -4
  27. nat/cli/type_registry.py +71 -0
  28. nat/control_flow/__init__.py +0 -0
  29. nat/control_flow/register.py +20 -0
  30. nat/control_flow/router_agent/__init__.py +0 -0
  31. nat/control_flow/router_agent/agent.py +329 -0
  32. nat/control_flow/router_agent/prompt.py +48 -0
  33. nat/control_flow/router_agent/register.py +91 -0
  34. nat/control_flow/sequential_executor.py +167 -0
  35. nat/data_models/agent.py +34 -0
  36. nat/data_models/authentication.py +38 -0
  37. nat/data_models/component.py +2 -0
  38. nat/data_models/component_ref.py +11 -0
  39. nat/data_models/config.py +40 -16
  40. nat/data_models/function.py +34 -0
  41. nat/data_models/function_dependencies.py +8 -0
  42. nat/data_models/optimizable.py +119 -0
  43. nat/data_models/optimizer.py +149 -0
  44. nat/data_models/temperature_mixin.py +4 -3
  45. nat/data_models/top_p_mixin.py +4 -3
  46. nat/embedder/nim_embedder.py +1 -1
  47. nat/embedder/openai_embedder.py +1 -1
  48. nat/eval/config.py +1 -1
  49. nat/eval/evaluate.py +5 -1
  50. nat/eval/register.py +4 -0
  51. nat/eval/runtime_evaluator/__init__.py +14 -0
  52. nat/eval/runtime_evaluator/evaluate.py +123 -0
  53. nat/eval/runtime_evaluator/register.py +100 -0
  54. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
  55. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  56. nat/front_ends/fastapi/fastapi_front_end_config.py +18 -3
  57. nat/front_ends/fastapi/fastapi_front_end_plugin.py +134 -3
  58. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
  59. nat/front_ends/fastapi/job_store.py +518 -99
  60. nat/front_ends/fastapi/main.py +11 -19
  61. nat/front_ends/fastapi/utils.py +57 -0
  62. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  63. nat/front_ends/mcp/mcp_front_end_config.py +5 -1
  64. nat/front_ends/mcp/mcp_front_end_plugin.py +37 -11
  65. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +111 -3
  66. nat/front_ends/mcp/tool_converter.py +3 -0
  67. nat/llm/aws_bedrock_llm.py +14 -3
  68. nat/llm/nim_llm.py +14 -3
  69. nat/llm/openai_llm.py +8 -1
  70. nat/observability/exporter/processing_exporter.py +29 -55
  71. nat/observability/mixin/redaction_config_mixin.py +5 -4
  72. nat/observability/mixin/tagging_config_mixin.py +26 -14
  73. nat/observability/mixin/type_introspection_mixin.py +420 -107
  74. nat/observability/processor/processor.py +3 -0
  75. nat/observability/processor/redaction/__init__.py +24 -0
  76. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  77. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  78. nat/observability/processor/redaction/redaction_processor.py +177 -0
  79. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  80. nat/observability/processor/span_tagging_processor.py +21 -14
  81. nat/profiler/decorators/framework_wrapper.py +9 -6
  82. nat/profiler/parameter_optimization/__init__.py +0 -0
  83. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  84. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  85. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  86. nat/profiler/parameter_optimization/parameter_selection.py +108 -0
  87. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  88. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  89. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  90. nat/profiler/utils.py +3 -1
  91. nat/tool/chat_completion.py +4 -1
  92. nat/tool/github_tools.py +450 -0
  93. nat/tool/register.py +2 -7
  94. nat/utils/callable_utils.py +70 -0
  95. nat/utils/exception_handlers/automatic_retries.py +103 -48
  96. nat/utils/log_levels.py +25 -0
  97. nat/utils/type_utils.py +4 -0
  98. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/METADATA +10 -1
  99. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/RECORD +105 -76
  100. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/entry_points.txt +1 -0
  101. nat/observability/processor/header_redaction_processor.py +0 -123
  102. nat/observability/processor/redaction_processor.py +0 -77
  103. nat/tool/github_tools/create_github_commit.py +0 -133
  104. nat/tool/github_tools/create_github_issue.py +0 -87
  105. nat/tool/github_tools/create_github_pr.py +0 -106
  106. nat/tool/github_tools/get_github_file.py +0 -106
  107. nat/tool/github_tools/get_github_issue.py +0 -166
  108. nat/tool/github_tools/get_github_pr.py +0 -256
  109. nat/tool/github_tools/update_github_issue.py +0 -100
  110. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  111. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/WHEEL +0 -0
  112. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  113. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/licenses/LICENSE.md +0 -0
  114. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250922.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,9 @@
16
16
  import dataclasses
17
17
  import inspect
18
18
  import logging
19
+ import typing
19
20
  import warnings
21
+ from collections.abc import Sequence
20
22
  from contextlib import AbstractAsyncContextManager
21
23
  from contextlib import AsyncExitStack
22
24
  from contextlib import asynccontextmanager
@@ -31,6 +33,7 @@ from nat.builder.context import ContextState
31
33
  from nat.builder.embedder import EmbedderProviderInfo
32
34
  from nat.builder.framework_enum import LLMFrameworkEnum
33
35
  from nat.builder.function import Function
36
+ from nat.builder.function import FunctionGroup
34
37
  from nat.builder.function import LambdaFunction
35
38
  from nat.builder.function_info import FunctionInfo
36
39
  from nat.builder.llm import LLMProviderInfo
@@ -42,6 +45,7 @@ from nat.data_models.authentication import AuthProviderBaseConfig
42
45
  from nat.data_models.component import ComponentGroup
43
46
  from nat.data_models.component_ref import AuthenticationRef
44
47
  from nat.data_models.component_ref import EmbedderRef
48
+ from nat.data_models.component_ref import FunctionGroupRef
45
49
  from nat.data_models.component_ref import FunctionRef
46
50
  from nat.data_models.component_ref import LLMRef
47
51
  from nat.data_models.component_ref import MemoryRef
@@ -52,6 +56,7 @@ from nat.data_models.config import Config
52
56
  from nat.data_models.config import GeneralConfig
53
57
  from nat.data_models.embedder import EmbedderBaseConfig
54
58
  from nat.data_models.function import FunctionBaseConfig
59
+ from nat.data_models.function import FunctionGroupBaseConfig
55
60
  from nat.data_models.function_dependencies import FunctionDependencies
56
61
  from nat.data_models.llm import LLMBaseConfig
57
62
  from nat.data_models.memory import MemoryBaseConfig
@@ -85,6 +90,12 @@ class ConfiguredFunction:
85
90
  instance: Function
86
91
 
87
92
 
93
@dataclasses.dataclass
class ConfiguredFunctionGroup:
    """A built function group paired with the configuration it was built from."""

    # Configuration object the group was constructed from.
    config: FunctionGroupBaseConfig
    # FunctionGroup instance produced by the group's build function.
    instance: FunctionGroup
97
+
98
+
88
99
  @dataclasses.dataclass
89
100
  class ConfiguredLLM:
90
101
  config: LLMBaseConfig
@@ -145,6 +156,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
145
156
  self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
146
157
 
147
158
  self._functions: dict[str, ConfiguredFunction] = {}
159
+ self._function_groups: dict[str, ConfiguredFunctionGroup] = {}
148
160
  self._workflow: ConfiguredFunction | None = None
149
161
 
150
162
  self._llms: dict[str, ConfiguredLLM] = {}
@@ -161,7 +173,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
161
173
 
162
174
  # Create a mapping to track function name -> other function names it depends on
163
175
  self.function_dependencies: dict[str, FunctionDependencies] = {}
176
+ self.function_group_dependencies: dict[str, FunctionDependencies] = {}
164
177
  self.current_function_building: str | None = None
178
+ self.current_function_group_building: str | None = None
165
179
 
166
180
  async def __aenter__(self):
167
181
 
@@ -224,12 +238,32 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
224
238
  if (self._workflow is None):
225
239
  raise ValueError("Must set a workflow before building")
226
240
 
241
+ # Set of all functions which are "included" by function groups
242
+ included_functions = set()
243
+ # Dictionary of function configs
244
+ function_configs = dict()
245
+ # Dictionary of function group configs
246
+ function_group_configs = dict()
247
+ # Dictionary of function instances
248
+ function_instances = dict()
249
+ # Dictionary of function group instances
250
+ function_group_instances = dict()
251
+
252
+ for k, v in self._function_groups.items():
253
+ included_functions.update(v.instance.get_included_functions().keys())
254
+ function_group_configs[k] = v.config
255
+ function_group_instances[k] = v.instance
256
+
257
+ # Function configs need to be restricted to only the functions that are not in a function group
258
+ for k, v in self._functions.items():
259
+ if k not in included_functions:
260
+ function_configs[k] = v.config
261
+ function_instances[k] = v.instance
262
+
227
263
  # Build the config from the added objects
228
264
  config = Config(general=self.general_config,
229
- functions={
230
- k: v.config
231
- for k, v in self._functions.items()
232
- },
265
+ functions=function_configs,
266
+ function_groups=function_group_configs,
233
267
  workflow=self._workflow.config,
234
268
  llms={
235
269
  k: v.config
@@ -263,10 +297,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
263
297
 
264
298
  workflow = Workflow.from_entry_fn(config=config,
265
299
  entry_fn=entry_fn_obj,
266
- functions={
267
- k: v.instance
268
- for k, v in self._functions.items()
269
- },
300
+ functions=function_instances,
301
+ function_groups=function_group_instances,
270
302
  llms={
271
303
  k: v.instance
272
304
  for k, v in self._llms.items()
@@ -347,11 +379,53 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
347
379
 
348
380
  return ConfiguredFunction(config=config, instance=build_result)
349
381
 
382
async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
    """Build a function group from the provided configuration.

    The registered build function is wrapped via ``chain_wrapped_build_fn`` (using the
    frameworks detected in the registration), entered on the builder's exit stack, and
    the dependencies recorded by the ``ChildBuilder`` are stored under ``name``.

    Args:
        name: The workflow-provided instance name of the function group.
        config: The function group configuration.

    Returns:
        ConfiguredFunctionGroup: The built function group paired with its config.

    Raises:
        ValueError: If the function group builder does not return a ``FunctionGroup``.
    """
    registration = self._registry.get_function_group(type(config))

    inner_builder = ChildBuilder(self)

    # Build the function group - use the same wrapping pattern as _build_function
    llms = {k: v.instance for k, v in self._llms.items()}
    function_frameworks = detect_llm_frameworks_in_build_fn(registration)

    build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)

    # Set the currently building function group so the ChildBuilder can track dependencies.
    # NOTE(review): kept as config.type to match the existing convention; confirm readers expect the type.
    self.current_function_group_building = config.type
    # Reserve an empty dependency entry keyed by the instance *name*, the same key the final
    # entry uses below. Keying the placeholder by config.type could overwrite the recorded
    # dependencies of a previously built group whose instance name equals this type string.
    self.function_group_dependencies[name] = FunctionDependencies()

    build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))

    self.function_group_dependencies[name] = inner_builder.dependencies

    if not isinstance(build_result, FunctionGroup):
        raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
                         f"Got {type(build_result)}")

    # set the instance name for the function group based on the workflow-provided name
    build_result.set_instance_name(name)
    return ConfiguredFunctionGroup(config=config, instance=build_result)
421
+
350
422
  @override
351
423
  async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
424
+ if isinstance(name, FunctionRef):
425
+ name = str(name)
352
426
 
353
- if (name in self._functions):
354
- raise ValueError(f"Function `{name}` already exists in the list of functions")
427
+ if (name in self._functions or name in self._function_groups):
428
+ raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
355
429
 
356
430
  build_result = await self._build_function(name=name, config=config)
357
431
 
@@ -360,20 +434,66 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
360
434
  return build_result.instance
361
435
 
362
436
  @override
363
- def get_function(self, name: str | FunctionRef) -> Function:
437
async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
    """Build a function group and register it along with the functions it exposes.

    Args:
        name: Instance name (or reference) for the function group.
        config: The function group configuration.

    Returns:
        FunctionGroup: The built function group instance.

    Raises:
        ValueError: If ``name`` is already used by a function or function group, or if a
            function exposed by the group conflicts with an existing function.
    """
    if isinstance(name, FunctionGroupRef):
        name = str(name)

    if (name in self._function_groups or name in self._functions):
        raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")

    # Build the function group
    build_result = await self._build_function_group(name=name, config=config)

    # Fetch the exposed functions once and validate every name *before* registering
    # anything, so a conflict cannot leave the group registered without its functions.
    included_functions = build_result.instance.get_included_functions()
    for k in included_functions:
        if k in self._functions:
            raise ValueError(f"Exposed function `{k}` from group `{name}` conflicts with an existing function")

    self._function_groups[name] = build_result

    # Record the exposed functions in the global function registry
    self._functions.update({
        k: ConfiguredFunction(config=v.config, instance=v)
        for k, v in included_functions.items()
    })

    return build_result.instance
364
460
 
461
+ @override
462
+ def get_function(self, name: str | FunctionRef) -> Function:
463
+ if isinstance(name, FunctionRef):
464
+ name = str(name)
365
465
  if name not in self._functions:
366
466
  raise ValueError(f"Function `{name}` not found")
367
467
 
368
468
  return self._functions[name].instance
369
469
 
470
@override
def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
    """Return the function group instance registered under ``name``.

    Raises:
        ValueError: If no function group with that name has been added.
    """
    key = str(name) if isinstance(name, FunctionGroupRef) else name
    entry = self._function_groups.get(key)
    if entry is None:
        raise ValueError(f"Function group `{key}` not found")
    return entry.instance
478
+
370
479
  @override
371
480
  def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
481
+ if isinstance(name, FunctionRef):
482
+ name = str(name)
372
483
  if name not in self._functions:
373
484
  raise ValueError(f"Function `{name}` not found")
374
485
 
375
486
  return self._functions[name].config
376
487
 
488
@override
def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
    """Return the configuration of the function group registered under ``name``.

    Raises:
        ValueError: If no function group with that name has been added.
    """
    key = str(name) if isinstance(name, FunctionGroupRef) else name
    entry = self._function_groups.get(key)
    if entry is None:
        raise ValueError(f"Function group `{key}` not found")
    return entry.config
496
+
377
497
  @override
378
498
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
379
499
 
@@ -403,16 +523,59 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
403
523
 
404
524
  @override
405
525
  def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
526
+ if isinstance(fn_name, FunctionRef):
527
+ fn_name = str(fn_name)
406
528
  return self.function_dependencies[fn_name]
407
529
 
408
530
  @override
409
- def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
531
def get_function_group_dependencies(self, fn_name: str | FunctionGroupRef) -> FunctionDependencies:
    """Return the dependencies recorded while building the function group ``fn_name``.

    Raises:
        KeyError: If no dependencies were recorded under that name.
    """
    key = str(fn_name) if isinstance(fn_name, FunctionGroupRef) else fn_name
    return self.function_group_dependencies[key]
535
+
536
@override
def get_tools(self,
              tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
              wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
    """Wrap the named functions and function groups as tools for ``wrapper_type``.

    A name that refers to a registered function group expands to every function returned
    by the group's ``get_accessible_functions()``. Any other name is resolved with
    ``get_tool``.

    Args:
        tool_names: Function names, function references, or function group references.
        wrapper_type: The LLM framework whose tool wrapper should be used.

    Returns:
        list[typing.Any]: The wrapped tools, in the order the names were given.

    Raises:
        ValueError: If a name appears more than once, or if a ``FunctionGroupRef`` does not
            name a known function group. Errors raised by ``get_tool`` or by the tool
            wrapper are propagated.
    """
    tools = []
    seen = set()
    # Looked up lazily, at most once, the first time a function group is expanded.
    tool_wrapper_reg = None
    for n in tool_names:
        is_function_group_ref = isinstance(n, FunctionGroupRef)
        if isinstance(n, FunctionRef) or is_function_group_ref:
            n = str(n)
        if n in seen:
            raise ValueError(f"Function or Function Group `{n}` already seen")
        seen.add(n)
        if n not in self._function_groups:
            # the passed tool name is probably a function
            if is_function_group_ref:
                raise ValueError(f"Function group `{n}` not found in the list of function groups")
            tools.append(self.get_tool(n, wrapper_type))
            continue

        if tool_wrapper_reg is None:
            # Using the registry, get the tool wrapper for the requested framework
            tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)

        current_function_group = self._function_groups[n]

        # Wrapping an individual function can fail; log which one before re-raising.
        for fn_name, fn_instance in current_function_group.instance.get_accessible_functions().items():
            try:
                # Wrap in the correct wrapper and add to tools list
                tools.append(tool_wrapper_reg.build_fn(fn_name, fn_instance, self))
            except Exception:
                logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
                raise

    return tools
410
571
 
572
+ @override
573
+ def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
574
+ if isinstance(fn_name, FunctionRef):
575
+ fn_name = str(fn_name)
411
576
  if fn_name not in self._functions:
412
577
  raise ValueError(f"Function `{fn_name}` not found in list of functions")
413
-
414
578
  fn = self._functions[fn_name]
415
-
416
579
  try:
417
580
  # Using the registry, get the tool wrapper for the requested framework
418
581
  tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
@@ -892,12 +1055,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
892
1055
  # Instantiate a memory client
893
1056
  elif component_instance.component_group == ComponentGroup.MEMORY:
894
1057
  await self.add_memory_client(component_instance.name, component_instance.config)
895
- # Instantiate a object store client
1058
+ # Instantiate a object store client
896
1059
  elif component_instance.component_group == ComponentGroup.OBJECT_STORES:
897
1060
  await self.add_object_store(component_instance.name, component_instance.config)
898
1061
  # Instantiate a retriever client
899
1062
  elif component_instance.component_group == ComponentGroup.RETRIEVERS:
900
1063
  await self.add_retriever(component_instance.name, component_instance.config)
1064
+ # Instantiate a function group
1065
+ elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
1066
+ await self.add_function_group(component_instance.name, component_instance.config)
901
1067
  # Instantiate a function
902
1068
  elif component_instance.component_group == ComponentGroup.FUNCTIONS:
903
1069
  # If the function is the root, set it as the workflow later
@@ -956,6 +1122,10 @@ class ChildBuilder(Builder):
956
1122
  async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
957
1123
  return await self._workflow_builder.add_function(name, config)
958
1124
 
1125
@override
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
    """Add a function group by delegating to the parent workflow builder."""
    parent = self._workflow_builder
    return await parent.add_function_group(name, config)
1128
+
959
1129
  @override
960
1130
  def get_function(self, name: str) -> Function:
961
1131
  # If a function tries to get another function, we assume it uses it
@@ -965,10 +1135,23 @@ class ChildBuilder(Builder):
965
1135
 
966
1136
  return fn
967
1137
 
1138
@override
def get_function_group(self, name: str) -> FunctionGroup:
    """Fetch a function group from the parent builder and record it as a dependency.

    A function that looks up a function group is assumed to use it.
    """
    group = self._workflow_builder.get_function_group(name)
    # Recorded only after the lookup succeeds; the parent lookup raises for unknown names.
    self._dependencies.add_function_group(name)
    return group
1146
+
968
1147
  @override
969
1148
  def get_function_config(self, name: str) -> FunctionBaseConfig:
970
1149
  return self._workflow_builder.get_function_config(name)
971
1150
 
1151
@override
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
    """Return the configuration of function group ``name`` as held by the parent builder.

    This lookup is not recorded as a dependency.
    """
    parent = self._workflow_builder
    return parent.get_function_group_config(name)
1154
+
972
1155
  @override
973
1156
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
974
1157
  return await self._workflow_builder.set_workflow(config)
@@ -982,7 +1165,19 @@ class ChildBuilder(Builder):
982
1165
  return self._workflow_builder.get_workflow_config()
983
1166
 
984
1167
  @override
985
- def get_tool(self, fn_name: str, wrapper_type: LLMFrameworkEnum | str):
1168
def get_tools(self,
              tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
              wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
    """Fetch tools from the parent builder and record each requested name as a dependency.

    Names that match a function group registered on the parent builder are recorded as
    function-group dependencies. All other names are recorded as function dependencies.
    """
    tools = self._workflow_builder.get_tools(tool_names, wrapper_type)
    known_groups = self._workflow_builder._function_groups
    for tool_name in tool_names:
        record = (self._dependencies.add_function_group
                  if tool_name in known_groups else self._dependencies.add_function)
        record(tool_name)
    return tools
1178
+
1179
+ @override
1180
+ def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
986
1181
  # If a function tries to get another function as a tool, we assume it uses it
987
1182
  fn = self._workflow_builder.get_tool(fn_name, wrapper_type)
988
1183
 
@@ -1111,3 +1306,7 @@ class ChildBuilder(Builder):
1111
1306
  @override
1112
1307
  def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
1113
1308
  return self._workflow_builder.get_function_dependencies(fn_name)
1309
+
1310
@override
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
    """Return the parent builder's recorded dependencies for the function group ``fn_name``."""
    parent = self._workflow_builder
    return parent.get_function_group_dependencies(fn_name)
@@ -0,0 +1,90 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import logging
18
+ from pathlib import Path
19
+
20
+ import click
21
+
22
+ from nat.data_models.optimizer import OptimizerRunConfig
23
+ from nat.profiler.parameter_optimization.optimizer_runtime import optimize_config
24
+
25
logger = logging.getLogger(__name__)


@click.group(help="Optimize a workflow with the specified dataset.", name=__name__, invoke_without_command=True)
@click.option(
    "--config_file",
    required=True,
    type=click.Path(path_type=Path, exists=True, file_okay=True, dir_okay=False),
    help="A JSON/YAML file that sets the parameters for the workflow and evaluation.",
)
@click.option(
    "--dataset",
    required=False,
    type=click.Path(path_type=Path, exists=True, file_okay=True, dir_okay=False),
    help=("A json file with questions and ground truth answers. "
          "This will override the dataset path in the config file."),
)
@click.option(
    "--result_json_path",
    default="$",
    type=str,
    help=("A JSON path to extract the result from the workflow. Use this when the workflow returns "
          "multiple objects or a dictionary. For example, '$.output' will extract the 'output' field "
          "from the result."),
)
@click.option(
    "--endpoint",
    default=None,
    type=str,
    help="Use endpoint for running the workflow. Example: http://localhost:8000/generate",
)
@click.option(
    "--endpoint_timeout",
    default=300,
    type=int,
    help="HTTP response timeout in seconds. Only relevant if endpoint is specified.",
)
@click.pass_context
def optimizer_command(ctx, **kwargs) -> None:
    """Entry point for ``optimize``; the actual work happens in the result callback.

    The parsed options are forwarded by click to the registered result callback, so this
    group body intentionally does nothing.
    """
65
+
66
+
67
async def run_optimizer(config: OptimizerRunConfig):
    """Await ``optimize_config`` for the given run configuration; its result is discarded."""
    await optimize_config(config)
    return None
69
+
70
+
71
@optimizer_command.result_callback(replace=True)
def run_optimizer_callback(
    processors,  # pylint: disable=unused-argument
    *,
    config_file: Path,
    dataset: Path | None,
    result_json_path: str,
    endpoint: str | None,
    endpoint_timeout: int,
):
    """Run the optimizer with the provided config file and dataset.

    Click invokes this after ``optimizer_command`` and passes the parsed command-line
    options as keyword arguments.

    Args:
        processors: Return value of the group callback (unused).
        config_file: Path to the workflow/optimizer configuration file.
        dataset: Dataset path overriding the one in the config file; ``None`` when
            ``--dataset`` is not given.
        result_json_path: JSON path used to extract the result from the workflow output.
        endpoint: Endpoint URL to run the workflow against; ``None`` when not given.
        endpoint_timeout: HTTP response timeout in seconds (only relevant with ``endpoint``).
    """
    config = OptimizerRunConfig(
        config_file=config_file,
        dataset=dataset,
        result_json_path=result_json_path,
        endpoint=endpoint,
        endpoint_timeout=endpoint_timeout,
    )

    asyncio.run(run_optimizer(config))
@@ -1,5 +1,4 @@
1
1
  general:
2
- use_uvloop: true
3
2
  logging:
4
3
  console:
5
4
  _type: console
@@ -171,6 +171,9 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
171
171
  workflow_dir (str): The directory to create the workflow package.
172
172
  description (str): Description to pre-popluate the workflow docstring.
173
173
  """
174
+ # Fail fast with Click's standard exit code (2) for bad params.
175
+ if not workflow_name or not workflow_name.strip():
176
+ raise click.BadParameter("Workflow name cannot be empty.") # noqa: TRY003
174
177
  try:
175
178
  # Get the repository root
176
179
  try:
@@ -217,15 +220,13 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
217
220
  else:
218
221
  install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
219
222
 
220
- config_source = configs_dir / 'config.yml'
221
-
222
223
  # List of templates and their destinations
223
224
  files_to_render = {
224
225
  'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
225
226
  'register.py.j2': base_dir / 'register.py',
226
227
  'workflow.py.j2': base_dir / f'{workflow_name}_function.py',
227
228
  '__init__.py.j2': base_dir / '__init__.py',
228
- 'config.yml.j2': config_source,
229
+ 'config.yml.j2': configs_dir / 'config.yml',
229
230
  }
230
231
 
231
232
  # Render templates
@@ -246,10 +247,6 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
246
247
  with open(output_path, 'w', encoding="utf-8") as f:
247
248
  f.write(content)
248
249
 
249
- # Create symlink for config.yml
250
- config_link = new_workflow_dir / 'configs' / 'config.yml'
251
- os.symlink(config_source, config_link)
252
-
253
250
  # Create symlinks for config and data directories
254
251
  config_dir_source = configs_dir
255
252
  config_dir_link = new_workflow_dir / 'configs'
nat/cli/entrypoint.py CHANGED
@@ -30,10 +30,13 @@ import time
30
30
  import click
31
31
  import nest_asyncio
32
32
 
33
+ from nat.utils.log_levels import LOG_LEVELS
34
+
33
35
  from .commands.configure.configure import configure_command
34
36
  from .commands.evaluate import eval_command
35
37
  from .commands.info.info import info_command
36
38
  from .commands.object_store.object_store import object_store_command
39
+ from .commands.optimize import optimizer_command
37
40
  from .commands.registry.registry import registry_command
38
41
  from .commands.sizing.sizing import sizing
39
42
  from .commands.start import start_command
@@ -44,15 +47,6 @@ from .commands.workflow.workflow import workflow_command
44
47
  # Apply at the beginning of the file to avoid issues with asyncio
45
48
  nest_asyncio.apply()
46
49
 
47
- # Define log level choices
48
- LOG_LEVELS = {
49
- 'DEBUG': logging.DEBUG,
50
- 'INFO': logging.INFO,
51
- 'WARNING': logging.WARNING,
52
- 'ERROR': logging.ERROR,
53
- 'CRITICAL': logging.CRITICAL
54
- }
55
-
56
50
 
57
51
  def setup_logging(log_level: str):
58
52
  """Configure logging with the specified level"""
@@ -108,6 +102,7 @@ cli.add_command(uninstall_command, name="uninstall")
108
102
  cli.add_command(validate_command, name="validate")
109
103
  cli.add_command(workflow_command, name="workflow")
110
104
  cli.add_command(sizing, name="sizing")
105
+ cli.add_command(optimizer_command, name="optimize")
111
106
  cli.add_command(object_store_command, name="object-store")
112
107
 
113
108
  # Aliases
@@ -27,6 +27,8 @@ from nat.cli.type_registry import EvaluatorRegisteredCallableT
27
27
  from nat.cli.type_registry import FrontEndBuildCallableT
28
28
  from nat.cli.type_registry import FrontEndRegisteredCallableT
29
29
  from nat.cli.type_registry import FunctionBuildCallableT
30
+ from nat.cli.type_registry import FunctionGroupBuildCallableT
31
+ from nat.cli.type_registry import FunctionGroupRegisteredCallableT
30
32
  from nat.cli.type_registry import FunctionRegisteredCallableT
31
33
  from nat.cli.type_registry import LLMClientBuildCallableT
32
34
  from nat.cli.type_registry import LLMClientRegisteredCallableT
@@ -60,6 +62,7 @@ from nat.data_models.embedder import EmbedderBaseConfigT
60
62
  from nat.data_models.evaluator import EvaluatorBaseConfigT
61
63
  from nat.data_models.front_end import FrontEndConfigT
62
64
  from nat.data_models.function import FunctionConfigT
65
+ from nat.data_models.function import FunctionGroupConfigT
63
66
  from nat.data_models.llm import LLMBaseConfigT
64
67
  from nat.data_models.memory import MemoryBaseConfigT
65
68
  from nat.data_models.object_store import ObjectStoreBaseConfigT
@@ -155,10 +158,7 @@ def register_function(config_type: type[FunctionConfigT],
155
158
 
156
159
  context_manager_fn = asynccontextmanager(fn)
157
160
 
158
- if framework_wrappers is None:
159
- framework_wrappers_list: list[str] = []
160
- else:
161
- framework_wrappers_list = list(framework_wrappers)
161
+ framework_wrappers_list = list(framework_wrappers or [])
162
162
 
163
163
  discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
164
164
  component_type=ComponentEnum.FUNCTION)
@@ -177,6 +177,40 @@ def register_function(config_type: type[FunctionConfigT],
177
177
  return register_function_inner
178
178
 
179
179
 
180
def register_function_group(config_type: type[FunctionGroupConfigT],
                            framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
    """Return a decorator that registers a function-group build function for ``config_type``.

    The decorated function is wrapped with ``asynccontextmanager`` and recorded in the
    global type registry together with ``framework_wrappers`` (optional, used for automatic
    profiler hooking). Function groups share configuration/resources across multiple
    functions.
    """

    def register_function_group_inner(
        fn: FunctionGroupBuildCallableT[FunctionGroupConfigT]
    ) -> FunctionGroupRegisteredCallableT[FunctionGroupConfigT]:
        from .type_registry import GlobalTypeRegistry
        from .type_registry import RegisteredFunctionGroupInfo

        wrapped = asynccontextmanager(fn)
        wrappers = list(framework_wrappers or [])
        metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
                                                      component_type=ComponentEnum.FUNCTION_GROUP)

        info = RegisteredFunctionGroupInfo(
            full_type=config_type.full_type,
            config_type=config_type,
            build_fn=wrapped,
            framework_wrappers=wrappers,
            discovery_metadata=metadata,
        )
        GlobalTypeRegistry.get().register_function_group(info)

        return wrapped

    return register_function_group_inner
212
+
213
+
180
214
  def register_llm_provider(config_type: type[LLMBaseConfigT]):
181
215
 
182
216
  def register_llm_provider_inner(