nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. nat/agent/base.py +9 -4
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +1 -1
  5. nat/agent/react_agent/register.py +15 -5
  6. nat/agent/reasoning_agent/reasoning_agent.py +6 -1
  7. nat/agent/register.py +2 -0
  8. nat/agent/rewoo_agent/agent.py +4 -2
  9. nat/agent/rewoo_agent/register.py +8 -3
  10. nat/agent/router_agent/__init__.py +0 -0
  11. nat/agent/router_agent/agent.py +329 -0
  12. nat/agent/router_agent/prompt.py +48 -0
  13. nat/agent/router_agent/register.py +97 -0
  14. nat/agent/tool_calling_agent/agent.py +69 -7
  15. nat/agent/tool_calling_agent/register.py +11 -3
  16. nat/builder/builder.py +27 -4
  17. nat/builder/component_utils.py +7 -3
  18. nat/builder/function.py +167 -0
  19. nat/builder/function_info.py +1 -1
  20. nat/builder/workflow.py +5 -0
  21. nat/builder/workflow_builder.py +213 -16
  22. nat/cli/commands/optimize.py +90 -0
  23. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  24. nat/cli/commands/workflow/workflow_commands.py +4 -7
  25. nat/cli/entrypoint.py +2 -0
  26. nat/cli/register_workflow.py +38 -4
  27. nat/cli/type_registry.py +71 -0
  28. nat/data_models/component.py +2 -0
  29. nat/data_models/component_ref.py +11 -0
  30. nat/data_models/config.py +40 -16
  31. nat/data_models/function.py +34 -0
  32. nat/data_models/function_dependencies.py +8 -0
  33. nat/data_models/optimizable.py +119 -0
  34. nat/data_models/optimizer.py +149 -0
  35. nat/data_models/temperature_mixin.py +4 -3
  36. nat/data_models/top_p_mixin.py +4 -3
  37. nat/embedder/nim_embedder.py +1 -1
  38. nat/embedder/openai_embedder.py +1 -1
  39. nat/eval/config.py +1 -1
  40. nat/eval/evaluate.py +5 -1
  41. nat/eval/register.py +4 -0
  42. nat/eval/runtime_evaluator/__init__.py +14 -0
  43. nat/eval/runtime_evaluator/evaluate.py +123 -0
  44. nat/eval/runtime_evaluator/register.py +100 -0
  45. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
  46. nat/front_ends/fastapi/dask_client_mixin.py +43 -0
  47. nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
  48. nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
  49. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
  50. nat/front_ends/fastapi/job_store.py +518 -99
  51. nat/front_ends/fastapi/main.py +11 -19
  52. nat/front_ends/fastapi/utils.py +57 -0
  53. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
  54. nat/llm/aws_bedrock_llm.py +14 -3
  55. nat/llm/nim_llm.py +14 -3
  56. nat/llm/openai_llm.py +8 -1
  57. nat/observability/exporter/processing_exporter.py +29 -55
  58. nat/observability/mixin/redaction_config_mixin.py +5 -4
  59. nat/observability/mixin/tagging_config_mixin.py +26 -14
  60. nat/observability/mixin/type_introspection_mixin.py +401 -107
  61. nat/observability/processor/processor.py +3 -0
  62. nat/observability/processor/redaction/__init__.py +24 -0
  63. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  64. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  65. nat/observability/processor/redaction/redaction_processor.py +177 -0
  66. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  67. nat/observability/processor/span_tagging_processor.py +21 -14
  68. nat/profiler/decorators/framework_wrapper.py +9 -6
  69. nat/profiler/parameter_optimization/__init__.py +0 -0
  70. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  71. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  72. nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
  73. nat/profiler/parameter_optimization/parameter_selection.py +108 -0
  74. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  75. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  76. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  77. nat/profiler/utils.py +3 -1
  78. nat/tool/chat_completion.py +4 -1
  79. nat/tool/github_tools.py +450 -0
  80. nat/tool/register.py +2 -7
  81. nat/utils/callable_utils.py +70 -0
  82. nat/utils/exception_handlers/automatic_retries.py +103 -48
  83. nat/utils/type_utils.py +4 -0
  84. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
  85. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +91 -71
  86. nat/observability/processor/header_redaction_processor.py +0 -123
  87. nat/observability/processor/redaction_processor.py +0 -77
  88. nat/tool/github_tools/create_github_commit.py +0 -133
  89. nat/tool/github_tools/create_github_issue.py +0 -87
  90. nat/tool/github_tools/create_github_pr.py +0 -106
  91. nat/tool/github_tools/get_github_file.py +0 -106
  92. nat/tool/github_tools/get_github_issue.py +0 -166
  93. nat/tool/github_tools/get_github_pr.py +0 -256
  94. nat/tool/github_tools/update_github_issue.py +0 -100
  95. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  96. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
  97. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
  98. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  99. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
  100. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,9 @@
16
16
  import dataclasses
17
17
  import inspect
18
18
  import logging
19
+ import typing
19
20
  import warnings
21
+ from collections.abc import Sequence
20
22
  from contextlib import AbstractAsyncContextManager
21
23
  from contextlib import AsyncExitStack
22
24
  from contextlib import asynccontextmanager
@@ -31,6 +33,7 @@ from nat.builder.context import ContextState
31
33
  from nat.builder.embedder import EmbedderProviderInfo
32
34
  from nat.builder.framework_enum import LLMFrameworkEnum
33
35
  from nat.builder.function import Function
36
+ from nat.builder.function import FunctionGroup
34
37
  from nat.builder.function import LambdaFunction
35
38
  from nat.builder.function_info import FunctionInfo
36
39
  from nat.builder.llm import LLMProviderInfo
@@ -42,6 +45,7 @@ from nat.data_models.authentication import AuthProviderBaseConfig
42
45
  from nat.data_models.component import ComponentGroup
43
46
  from nat.data_models.component_ref import AuthenticationRef
44
47
  from nat.data_models.component_ref import EmbedderRef
48
+ from nat.data_models.component_ref import FunctionGroupRef
45
49
  from nat.data_models.component_ref import FunctionRef
46
50
  from nat.data_models.component_ref import LLMRef
47
51
  from nat.data_models.component_ref import MemoryRef
@@ -52,6 +56,7 @@ from nat.data_models.config import Config
52
56
  from nat.data_models.config import GeneralConfig
53
57
  from nat.data_models.embedder import EmbedderBaseConfig
54
58
  from nat.data_models.function import FunctionBaseConfig
59
+ from nat.data_models.function import FunctionGroupBaseConfig
55
60
  from nat.data_models.function_dependencies import FunctionDependencies
56
61
  from nat.data_models.llm import LLMBaseConfig
57
62
  from nat.data_models.memory import MemoryBaseConfig
@@ -85,6 +90,12 @@ class ConfiguredFunction:
85
90
  instance: Function
86
91
 
87
92
 
93
@dataclasses.dataclass
class ConfiguredFunctionGroup:
    """A built function group paired with the configuration it was built from."""
    # Configuration the group was built from.
    config: FunctionGroupBaseConfig
    # The FunctionGroup instance returned by the group's build function.
    instance: FunctionGroup
97
+
98
+
88
99
  @dataclasses.dataclass
89
100
  class ConfiguredLLM:
90
101
  config: LLMBaseConfig
@@ -145,6 +156,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
145
156
  self._telemetry_exporters: dict[str, ConfiguredTelemetryExporter] = {}
146
157
 
147
158
  self._functions: dict[str, ConfiguredFunction] = {}
159
+ self._function_groups: dict[str, ConfiguredFunctionGroup] = {}
148
160
  self._workflow: ConfiguredFunction | None = None
149
161
 
150
162
  self._llms: dict[str, ConfiguredLLM] = {}
@@ -161,7 +173,9 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
161
173
 
162
174
  # Create a mapping to track function name -> other function names it depends on
163
175
  self.function_dependencies: dict[str, FunctionDependencies] = {}
176
+ self.function_group_dependencies: dict[str, FunctionDependencies] = {}
164
177
  self.current_function_building: str | None = None
178
+ self.current_function_group_building: str | None = None
165
179
 
166
180
  async def __aenter__(self):
167
181
 
@@ -224,12 +238,32 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
224
238
  if (self._workflow is None):
225
239
  raise ValueError("Must set a workflow before building")
226
240
 
241
+ # Set of all functions which are "included" by function groups
242
+ included_functions = set()
243
+ # Dictionary of function configs
244
+ function_configs = dict()
245
+ # Dictionary of function group configs
246
+ function_group_configs = dict()
247
+ # Dictionary of function instances
248
+ function_instances = dict()
249
+ # Dictionary of function group instances
250
+ function_group_instances = dict()
251
+
252
+ for k, v in self._function_groups.items():
253
+ included_functions.update(v.instance.get_included_functions().keys())
254
+ function_group_configs[k] = v.config
255
+ function_group_instances[k] = v.instance
256
+
257
+ # Function configs need to be restricted to only the functions that are not in a function group
258
+ for k, v in self._functions.items():
259
+ if k not in included_functions:
260
+ function_configs[k] = v.config
261
+ function_instances[k] = v.instance
262
+
227
263
  # Build the config from the added objects
228
264
  config = Config(general=self.general_config,
229
- functions={
230
- k: v.config
231
- for k, v in self._functions.items()
232
- },
265
+ functions=function_configs,
266
+ function_groups=function_group_configs,
233
267
  workflow=self._workflow.config,
234
268
  llms={
235
269
  k: v.config
@@ -263,10 +297,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
263
297
 
264
298
  workflow = Workflow.from_entry_fn(config=config,
265
299
  entry_fn=entry_fn_obj,
266
- functions={
267
- k: v.instance
268
- for k, v in self._functions.items()
269
- },
300
+ functions=function_instances,
301
+ function_groups=function_group_instances,
270
302
  llms={
271
303
  k: v.instance
272
304
  for k, v in self._llms.items()
@@ -347,11 +379,51 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
347
379
 
348
380
  return ConfiguredFunction(config=config, instance=build_result)
349
381
 
382
async def _build_function_group(self, name: str, config: FunctionGroupBaseConfig) -> ConfiguredFunctionGroup:
    """Build a function group from the provided configuration.

    The build function registered for ``type(config)`` is wrapped for the LLM frameworks it
    uses and entered on the builder's exit stack.

    Args:
        name: The name of the function group. Its dependencies are recorded under this key,
            which is the key ``get_function_group_dependencies`` looks up.
        config: The function group configuration.

    Returns:
        ConfiguredFunctionGroup: The built function group paired with its config.

    Raises:
        ValueError: If the function group builder does not return a ``FunctionGroup``.
    """
    registration = self._registry.get_function_group(type(config))

    inner_builder = ChildBuilder(self)

    # Build the function group - use the same wrapping pattern as _build_function
    llms = {k: v.instance for k, v in self._llms.items()}
    function_frameworks = detect_llm_frameworks_in_build_fn(registration)

    build_fn = chain_wrapped_build_fn(registration.build_fn, llms, function_frameworks)

    # Set the currently building function group so the ChildBuilder can track dependencies.
    # NOTE(review): this records the config *type*, not the instance name -- confirm which one
    # consumers of `current_function_group_building` expect.
    self.current_function_group_building = config.type
    # Seed an empty dependency set under the group's *name*: the same key that is overwritten
    # after the build below. Seeding under `config.type` left a stray entry and could clobber the
    # recorded dependencies of an earlier group whose name happens to equal this type.
    self.function_group_dependencies[name] = FunctionDependencies()

    build_result = await self._get_exit_stack().enter_async_context(build_fn(config, inner_builder))

    self.function_group_dependencies[name] = inner_builder.dependencies

    if not isinstance(build_result, FunctionGroup):
        raise ValueError("Expected a FunctionGroup object to be returned from the function group builder. "
                         f"Got {type(build_result)}")

    return ConfiguredFunctionGroup(config=config, instance=build_result)
419
+
350
420
  @override
351
421
  async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
422
+ if isinstance(name, FunctionRef):
423
+ name = str(name)
352
424
 
353
- if (name in self._functions):
354
- raise ValueError(f"Function `{name}` already exists in the list of functions")
425
+ if (name in self._functions or name in self._function_groups):
426
+ raise ValueError(f"Function `{name}` already exists in the list of functions or function groups")
355
427
 
356
428
  build_result = await self._build_function(name=name, config=config)
357
429
 
@@ -360,20 +432,66 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
360
432
  return build_result.instance
361
433
 
362
434
  @override
363
- def get_function(self, name: str | FunctionRef) -> Function:
435
+ async def add_function_group(self, name: str | FunctionGroupRef, config: FunctionGroupBaseConfig) -> FunctionGroup:
436
+ if isinstance(name, FunctionGroupRef):
437
+ name = str(name)
438
+
439
+ if (name in self._function_groups or name in self._functions):
440
+ raise ValueError(f"Function group `{name}` already exists in the list of function groups or functions")
441
+
442
+ # Build the function group
443
+ build_result = await self._build_function_group(name=name, config=config)
444
+
445
+ self._function_groups[name] = build_result
446
+
447
+ # If the function group exposes functions, add them to the global function registry
448
+ # If the function group exposes functions, record and add them to the registry
449
+ for k in build_result.instance.get_included_functions():
450
+ if k in self._functions:
451
+ raise ValueError(f"Exposed function `{k}` from group `{name}` conflicts with an existing function")
452
+ self._functions.update({
453
+ k: ConfiguredFunction(config=v.config, instance=v)
454
+ for k, v in build_result.instance.get_included_functions().items()
455
+ })
456
+
457
+ return build_result.instance
364
458
 
459
@override
def get_function(self, name: str | FunctionRef) -> Function:
    """Look up a built function by name.

    Args:
        name: Function name or ``FunctionRef``.

    Returns:
        Function: The built function instance.

    Raises:
        ValueError: If no function is registered under ``name``.
    """
    key = str(name) if isinstance(name, FunctionRef) else name
    registered = self._functions.get(key)
    if registered is None:
        raise ValueError(f"Function `{key}` not found")
    return registered.instance

@override
def get_function_group(self, name: str | FunctionGroupRef) -> FunctionGroup:
    """Look up a built function group by name.

    Args:
        name: Function group name or ``FunctionGroupRef``.

    Returns:
        FunctionGroup: The built function group instance.

    Raises:
        ValueError: If no function group is registered under ``name``.
    """
    key = str(name) if isinstance(name, FunctionGroupRef) else name
    registered = self._function_groups.get(key)
    if registered is None:
        raise ValueError(f"Function group `{key}` not found")
    return registered.instance

@override
def get_function_config(self, name: str | FunctionRef) -> FunctionBaseConfig:
    """Return the configuration a registered function was built from.

    Raises:
        ValueError: If no function is registered under ``name``.
    """
    key = str(name) if isinstance(name, FunctionRef) else name
    registered = self._functions.get(key)
    if registered is None:
        raise ValueError(f"Function `{key}` not found")
    return registered.config

@override
def get_function_group_config(self, name: str | FunctionGroupRef) -> FunctionGroupBaseConfig:
    """Return the configuration a registered function group was built from.

    Raises:
        ValueError: If no function group is registered under ``name``.
    """
    key = str(name) if isinstance(name, FunctionGroupRef) else name
    registered = self._function_groups.get(key)
    if registered is None:
        raise ValueError(f"Function group `{key}` not found")
    return registered.config
494
+
377
495
  @override
378
496
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
379
497
 
@@ -403,16 +521,59 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
403
521
 
404
522
  @override
405
523
  def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
524
+ if isinstance(fn_name, FunctionRef):
525
+ fn_name = str(fn_name)
406
526
  return self.function_dependencies[fn_name]
407
527
 
408
528
  @override
409
- def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
529
+ def get_function_group_dependencies(self, fn_name: str | FunctionGroupRef) -> FunctionDependencies:
530
+ if isinstance(fn_name, FunctionGroupRef):
531
+ fn_name = str(fn_name)
532
+ return self.function_group_dependencies[fn_name]
533
+
534
@override
def get_tools(self,
              tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
              wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
    """Return framework-specific tool wrappers for functions and/or function groups.

    Each name is first looked up as a function group. If no group is registered under it,
    it is treated as a single function and resolved through ``get_tool``. For a function
    group, every function returned by its ``get_accessible_functions()`` is wrapped.

    Args:
        tool_names: Function names/refs and function group names/refs to convert to tools.
        wrapper_type: The LLM framework whose tool wrapper should be used.

    Returns:
        list[typing.Any]: The wrapped tools, in request order. A group's members appear in the
        order returned by ``get_accessible_functions()``.

    Raises:
        ValueError: If a name is requested more than once, if a ``FunctionGroupRef`` does not
            name a registered group, or if a plain name is not a registered function.
    """
    tools = []
    seen = set()
    for n in tool_names:
        is_function_group_ref = isinstance(n, FunctionGroupRef)
        if isinstance(n, FunctionRef) or is_function_group_ref:
            n = str(n)
        if n in seen:
            raise ValueError(f"Function or Function Group `{n}` already seen")
        seen.add(n)

        current_function_group = self._function_groups.get(n)
        if current_function_group is None:
            # Not a function group: the passed tool name is probably a function
            if is_function_group_ref:
                raise ValueError(f"Function group `{n}` not found in the list of function groups")
            tools.append(self.get_tool(n, wrapper_type))
            continue

        # Using the registry, get the tool wrapper for the requested framework
        tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)

        # Wrapping an individual function can raise; log which one failed, then propagate.
        for fn_name, fn_instance in current_function_group.instance.get_accessible_functions().items():
            try:
                # Wrap in the correct wrapper and add to tools list
                tools.append(tool_wrapper_reg.build_fn(fn_name, fn_instance, self))
            except Exception:
                logger.error("Error fetching tool `%s`", fn_name, exc_info=True)
                raise

    return tools
410
569
 
570
+ @override
571
+ def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
572
+ if isinstance(fn_name, FunctionRef):
573
+ fn_name = str(fn_name)
411
574
  if fn_name not in self._functions:
412
575
  raise ValueError(f"Function `{fn_name}` not found in list of functions")
413
-
414
576
  fn = self._functions[fn_name]
415
-
416
577
  try:
417
578
  # Using the registry, get the tool wrapper for the requested framework
418
579
  tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
@@ -892,12 +1053,15 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
892
1053
  # Instantiate a memory client
893
1054
  elif component_instance.component_group == ComponentGroup.MEMORY:
894
1055
  await self.add_memory_client(component_instance.name, component_instance.config)
895
- # Instantiate a object store client
1056
+ # Instantiate a object store client
896
1057
  elif component_instance.component_group == ComponentGroup.OBJECT_STORES:
897
1058
  await self.add_object_store(component_instance.name, component_instance.config)
898
1059
  # Instantiate a retriever client
899
1060
  elif component_instance.component_group == ComponentGroup.RETRIEVERS:
900
1061
  await self.add_retriever(component_instance.name, component_instance.config)
1062
+ # Instantiate a function group
1063
+ elif component_instance.component_group == ComponentGroup.FUNCTION_GROUPS:
1064
+ await self.add_function_group(component_instance.name, component_instance.config)
901
1065
  # Instantiate a function
902
1066
  elif component_instance.component_group == ComponentGroup.FUNCTIONS:
903
1067
  # If the function is the root, set it as the workflow later
@@ -956,6 +1120,10 @@ class ChildBuilder(Builder):
956
1120
  async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
      """Add a function by delegating to the parent ``WorkflowBuilder``."""
      return await self._workflow_builder.add_function(name, config)

  @override
  async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
      """Add a function group by delegating to the parent ``WorkflowBuilder``.

      No dependency is recorded by this call.
      """
      return await self._workflow_builder.add_function_group(name, config)
1126
+
959
1127
  @override
960
1128
  def get_function(self, name: str) -> Function:
961
1129
  # If a function tries to get another function, we assume it uses it
@@ -965,10 +1133,23 @@ class ChildBuilder(Builder):
965
1133
 
966
1134
  return fn
967
1135
 
1136
@override
def get_function_group(self, name: str) -> FunctionGroup:
    """Return a function group from the parent builder and record it as a dependency.

    The lookup happens first, so an unknown name raises (``ValueError`` from the parent
    builder) before any dependency is recorded.
    """
    # If a function tries to get a function group, we assume it uses it
    function_group = self._workflow_builder.get_function_group(name)

    self._dependencies.add_function_group(name)

    return function_group

@override
def get_function_config(self, name: str) -> FunctionBaseConfig:
    """Return a function's config from the parent builder; no dependency is recorded."""
    return self._workflow_builder.get_function_config(name)

@override
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
    """Return a function group's config from the parent builder; no dependency is recorded."""
    return self._workflow_builder.get_function_group_config(name)

@override
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
    """Set the workflow by delegating to the parent ``WorkflowBuilder``."""
    return await self._workflow_builder.set_workflow(config)
@@ -982,7 +1163,19 @@ class ChildBuilder(Builder):
982
1163
  return self._workflow_builder.get_workflow_config()
983
1164
 
984
1165
  @override
985
- def get_tool(self, fn_name: str, wrapper_type: LLMFrameworkEnum | str):
1166
+ def get_tools(self,
1167
+ tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
1168
+ wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
1169
+ tools = self._workflow_builder.get_tools(tool_names, wrapper_type)
1170
+ for tool_name in tool_names:
1171
+ if tool_name in self._workflow_builder._function_groups:
1172
+ self._dependencies.add_function_group(tool_name)
1173
+ else:
1174
+ self._dependencies.add_function(tool_name)
1175
+ return tools
1176
+
1177
+ @override
1178
+ def get_tool(self, fn_name: str | FunctionRef, wrapper_type: LLMFrameworkEnum | str):
986
1179
  # If a function tries to get another function as a tool, we assume it uses it
987
1180
  fn = self._workflow_builder.get_tool(fn_name, wrapper_type)
988
1181
 
@@ -1111,3 +1304,7 @@ class ChildBuilder(Builder):
1111
1304
  @override
  def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
      """Return a function's recorded dependencies from the parent builder."""
      return self._workflow_builder.get_function_dependencies(fn_name)

  @override
  def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
      """Return a function group's recorded dependencies from the parent builder."""
      return self._workflow_builder.get_function_group_dependencies(fn_name)
@@ -0,0 +1,90 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import logging
18
+ from pathlib import Path
19
+
20
+ import click
21
+
22
+ from nat.data_models.optimizer import OptimizerRunConfig
23
+ from nat.profiler.parameter_optimization.optimizer_runtime import optimize_config
24
+
25
logger = logging.getLogger(__name__)


# `nat optimize` CLI group. It has no subcommands; the parsed options below are handled by the
# callback registered with `optimizer_command.result_callback`.
@click.group(name=__name__, invoke_without_command=True, help="Optimize a workflow with the specified dataset.")
@click.option(
    "--config_file",
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
    required=True,
    help="A JSON/YAML file that sets the parameters for the workflow and evaluation.",
)
@click.option(
    "--dataset",
    type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
    required=False,
    help="A json file with questions and ground truth answers. This will override the dataset path in the config file.",
)
@click.option(
    "--result_json_path",
    type=str,
    default="$",
    help=("A JSON path to extract the result from the workflow. Use this when the workflow returns "
          "multiple objects or a dictionary. For example, '$.output' will extract the 'output' field "
          "from the result."),
)
@click.option(
    "--endpoint",
    type=str,
    default=None,
    help="Use endpoint for running the workflow. Example: http://localhost:8000/generate",
)
@click.option(
    "--endpoint_timeout",
    type=int,
    default=300,
    help="HTTP response timeout in seconds. Only relevant if endpoint is specified.",
)
@click.pass_context
def optimizer_command(ctx: click.Context, **kwargs) -> None:
    """Optimize a workflow with the specified dataset.

    The group body itself does no work; the optimizer is started from the group's result
    callback, which receives the parsed options as keyword arguments.
    """
    pass
65
+
66
+
67
async def run_optimizer(config: OptimizerRunConfig) -> None:
    """Await ``optimize_config`` for the given run configuration; its result is discarded."""
    await optimize_config(config)
69
+
70
+
71
@optimizer_command.result_callback(replace=True)
def run_optimizer_callback(
    processors,  # pylint: disable=unused-argument
    *,
    config_file: Path,
    dataset: Path | None,
    result_json_path: str,
    endpoint: str | None,
    endpoint_timeout: int,
) -> None:
    """Run the optimizer with the provided config file and dataset.

    Args:
        processors: Value passed positionally by click's result-callback mechanism; unused.
        config_file: Workflow/evaluation config file (``--config_file``).
        dataset: Optional dataset file (``--dataset``); ``None`` when the option is omitted.
        result_json_path: JSON path used to extract the result from the workflow output.
        endpoint: Optional endpoint URL to run the workflow against; ``None`` when omitted.
        endpoint_timeout: HTTP response timeout in seconds; only relevant with ``endpoint``.
    """
    config = OptimizerRunConfig(
        config_file=config_file,
        dataset=dataset,
        result_json_path=result_json_path,
        endpoint=endpoint,
        endpoint_timeout=endpoint_timeout,
    )

    # Synchronous CLI entry point: drive the async optimizer to completion on a new event loop.
    asyncio.run(run_optimizer(config))
@@ -1,5 +1,4 @@
1
1
  general:
2
- use_uvloop: true
3
2
  logging:
4
3
  console:
5
4
  _type: console
@@ -171,6 +171,9 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
171
171
  workflow_dir (str): The directory to create the workflow package.
172
172
  description (str): Description to pre-popluate the workflow docstring.
173
173
  """
174
+ # Fail fast with Click's standard exit code (2) for bad params.
175
+ if not workflow_name or not workflow_name.strip():
176
+ raise click.BadParameter("Workflow name cannot be empty.") # noqa: TRY003
174
177
  try:
175
178
  # Get the repository root
176
179
  try:
@@ -217,15 +220,13 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
217
220
  else:
218
221
  install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
219
222
 
220
- config_source = configs_dir / 'config.yml'
221
-
222
223
  # List of templates and their destinations
223
224
  files_to_render = {
224
225
  'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
225
226
  'register.py.j2': base_dir / 'register.py',
226
227
  'workflow.py.j2': base_dir / f'{workflow_name}_function.py',
227
228
  '__init__.py.j2': base_dir / '__init__.py',
228
- 'config.yml.j2': config_source,
229
+ 'config.yml.j2': configs_dir / 'config.yml',
229
230
  }
230
231
 
231
232
  # Render templates
@@ -246,10 +247,6 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
246
247
  with open(output_path, 'w', encoding="utf-8") as f:
247
248
  f.write(content)
248
249
 
249
- # Create symlink for config.yml
250
- config_link = new_workflow_dir / 'configs' / 'config.yml'
251
- os.symlink(config_source, config_link)
252
-
253
250
  # Create symlinks for config and data directories
254
251
  config_dir_source = configs_dir
255
252
  config_dir_link = new_workflow_dir / 'configs'
nat/cli/entrypoint.py CHANGED
@@ -34,6 +34,7 @@ from .commands.configure.configure import configure_command
34
34
  from .commands.evaluate import eval_command
35
35
  from .commands.info.info import info_command
36
36
  from .commands.object_store.object_store import object_store_command
37
+ from .commands.optimize import optimizer_command
37
38
  from .commands.registry.registry import registry_command
38
39
  from .commands.sizing.sizing import sizing
39
40
  from .commands.start import start_command
@@ -108,6 +109,7 @@ cli.add_command(uninstall_command, name="uninstall")
108
109
  cli.add_command(validate_command, name="validate")
109
110
  cli.add_command(workflow_command, name="workflow")
110
111
  cli.add_command(sizing, name="sizing")
112
+ cli.add_command(optimizer_command, name="optimize")
111
113
  cli.add_command(object_store_command, name="object-store")
112
114
 
113
115
  # Aliases
@@ -27,6 +27,8 @@ from nat.cli.type_registry import EvaluatorRegisteredCallableT
27
27
  from nat.cli.type_registry import FrontEndBuildCallableT
28
28
  from nat.cli.type_registry import FrontEndRegisteredCallableT
29
29
  from nat.cli.type_registry import FunctionBuildCallableT
30
+ from nat.cli.type_registry import FunctionGroupBuildCallableT
31
+ from nat.cli.type_registry import FunctionGroupRegisteredCallableT
30
32
  from nat.cli.type_registry import FunctionRegisteredCallableT
31
33
  from nat.cli.type_registry import LLMClientBuildCallableT
32
34
  from nat.cli.type_registry import LLMClientRegisteredCallableT
@@ -60,6 +62,7 @@ from nat.data_models.embedder import EmbedderBaseConfigT
60
62
  from nat.data_models.evaluator import EvaluatorBaseConfigT
61
63
  from nat.data_models.front_end import FrontEndConfigT
62
64
  from nat.data_models.function import FunctionConfigT
65
+ from nat.data_models.function import FunctionGroupConfigT
63
66
  from nat.data_models.llm import LLMBaseConfigT
64
67
  from nat.data_models.memory import MemoryBaseConfigT
65
68
  from nat.data_models.object_store import ObjectStoreBaseConfigT
@@ -155,10 +158,7 @@ def register_function(config_type: type[FunctionConfigT],
155
158
 
156
159
  context_manager_fn = asynccontextmanager(fn)
157
160
 
158
- if framework_wrappers is None:
159
- framework_wrappers_list: list[str] = []
160
- else:
161
- framework_wrappers_list = list(framework_wrappers)
161
+ framework_wrappers_list = list(framework_wrappers or [])
162
162
 
163
163
  discovery_metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
164
164
  component_type=ComponentEnum.FUNCTION)
@@ -177,6 +177,40 @@ def register_function(config_type: type[FunctionConfigT],
177
177
  return register_function_inner
178
178
 
179
179
 
180
def register_function_group(config_type: type[FunctionGroupConfigT],
                            framework_wrappers: list[LLMFrameworkEnum | str] | None = None):
    """Decorator factory that registers a function group build function.

    Function groups share configuration/resources across multiple functions. The optional
    ``framework_wrappers`` are stored with the registration for automatic profiler hooking.

    Args:
        config_type: The function group config class the decorated builder handles.
        framework_wrappers: LLM frameworks the builder uses; ``None`` means none.

    Returns:
        A decorator that wraps the build function with ``asynccontextmanager``, registers it
        in the global type registry, and returns the wrapped function.
    """

    def register_function_group_inner(
        fn: FunctionGroupBuildCallableT[FunctionGroupConfigT]
    ) -> FunctionGroupRegisteredCallableT[FunctionGroupConfigT]:
        from .type_registry import GlobalTypeRegistry
        from .type_registry import RegisteredFunctionGroupInfo

        wrapped_build_fn = asynccontextmanager(fn)

        # Copy the caller's sequence so later mutation of it cannot affect the registration.
        wrappers = [] if framework_wrappers is None else list(framework_wrappers)

        metadata = DiscoveryMetadata.from_config_type(config_type=config_type,
                                                      component_type=ComponentEnum.FUNCTION_GROUP)

        registration_info = RegisteredFunctionGroupInfo(
            full_type=config_type.full_type,
            config_type=config_type,
            build_fn=wrapped_build_fn,
            framework_wrappers=wrappers,
            discovery_metadata=metadata,
        )
        GlobalTypeRegistry.get().register_function_group(registration_info)

        return wrapped_build_fn

    return register_function_group_inner
212
+
213
+
180
214
  def register_llm_provider(config_type: type[LLMBaseConfigT]):
181
215
 
182
216
  def register_llm_provider_inner(