nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.3.0a20250917__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +1 -1
- nat/agent/react_agent/register.py +15 -5
- nat/agent/reasoning_agent/reasoning_agent.py +6 -1
- nat/agent/register.py +2 -0
- nat/agent/rewoo_agent/agent.py +4 -2
- nat/agent/rewoo_agent/register.py +8 -3
- nat/agent/router_agent/__init__.py +0 -0
- nat/agent/router_agent/agent.py +329 -0
- nat/agent/router_agent/prompt.py +48 -0
- nat/agent/router_agent/register.py +97 -0
- nat/agent/tool_calling_agent/agent.py +69 -7
- nat/agent/tool_calling_agent/register.py +11 -3
- nat/builder/builder.py +27 -4
- nat/builder/component_utils.py +7 -3
- nat/builder/function.py +167 -0
- nat/builder/function_info.py +1 -1
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +213 -16
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -7
- nat/cli/entrypoint.py +2 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +71 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +40 -16
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/temperature_mixin.py +4 -3
- nat/data_models/top_p_mixin.py +4 -3
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/eval/config.py +1 -1
- nat/eval/evaluate.py +5 -1
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +5 -1
- nat/front_ends/fastapi/dask_client_mixin.py +43 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +14 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +111 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +243 -228
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +3 -2
- nat/llm/aws_bedrock_llm.py +14 -3
- nat/llm/nim_llm.py +14 -3
- nat/llm/openai_llm.py +8 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +401 -107
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/profiler/decorators/framework_wrapper.py +9 -6
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +149 -0
- nat/profiler/parameter_optimization/parameter_selection.py +108 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/tool/chat_completion.py +4 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/register.py +2 -7
- nat/utils/callable_utils.py +70 -0
- nat/utils/exception_handlers/automatic_retries.py +103 -48
- nat/utils/type_utils.py +4 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/METADATA +8 -1
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/RECORD +91 -71
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.3.0a20250917.dist-info}/top_level.txt +0 -0
nat/cli/type_registry.py
CHANGED
|
@@ -37,6 +37,7 @@ from nat.builder.embedder import EmbedderProviderInfo
|
|
|
37
37
|
from nat.builder.evaluator import EvaluatorInfo
|
|
38
38
|
from nat.builder.front_end import FrontEndBase
|
|
39
39
|
from nat.builder.function import Function
|
|
40
|
+
from nat.builder.function import FunctionGroup
|
|
40
41
|
from nat.builder.function_base import FunctionBase
|
|
41
42
|
from nat.builder.function_info import FunctionInfo
|
|
42
43
|
from nat.builder.llm import LLMProviderInfo
|
|
@@ -55,6 +56,8 @@ from nat.data_models.front_end import FrontEndBaseConfig
|
|
|
55
56
|
from nat.data_models.front_end import FrontEndConfigT
|
|
56
57
|
from nat.data_models.function import FunctionBaseConfig
|
|
57
58
|
from nat.data_models.function import FunctionConfigT
|
|
59
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
60
|
+
from nat.data_models.function import FunctionGroupConfigT
|
|
58
61
|
from nat.data_models.llm import LLMBaseConfig
|
|
59
62
|
from nat.data_models.llm import LLMBaseConfigT
|
|
60
63
|
from nat.data_models.logging import LoggingBaseConfig
|
|
@@ -85,6 +88,7 @@ EmbedderProviderBuildCallableT = Callable[[EmbedderBaseConfigT, Builder], AsyncI
|
|
|
85
88
|
EvaluatorBuildCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], AsyncIterator[EvaluatorInfo]]
|
|
86
89
|
FrontEndBuildCallableT = Callable[[FrontEndConfigT, Config], AsyncIterator[FrontEndBase]]
|
|
87
90
|
FunctionBuildCallableT = Callable[[FunctionConfigT, Builder], AsyncIterator[FunctionInfo | Callable | FunctionBase]]
|
|
91
|
+
FunctionGroupBuildCallableT = Callable[[FunctionGroupConfigT, Builder], AsyncIterator[FunctionGroup]]
|
|
88
92
|
TTCStrategyBuildCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AsyncIterator[StrategyBase]]
|
|
89
93
|
LLMClientBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[typing.Any]]
|
|
90
94
|
LLMProviderBuildCallableT = Callable[[LLMBaseConfigT, Builder], AsyncIterator[LLMProviderInfo]]
|
|
@@ -106,6 +110,7 @@ EvaluatorRegisteredCallableT = Callable[[EvaluatorBaseConfigT, EvalBuilder], Abs
|
|
|
106
110
|
FrontEndRegisteredCallableT = Callable[[FrontEndConfigT, Config], AbstractAsyncContextManager[FrontEndBase]]
|
|
107
111
|
FunctionRegisteredCallableT = Callable[[FunctionConfigT, Builder],
|
|
108
112
|
AbstractAsyncContextManager[FunctionInfo | Callable | FunctionBase]]
|
|
113
|
+
FunctionGroupRegisteredCallableT = Callable[[FunctionGroupConfigT, Builder], AbstractAsyncContextManager[FunctionGroup]]
|
|
109
114
|
TTCStrategyRegisterCallableT = Callable[[TTCStrategyBaseConfigT, Builder], AbstractAsyncContextManager[StrategyBase]]
|
|
110
115
|
LLMClientRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[typing.Any]]
|
|
111
116
|
LLMProviderRegisteredCallableT = Callable[[LLMBaseConfigT, Builder], AbstractAsyncContextManager[LLMProviderInfo]]
|
|
@@ -178,6 +183,16 @@ class RegisteredFunctionInfo(RegisteredInfo[FunctionBaseConfig]):
|
|
|
178
183
|
framework_wrappers: list[str] = Field(default_factory=list)
|
|
179
184
|
|
|
180
185
|
|
|
186
|
+
class RegisteredFunctionGroupInfo(RegisteredInfo[FunctionGroupBaseConfig]):
|
|
187
|
+
"""
|
|
188
|
+
Represents a registered function group. Function groups are collections of functions that share configuration
|
|
189
|
+
and resources.
|
|
190
|
+
"""
|
|
191
|
+
|
|
192
|
+
build_fn: FunctionGroupRegisteredCallableT = Field(repr=False)
|
|
193
|
+
framework_wrappers: list[str] = Field(default_factory=list)
|
|
194
|
+
|
|
195
|
+
|
|
181
196
|
class RegisteredLLMProviderInfo(RegisteredInfo[LLMBaseConfig]):
|
|
182
197
|
"""
|
|
183
198
|
Represents a registered LLM provider. LLM Providers are the operators of the LLMs. i.e. NIMs, OpenAI, Anthropic,
|
|
@@ -313,6 +328,9 @@ class TypeRegistry:
|
|
|
313
328
|
# Functions
|
|
314
329
|
self._registered_functions: dict[type[FunctionBaseConfig], RegisteredFunctionInfo] = {}
|
|
315
330
|
|
|
331
|
+
# Function Groups
|
|
332
|
+
self._registered_function_groups: dict[type[FunctionGroupBaseConfig], RegisteredFunctionGroupInfo] = {}
|
|
333
|
+
|
|
316
334
|
# LLMs
|
|
317
335
|
self._registered_llm_provider_infos: dict[type[LLMBaseConfig], RegisteredLLMProviderInfo] = {}
|
|
318
336
|
self._llm_client_provider_to_framework: dict[type[LLMBaseConfig], dict[str, RegisteredLLMClientInfo]] = {}
|
|
@@ -478,6 +496,50 @@ class TypeRegistry:
|
|
|
478
496
|
|
|
479
497
|
return list(self._registered_functions.values())
|
|
480
498
|
|
|
499
|
+
def register_function_group(self, registration: RegisteredFunctionGroupInfo):
|
|
500
|
+
"""Register a function group with the type registry.
|
|
501
|
+
|
|
502
|
+
Args:
|
|
503
|
+
registration: The function group registration information
|
|
504
|
+
|
|
505
|
+
Raises:
|
|
506
|
+
ValueError: If a function group with the same config type is already registered
|
|
507
|
+
"""
|
|
508
|
+
if (registration.config_type in self._registered_function_groups):
|
|
509
|
+
raise ValueError(
|
|
510
|
+
f"A function group with the same config type `{registration.config_type}` has already been "
|
|
511
|
+
"registered.")
|
|
512
|
+
|
|
513
|
+
self._registered_function_groups[registration.config_type] = registration
|
|
514
|
+
|
|
515
|
+
self._registration_changed()
|
|
516
|
+
|
|
517
|
+
def get_function_group(self, config_type: type[FunctionGroupBaseConfig]) -> RegisteredFunctionGroupInfo:
|
|
518
|
+
"""Get a registered function group by its config type.
|
|
519
|
+
|
|
520
|
+
Args:
|
|
521
|
+
config_type: The function group configuration type
|
|
522
|
+
|
|
523
|
+
Returns:
|
|
524
|
+
RegisteredFunctionGroupInfo: The registered function group information
|
|
525
|
+
|
|
526
|
+
Raises:
|
|
527
|
+
KeyError: If no function group is registered for the given config type
|
|
528
|
+
"""
|
|
529
|
+
try:
|
|
530
|
+
return self._registered_function_groups[config_type]
|
|
531
|
+
except KeyError as err:
|
|
532
|
+
raise KeyError(f"Could not find a registered function group for config `{config_type}`. "
|
|
533
|
+
f"Registered configs: {set(self._registered_function_groups.keys())}") from err
|
|
534
|
+
|
|
535
|
+
def get_registered_function_groups(self) -> list[RegisteredInfo[FunctionGroupBaseConfig]]:
|
|
536
|
+
"""Get all registered function groups.
|
|
537
|
+
|
|
538
|
+
Returns:
|
|
539
|
+
list[RegisteredInfo[FunctionGroupBaseConfig]]: List of all registered function groups
|
|
540
|
+
"""
|
|
541
|
+
return list(self._registered_function_groups.values())
|
|
542
|
+
|
|
481
543
|
def register_llm_provider(self, info: RegisteredLLMProviderInfo):
|
|
482
544
|
|
|
483
545
|
if (info.config_type in self._registered_llm_provider_infos):
|
|
@@ -790,6 +852,9 @@ class TypeRegistry:
|
|
|
790
852
|
if component_type == ComponentEnum.FUNCTION:
|
|
791
853
|
return self._registered_functions
|
|
792
854
|
|
|
855
|
+
if component_type == ComponentEnum.FUNCTION_GROUP:
|
|
856
|
+
return self._registered_function_groups
|
|
857
|
+
|
|
793
858
|
if component_type == ComponentEnum.TOOL_WRAPPER:
|
|
794
859
|
return self._registered_tool_wrappers
|
|
795
860
|
|
|
@@ -854,6 +919,9 @@ class TypeRegistry:
|
|
|
854
919
|
if component_type == ComponentEnum.FUNCTION:
|
|
855
920
|
return [i.static_type() for i in self._registered_functions]
|
|
856
921
|
|
|
922
|
+
if component_type == ComponentEnum.FUNCTION_GROUP:
|
|
923
|
+
return [i.static_type() for i in self._registered_function_groups]
|
|
924
|
+
|
|
857
925
|
if component_type == ComponentEnum.TOOL_WRAPPER:
|
|
858
926
|
return list(self._registered_tool_wrappers)
|
|
859
927
|
|
|
@@ -943,6 +1011,9 @@ class TypeRegistry:
|
|
|
943
1011
|
if issubclass(cls, FunctionBaseConfig):
|
|
944
1012
|
return self._do_compute_annotation(cls, self.get_registered_functions())
|
|
945
1013
|
|
|
1014
|
+
if issubclass(cls, FunctionGroupBaseConfig):
|
|
1015
|
+
return self._do_compute_annotation(cls, self.get_registered_function_groups())
|
|
1016
|
+
|
|
946
1017
|
if issubclass(cls, LLMBaseConfig):
|
|
947
1018
|
return self._do_compute_annotation(cls, self.get_registered_llm_providers())
|
|
948
1019
|
|
nat/data_models/component.py
CHANGED
|
@@ -27,6 +27,7 @@ class ComponentEnum(StrEnum):
|
|
|
27
27
|
EVALUATOR = "evaluator"
|
|
28
28
|
FRONT_END = "front_end"
|
|
29
29
|
FUNCTION = "function"
|
|
30
|
+
FUNCTION_GROUP = "function_group"
|
|
30
31
|
TTC_STRATEGY = "ttc_strategy"
|
|
31
32
|
LLM_CLIENT = "llm_client"
|
|
32
33
|
LLM_PROVIDER = "llm_provider"
|
|
@@ -47,6 +48,7 @@ class ComponentGroup(StrEnum):
|
|
|
47
48
|
AUTHENTICATION = "authentication"
|
|
48
49
|
EMBEDDERS = "embedders"
|
|
49
50
|
FUNCTIONS = "functions"
|
|
51
|
+
FUNCTION_GROUPS = "function_groups"
|
|
50
52
|
TTC_STRATEGIES = "ttc_strategies"
|
|
51
53
|
LLMS = "llms"
|
|
52
54
|
MEMORY = "memory"
|
nat/data_models/component_ref.py
CHANGED
|
@@ -102,6 +102,17 @@ class FunctionRef(ComponentRef):
|
|
|
102
102
|
return ComponentGroup.FUNCTIONS
|
|
103
103
|
|
|
104
104
|
|
|
105
|
+
class FunctionGroupRef(ComponentRef):
|
|
106
|
+
"""
|
|
107
|
+
A reference to a function group in a NAT configuration object.
|
|
108
|
+
"""
|
|
109
|
+
|
|
110
|
+
@property
|
|
111
|
+
@override
|
|
112
|
+
def component_group(self):
|
|
113
|
+
return ComponentGroup.FUNCTION_GROUPS
|
|
114
|
+
|
|
115
|
+
|
|
105
116
|
class LLMRef(ComponentRef):
|
|
106
117
|
"""
|
|
107
118
|
A reference to an LLM in a NAT configuration object.
|
nat/data_models/config.py
CHANGED
|
@@ -20,6 +20,7 @@ import typing
|
|
|
20
20
|
from pydantic import BaseModel
|
|
21
21
|
from pydantic import ConfigDict
|
|
22
22
|
from pydantic import Discriminator
|
|
23
|
+
from pydantic import Field
|
|
23
24
|
from pydantic import ValidationError
|
|
24
25
|
from pydantic import ValidationInfo
|
|
25
26
|
from pydantic import ValidatorFunctionWrapHandler
|
|
@@ -29,7 +30,9 @@ from nat.data_models.evaluate import EvalConfig
|
|
|
29
30
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
30
31
|
from nat.data_models.function import EmptyFunctionConfig
|
|
31
32
|
from nat.data_models.function import FunctionBaseConfig
|
|
33
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
32
34
|
from nat.data_models.logging import LoggingBaseConfig
|
|
35
|
+
from nat.data_models.optimizer import OptimizerConfig
|
|
33
36
|
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
34
37
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
35
38
|
from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
@@ -57,9 +60,10 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
57
60
|
error_type = e['type']
|
|
58
61
|
if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
|
|
59
62
|
requested_type = e["ctx"]["tag"]
|
|
60
|
-
|
|
61
63
|
if (info.field_name in ('workflow', 'functions')):
|
|
62
64
|
registered_keys = GlobalTypeRegistry.get().get_registered_functions()
|
|
65
|
+
elif (info.field_name == "function_groups"):
|
|
66
|
+
registered_keys = GlobalTypeRegistry.get().get_registered_function_groups()
|
|
63
67
|
elif (info.field_name == "authentication"):
|
|
64
68
|
registered_keys = GlobalTypeRegistry.get().get_registered_auth_providers()
|
|
65
69
|
elif (info.field_name == "llms"):
|
|
@@ -135,8 +139,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
135
139
|
|
|
136
140
|
class TelemetryConfig(BaseModel):
|
|
137
141
|
|
|
138
|
-
logging: dict[str, LoggingBaseConfig] =
|
|
139
|
-
tracing: dict[str, TelemetryExporterBaseConfig] =
|
|
142
|
+
logging: dict[str, LoggingBaseConfig] = Field(default_factory=dict)
|
|
143
|
+
tracing: dict[str, TelemetryExporterBaseConfig] = Field(default_factory=dict)
|
|
140
144
|
|
|
141
145
|
@field_validator("logging", "tracing", mode="wrap")
|
|
142
146
|
@classmethod
|
|
@@ -185,10 +189,14 @@ class GeneralConfig(BaseModel):
|
|
|
185
189
|
|
|
186
190
|
model_config = ConfigDict(protected_namespaces=())
|
|
187
191
|
|
|
188
|
-
use_uvloop: bool =
|
|
192
|
+
use_uvloop: bool | None = Field(
|
|
193
|
+
default=None,
|
|
194
|
+
deprecated=
|
|
195
|
+
"`use_uvloop` field is deprecated and will be removed in a future release. The use of `uv_loop` is now" +
|
|
196
|
+
"automatically determined based on platform")
|
|
189
197
|
"""
|
|
190
|
-
|
|
191
|
-
|
|
198
|
+
This field is deprecated and ignored. It previously controlled whether to use uvloop as the event loop. uvloop
|
|
199
|
+
usage is now determined automatically based on the platform.
|
|
192
200
|
"""
|
|
193
201
|
|
|
194
202
|
telemetry: TelemetryConfig = TelemetryConfig()
|
|
@@ -240,31 +248,37 @@ class Config(HashableBaseModel):
|
|
|
240
248
|
general: GeneralConfig = GeneralConfig()
|
|
241
249
|
|
|
242
250
|
# Functions Configuration
|
|
243
|
-
functions: dict[str, FunctionBaseConfig] =
|
|
251
|
+
functions: dict[str, FunctionBaseConfig] = Field(default_factory=dict)
|
|
252
|
+
|
|
253
|
+
# Function Groups Configuration
|
|
254
|
+
function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
|
|
244
255
|
|
|
245
256
|
# LLMs Configuration
|
|
246
|
-
llms: dict[str, LLMBaseConfig] =
|
|
257
|
+
llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
|
|
247
258
|
|
|
248
259
|
# Embedders Configuration
|
|
249
|
-
embedders: dict[str, EmbedderBaseConfig] =
|
|
260
|
+
embedders: dict[str, EmbedderBaseConfig] = Field(default_factory=dict)
|
|
250
261
|
|
|
251
262
|
# Memory Configuration
|
|
252
|
-
memory: dict[str, MemoryBaseConfig] =
|
|
263
|
+
memory: dict[str, MemoryBaseConfig] = Field(default_factory=dict)
|
|
253
264
|
|
|
254
265
|
# Object Stores Configuration
|
|
255
|
-
object_stores: dict[str, ObjectStoreBaseConfig] =
|
|
266
|
+
object_stores: dict[str, ObjectStoreBaseConfig] = Field(default_factory=dict)
|
|
267
|
+
|
|
268
|
+
# Optimizer Configuration
|
|
269
|
+
optimizer: OptimizerConfig = OptimizerConfig()
|
|
256
270
|
|
|
257
271
|
# Retriever Configuration
|
|
258
|
-
retrievers: dict[str, RetrieverBaseConfig] =
|
|
272
|
+
retrievers: dict[str, RetrieverBaseConfig] = Field(default_factory=dict)
|
|
259
273
|
|
|
260
274
|
# TTC Strategies
|
|
261
|
-
ttc_strategies: dict[str, TTCStrategyBaseConfig] =
|
|
275
|
+
ttc_strategies: dict[str, TTCStrategyBaseConfig] = Field(default_factory=dict)
|
|
262
276
|
|
|
263
277
|
# Workflow Configuration
|
|
264
278
|
workflow: FunctionBaseConfig = EmptyFunctionConfig()
|
|
265
279
|
|
|
266
280
|
# Authentication Configuration
|
|
267
|
-
authentication: dict[str, AuthProviderBaseConfig] =
|
|
281
|
+
authentication: dict[str, AuthProviderBaseConfig] = Field(default_factory=dict)
|
|
268
282
|
|
|
269
283
|
# Evaluation Options
|
|
270
284
|
eval: EvalConfig = EvalConfig()
|
|
@@ -278,6 +292,7 @@ class Config(HashableBaseModel):
|
|
|
278
292
|
stream.write(f"Workflow Type: {self.workflow.type}\n")
|
|
279
293
|
|
|
280
294
|
stream.write(f"Number of Functions: {len(self.functions)}\n")
|
|
295
|
+
stream.write(f"Number of Function Groups: {len(self.function_groups)}\n")
|
|
281
296
|
stream.write(f"Number of LLMs: {len(self.llms)}\n")
|
|
282
297
|
stream.write(f"Number of Embedders: {len(self.embedders)}\n")
|
|
283
298
|
stream.write(f"Number of Memory: {len(self.memory)}\n")
|
|
@@ -287,6 +302,7 @@ class Config(HashableBaseModel):
|
|
|
287
302
|
stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")
|
|
288
303
|
|
|
289
304
|
@field_validator("functions",
|
|
305
|
+
"function_groups",
|
|
290
306
|
"llms",
|
|
291
307
|
"embedders",
|
|
292
308
|
"memory",
|
|
@@ -328,6 +344,10 @@ class Config(HashableBaseModel):
|
|
|
328
344
|
typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
329
345
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
330
346
|
|
|
347
|
+
FunctionGroupsAnnotation = dict[str,
|
|
348
|
+
typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
|
|
349
|
+
Discriminator(TypedBaseModel.discriminator)]]
|
|
350
|
+
|
|
331
351
|
MemoryAnnotation = dict[str,
|
|
332
352
|
typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
|
|
333
353
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -335,7 +355,6 @@ class Config(HashableBaseModel):
|
|
|
335
355
|
ObjectStoreAnnotation = dict[str,
|
|
336
356
|
typing.Annotated[type_registry.compute_annotation(ObjectStoreBaseConfig),
|
|
337
357
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
338
|
-
|
|
339
358
|
RetrieverAnnotation = dict[str,
|
|
340
359
|
typing.Annotated[type_registry.compute_annotation(RetrieverBaseConfig),
|
|
341
360
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -344,7 +363,7 @@ class Config(HashableBaseModel):
|
|
|
344
363
|
typing.Annotated[type_registry.compute_annotation(TTCStrategyBaseConfig),
|
|
345
364
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
346
365
|
|
|
347
|
-
WorkflowAnnotation = typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
366
|
+
WorkflowAnnotation = typing.Annotated[(type_registry.compute_annotation(FunctionBaseConfig)),
|
|
348
367
|
Discriminator(TypedBaseModel.discriminator)]
|
|
349
368
|
|
|
350
369
|
should_rebuild = False
|
|
@@ -369,6 +388,11 @@ class Config(HashableBaseModel):
|
|
|
369
388
|
functions_field.annotation = FunctionsAnnotation
|
|
370
389
|
should_rebuild = True
|
|
371
390
|
|
|
391
|
+
function_groups_field = cls.model_fields.get("function_groups")
|
|
392
|
+
if function_groups_field is not None and function_groups_field.annotation != FunctionGroupsAnnotation:
|
|
393
|
+
function_groups_field.annotation = FunctionGroupsAnnotation
|
|
394
|
+
should_rebuild = True
|
|
395
|
+
|
|
372
396
|
memory_field = cls.model_fields.get("memory")
|
|
373
397
|
if memory_field is not None and memory_field.annotation != MemoryAnnotation:
|
|
374
398
|
memory_field.annotation = MemoryAnnotation
|
nat/data_models/function.py
CHANGED
|
@@ -15,6 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
import typing
|
|
17
17
|
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
from pydantic import field_validator
|
|
20
|
+
from pydantic import model_validator
|
|
21
|
+
|
|
18
22
|
from .common import BaseModelRegistryTag
|
|
19
23
|
from .common import TypedBaseModel
|
|
20
24
|
|
|
@@ -23,8 +27,38 @@ class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
|
23
27
|
pass
|
|
24
28
|
|
|
25
29
|
|
|
30
|
+
class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
31
|
+
"""Base configuration for function groups.
|
|
32
|
+
|
|
33
|
+
Function groups enable sharing of configurations and resources across multiple functions.
|
|
34
|
+
"""
|
|
35
|
+
include: list[str] = Field(
|
|
36
|
+
default_factory=list,
|
|
37
|
+
description="The list of function names which should be added to the global Function registry",
|
|
38
|
+
)
|
|
39
|
+
exclude: list[str] = Field(
|
|
40
|
+
default_factory=list,
|
|
41
|
+
description="The list of function names which should be excluded from default access to the group",
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
@field_validator("include", "exclude")
|
|
45
|
+
@classmethod
|
|
46
|
+
def _validate_fields_include_exclude(cls, value: list[str]) -> list[str]:
|
|
47
|
+
if len(set(value)) != len(value):
|
|
48
|
+
raise ValueError("Function names must be unique")
|
|
49
|
+
return sorted(value)
|
|
50
|
+
|
|
51
|
+
@model_validator(mode="after")
|
|
52
|
+
def _validate_include_exclude(self):
|
|
53
|
+
if self.include and self.exclude:
|
|
54
|
+
raise ValueError("include and exclude cannot be used together")
|
|
55
|
+
return self
|
|
56
|
+
|
|
57
|
+
|
|
26
58
|
class EmptyFunctionConfig(FunctionBaseConfig, name="EmptyFunctionConfig"):
|
|
27
59
|
pass
|
|
28
60
|
|
|
29
61
|
|
|
30
62
|
FunctionConfigT = typing.TypeVar("FunctionConfigT", bound=FunctionBaseConfig)
|
|
63
|
+
|
|
64
|
+
FunctionGroupConfigT = typing.TypeVar("FunctionGroupConfigT", bound=FunctionGroupBaseConfig)
|
|
@@ -23,6 +23,7 @@ class FunctionDependencies(BaseModel):
|
|
|
23
23
|
A class to represent the dependencies of a function.
|
|
24
24
|
"""
|
|
25
25
|
functions: set[str] = Field(default_factory=set)
|
|
26
|
+
function_groups: set[str] = Field(default_factory=set)
|
|
26
27
|
llms: set[str] = Field(default_factory=set)
|
|
27
28
|
embedders: set[str] = Field(default_factory=set)
|
|
28
29
|
memory_clients: set[str] = Field(default_factory=set)
|
|
@@ -33,6 +34,10 @@ class FunctionDependencies(BaseModel):
|
|
|
33
34
|
def serialize_functions(self, v: set[str]) -> list[str]:
|
|
34
35
|
return list(v)
|
|
35
36
|
|
|
37
|
+
@field_serializer("function_groups", when_used="json")
|
|
38
|
+
def serialize_function_groups(self, v: set[str]) -> list[str]:
|
|
39
|
+
return list(v)
|
|
40
|
+
|
|
36
41
|
@field_serializer("llms", when_used="json")
|
|
37
42
|
def serialize_llms(self, v: set[str]) -> list[str]:
|
|
38
43
|
return list(v)
|
|
@@ -56,6 +61,9 @@ class FunctionDependencies(BaseModel):
|
|
|
56
61
|
def add_function(self, function: str):
|
|
57
62
|
self.functions.add(function)
|
|
58
63
|
|
|
64
|
+
def add_function_group(self, function_group: str):
|
|
65
|
+
self.function_groups.add(function_group) # pylint: disable=no-member
|
|
66
|
+
|
|
59
67
|
def add_llm(self, llm: str):
|
|
60
68
|
self.llms.add(llm)
|
|
61
69
|
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections.abc import Sequence
|
|
17
|
+
from typing import Any
|
|
18
|
+
from typing import Generic
|
|
19
|
+
from typing import TypeVar
|
|
20
|
+
|
|
21
|
+
from optuna import Trial
|
|
22
|
+
from pydantic import BaseModel
|
|
23
|
+
from pydantic import ConfigDict
|
|
24
|
+
from pydantic import Field
|
|
25
|
+
from pydantic import model_validator
|
|
26
|
+
|
|
27
|
+
T = TypeVar("T", int, float, bool, str)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# --------------------------------------------------------------------- #
|
|
31
|
+
# 1. Hyper‑parameter metadata container #
|
|
32
|
+
# --------------------------------------------------------------------- #
|
|
33
|
+
class SearchSpace(BaseModel, Generic[T]):
|
|
34
|
+
values: Sequence[T] | None = None
|
|
35
|
+
low: T | None = None
|
|
36
|
+
high: T | None = None
|
|
37
|
+
log: bool = False # log scale
|
|
38
|
+
step: float | None = None
|
|
39
|
+
is_prompt: bool = False
|
|
40
|
+
prompt: str | None = None # prompt to optimize
|
|
41
|
+
prompt_purpose: str | None = None # purpose of the prompt
|
|
42
|
+
|
|
43
|
+
model_config = ConfigDict(protected_namespaces=(), extra="forbid")
|
|
44
|
+
|
|
45
|
+
@model_validator(mode="after")
|
|
46
|
+
def validate_search_space_parameters(self):
|
|
47
|
+
"""Validate that either values is provided, or both high and low."""
|
|
48
|
+
if self.values is not None:
|
|
49
|
+
# If values is provided, we don't need high/low
|
|
50
|
+
if self.high is not None or self.low is not None:
|
|
51
|
+
raise ValueError("SearchSpace 'values' is mutually exclusive with 'high' and 'low'")
|
|
52
|
+
return self
|
|
53
|
+
|
|
54
|
+
return self
|
|
55
|
+
|
|
56
|
+
# Helper for Optuna Trials
|
|
57
|
+
def suggest(self, trial: Trial, name: str):
|
|
58
|
+
if self.is_prompt:
|
|
59
|
+
raise ValueError("Prompt optimization not currently supported using Optuna. "
|
|
60
|
+
"Use the genetic algorithm implementation instead.")
|
|
61
|
+
if self.values is not None:
|
|
62
|
+
return trial.suggest_categorical(name, self.values)
|
|
63
|
+
if isinstance(self.low, int):
|
|
64
|
+
return trial.suggest_int(name, self.low, self.high, log=self.log, step=self.step)
|
|
65
|
+
return trial.suggest_float(name, self.low, self.high, log=self.log, step=self.step)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def OptimizableField(
|
|
69
|
+
default: Any,
|
|
70
|
+
*,
|
|
71
|
+
space: SearchSpace | None = None,
|
|
72
|
+
merge_conflict: str = "overwrite",
|
|
73
|
+
**fld_kw,
|
|
74
|
+
):
|
|
75
|
+
# 1. Pull out any user‑supplied extras (must be a dict)
|
|
76
|
+
user_extra = fld_kw.pop("json_schema_extra", None) or {}
|
|
77
|
+
if not isinstance(user_extra, dict):
|
|
78
|
+
raise TypeError("`json_schema_extra` must be a mapping.")
|
|
79
|
+
|
|
80
|
+
# 2. If the space is a prompt, ensure a concrete base prompt exists
|
|
81
|
+
if space is not None and getattr(space, "is_prompt", False):
|
|
82
|
+
if getattr(space, "prompt", None) is None:
|
|
83
|
+
if default is None:
|
|
84
|
+
raise ValueError("Prompt-optimized fields require a base prompt: provide a "
|
|
85
|
+
"non-None field default or set space.prompt.")
|
|
86
|
+
# Default prompt not provided in space; fall back to the field's default
|
|
87
|
+
space.prompt = default
|
|
88
|
+
|
|
89
|
+
# 3. Prepare our own metadata
|
|
90
|
+
ours = {"optimizable": True}
|
|
91
|
+
if space is not None:
|
|
92
|
+
ours["search_space"] = space
|
|
93
|
+
|
|
94
|
+
# 4. Merge with user extras according to merge_conflict policy
|
|
95
|
+
intersect = ours.keys() & user_extra.keys()
|
|
96
|
+
if intersect:
|
|
97
|
+
if merge_conflict == "error":
|
|
98
|
+
raise ValueError("`json_schema_extra` already contains reserved key(s): "
|
|
99
|
+
f"{', '.join(intersect)}")
|
|
100
|
+
if merge_conflict == "keep":
|
|
101
|
+
# remove the ones the user already set so we don't overwrite them
|
|
102
|
+
ours = {k: v for k, v in ours.items() if k not in intersect}
|
|
103
|
+
|
|
104
|
+
merged_extra = {**user_extra, **ours} # ours wins if 'overwrite'
|
|
105
|
+
|
|
106
|
+
# 5. Return a normal Pydantic Field with merged extras
|
|
107
|
+
return Field(default, json_schema_extra=merged_extra, **fld_kw)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class OptimizableMixin(BaseModel):
    # Mixin that adds optimizer bookkeeping to a config model. Both fields are
    # declared with ``exclude=True``, so pydantic leaves them out when the
    # model is dumped/serialized.

    # Names of the parameters the optimizer may tune.
    optimizable_params: list[str] = Field(
        description="List of parameters that can be optimized.",
        exclude=True,
        default_factory=list,
    )

    # Per-parameter search-space overrides; keys are presumably parameter
    # names (matching entries of ``optimizable_params``) -- confirm with callers.
    search_space: dict[str, SearchSpace] = Field(
        description="Optional search space overrides for optimizable parameters.",
        exclude=True,
        default_factory=dict,
    )
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class OptimizerMetric(BaseModel):
    """Parameters used by the workflow optimizer to define a metric to optimize."""

    # Name of the evaluator whose score is optimized (required).
    evaluator_name: str = Field(description="Name of the metric to optimize.")

    # Expected to be 'maximize' or 'minimize'; not validated by this model.
    direction: str = Field(description="Direction of the optimization. Can be 'maximize' or 'minimize'.")

    # Weight applied to this metric by the optimizer; 1.0 unless overridden.
    weight: float = Field(
        default=1.0,
        description="Weight of the metric in the optimization process.",
    )
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class NumericOptimizationConfig(BaseModel):
    """Configuration for numeric/enum optimization (Optuna)."""

    # Numeric search runs unless explicitly switched off.
    enabled: bool = Field(description="Enable numeric optimization", default=True)

    # Trial budget for the numeric search.
    n_trials: int = Field(
        default=20,
        description="Number of trials for numeric optimization.",
    )
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class PromptGAOptimizationConfig(BaseModel):
    """
    Configuration for prompt optimization using a Genetic Algorithm.
    """
    # Prompt optimization is opt-in (disabled by default).
    enabled: bool = Field(default=False, description="Enable GA-based prompt optimization")

    # Prompt optimization function hooks
    prompt_population_init_function: str | None = Field(
        default=None,
        description="Optional function name to initialize/mutate candidate prompts.",
    )
    prompt_recombination_function: str | None = Field(
        default=None,
        description="Optional function name to recombine two parent prompts into a child.",
    )

    # Genetic algorithm configuration.
    # Integer knobs carry lower bounds (matching the ge/le bounds already used
    # on the rate fields) so nonsensical values fail at config-validation time.
    ga_population_size: int = Field(
        description="Population size for genetic algorithm prompt optimization.",
        default=24,
        ge=1,
    )
    ga_generations: int = Field(
        description="Number of generations to evolve in GA prompt optimization.",
        default=15,
        ge=0,
    )
    # None means "derive from population size and elitism" per the description.
    ga_offspring_size: int | None = Field(
        description="Number of offspring to produce per generation. Defaults to population_size - elitism.",
        default=None,
    )
    ga_crossover_rate: float = Field(
        description="Probability of applying crossover during reproduction.",
        default=0.8,
        ge=0.0,
        le=1.0,
    )
    ga_mutation_rate: float = Field(
        description="Probability of mutating a child after crossover.",
        default=0.3,
        ge=0.0,
        le=1.0,
    )
    ga_elitism: int = Field(
        description="Number of top individuals carried over unchanged each generation.",
        default=2,
        ge=0,
    )
    # Expected to be 'tournament' or 'roulette'; not validated by this model.
    ga_selection_method: str = Field(
        description="Parent selection strategy: 'tournament' or 'roulette'.",
        default="tournament",
    )
    ga_tournament_size: int = Field(
        description="Tournament size when using tournament selection.",
        default=3,
        ge=1,
    )
    # A concurrency limit of 0 could never evaluate anything, so require >= 1.
    ga_parallel_evaluations: int = Field(
        description="Max number of individuals to evaluate concurrently per generation.",
        default=8,
        ge=1,
    )
    ga_diversity_lambda: float = Field(
        description="Strength of diversity penalty (0 disables). Penalizes identical/near-identical prompts.",
        default=0.0,
        ge=0.0,
    )
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class OptimizerConfig(BaseModel):
    """
    Parameters used by the workflow optimizer.
    """
    output_path: Path | None = Field(
        default=None,
        description="Path to the output directory where the results will be saved.",
    )

    # The type is a mapping, so the description says so (it previously said "List").
    eval_metrics: dict[str, OptimizerMetric] | None = Field(
        description="Mapping of evaluation metrics to optimize.",
        default=None,
    )

    # At least one repetition is required for a parameter set to be scored.
    reps_per_param_set: int = Field(
        default=3,
        description="Number of repetitions per parameter set for the optimization.",
        ge=1,
    )

    target: float | None = Field(
        description=(
            "Target value for the optimization. If set, the optimization will stop when this value is reached."),
        default=None,
    )

    multi_objective_combination_mode: str = Field(
        description="Method to combine multiple objectives into a single score.",
        default="harmonic",
    )

    # Nested configs. The model-instance defaults are copied per instance by
    # pydantic, so they are not shared between OptimizerConfig objects.
    numeric: NumericOptimizationConfig = NumericOptimizationConfig()
    prompt: PromptGAOptimizationConfig = PromptGAOptimizationConfig()
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class OptimizerRunConfig(BaseModel):
    """
    Parameters used for an Optimizer run
    """
    # Eval parameters

    config_file: Path | BaseModel  # allow for instantiated configs to be passed in
    dataset: str | Path | None  # dataset file path can be specified in the config file
    # Presumably a JSONPath selecting the result ("$" = root) -- confirm with callers.
    result_json_path: str = "$"
    endpoint: str | None = None  # only used when running the workflow remotely
    # Timeout for the remote endpoint; units presumably seconds -- confirm.
    endpoint_timeout: int = 300
    # Sequence of (str, str) pairs; presumably config overrides as
    # (key path, value) -- verify against caller.
    override: tuple[tuple[str, str], ...] = ()
|