nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +13 -8
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +6 -5
- nat/agent/react_agent/register.py +49 -39
- nat/agent/reasoning_agent/reasoning_agent.py +17 -15
- nat/agent/register.py +2 -0
- nat/agent/responses_api_agent/__init__.py +14 -0
- nat/agent/responses_api_agent/register.py +126 -0
- nat/agent/rewoo_agent/agent.py +304 -117
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +51 -38
- nat/agent/tool_calling_agent/agent.py +75 -17
- nat/agent/tool_calling_agent/register.py +46 -23
- nat/authentication/api_key/api_key_auth_provider.py +6 -11
- nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
- nat/authentication/credential_validator/__init__.py +14 -0
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
- nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
- nat/builder/builder.py +55 -23
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +54 -15
- nat/builder/eval_builder.py +14 -9
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +370 -0
- nat/builder/function_info.py +1 -1
- nat/builder/intermediate_step_manager.py +38 -2
- nat/builder/workflow.py +5 -0
- nat/builder/workflow_builder.py +306 -54
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/register.py.j2 +2 -2
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +60 -18
- nat/cli/entrypoint.py +15 -11
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +72 -1
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +199 -69
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +47 -0
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +4 -3
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +8 -0
- nat/data_models/intermediate_step.py +9 -1
- nat/data_models/llm.py +15 -1
- nat/data_models/openai_mcp.py +46 -0
- nat/data_models/optimizable.py +208 -0
- nat/data_models/optimizer.py +161 -0
- nat/data_models/span.py +41 -3
- nat/data_models/thinking_mixin.py +2 -2
- nat/embedder/azure_openai_embedder.py +2 -1
- nat/embedder/nim_embedder.py +3 -2
- nat/embedder/openai_embedder.py +3 -2
- nat/eval/config.py +1 -1
- nat/eval/dataset_handler/dataset_downloader.py +3 -2
- nat/eval/dataset_handler/dataset_filter.py +34 -2
- nat/eval/evaluate.py +10 -3
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +7 -4
- nat/eval/register.py +4 -0
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
- nat/eval/usage_stats.py +2 -0
- nat/eval/utils/output_uploader.py +3 -2
- nat/eval/utils/weave_eval.py +17 -3
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +2 -2
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +19 -7
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +69 -44
- nat/front_ends/fastapi/message_validator.py +8 -7
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +71 -3
- nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
- nat/front_ends/mcp/memory_profiler.py +320 -0
- nat/front_ends/mcp/tool_converter.py +78 -25
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +21 -8
- nat/llm/azure_openai_llm.py +14 -5
- nat/llm/litellm_llm.py +80 -0
- nat/llm/nim_llm.py +23 -9
- nat/llm/openai_llm.py +19 -7
- nat/llm/register.py +4 -0
- nat/llm/utils/thinking.py +1 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +29 -55
- nat/observability/exporter/span_exporter.py +43 -15
- nat/observability/exporter_manager.py +2 -2
- nat/observability/mixin/redaction_config_mixin.py +5 -4
- nat/observability/mixin/tagging_config_mixin.py +26 -14
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +1 -1
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +21 -14
- nat/observability/register.py +16 -0
- nat/profiler/callbacks/langchain_callback_handler.py +32 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
- nat/profiler/callbacks/token_usage_base_model.py +2 -0
- nat/profiler/decorators/framework_wrapper.py +61 -9
- nat/profiler/decorators/function_tracking.py +35 -3
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/utils.py +3 -1
- nat/registry_handlers/pypi/register_pypi.py +5 -3
- nat/registry_handlers/rest/register_rest.py +5 -3
- nat/retriever/milvus/retriever.py +1 -1
- nat/retriever/nemo_retriever/register.py +2 -1
- nat/runtime/loader.py +1 -1
- nat/runtime/runner.py +111 -6
- nat/runtime/session.py +49 -3
- nat/settings/global_settings.py +2 -2
- nat/tool/chat_completion.py +4 -1
- nat/tool/code_execution/code_sandbox.py +3 -6
- nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
- nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
- nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
- nat/tool/datetime_tools.py +1 -1
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/tool/register.py +2 -7
- nat/tool/server_tools.py +15 -2
- nat/utils/__init__.py +76 -0
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +1 -1
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +278 -72
- nat/utils/io/yaml_tools.py +73 -3
- nat/utils/log_levels.py +25 -0
- nat/utils/responses_api.py +26 -0
- nat/utils/string_utils.py +16 -0
- nat/utils/type_converter.py +12 -3
- nat/utils/type_utils.py +6 -2
- nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -461
- nat/data_models/temperature_mixin.py +0 -43
- nat/data_models/top_p_mixin.py +0 -43
- nat/observability/processor/header_redaction_processor.py +0 -123
- nat/observability/processor/redaction_processor.py +0 -77
- nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
nat/data_models/config.py
CHANGED
|
@@ -20,6 +20,7 @@ import typing
|
|
|
20
20
|
from pydantic import BaseModel
|
|
21
21
|
from pydantic import ConfigDict
|
|
22
22
|
from pydantic import Discriminator
|
|
23
|
+
from pydantic import Field
|
|
23
24
|
from pydantic import ValidationError
|
|
24
25
|
from pydantic import ValidationInfo
|
|
25
26
|
from pydantic import ValidatorFunctionWrapHandler
|
|
@@ -29,7 +30,9 @@ from nat.data_models.evaluate import EvalConfig
|
|
|
29
30
|
from nat.data_models.front_end import FrontEndBaseConfig
|
|
30
31
|
from nat.data_models.function import EmptyFunctionConfig
|
|
31
32
|
from nat.data_models.function import FunctionBaseConfig
|
|
33
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
32
34
|
from nat.data_models.logging import LoggingBaseConfig
|
|
35
|
+
from nat.data_models.optimizer import OptimizerConfig
|
|
33
36
|
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
|
|
34
37
|
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
35
38
|
from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
|
|
@@ -57,9 +60,10 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
57
60
|
error_type = e['type']
|
|
58
61
|
if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
|
|
59
62
|
requested_type = e["ctx"]["tag"]
|
|
60
|
-
|
|
61
63
|
if (info.field_name in ('workflow', 'functions')):
|
|
62
64
|
registered_keys = GlobalTypeRegistry.get().get_registered_functions()
|
|
65
|
+
elif (info.field_name == "function_groups"):
|
|
66
|
+
registered_keys = GlobalTypeRegistry.get().get_registered_function_groups()
|
|
63
67
|
elif (info.field_name == "authentication"):
|
|
64
68
|
registered_keys = GlobalTypeRegistry.get().get_registered_auth_providers()
|
|
65
69
|
elif (info.field_name == "llms"):
|
|
@@ -135,8 +139,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
|
|
|
135
139
|
|
|
136
140
|
class TelemetryConfig(BaseModel):
|
|
137
141
|
|
|
138
|
-
logging: dict[str, LoggingBaseConfig] =
|
|
139
|
-
tracing: dict[str, TelemetryExporterBaseConfig] =
|
|
142
|
+
logging: dict[str, LoggingBaseConfig] = Field(default_factory=dict)
|
|
143
|
+
tracing: dict[str, TelemetryExporterBaseConfig] = Field(default_factory=dict)
|
|
140
144
|
|
|
141
145
|
@field_validator("logging", "tracing", mode="wrap")
|
|
142
146
|
@classmethod
|
|
@@ -183,12 +187,16 @@ class TelemetryConfig(BaseModel):
|
|
|
183
187
|
|
|
184
188
|
class GeneralConfig(BaseModel):
|
|
185
189
|
|
|
186
|
-
model_config = ConfigDict(protected_namespaces=())
|
|
190
|
+
model_config = ConfigDict(protected_namespaces=(), extra="forbid")
|
|
187
191
|
|
|
188
|
-
use_uvloop: bool =
|
|
192
|
+
use_uvloop: bool | None = Field(
|
|
193
|
+
default=None,
|
|
194
|
+
deprecated=
|
|
195
|
+
"`use_uvloop` field is deprecated and will be removed in a future release. The use of `uv_loop` is now" +
|
|
196
|
+
"automatically determined based on platform")
|
|
189
197
|
"""
|
|
190
|
-
|
|
191
|
-
|
|
198
|
+
This field is deprecated and ignored. It previously controlled whether to use uvloop as the event loop. uvloop
|
|
199
|
+
usage is now determined automatically based on the platform.
|
|
192
200
|
"""
|
|
193
201
|
|
|
194
202
|
telemetry: TelemetryConfig = TelemetryConfig()
|
|
@@ -240,31 +248,37 @@ class Config(HashableBaseModel):
|
|
|
240
248
|
general: GeneralConfig = GeneralConfig()
|
|
241
249
|
|
|
242
250
|
# Functions Configuration
|
|
243
|
-
functions: dict[str, FunctionBaseConfig] =
|
|
251
|
+
functions: dict[str, FunctionBaseConfig] = Field(default_factory=dict)
|
|
252
|
+
|
|
253
|
+
# Function Groups Configuration
|
|
254
|
+
function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
|
|
244
255
|
|
|
245
256
|
# LLMs Configuration
|
|
246
|
-
llms: dict[str, LLMBaseConfig] =
|
|
257
|
+
llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
|
|
247
258
|
|
|
248
259
|
# Embedders Configuration
|
|
249
|
-
embedders: dict[str, EmbedderBaseConfig] =
|
|
260
|
+
embedders: dict[str, EmbedderBaseConfig] = Field(default_factory=dict)
|
|
250
261
|
|
|
251
262
|
# Memory Configuration
|
|
252
|
-
memory: dict[str, MemoryBaseConfig] =
|
|
263
|
+
memory: dict[str, MemoryBaseConfig] = Field(default_factory=dict)
|
|
253
264
|
|
|
254
265
|
# Object Stores Configuration
|
|
255
|
-
object_stores: dict[str, ObjectStoreBaseConfig] =
|
|
266
|
+
object_stores: dict[str, ObjectStoreBaseConfig] = Field(default_factory=dict)
|
|
267
|
+
|
|
268
|
+
# Optimizer Configuration
|
|
269
|
+
optimizer: OptimizerConfig = OptimizerConfig()
|
|
256
270
|
|
|
257
271
|
# Retriever Configuration
|
|
258
|
-
retrievers: dict[str, RetrieverBaseConfig] =
|
|
272
|
+
retrievers: dict[str, RetrieverBaseConfig] = Field(default_factory=dict)
|
|
259
273
|
|
|
260
274
|
# TTC Strategies
|
|
261
|
-
ttc_strategies: dict[str, TTCStrategyBaseConfig] =
|
|
275
|
+
ttc_strategies: dict[str, TTCStrategyBaseConfig] = Field(default_factory=dict)
|
|
262
276
|
|
|
263
277
|
# Workflow Configuration
|
|
264
278
|
workflow: FunctionBaseConfig = EmptyFunctionConfig()
|
|
265
279
|
|
|
266
280
|
# Authentication Configuration
|
|
267
|
-
authentication: dict[str, AuthProviderBaseConfig] =
|
|
281
|
+
authentication: dict[str, AuthProviderBaseConfig] = Field(default_factory=dict)
|
|
268
282
|
|
|
269
283
|
# Evaluation Options
|
|
270
284
|
eval: EvalConfig = EvalConfig()
|
|
@@ -278,6 +292,7 @@ class Config(HashableBaseModel):
|
|
|
278
292
|
stream.write(f"Workflow Type: {self.workflow.type}\n")
|
|
279
293
|
|
|
280
294
|
stream.write(f"Number of Functions: {len(self.functions)}\n")
|
|
295
|
+
stream.write(f"Number of Function Groups: {len(self.function_groups)}\n")
|
|
281
296
|
stream.write(f"Number of LLMs: {len(self.llms)}\n")
|
|
282
297
|
stream.write(f"Number of Embedders: {len(self.embedders)}\n")
|
|
283
298
|
stream.write(f"Number of Memory: {len(self.memory)}\n")
|
|
@@ -287,6 +302,7 @@ class Config(HashableBaseModel):
|
|
|
287
302
|
stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")
|
|
288
303
|
|
|
289
304
|
@field_validator("functions",
|
|
305
|
+
"function_groups",
|
|
290
306
|
"llms",
|
|
291
307
|
"embedders",
|
|
292
308
|
"memory",
|
|
@@ -328,6 +344,10 @@ class Config(HashableBaseModel):
|
|
|
328
344
|
typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
329
345
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
330
346
|
|
|
347
|
+
FunctionGroupsAnnotation = dict[str,
|
|
348
|
+
typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
|
|
349
|
+
Discriminator(TypedBaseModel.discriminator)]]
|
|
350
|
+
|
|
331
351
|
MemoryAnnotation = dict[str,
|
|
332
352
|
typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
|
|
333
353
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -335,7 +355,6 @@ class Config(HashableBaseModel):
|
|
|
335
355
|
ObjectStoreAnnotation = dict[str,
|
|
336
356
|
typing.Annotated[type_registry.compute_annotation(ObjectStoreBaseConfig),
|
|
337
357
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
338
|
-
|
|
339
358
|
RetrieverAnnotation = dict[str,
|
|
340
359
|
typing.Annotated[type_registry.compute_annotation(RetrieverBaseConfig),
|
|
341
360
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
@@ -344,7 +363,7 @@ class Config(HashableBaseModel):
|
|
|
344
363
|
typing.Annotated[type_registry.compute_annotation(TTCStrategyBaseConfig),
|
|
345
364
|
Discriminator(TypedBaseModel.discriminator)]]
|
|
346
365
|
|
|
347
|
-
WorkflowAnnotation = typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
|
|
366
|
+
WorkflowAnnotation = typing.Annotated[(type_registry.compute_annotation(FunctionBaseConfig)),
|
|
348
367
|
Discriminator(TypedBaseModel.discriminator)]
|
|
349
368
|
|
|
350
369
|
should_rebuild = False
|
|
@@ -369,6 +388,11 @@ class Config(HashableBaseModel):
|
|
|
369
388
|
functions_field.annotation = FunctionsAnnotation
|
|
370
389
|
should_rebuild = True
|
|
371
390
|
|
|
391
|
+
function_groups_field = cls.model_fields.get("function_groups")
|
|
392
|
+
if function_groups_field is not None and function_groups_field.annotation != FunctionGroupsAnnotation:
|
|
393
|
+
function_groups_field.annotation = FunctionGroupsAnnotation
|
|
394
|
+
should_rebuild = True
|
|
395
|
+
|
|
372
396
|
memory_field = cls.model_fields.get("memory")
|
|
373
397
|
if memory_field is not None and memory_field.annotation != MemoryAnnotation:
|
|
374
398
|
memory_field.annotation = MemoryAnnotation
|
|
@@ -26,6 +26,7 @@ from pydantic import FilePath
|
|
|
26
26
|
from pydantic import Tag
|
|
27
27
|
|
|
28
28
|
from nat.data_models.common import BaseModelRegistryTag
|
|
29
|
+
from nat.data_models.common import SerializableSecretStr
|
|
29
30
|
from nat.data_models.common import TypedBaseModel
|
|
30
31
|
|
|
31
32
|
|
|
@@ -34,8 +35,8 @@ class EvalS3Config(BaseModel):
|
|
|
34
35
|
endpoint_url: str | None = None
|
|
35
36
|
region_name: str | None = None
|
|
36
37
|
bucket: str
|
|
37
|
-
access_key:
|
|
38
|
-
secret_key:
|
|
38
|
+
access_key: SerializableSecretStr
|
|
39
|
+
secret_key: SerializableSecretStr
|
|
39
40
|
|
|
40
41
|
|
|
41
42
|
class EvalFilterEntryConfig(BaseModel):
|
|
@@ -80,7 +81,7 @@ class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):
|
|
|
80
81
|
|
|
81
82
|
|
|
82
83
|
def read_jsonl(file_path: FilePath):
|
|
83
|
-
with open(file_path,
|
|
84
|
+
with open(file_path, encoding='utf-8') as f:
|
|
84
85
|
data = [json.loads(line) for line in f]
|
|
85
86
|
return pd.DataFrame(data)
|
|
86
87
|
|
nat/data_models/function.py
CHANGED
|
@@ -15,6 +15,10 @@
|
|
|
15
15
|
|
|
16
16
|
import typing
|
|
17
17
|
|
|
18
|
+
from pydantic import Field
|
|
19
|
+
from pydantic import field_validator
|
|
20
|
+
from pydantic import model_validator
|
|
21
|
+
|
|
18
22
|
from .common import BaseModelRegistryTag
|
|
19
23
|
from .common import TypedBaseModel
|
|
20
24
|
|
|
@@ -23,8 +27,38 @@ class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
|
23
27
|
pass
|
|
24
28
|
|
|
25
29
|
|
|
30
|
+
class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
31
|
+
"""Base configuration for function groups.
|
|
32
|
+
|
|
33
|
+
Function groups enable sharing of configurations and resources across multiple functions.
|
|
34
|
+
"""
|
|
35
|
+
include: list[str] = Field(
|
|
36
|
+
default_factory=list,
|
|
37
|
+
description="The list of function names which should be added to the global Function registry",
|
|
38
|
+
)
|
|
39
|
+
exclude: list[str] = Field(
|
|
40
|
+
default_factory=list,
|
|
41
|
+
description="The list of function names which should be excluded from default access to the group",
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
@field_validator("include", "exclude")
|
|
45
|
+
@classmethod
|
|
46
|
+
def _validate_fields_include_exclude(cls, value: list[str]) -> list[str]:
|
|
47
|
+
if len(set(value)) != len(value):
|
|
48
|
+
raise ValueError("Function names must be unique")
|
|
49
|
+
return sorted(value)
|
|
50
|
+
|
|
51
|
+
@model_validator(mode="after")
|
|
52
|
+
def _validate_include_exclude(self):
|
|
53
|
+
if self.include and self.exclude:
|
|
54
|
+
raise ValueError("include and exclude cannot be used together")
|
|
55
|
+
return self
|
|
56
|
+
|
|
57
|
+
|
|
26
58
|
class EmptyFunctionConfig(FunctionBaseConfig, name="EmptyFunctionConfig"):
|
|
27
59
|
pass
|
|
28
60
|
|
|
29
61
|
|
|
30
62
|
FunctionConfigT = typing.TypeVar("FunctionConfigT", bound=FunctionBaseConfig)
|
|
63
|
+
|
|
64
|
+
FunctionGroupConfigT = typing.TypeVar("FunctionGroupConfigT", bound=FunctionGroupBaseConfig)
|
|
@@ -23,6 +23,7 @@ class FunctionDependencies(BaseModel):
|
|
|
23
23
|
A class to represent the dependencies of a function.
|
|
24
24
|
"""
|
|
25
25
|
functions: set[str] = Field(default_factory=set)
|
|
26
|
+
function_groups: set[str] = Field(default_factory=set)
|
|
26
27
|
llms: set[str] = Field(default_factory=set)
|
|
27
28
|
embedders: set[str] = Field(default_factory=set)
|
|
28
29
|
memory_clients: set[str] = Field(default_factory=set)
|
|
@@ -33,6 +34,10 @@ class FunctionDependencies(BaseModel):
|
|
|
33
34
|
def serialize_functions(self, v: set[str]) -> list[str]:
|
|
34
35
|
return list(v)
|
|
35
36
|
|
|
37
|
+
@field_serializer("function_groups", when_used="json")
|
|
38
|
+
def serialize_function_groups(self, v: set[str]) -> list[str]:
|
|
39
|
+
return list(v)
|
|
40
|
+
|
|
36
41
|
@field_serializer("llms", when_used="json")
|
|
37
42
|
def serialize_llms(self, v: set[str]) -> list[str]:
|
|
38
43
|
return list(v)
|
|
@@ -56,6 +61,9 @@ class FunctionDependencies(BaseModel):
|
|
|
56
61
|
def add_function(self, function: str):
|
|
57
62
|
self.functions.add(function)
|
|
58
63
|
|
|
64
|
+
def add_function_group(self, function_group: str):
|
|
65
|
+
self.function_groups.add(function_group) # pylint: disable=no-member
|
|
66
|
+
|
|
59
67
|
def add_llm(self, llm: str):
|
|
60
68
|
self.llms.add(llm)
|
|
61
69
|
|
|
@@ -103,11 +103,19 @@ class ToolSchema(BaseModel):
|
|
|
103
103
|
function: ToolDetails = Field(..., description="The function details.")
|
|
104
104
|
|
|
105
105
|
|
|
106
|
+
class ServerToolUseSchema(BaseModel):
|
|
107
|
+
name: str
|
|
108
|
+
arguments: str | dict[str, typing.Any] | typing.Any
|
|
109
|
+
output: typing.Any
|
|
110
|
+
|
|
111
|
+
model_config = ConfigDict(extra="ignore")
|
|
112
|
+
|
|
113
|
+
|
|
106
114
|
class TraceMetadata(BaseModel):
|
|
107
115
|
chat_responses: typing.Any | None = None
|
|
108
116
|
chat_inputs: typing.Any | None = None
|
|
109
117
|
tool_inputs: typing.Any | None = None
|
|
110
|
-
tool_outputs: typing.Any | None = None
|
|
118
|
+
tool_outputs: list[ServerToolUseSchema] | typing.Any | None = None
|
|
111
119
|
tool_info: typing.Any | None = None
|
|
112
120
|
span_inputs: typing.Any | None = None
|
|
113
121
|
span_outputs: typing.Any | None = None
|
nat/data_models/llm.py
CHANGED
|
@@ -14,14 +14,28 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import typing
|
|
17
|
+
from enum import Enum
|
|
18
|
+
|
|
19
|
+
from pydantic import Field
|
|
17
20
|
|
|
18
21
|
from .common import BaseModelRegistryTag
|
|
19
22
|
from .common import TypedBaseModel
|
|
20
23
|
|
|
21
24
|
|
|
25
|
+
class APITypeEnum(str, Enum):
|
|
26
|
+
CHAT_COMPLETION = "chat_completion"
|
|
27
|
+
RESPONSES = "responses"
|
|
28
|
+
|
|
29
|
+
|
|
22
30
|
class LLMBaseConfig(TypedBaseModel, BaseModelRegistryTag):
|
|
23
31
|
"""Base configuration for LLM providers."""
|
|
24
|
-
|
|
32
|
+
|
|
33
|
+
api_type: APITypeEnum = Field(default=APITypeEnum.CHAT_COMPLETION,
|
|
34
|
+
description="The type of API to use for the LLM provider.",
|
|
35
|
+
json_schema_extra={
|
|
36
|
+
"enum": [e.value for e in APITypeEnum],
|
|
37
|
+
"examples": [e.value for e in APITypeEnum],
|
|
38
|
+
})
|
|
25
39
|
|
|
26
40
|
|
|
27
41
|
LLMBaseConfigT = typing.TypeVar("LLMBaseConfigT", bound=LLMBaseConfig)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from enum import Enum
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import ConfigDict
|
|
20
|
+
from pydantic import Field
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MCPApprovalRequiredEnum(str, Enum):
|
|
24
|
+
"""
|
|
25
|
+
Enum to specify if approval is required for tool usage in the OpenAI MCP schema.
|
|
26
|
+
"""
|
|
27
|
+
NEVER = "never"
|
|
28
|
+
ALWAYS = "always"
|
|
29
|
+
AUTO = "auto"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class OpenAIMCPSchemaTool(BaseModel):
|
|
33
|
+
"""
|
|
34
|
+
Represents a tool in the OpenAI MCP schema.
|
|
35
|
+
"""
|
|
36
|
+
type: str = "mcp"
|
|
37
|
+
server_label: str = Field(description="Label for the server where the tool is hosted.")
|
|
38
|
+
server_url: str = Field(description="URL of the server hosting the tool.")
|
|
39
|
+
allowed_tools: list[str] | None = Field(default=None,
|
|
40
|
+
description="List of allowed tool names that can be used by the agent.")
|
|
41
|
+
require_approval: MCPApprovalRequiredEnum = Field(default=MCPApprovalRequiredEnum.NEVER,
|
|
42
|
+
description="Specifies if approval is required for tool usage.")
|
|
43
|
+
headers: dict[str, str] | None = Field(default=None,
|
|
44
|
+
description="Optional headers to include in requests to the tool server.")
|
|
45
|
+
|
|
46
|
+
model_config = ConfigDict(use_enum_values=True)
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections.abc import Sequence
|
|
17
|
+
from typing import Any
|
|
18
|
+
from typing import Generic
|
|
19
|
+
from typing import TypeVar
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
from optuna import Trial
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
from pydantic import ConfigDict
|
|
25
|
+
from pydantic import Field
|
|
26
|
+
from pydantic import model_validator
|
|
27
|
+
from pydantic_core import PydanticUndefined
|
|
28
|
+
|
|
29
|
+
T = TypeVar("T", int, float, bool, str)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# --------------------------------------------------------------------- #
|
|
33
|
+
# 1. Hyper‑parameter metadata container #
|
|
34
|
+
# --------------------------------------------------------------------- #
|
|
35
|
+
class SearchSpace(BaseModel, Generic[T]):
|
|
36
|
+
values: Sequence[T] | None = None
|
|
37
|
+
low: T | None = None
|
|
38
|
+
high: T | None = None
|
|
39
|
+
log: bool = False # log scale
|
|
40
|
+
step: float | None = None
|
|
41
|
+
is_prompt: bool = False
|
|
42
|
+
prompt: str | None = None # prompt to optimize
|
|
43
|
+
prompt_purpose: str | None = None # purpose of the prompt
|
|
44
|
+
|
|
45
|
+
model_config = ConfigDict(protected_namespaces=(), extra="forbid")
|
|
46
|
+
|
|
47
|
+
@model_validator(mode="after")
|
|
48
|
+
def validate_search_space_parameters(self):
|
|
49
|
+
"""Validate SearchSpace configuration."""
|
|
50
|
+
# 1. Prompt-specific validation
|
|
51
|
+
if self.is_prompt:
|
|
52
|
+
# When optimizing prompts, numeric parameters don't make sense
|
|
53
|
+
if self.low is not None or self.high is not None:
|
|
54
|
+
raise ValueError("SearchSpace with 'is_prompt=True' cannot have 'low' or 'high' parameters")
|
|
55
|
+
if self.log:
|
|
56
|
+
raise ValueError("SearchSpace with 'is_prompt=True' cannot have 'log=True'")
|
|
57
|
+
if self.step is not None:
|
|
58
|
+
raise ValueError("SearchSpace with 'is_prompt=True' cannot have 'step' parameter")
|
|
59
|
+
return self
|
|
60
|
+
|
|
61
|
+
# 2. Values-based validation
|
|
62
|
+
if self.values is not None:
|
|
63
|
+
# If values is provided, we don't need high/low
|
|
64
|
+
if self.high is not None or self.low is not None:
|
|
65
|
+
raise ValueError("SearchSpace 'values' is mutually exclusive with 'high' and 'low'")
|
|
66
|
+
# Ensure values is not empty
|
|
67
|
+
if len(self.values) == 0:
|
|
68
|
+
raise ValueError("SearchSpace 'values' must not be empty")
|
|
69
|
+
return self
|
|
70
|
+
|
|
71
|
+
# 3. Range-based validation
|
|
72
|
+
if (self.low is None) != (self.high is None): # XOR using !=
|
|
73
|
+
raise ValueError(f"SearchSpace range requires both 'low' and 'high'; got low={self.low}, high={self.high}")
|
|
74
|
+
if self.low is not None and self.high is not None and self.low >= self.high:
|
|
75
|
+
raise ValueError(f"SearchSpace 'low' must be less than 'high'; got low={self.low}, high={self.high}")
|
|
76
|
+
|
|
77
|
+
return self
|
|
78
|
+
|
|
79
|
+
# Helper for Optuna Trials
def suggest(self, trial: Trial, name: str):
    """Draw a value for parameter *name* from this space via the Optuna *trial*."""
    if self.is_prompt:
        raise ValueError("Prompt optimization not currently supported using Optuna. "
                         "Use the genetic algorithm implementation instead.")
    # Categorical spaces map directly onto Optuna's categorical sampler.
    if self.values is not None:
        return trial.suggest_categorical(name, self.values)
    # An integer lower bound selects the integer sampler; otherwise float.
    sampler = trial.suggest_int if isinstance(self.low, int) else trial.suggest_float
    return sampler(name, self.low, self.high, log=self.log, step=self.step)
89
|
+
|
|
90
|
+
def to_grid_values(self) -> list[Any]:
    """
    Expand this SearchSpace into the explicit value list a GridSampler needs.

    Two configurations are accepted:
    1. Explicit values: SearchSpace(values=[0.1, 0.5, 0.9])
    2. Range with step: SearchSpace(low=0.1, high=0.9, step=0.2)

    For range form, 'step' must be given explicitly (no default is assumed)
    so a grid cannot blow up combinatorially by accident.
    """

    if self.is_prompt:
        raise ValueError("Prompt optimization not currently supported using Optuna. "
                         "Use the genetic algorithm implementation instead.")

    # Case 1: explicit choices. Copy so callers cannot mutate our field.
    if self.values is not None:
        return list(self.values)

    # Case 2: a range, which requires both bounds and an explicit step.
    if self.low is None or self.high is None:
        raise ValueError("Grid search requires either 'values' or both 'low' and 'high' to be defined")

    if self.step is None:
        raise ValueError(
            f"Grid search with range (low={self.low}, high={self.high}) requires 'step' to be specified. "
            "Please define the step size to discretize the range, for example: step=0.1")

    stride = float(self.step)
    if stride <= 0:
        raise ValueError(f"Grid search step must be positive; got step={self.step}")

    # Integral grid only when both bounds and the stride are whole numbers.
    if isinstance(self.low, int) and isinstance(self.high, int) and stride.is_integer():
        if self.log:
            raise ValueError("Log scale is not supported for integer ranges in grid search. "
                             "Please use linear scale or provide explicit values.")
        grid = list(range(self.low, self.high + 1, int(stride)))
        # range() may stop short of 'high'; ensure the upper bound is explored.
        if grid and grid[-1] != self.high:
            grid.append(self.high)
        return grid

    # Float grid (also used when integer bounds come with a fractional step).
    if self.log:
        raise ValueError("Log scale is not yet supported for grid search with ranges. "
                         "Please provide explicit values using the 'values' field.")

    upper = float(self.high)
    grid = np.arange(float(self.low), upper, stride).tolist()

    # arange excludes its stop value; append 'high' unless it is already the
    # tail (within a small tolerance for float accumulation error) so the full
    # range is explored.
    if not grid or abs(grid[-1] - upper) > 1e-9:
        grid.append(upper)

    return grid
|
+
|
|
156
|
+
|
|
157
|
+
def OptimizableField(
    default: Any = PydanticUndefined,
    *,
    space: SearchSpace | None = None,
    merge_conflict: str = "overwrite",
    **fld_kw,
):
    """Create a pydantic ``Field`` tagged as optimizable.

    Args:
        default: The field's default value; also used as the base prompt when
            ``space.is_prompt`` is set and ``space.prompt`` is not.
        space: Optional SearchSpace describing how the field may be tuned.
        merge_conflict: Policy when our reserved ``json_schema_extra`` keys
            collide with user-supplied ones: "overwrite" (ours win), "keep"
            (user's win), or "error" (raise).
        **fld_kw: Forwarded to ``pydantic.Field``.

    Raises:
        TypeError: If a user-supplied ``json_schema_extra`` is not a mapping.
        ValueError: If a prompt space has no usable base prompt, or reserved
            keys collide under the "error" policy.
    """
    # 1. Pull out any user-supplied extras (must be a dict)
    user_extra = fld_kw.pop("json_schema_extra", None) or {}
    if not isinstance(user_extra, dict):
        raise TypeError("`json_schema_extra` must be a mapping.")

    # 2. If the space is a prompt, ensure a concrete base prompt exists
    if space is not None and getattr(space, "is_prompt", False):
        if getattr(space, "prompt", None) is None:
            # Bug fix: previously only `default is None` was rejected, so a field
            # declared with NO default at all (the PydanticUndefined sentinel)
            # slipped through and the sentinel itself was stored as the prompt.
            if default is None or default is PydanticUndefined:
                raise ValueError("Prompt-optimized fields require a base prompt: provide a "
                                 "non-None field default or set space.prompt.")
            # Default prompt not provided in space; fall back to the field's default
            space.prompt = default

    # 3. Prepare our own metadata
    ours: dict[str, Any] = {"optimizable": True}
    if space is not None:
        ours["search_space"] = space

    # 4. Merge with user extras according to merge_conflict policy
    intersect = ours.keys() & user_extra.keys()
    if intersect:
        if merge_conflict == "error":
            raise ValueError("`json_schema_extra` already contains reserved key(s): "
                             f"{', '.join(intersect)}")
        if merge_conflict == "keep":
            # remove the ones the user already set so we don't overwrite them
            ours = {k: v for k, v in ours.items() if k not in intersect}

    merged_extra = {**user_extra, **ours}  # ours wins if 'overwrite'

    # 5. Return a normal Pydantic Field with merged extras
    return Field(default, json_schema_extra=merged_extra, **fld_kw)
197
|
+
|
|
198
|
+
|
|
199
|
+
class OptimizableMixin(BaseModel):
    """Mixin for pydantic configs that expose tunable parameters.

    Concrete models inherit two bookkeeping fields: the names of parameters an
    optimizer may tune, and optional per-parameter SearchSpace overrides. Both
    are declared with ``exclude=True`` so they are dropped from serialization.
    """

    # Names of fields on the concrete model that the optimizer may tune.
    optimizable_params: list[str] = Field(default_factory=list,
                                          description="List of parameters that can be optimized.",
                                          exclude=True)

    # Per-parameter SearchSpace overrides, keyed by field name.
    search_space: dict[str, SearchSpace] = Field(
        default_factory=dict,
        description="Optional search space overrides for optimizable parameters.",
        exclude=True,
    )