nvidia-nat 1.3.0a20250909__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. nat/agent/base.py +11 -6
  2. nat/agent/dual_node.py +2 -2
  3. nat/agent/prompt_optimizer/prompt.py +68 -0
  4. nat/agent/prompt_optimizer/register.py +149 -0
  5. nat/agent/react_agent/agent.py +1 -1
  6. nat/agent/react_agent/register.py +17 -7
  7. nat/agent/reasoning_agent/reasoning_agent.py +6 -1
  8. nat/agent/register.py +2 -0
  9. nat/agent/rewoo_agent/agent.py +6 -3
  10. nat/agent/rewoo_agent/register.py +16 -10
  11. nat/agent/router_agent/__init__.py +0 -0
  12. nat/agent/router_agent/agent.py +329 -0
  13. nat/agent/router_agent/prompt.py +48 -0
  14. nat/agent/router_agent/register.py +97 -0
  15. nat/agent/tool_calling_agent/agent.py +69 -7
  16. nat/agent/tool_calling_agent/register.py +17 -9
  17. nat/builder/builder.py +27 -4
  18. nat/builder/component_utils.py +7 -3
  19. nat/builder/function.py +167 -0
  20. nat/builder/function_info.py +1 -1
  21. nat/builder/workflow.py +5 -0
  22. nat/builder/workflow_builder.py +213 -16
  23. nat/cli/commands/optimize.py +90 -0
  24. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  25. nat/cli/commands/workflow/workflow_commands.py +5 -8
  26. nat/cli/entrypoint.py +2 -0
  27. nat/cli/register_workflow.py +38 -4
  28. nat/cli/type_registry.py +71 -0
  29. nat/data_models/api_server.py +1 -1
  30. nat/data_models/component.py +2 -0
  31. nat/data_models/component_ref.py +11 -0
  32. nat/data_models/config.py +40 -16
  33. nat/data_models/function.py +34 -0
  34. nat/data_models/function_dependencies.py +8 -0
  35. nat/data_models/optimizable.py +119 -0
  36. nat/data_models/optimizer.py +149 -0
  37. nat/data_models/temperature_mixin.py +4 -3
  38. nat/data_models/top_p_mixin.py +4 -3
  39. nat/embedder/nim_embedder.py +1 -1
  40. nat/embedder/openai_embedder.py +1 -1
  41. nat/eval/config.py +1 -1
  42. nat/eval/evaluate.py +5 -1
  43. nat/eval/register.py +4 -0
  44. nat/eval/runtime_evaluator/__init__.py +14 -0
  45. nat/eval/runtime_evaluator/evaluate.py +123 -0
  46. nat/eval/runtime_evaluator/register.py +100 -0
  47. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
  48. nat/front_ends/fastapi/dask_client_mixin.py +43 -0
  49. nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
  50. nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
  51. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
  52. nat/front_ends/fastapi/job_store.py +518 -99
  53. nat/front_ends/fastapi/main.py +11 -19
  54. nat/front_ends/fastapi/utils.py +57 -0
  55. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
  56. nat/llm/aws_bedrock_llm.py +15 -4
  57. nat/llm/nim_llm.py +14 -3
  58. nat/llm/openai_llm.py +8 -1
  59. nat/observability/exporter/processing_exporter.py +29 -55
  60. nat/observability/mixin/redaction_config_mixin.py +5 -4
  61. nat/observability/mixin/tagging_config_mixin.py +26 -14
  62. nat/observability/mixin/type_introspection_mixin.py +401 -107
  63. nat/observability/processor/processor.py +3 -0
  64. nat/observability/processor/redaction/__init__.py +24 -0
  65. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  66. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  67. nat/observability/processor/redaction/redaction_processor.py +177 -0
  68. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  69. nat/observability/processor/span_tagging_processor.py +21 -14
  70. nat/profiler/decorators/framework_wrapper.py +9 -6
  71. nat/profiler/parameter_optimization/__init__.py +0 -0
  72. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  73. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  74. nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
  75. nat/profiler/parameter_optimization/parameter_selection.py +108 -0
  76. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  77. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  78. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  79. nat/profiler/utils.py +3 -1
  80. nat/tool/chat_completion.py +5 -2
  81. nat/tool/document_search.py +1 -1
  82. nat/tool/github_tools.py +450 -0
  83. nat/tool/register.py +2 -7
  84. nat/utils/callable_utils.py +70 -0
  85. nat/utils/exception_handlers/automatic_retries.py +103 -48
  86. nat/utils/type_utils.py +4 -0
  87. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
  88. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +94 -74
  89. nat/observability/processor/header_redaction_processor.py +0 -123
  90. nat/observability/processor/redaction_processor.py +0 -77
  91. nat/tool/github_tools/create_github_commit.py +0 -133
  92. nat/tool/github_tools/create_github_issue.py +0 -87
  93. nat/tool/github_tools/create_github_pr.py +0 -106
  94. nat/tool/github_tools/get_github_file.py +0 -106
  95. nat/tool/github_tools/get_github_issue.py +0 -166
  96. nat/tool/github_tools/get_github_pr.py +0 -256
  97. nat/tool/github_tools/update_github_issue.py +0 -100
  98. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  99. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
  100. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
  101. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  102. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
  103. {nvidia_nat-1.3.0a20250909.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
nat/cli/type_registry.py CHANGED
@@ -37,6 +37,7 @@ from nat.builder.embedder import EmbedderProviderInfo
37
37
  from nat.builder.evaluator import EvaluatorInfo
38
38
  from nat.builder.front_end import FrontEndBase
39
39
  from nat.builder.function import Function
40
+ from nat.builder.function import FunctionGroup
40
41
  from nat.builder.function_base import FunctionBase
41
42
  from nat.builder.function_info import FunctionInfo
42
43
  from nat.builder.llm import LLMProviderInfo
@@ -55,6 +56,8 @@ from nat.data_models.front_end import FrontEndBaseConfig
55
56
  from nat.data_models.front_end import FrontEndConfigT
56
57
  from nat.data_models.function import FunctionBaseConfig
57
58
  from nat.data_models.function import FunctionConfigT
59
+ from nat.data_models.function import FunctionGroupBaseConfig
60
+ from nat.data_models.function import FunctionGroupConfigT
58
61
  from nat.data_models.llm import LLMBaseConfig
59
62
  from nat.data_models.llm import LLMBaseConfigT
60
63
  from nat.data_models.logging import LoggingBaseConfig
@@ -85,6 +88,7 @@ EmbedderProviderBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncI
85
88
  EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIterator[EvaluatorInfo]]
86
89
  FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
87
90
  FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
91
+ FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
88
92
  TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
89
93
  LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
90
94
  LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
@@ -106,6 +110,7 @@ EvaluatorRegisteredCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], Abs
106
110
  FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncContextManager[FrontEndBase]]
107
111
  FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
108
112
  AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
113
+ FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
109
114
  TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
110
115
  LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
111
116
  LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
@@ -178,6 +183,16 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
178
183
  framework_wrappers: list[str] = Field(default_factory=list)
179
184
 
180
185
 
186
class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
    """
    Registry entry describing a function group.

    A function group is a collection of functions that share configuration and resources.
    """

    # Factory returning an async context manager that yields the built ``FunctionGroup``; hidden from ``repr``.
    build_fn: FunctionGroupRegisteredCallableT = Field(repr=False)

    # Framework wrapper names stored with this registration; defaults to a fresh empty list per instance.
    framework_wrappers: list[str] = Field(default_factory=list)
194
+
195
+
181
196
  class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
182
197
  """
183
198
  Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
@@ -313,6 +328,9 @@ class TypeRegistry:
313
328
  # Functions
314
329
  self._registered_functions: dict[type[FunctionBaseConfig], RegisteredFunctionInfo] = {}
315
330
 
331
+ # Function Groups
332
+ self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
333
+
316
334
  # LLMs
317
335
  self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
318
336
  self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
@@ -478,6 +496,50 @@ class TypeRegistry:
478
496
 
479
497
  return list(self._registered_functions.values())
480
498
 
499
def register_function_group(self, registration: RegisteredFunctionGroupInfo):
    """Add a function group registration to the type registry.

    Args:
        registration: Registration record; it is stored under ``registration.config_type``.

    Raises:
        ValueError: If a registration for the same config type already exists.
    """
    config_type = registration.config_type
    if config_type in self._registered_function_groups:
        raise ValueError(f"A function group with the same config type `{config_type}` has already been registered.")

    self._registered_function_groups[config_type] = registration
    # Invoke the registry's change hook only after the new entry has been stored.
    self._registration_changed()
516
+
517
def get_function_group(self, config_type: type[FunctionGroupBaseConfig]) -> RegisteredFunctionGroupInfo:
    """Look up the registration for a function group config type.

    Args:
        config_type: The ``FunctionGroupBaseConfig`` subclass used as the registry key.

    Returns:
        RegisteredFunctionGroupInfo: The matching registration.

    Raises:
        KeyError: If ``config_type`` has not been registered; the message lists the registered config types and
            the original lookup error is chained as the cause.
    """
    groups = self._registered_function_groups
    try:
        found = groups[config_type]
    except KeyError as missing:
        known_configs = set(groups.keys())
        raise KeyError(f"Could not find a registered function group for config `{config_type}`. "
                       f"Registered configs: {known_configs}") from missing
    return found
534
+
535
def get_registered_function_groups(self) -> list[RegisteredInfo[FunctionGroupBaseConfig]]:
    """Return every function group registration, in the order the entries were inserted.

    Returns:
        list[RegisteredInfo[FunctionGroupBaseConfig]]: A new list; mutating it does not affect the registry.
    """
    return [*self._registered_function_groups.values()]
542
+
481
543
  def register_llm_provider(self, info: RegisteredLLMProviderInfo):
482
544
 
483
545
  if (info.config_type in self._registered_llm_provider_infos):
@@ -790,6 +852,9 @@ class TypeRegistry:
790
852
  if component_type == ComponentEnum.FUNCTION:
791
853
  return self._registered_functions
792
854
 
855
+ if component_type == ComponentEnum.FUNCTION_GROUP:
856
+ return self._registered_function_groups
857
+
793
858
  if component_type == ComponentEnum.TOOL_WRAPPER:
794
859
  return self._registered_tool_wrappers
795
860
 
@@ -854,6 +919,9 @@ class TypeRegistry:
854
919
  if component_type == ComponentEnum.FUNCTION:
855
920
  return [i.static_type() for i in self._registered_functions]
856
921
 
922
+ if component_type == ComponentEnum.FUNCTION_GROUP:
923
+ return [i.static_type() for i in self._registered_function_groups]
924
+
857
925
  if component_type == ComponentEnum.TOOL_WRAPPER:
858
926
  return list(self._registered_tool_wrappers)
859
927
 
@@ -943,6 +1011,9 @@ class TypeRegistry:
943
1011
  if issubclass(cls, FunctionBaseConfig):
944
1012
  return self._do_compute_annotation(cls, self.get_registered_functions())
945
1013
 
1014
+ if issubclass(cls, FunctionGroupBaseConfig):
1015
+ return self._do_compute_annotation(cls, self.get_registered_function_groups())
1016
+
946
1017
  if issubclass(cls, LLMBaseConfig):
947
1018
  return self._do_compute_annotation(cls, self.get_registered_llm_providers())
948
1019
 
nat/data_models/api_server.py CHANGED
@@ -688,7 +688,7 @@ GlobalTypeConverter.register_converter(_string_to_nat_chat_response_chunk)
688
688
 
689
689
  # ======== AINodeMessageChunk Converters ========
690
690
  def _ai_message_chunk_to_nat_chat_response_chunk(data) -> ChatResponseChunk:
691
- '''Converts LangChain AINodeMessageChunk to ChatResponseChunk'''
691
+ '''Converts LangChain/LangGraph AINodeMessageChunk to ChatResponseChunk'''
692
692
  content = ""
693
693
  if hasattr(data, 'content') and data.content is not None:
694
694
  content = str(data.content)
nat/data_models/component.py CHANGED
@@ -27,6 +27,7 @@ class ComponentEnum(StrEnum):
27
27
  EVALUATOR = "evaluator"
28
28
  FRONT_END = "front_end"
29
29
  FUNCTION = "function"
30
+ FUNCTION_GROUP = "function_group"
30
31
  TTC_STRATEGY = "ttc_strategy"
31
32
  LLM_CLIENT = "llm_client"
32
33
  LLM_PROVIDER = "llm_provider"
@@ -47,6 +48,7 @@ class ComponentGroup(StrEnum):
47
48
  AUTHENTICATION = "authentication"
48
49
  EMBEDDERS = "embedders"
49
50
  FUNCTIONS = "functions"
51
+ FUNCTION_GROUPS = "function_groups"
50
52
  TTC_STRATEGIES = "ttc_strategies"
51
53
  LLMS = "llms"
52
54
  MEMORY = "memory"
nat/data_models/component_ref.py CHANGED
@@ -102,6 +102,17 @@ class FunctionRef(ComponentRef):
102
102
  return ComponentGroup.FUNCTIONS
103
103
 
104
104
 
105
class FunctionGroupRef(ComponentRef):
    """
    Reference, inside a NAT configuration object, to a function group.

    Its ``component_group`` is always ``ComponentGroup.FUNCTION_GROUPS``.
    """

    @property
    @override
    def component_group(self):
        # Function group references are looked up in the function-groups component group only.
        return ComponentGroup.FUNCTION_GROUPS
114
+
115
+
105
116
  class LLMRef(ComponentRef):
106
117
  """
107
118
  A reference to an LLM in a NAT configuration object.
nat/data_models/config.py CHANGED
@@ -20,6 +20,7 @@ import typing
20
20
  from pydantic import BaseModel
21
21
  from pydantic import ConfigDict
22
22
  from pydantic import Discriminator
23
+ from pydantic import Field
23
24
  from pydantic import ValidationError
24
25
  from pydantic import ValidationInfo
25
26
  from pydantic import ValidatorFunctionWrapHandler
@@ -29,7 +30,9 @@ from nat.data_models.evaluate import EvalConfig
29
30
  from nat.data_models.front_end import FrontEndBaseConfig
30
31
  from nat.data_models.function import EmptyFunctionConfig
31
32
  from nat.data_models.function import FunctionBaseConfig
33
+ from nat.data_models.function import FunctionGroupBaseConfig
32
34
  from nat.data_models.logging import LoggingBaseConfig
35
+ from nat.data_models.optimizer import OptimizerConfig
33
36
  from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
34
37
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
35
38
  from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
@@ -57,9 +60,10 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
57
60
  error_type = e['type']
58
61
  if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
59
62
  requested_type = e["ctx"]["tag"]
60
-
61
63
  if (info.field_name in ('workflow', 'functions')):
62
64
  registered_keys = GlobalTypeRegistry.get().get_registered_functions()
65
+ elif (info.field_name == "function_groups"):
66
+ registered_keys = GlobalTypeRegistry.get().get_registered_function_groups()
63
67
  elif (info.field_name == "authentication"):
64
68
  registered_keys = GlobalTypeRegistry.get().get_registered_auth_providers()
65
69
  elif (info.field_name == "llms"):
@@ -135,8 +139,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
135
139
 
136
140
  class TelemetryConfig(BaseModel):
137
141
 
138
- logging: dict[str, LoggingBaseConfig] = {}
139
- tracing: dict[str, TelemetryExporterBaseConfig] = {}
142
+ logging: dict[str, LoggingBaseConfig] = Field(default_factory=dict)
143
+ tracing: dict[str, TelemetryExporterBaseConfig] = Field(default_factory=dict)
140
144
 
141
145
  @field_validator("logging", "tracing", mode="wrap")
142
146
  @classmethod
@@ -185,10 +189,14 @@ class GeneralConfig(BaseModel):
185
189
 
186
190
  model_config = ConfigDict(protected_namespaces=())
187
191
 
188
# Adjacent string literals are joined with no separator, so the space between sentences' words must live
# inside one of the literals (the previous `+` join rendered "nowautomatically").
use_uvloop: bool | None = Field(
    default=None,
    deprecated=("`use_uvloop` field is deprecated and will be removed in a future release. The use of `uvloop` is "
                "now automatically determined based on platform"))
"""
This field is deprecated and ignored. It previously controlled whether to use uvloop as the event loop. uvloop
usage is now determined automatically based on the platform.
"""
193
201
 
194
202
  telemetry: TelemetryConfig = TelemetryConfig()
@@ -240,31 +248,37 @@ class Config(HashableBaseModel):
240
248
  general: GeneralConfig = GeneralConfig()
241
249
 
242
250
  # Functions Configuration
243
- functions: dict[str, FunctionBaseConfig] = {}
251
+ functions: dict[str, FunctionBaseConfig] = Field(default_factory=dict)
252
+
253
+ # Function Groups Configuration
254
+ function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
244
255
 
245
256
  # LLMs Configuration
246
- llms: dict[str, LLMBaseConfig] = {}
257
+ llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
247
258
 
248
259
  # Embedders Configuration
249
- embedders: dict[str, EmbedderBaseConfig] = {}
260
+ embedders: dict[str, EmbedderBaseConfig] = Field(default_factory=dict)
250
261
 
251
262
  # Memory Configuration
252
- memory: dict[str, MemoryBaseConfig] = {}
263
+ memory: dict[str, MemoryBaseConfig] = Field(default_factory=dict)
253
264
 
254
265
  # Object Stores Configuration
255
- object_stores: dict[str, ObjectStoreBaseConfig] = {}
266
+ object_stores: dict[str, ObjectStoreBaseConfig] = Field(default_factory=dict)
267
+
268
+ # Optimizer Configuration
269
+ optimizer: OptimizerConfig = OptimizerConfig()
256
270
 
257
271
  # Retriever Configuration
258
- retrievers: dict[str, RetrieverBaseConfig] = {}
272
+ retrievers: dict[str, RetrieverBaseConfig] = Field(default_factory=dict)
259
273
 
260
274
  # TTC Strategies
261
- ttc_strategies: dict[str, TTCStrategyBaseConfig] = {}
275
+ ttc_strategies: dict[str, TTCStrategyBaseConfig] = Field(default_factory=dict)
262
276
 
263
277
  # Workflow Configuration
264
278
  workflow: FunctionBaseConfig = EmptyFunctionConfig()
265
279
 
266
280
  # Authentication Configuration
267
- authentication: dict[str, AuthProviderBaseConfig] = {}
281
+ authentication: dict[str, AuthProviderBaseConfig] = Field(default_factory=dict)
268
282
 
269
283
  # Evaluation Options
270
284
  eval: EvalConfig = EvalConfig()
@@ -278,6 +292,7 @@ class Config(HashableBaseModel):
278
292
  stream.write(f"Workflow Type: {self.workflow.type}\n")
279
293
 
280
294
  stream.write(f"Number of Functions: {len(self.functions)}\n")
295
+ stream.write(f"Number of Function Groups: {len(self.function_groups)}\n")
281
296
  stream.write(f"Number of LLMs: {len(self.llms)}\n")
282
297
  stream.write(f"Number of Embedders: {len(self.embedders)}\n")
283
298
  stream.write(f"Number of Memory: {len(self.memory)}\n")
@@ -287,6 +302,7 @@ class Config(HashableBaseModel):
287
302
  stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")
288
303
 
289
304
  @field_validator("functions",
305
+ "function_groups",
290
306
  "llms",
291
307
  "embedders",
292
308
  "memory",
@@ -328,6 +344,10 @@ class Config(HashableBaseModel):
328
344
  typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
329
345
  Discriminator(TypedBaseModel.discriminator)]]
330
346
 
347
+ FunctionGroupsAnnotation = dict[str,
348
+ typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
349
+ Discriminator(TypedBaseModel.discriminator)]]
350
+
331
351
  MemoryAnnotation = dict[str,
332
352
  typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
333
353
  Discriminator(TypedBaseModel.discriminator)]]
@@ -335,7 +355,6 @@ class Config(HashableBaseModel):
335
355
  ObjectStoreAnnotation = dict[str,
336
356
  typing.Annotated[type_registry.compute_annotation(ObjectStoreBaseConfig),
337
357
  Discriminator(TypedBaseModel.discriminator)]]
338
-
339
358
  RetrieverAnnotation = dict[str,
340
359
  typing.Annotated[type_registry.compute_annotation(RetrieverBaseConfig),
341
360
  Discriminator(TypedBaseModel.discriminator)]]
@@ -344,7 +363,7 @@ class Config(HashableBaseModel):
344
363
  typing.Annotated[type_registry.compute_annotation(TTCStrategyBaseConfig),
345
364
  Discriminator(TypedBaseModel.discriminator)]]
346
365
 
347
- WorkflowAnnotation = typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
366
+ WorkflowAnnotation = typing.Annotated[(type_registry.compute_annotation(FunctionBaseConfig)),
348
367
  Discriminator(TypedBaseModel.discriminator)]
349
368
 
350
369
  should_rebuild = False
@@ -369,6 +388,11 @@ class Config(HashableBaseModel):
369
388
  functions_field.annotation = FunctionsAnnotation
370
389
  should_rebuild = True
371
390
 
391
+ function_groups_field = cls.model_fields.get("function_groups")
392
+ if function_groups_field is not None and function_groups_field.annotation != FunctionGroupsAnnotation:
393
+ function_groups_field.annotation = FunctionGroupsAnnotation
394
+ should_rebuild = True
395
+
372
396
  memory_field = cls.model_fields.get("memory")
373
397
  if memory_field is not None and memory_field.annotation != MemoryAnnotation:
374
398
  memory_field.annotation = MemoryAnnotation
nat/data_models/function.py CHANGED
@@ -15,6 +15,10 @@
15
15
 
16
16
  import typing
17
17
 
18
+ from pydantic import Field
19
+ from pydantic import field_validator
20
+ from pydantic import model_validator
21
+
18
22
  from .common import BaseModelRegistryTag
19
23
  from .common import TypedBaseModel
20
24
 
@@ -23,8 +27,38 @@ class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
23
27
  pass
24
28
 
25
29
 
30
class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
    """Base configuration for function groups.

    Function groups enable sharing of configurations and resources across multiple functions.
    ``include`` and ``exclude`` are mutually exclusive; each must contain unique names and is stored sorted.
    """

    include: list[str] = Field(
        default_factory=list,
        description="The list of function names which should be added to the global Function registry",
    )
    exclude: list[str] = Field(
        default_factory=list,
        description="The list of function names which should be excluded from default access to the group",
    )

    @field_validator("include", "exclude")
    @classmethod
    def _validate_fields_include_exclude(cls, value: list[str]) -> list[str]:
        # Reject repeated names, then normalize to sorted order so the input ordering does not matter.
        has_duplicates = len(value) != len(set(value))
        if has_duplicates:
            raise ValueError("Function names must be unique")
        return sorted(value)

    @model_validator(mode="after")
    def _validate_include_exclude(self):
        # At most one of the allow-list / deny-list may be non-empty.
        if not (self.include and self.exclude):
            return self
        raise ValueError("include and exclude cannot be used together")
56
+
57
+
26
58
  class EmptyFunctionConfig(FunctionBaseConfig, name="EmptyFunctionConfig"):
27
59
  pass
28
60
 
29
61
 
30
62
  FunctionConfigT = typing.TypeVar("FunctionConfigT", bound=FunctionBaseConfig)
63
+
64
+ FunctionGroupConfigT = typing.TypeVar("FunctionGroupConfigT", bound=FunctionGroupBaseConfig)
nat/data_models/function_dependencies.py CHANGED
@@ -23,6 +23,7 @@ class FunctionDependencies(BaseModel):
23
23
  A class to represent the dependencies of a function.
24
24
  """
25
25
  functions: set[str] = Field(default_factory=set)
26
+ function_groups: set[str] = Field(default_factory=set)
26
27
  llms: set[str] = Field(default_factory=set)
27
28
  embedders: set[str] = Field(default_factory=set)
28
29
  memory_clients: set[str] = Field(default_factory=set)
@@ -33,6 +34,10 @@ class FunctionDependencies(BaseModel):
33
34
  def serialize_functions(self, v: set[str]) -> list[str]:
34
35
  return list(v)
35
36
 
37
@field_serializer("function_groups", when_used="json")
def serialize_function_groups(self, v: set[str]) -> list[str]:
    """Serialize the function-group name set as a list for JSON output (JSON has no set type)."""
    # Same conversion as the sibling serializers: a list in the set's iteration order.
    return [*v]
40
+
36
41
  @field_serializer("llms", when_used="json")
37
42
  def serialize_llms(self, v: set[str]) -> list[str]:
38
43
  return list(v)
@@ -56,6 +61,9 @@ class FunctionDependencies(BaseModel):
56
61
  def add_function(self, function: str):
57
62
  self.functions.add(function)
58
63
 
64
def add_function_group(self, function_group: str):
    """Record ``function_group`` as a function-group dependency; repeated names are stored once."""
    groups = self.function_groups  # pylint: disable=no-member
    groups.add(function_group)
66
+
59
67
  def add_llm(self, llm: str):
60
68
  self.llms.add(llm)
61
69
 
nat/data_models/optimizable.py ADDED
@@ -0,0 +1,119 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections.abc import Sequence
17
+ from typing import Any
18
+ from typing import Generic
19
+ from typing import TypeVar
20
+
21
+ from optuna import Trial
22
+ from pydantic import BaseModel
23
+ from pydantic import ConfigDict
24
+ from pydantic import Field
25
+ from pydantic import model_validator
26
+
27
T = TypeVar("T", int, float, bool, str)


def _is_int_bound(value) -> bool:
    """Return True for genuine integers; ``bool`` subclasses ``int`` but is not a numeric bound."""
    return isinstance(value, int) and not isinstance(value, bool)


# --------------------------------------------------------------------- #
# 1. Hyper-parameter metadata container                                 #
# --------------------------------------------------------------------- #
class SearchSpace(BaseModel, Generic[T]):
    """Search space for a single optimizable parameter.

    A space is described in one of three ways:

    * ``values``: a categorical choice among explicit values;
    * ``low``/``high`` (optionally ``log`` and ``step``): a numeric range;
    * ``is_prompt=True`` with ``prompt``/``prompt_purpose``: a prompt to optimize (not supported via Optuna).
    """

    values: Sequence[T] | None = None
    low: T | None = None
    high: T | None = None
    log: bool = False  # log scale
    step: float | None = None  # None: continuous for float ranges, 1 for integer ranges
    is_prompt: bool = False
    prompt: str | None = None  # prompt to optimize
    prompt_purpose: str | None = None  # purpose of the prompt

    model_config = ConfigDict(protected_namespaces=(), extra="forbid")

    @model_validator(mode="after")
    def validate_search_space_parameters(self):
        """Reject spaces that specify both categorical ``values`` and a ``low``/``high`` range.

        Whether a numeric range is complete (both bounds present) is checked in :meth:`suggest`, where it is
        required, so partially specified spaces remain constructible.
        """
        if self.values is not None and (self.high is not None or self.low is not None):
            raise ValueError("SearchSpace 'values' is mutually exclusive with 'high' and 'low'")
        return self

    # Helper for Optuna Trials
    def suggest(self, trial: Trial, name: str):
        """Sample a value for parameter ``name`` from this space using an Optuna ``trial``.

        Raises:
            ValueError: If this is a prompt space, if a numeric space lacks ``low`` or ``high``, or if an integer
                range is given a non-integral ``step``.
        """
        if self.is_prompt:
            raise ValueError("Prompt optimization not currently supported using Optuna. "
                             "Use the genetic algorithm implementation instead.")
        if self.values is not None:
            return trial.suggest_categorical(name, self.values)
        if self.low is None or self.high is None:
            raise ValueError(f"SearchSpace for '{name}' needs either 'values' or both 'low' and 'high'.")
        if _is_int_bound(self.low) and _is_int_bound(self.high):
            # Optuna's suggest_int needs an integer step (its default is 1); forwarding None raises TypeError.
            int_step = 1 if self.step is None else int(self.step)
            if self.step is not None and int_step != self.step:
                raise ValueError(f"Integer SearchSpace for '{name}' requires an integral 'step', got {self.step}.")
            return trial.suggest_int(name, self.low, self.high, log=self.log, step=int_step)
        return trial.suggest_float(name, self.low, self.high, log=self.log, step=self.step)
66
+
67
+
68
def OptimizableField(
    default: Any,
    *,
    space: SearchSpace | None = None,
    merge_conflict: str = "overwrite",
    **fld_kw,
):
    """Create a Pydantic ``Field`` marked as optimizable.

    The returned field's ``json_schema_extra`` contains ``{"optimizable": True}``, plus ``"search_space": space``
    when a space is given, merged with any caller-supplied ``json_schema_extra``.

    Args:
        default: Default value of the field. For a prompt search space without ``space.prompt``, this value is
            used as the base prompt.
        space: Optional search space for the parameter. It is never mutated; when the base prompt has to be
            filled in, a copy is stored instead.
        merge_conflict: How to handle caller-supplied ``json_schema_extra`` keys that collide with the reserved
            keys: ``"overwrite"`` (reserved values win), ``"keep"`` (caller values win) or ``"error"``.
        **fld_kw: Forwarded unchanged to ``pydantic.Field``.

    Raises:
        TypeError: If ``json_schema_extra`` is supplied but is not a dict.
        ValueError: If ``merge_conflict`` is not a known policy, if a prompt space has no base prompt, or if
            ``merge_conflict="error"`` and reserved keys collide.
    """
    valid_policies = ("overwrite", "keep", "error")
    if merge_conflict not in valid_policies:
        # Previously an unknown policy silently behaved like "overwrite".
        raise ValueError(f"`merge_conflict` must be one of {valid_policies}, got {merge_conflict!r}.")

    # 1. Pull out any user-supplied extras (must be a dict)
    user_extra = fld_kw.pop("json_schema_extra", None) or {}
    if not isinstance(user_extra, dict):
        raise TypeError("`json_schema_extra` must be a mapping.")

    # 2. If the space is a prompt, ensure a concrete base prompt exists
    if space is not None and getattr(space, "is_prompt", False) and getattr(space, "prompt", None) is None:
        if default is None:
            raise ValueError("Prompt-optimized fields require a base prompt: provide a "
                             "non-None field default or set space.prompt.")
        # Fall back to the field's default on a copy, so a SearchSpace shared by several fields is not modified.
        space = space.model_copy(update={"prompt": default})

    # 3. Prepare our own metadata
    ours = {"optimizable": True}
    if space is not None:
        ours["search_space"] = space

    # 4. Merge with user extras according to merge_conflict policy
    intersect = ours.keys() & user_extra.keys()
    if intersect:
        if merge_conflict == "error":
            raise ValueError("`json_schema_extra` already contains reserved key(s): "
                             f"{', '.join(sorted(intersect))}")
        if merge_conflict == "keep":
            # Drop our colliding keys so the user's values survive the merge below.
            ours = {k: v for k, v in ours.items() if k not in intersect}

    merged_extra = {**user_extra, **ours}  # ours wins if 'overwrite'

    # 5. Return a normal Pydantic Field with merged extras
    return Field(default, json_schema_extra=merged_extra, **fld_kw)
108
+
109
+
110
class OptimizableMixin(BaseModel):
    """Mixin adding optimizer bookkeeping fields to a configuration model.

    Both fields are declared with ``exclude=True``, so they are omitted when the model is serialized.
    """

    # Names of the parameters the optimizer may tune; empty by default.
    optimizable_params: list[str] = Field(
        default_factory=list,
        description="List of parameters that can be optimized.",
        exclude=True,
    )

    # Optional search-space overrides (presumably keyed by parameter name; confirm against the optimizer).
    search_space: dict[str, SearchSpace] = Field(
        default_factory=dict,
        description="Optional search space overrides for optimizable parameters.",
        exclude=True,
    )