nvidia-nat 1.3.0a20250910__py3-none-any.whl → 1.4.0a20251112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. nat/agent/base.py +13 -8
  2. nat/agent/prompt_optimizer/prompt.py +68 -0
  3. nat/agent/prompt_optimizer/register.py +149 -0
  4. nat/agent/react_agent/agent.py +6 -5
  5. nat/agent/react_agent/register.py +49 -39
  6. nat/agent/reasoning_agent/reasoning_agent.py +17 -15
  7. nat/agent/register.py +2 -0
  8. nat/agent/responses_api_agent/__init__.py +14 -0
  9. nat/agent/responses_api_agent/register.py +126 -0
  10. nat/agent/rewoo_agent/agent.py +304 -117
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +51 -38
  13. nat/agent/tool_calling_agent/agent.py +75 -17
  14. nat/agent/tool_calling_agent/register.py +46 -23
  15. nat/authentication/api_key/api_key_auth_provider.py +6 -11
  16. nat/authentication/api_key/api_key_auth_provider_config.py +8 -5
  17. nat/authentication/credential_validator/__init__.py +14 -0
  18. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  19. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  20. nat/authentication/interfaces.py +5 -2
  21. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
  22. nat/authentication/oauth2/oauth2_auth_code_flow_provider_config.py +2 -1
  23. nat/authentication/oauth2/oauth2_resource_server_config.py +125 -0
  24. nat/builder/builder.py +55 -23
  25. nat/builder/component_utils.py +9 -5
  26. nat/builder/context.py +54 -15
  27. nat/builder/eval_builder.py +14 -9
  28. nat/builder/framework_enum.py +1 -0
  29. nat/builder/front_end.py +1 -1
  30. nat/builder/function.py +370 -0
  31. nat/builder/function_info.py +1 -1
  32. nat/builder/intermediate_step_manager.py +38 -2
  33. nat/builder/workflow.py +5 -0
  34. nat/builder/workflow_builder.py +306 -54
  35. nat/cli/cli_utils/config_override.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/mcp/__init__.py +14 -0
  38. nat/cli/commands/mcp/mcp.py +986 -0
  39. nat/cli/commands/optimize.py +90 -0
  40. nat/cli/commands/start.py +1 -1
  41. nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
  42. nat/cli/commands/workflow/templates/register.py.j2 +2 -2
  43. nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
  44. nat/cli/commands/workflow/workflow_commands.py +60 -18
  45. nat/cli/entrypoint.py +15 -11
  46. nat/cli/main.py +3 -0
  47. nat/cli/register_workflow.py +38 -4
  48. nat/cli/type_registry.py +72 -1
  49. nat/control_flow/__init__.py +0 -0
  50. nat/control_flow/register.py +20 -0
  51. nat/control_flow/router_agent/__init__.py +0 -0
  52. nat/control_flow/router_agent/agent.py +329 -0
  53. nat/control_flow/router_agent/prompt.py +48 -0
  54. nat/control_flow/router_agent/register.py +91 -0
  55. nat/control_flow/sequential_executor.py +166 -0
  56. nat/data_models/agent.py +34 -0
  57. nat/data_models/api_server.py +199 -69
  58. nat/data_models/authentication.py +23 -9
  59. nat/data_models/common.py +47 -0
  60. nat/data_models/component.py +2 -0
  61. nat/data_models/component_ref.py +11 -0
  62. nat/data_models/config.py +41 -17
  63. nat/data_models/dataset_handler.py +4 -3
  64. nat/data_models/function.py +34 -0
  65. nat/data_models/function_dependencies.py +8 -0
  66. nat/data_models/intermediate_step.py +9 -1
  67. nat/data_models/llm.py +15 -1
  68. nat/data_models/openai_mcp.py +46 -0
  69. nat/data_models/optimizable.py +208 -0
  70. nat/data_models/optimizer.py +161 -0
  71. nat/data_models/span.py +41 -3
  72. nat/data_models/thinking_mixin.py +2 -2
  73. nat/embedder/azure_openai_embedder.py +2 -1
  74. nat/embedder/nim_embedder.py +3 -2
  75. nat/embedder/openai_embedder.py +3 -2
  76. nat/eval/config.py +1 -1
  77. nat/eval/dataset_handler/dataset_downloader.py +3 -2
  78. nat/eval/dataset_handler/dataset_filter.py +34 -2
  79. nat/eval/evaluate.py +10 -3
  80. nat/eval/evaluator/base_evaluator.py +1 -1
  81. nat/eval/rag_evaluator/evaluate.py +7 -4
  82. nat/eval/register.py +4 -0
  83. nat/eval/runtime_evaluator/__init__.py +14 -0
  84. nat/eval/runtime_evaluator/evaluate.py +123 -0
  85. nat/eval/runtime_evaluator/register.py +100 -0
  86. nat/eval/swe_bench_evaluator/evaluate.py +1 -1
  87. nat/eval/trajectory_evaluator/register.py +1 -1
  88. nat/eval/tunable_rag_evaluator/evaluate.py +1 -1
  89. nat/eval/usage_stats.py +2 -0
  90. nat/eval/utils/output_uploader.py +3 -2
  91. nat/eval/utils/weave_eval.py +17 -3
  92. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  93. nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
  94. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  95. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +1 -1
  96. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
  97. nat/experimental/test_time_compute/models/strategy_base.py +2 -2
  98. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -1
  99. nat/front_ends/console/authentication_flow_handler.py +82 -30
  100. nat/front_ends/console/console_front_end_plugin.py +19 -7
  101. nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
  102. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  103. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  104. nat/front_ends/fastapi/fastapi_front_end_config.py +25 -3
  105. nat/front_ends/fastapi/fastapi_front_end_plugin.py +140 -3
  106. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +445 -265
  107. nat/front_ends/fastapi/job_store.py +518 -99
  108. nat/front_ends/fastapi/main.py +11 -19
  109. nat/front_ends/fastapi/message_handler.py +69 -44
  110. nat/front_ends/fastapi/message_validator.py +8 -7
  111. nat/front_ends/fastapi/utils.py +57 -0
  112. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  113. nat/front_ends/mcp/mcp_front_end_config.py +71 -3
  114. nat/front_ends/mcp/mcp_front_end_plugin.py +85 -21
  115. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +248 -29
  116. nat/front_ends/mcp/memory_profiler.py +320 -0
  117. nat/front_ends/mcp/tool_converter.py +78 -25
  118. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  119. nat/llm/aws_bedrock_llm.py +21 -8
  120. nat/llm/azure_openai_llm.py +14 -5
  121. nat/llm/litellm_llm.py +80 -0
  122. nat/llm/nim_llm.py +23 -9
  123. nat/llm/openai_llm.py +19 -7
  124. nat/llm/register.py +4 -0
  125. nat/llm/utils/thinking.py +1 -1
  126. nat/observability/exporter/base_exporter.py +1 -1
  127. nat/observability/exporter/processing_exporter.py +29 -55
  128. nat/observability/exporter/span_exporter.py +43 -15
  129. nat/observability/exporter_manager.py +2 -2
  130. nat/observability/mixin/redaction_config_mixin.py +5 -4
  131. nat/observability/mixin/tagging_config_mixin.py +26 -14
  132. nat/observability/mixin/type_introspection_mixin.py +420 -107
  133. nat/observability/processor/batching_processor.py +1 -1
  134. nat/observability/processor/processor.py +3 -0
  135. nat/observability/processor/redaction/__init__.py +24 -0
  136. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  137. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  138. nat/observability/processor/redaction/redaction_processor.py +177 -0
  139. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  140. nat/observability/processor/span_tagging_processor.py +21 -14
  141. nat/observability/register.py +16 -0
  142. nat/profiler/callbacks/langchain_callback_handler.py +32 -7
  143. nat/profiler/callbacks/llama_index_callback_handler.py +36 -2
  144. nat/profiler/callbacks/token_usage_base_model.py +2 -0
  145. nat/profiler/decorators/framework_wrapper.py +61 -9
  146. nat/profiler/decorators/function_tracking.py +35 -3
  147. nat/profiler/forecasting/models/linear_model.py +1 -1
  148. nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
  149. nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
  150. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
  151. nat/profiler/parameter_optimization/__init__.py +0 -0
  152. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  153. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  154. nat/profiler/parameter_optimization/parameter_optimizer.py +189 -0
  155. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  156. nat/profiler/parameter_optimization/pareto_visualizer.py +460 -0
  157. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  158. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  159. nat/profiler/utils.py +3 -1
  160. nat/registry_handlers/pypi/register_pypi.py +5 -3
  161. nat/registry_handlers/rest/register_rest.py +5 -3
  162. nat/retriever/milvus/retriever.py +1 -1
  163. nat/retriever/nemo_retriever/register.py +2 -1
  164. nat/runtime/loader.py +1 -1
  165. nat/runtime/runner.py +111 -6
  166. nat/runtime/session.py +49 -3
  167. nat/settings/global_settings.py +2 -2
  168. nat/tool/chat_completion.py +4 -1
  169. nat/tool/code_execution/code_sandbox.py +3 -6
  170. nat/tool/code_execution/local_sandbox/Dockerfile.sandbox +19 -32
  171. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +6 -1
  172. nat/tool/code_execution/local_sandbox/sandbox.requirements.txt +2 -0
  173. nat/tool/code_execution/local_sandbox/start_local_sandbox.sh +10 -4
  174. nat/tool/datetime_tools.py +1 -1
  175. nat/tool/github_tools.py +450 -0
  176. nat/tool/memory_tools/add_memory_tool.py +3 -3
  177. nat/tool/memory_tools/delete_memory_tool.py +3 -4
  178. nat/tool/memory_tools/get_memory_tool.py +4 -4
  179. nat/tool/register.py +2 -7
  180. nat/tool/server_tools.py +15 -2
  181. nat/utils/__init__.py +76 -0
  182. nat/utils/callable_utils.py +70 -0
  183. nat/utils/data_models/schema_validator.py +1 -1
  184. nat/utils/decorators.py +210 -0
  185. nat/utils/exception_handlers/automatic_retries.py +278 -72
  186. nat/utils/io/yaml_tools.py +73 -3
  187. nat/utils/log_levels.py +25 -0
  188. nat/utils/responses_api.py +26 -0
  189. nat/utils/string_utils.py +16 -0
  190. nat/utils/type_converter.py +12 -3
  191. nat/utils/type_utils.py +6 -2
  192. nvidia_nat-1.4.0a20251112.dist-info/METADATA +197 -0
  193. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/RECORD +199 -165
  194. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/entry_points.txt +1 -0
  195. nat/cli/commands/info/list_mcp.py +0 -461
  196. nat/data_models/temperature_mixin.py +0 -43
  197. nat/data_models/top_p_mixin.py +0 -43
  198. nat/observability/processor/header_redaction_processor.py +0 -123
  199. nat/observability/processor/redaction_processor.py +0 -77
  200. nat/tool/code_execution/test_code_execution_sandbox.py +0 -414
  201. nat/tool/github_tools/create_github_commit.py +0 -133
  202. nat/tool/github_tools/create_github_issue.py +0 -87
  203. nat/tool/github_tools/create_github_pr.py +0 -106
  204. nat/tool/github_tools/get_github_file.py +0 -106
  205. nat/tool/github_tools/get_github_issue.py +0 -166
  206. nat/tool/github_tools/get_github_pr.py +0 -256
  207. nat/tool/github_tools/update_github_issue.py +0 -100
  208. nvidia_nat-1.3.0a20250910.dist-info/METADATA +0 -373
  209. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  210. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/WHEEL +0 -0
  211. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  212. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/licenses/LICENSE.md +0 -0
  213. {nvidia_nat-1.3.0a20250910.dist-info → nvidia_nat-1.4.0a20251112.dist-info}/top_level.txt +0 -0
nat/data_models/config.py CHANGED
@@ -20,6 +20,7 @@ import typing
20
20
  from pydantic import BaseModel
21
21
  from pydantic import ConfigDict
22
22
  from pydantic import Discriminator
23
+ from pydantic import Field
23
24
  from pydantic import ValidationError
24
25
  from pydantic import ValidationInfo
25
26
  from pydantic import ValidatorFunctionWrapHandler
@@ -29,7 +30,9 @@ from nat.data_models.evaluate import EvalConfig
29
30
  from nat.data_models.front_end import FrontEndBaseConfig
30
31
  from nat.data_models.function import EmptyFunctionConfig
31
32
  from nat.data_models.function import FunctionBaseConfig
33
+ from nat.data_models.function import FunctionGroupBaseConfig
32
34
  from nat.data_models.logging import LoggingBaseConfig
35
+ from nat.data_models.optimizer import OptimizerConfig
33
36
  from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
34
37
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
35
38
  from nat.front_ends.fastapi.fastapi_front_end_config import FastApiFrontEndConfig
@@ -57,9 +60,10 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
57
60
  error_type = e['type']
58
61
  if error_type == 'union_tag_invalid' and "ctx" in e and not logged_once:
59
62
  requested_type = e["ctx"]["tag"]
60
-
61
63
  if (info.field_name in ('workflow', 'functions')):
62
64
  registered_keys = GlobalTypeRegistry.get().get_registered_functions()
65
+ elif (info.field_name == "function_groups"):
66
+ registered_keys = GlobalTypeRegistry.get().get_registered_function_groups()
63
67
  elif (info.field_name == "authentication"):
64
68
  registered_keys = GlobalTypeRegistry.get().get_registered_auth_providers()
65
69
  elif (info.field_name == "llms"):
@@ -135,8 +139,8 @@ def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWr
135
139
 
136
140
  class TelemetryConfig(BaseModel):
137
141
 
138
- logging: dict[str, LoggingBaseConfig] = {}
139
- tracing: dict[str, TelemetryExporterBaseConfig] = {}
142
+ logging: dict[str, LoggingBaseConfig] = Field(default_factory=dict)
143
+ tracing: dict[str, TelemetryExporterBaseConfig] = Field(default_factory=dict)
140
144
 
141
145
  @field_validator("logging", "tracing", mode="wrap")
142
146
  @classmethod
@@ -183,12 +187,16 @@ class TelemetryConfig(BaseModel):
183
187
 
184
188
  class GeneralConfig(BaseModel):
185
189
 
186
- model_config = ConfigDict(protected_namespaces=())
190
+ model_config = ConfigDict(protected_namespaces=(), extra="forbid")
187
191
 
188
- use_uvloop: bool = True
192
+ use_uvloop: bool | None = Field(
193
+ default=None,
194
+ deprecated=
195
+ "`use_uvloop` field is deprecated and will be removed in a future release. The use of `uv_loop` is now" +
196
+ "automatically determined based on platform")
189
197
  """
190
- Whether to use uvloop for the event loop. This can provide a significant speedup in some cases. Disable to provide
191
- better error messages when debugging.
198
+ This field is deprecated and ignored. It previously controlled whether to use uvloop as the event loop. uvloop
199
+ usage is now determined automatically based on the platform.
192
200
  """
193
201
 
194
202
  telemetry: TelemetryConfig = TelemetryConfig()
@@ -240,31 +248,37 @@ class Config(HashableBaseModel):
240
248
  general: GeneralConfig = GeneralConfig()
241
249
 
242
250
  # Functions Configuration
243
- functions: dict[str, FunctionBaseConfig] = {}
251
+ functions: dict[str, FunctionBaseConfig] = Field(default_factory=dict)
252
+
253
+ # Function Groups Configuration
254
+ function_groups: dict[str, FunctionGroupBaseConfig] = Field(default_factory=dict)
244
255
 
245
256
  # LLMs Configuration
246
- llms: dict[str, LLMBaseConfig] = {}
257
+ llms: dict[str, LLMBaseConfig] = Field(default_factory=dict)
247
258
 
248
259
  # Embedders Configuration
249
- embedders: dict[str, EmbedderBaseConfig] = {}
260
+ embedders: dict[str, EmbedderBaseConfig] = Field(default_factory=dict)
250
261
 
251
262
  # Memory Configuration
252
- memory: dict[str, MemoryBaseConfig] = {}
263
+ memory: dict[str, MemoryBaseConfig] = Field(default_factory=dict)
253
264
 
254
265
  # Object Stores Configuration
255
- object_stores: dict[str, ObjectStoreBaseConfig] = {}
266
+ object_stores: dict[str, ObjectStoreBaseConfig] = Field(default_factory=dict)
267
+
268
+ # Optimizer Configuration
269
+ optimizer: OptimizerConfig = OptimizerConfig()
256
270
 
257
271
  # Retriever Configuration
258
- retrievers: dict[str, RetrieverBaseConfig] = {}
272
+ retrievers: dict[str, RetrieverBaseConfig] = Field(default_factory=dict)
259
273
 
260
274
  # TTC Strategies
261
- ttc_strategies: dict[str, TTCStrategyBaseConfig] = {}
275
+ ttc_strategies: dict[str, TTCStrategyBaseConfig] = Field(default_factory=dict)
262
276
 
263
277
  # Workflow Configuration
264
278
  workflow: FunctionBaseConfig = EmptyFunctionConfig()
265
279
 
266
280
  # Authentication Configuration
267
- authentication: dict[str, AuthProviderBaseConfig] = {}
281
+ authentication: dict[str, AuthProviderBaseConfig] = Field(default_factory=dict)
268
282
 
269
283
  # Evaluation Options
270
284
  eval: EvalConfig = EvalConfig()
@@ -278,6 +292,7 @@ class Config(HashableBaseModel):
278
292
  stream.write(f"Workflow Type: {self.workflow.type}\n")
279
293
 
280
294
  stream.write(f"Number of Functions: {len(self.functions)}\n")
295
+ stream.write(f"Number of Function Groups: {len(self.function_groups)}\n")
281
296
  stream.write(f"Number of LLMs: {len(self.llms)}\n")
282
297
  stream.write(f"Number of Embedders: {len(self.embedders)}\n")
283
298
  stream.write(f"Number of Memory: {len(self.memory)}\n")
@@ -287,6 +302,7 @@ class Config(HashableBaseModel):
287
302
  stream.write(f"Number of Authentication Providers: {len(self.authentication)}\n")
288
303
 
289
304
  @field_validator("functions",
305
+ "function_groups",
290
306
  "llms",
291
307
  "embedders",
292
308
  "memory",
@@ -328,6 +344,10 @@ class Config(HashableBaseModel):
328
344
  typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
329
345
  Discriminator(TypedBaseModel.discriminator)]]
330
346
 
347
+ FunctionGroupsAnnotation = dict[str,
348
+ typing.Annotated[type_registry.compute_annotation(FunctionGroupBaseConfig),
349
+ Discriminator(TypedBaseModel.discriminator)]]
350
+
331
351
  MemoryAnnotation = dict[str,
332
352
  typing.Annotated[type_registry.compute_annotation(MemoryBaseConfig),
333
353
  Discriminator(TypedBaseModel.discriminator)]]
@@ -335,7 +355,6 @@ class Config(HashableBaseModel):
335
355
  ObjectStoreAnnotation = dict[str,
336
356
  typing.Annotated[type_registry.compute_annotation(ObjectStoreBaseConfig),
337
357
  Discriminator(TypedBaseModel.discriminator)]]
338
-
339
358
  RetrieverAnnotation = dict[str,
340
359
  typing.Annotated[type_registry.compute_annotation(RetrieverBaseConfig),
341
360
  Discriminator(TypedBaseModel.discriminator)]]
@@ -344,7 +363,7 @@ class Config(HashableBaseModel):
344
363
  typing.Annotated[type_registry.compute_annotation(TTCStrategyBaseConfig),
345
364
  Discriminator(TypedBaseModel.discriminator)]]
346
365
 
347
- WorkflowAnnotation = typing.Annotated[type_registry.compute_annotation(FunctionBaseConfig),
366
+ WorkflowAnnotation = typing.Annotated[(type_registry.compute_annotation(FunctionBaseConfig)),
348
367
  Discriminator(TypedBaseModel.discriminator)]
349
368
 
350
369
  should_rebuild = False
@@ -369,6 +388,11 @@ class Config(HashableBaseModel):
369
388
  functions_field.annotation = FunctionsAnnotation
370
389
  should_rebuild = True
371
390
 
391
+ function_groups_field = cls.model_fields.get("function_groups")
392
+ if function_groups_field is not None and function_groups_field.annotation != FunctionGroupsAnnotation:
393
+ function_groups_field.annotation = FunctionGroupsAnnotation
394
+ should_rebuild = True
395
+
372
396
  memory_field = cls.model_fields.get("memory")
373
397
  if memory_field is not None and memory_field.annotation != MemoryAnnotation:
374
398
  memory_field.annotation = MemoryAnnotation
@@ -26,6 +26,7 @@ from pydantic import FilePath
26
26
  from pydantic import Tag
27
27
 
28
28
  from nat.data_models.common import BaseModelRegistryTag
29
+ from nat.data_models.common import SerializableSecretStr
29
30
  from nat.data_models.common import TypedBaseModel
30
31
 
31
32
 
@@ -34,8 +35,8 @@ class EvalS3Config(BaseModel):
34
35
  endpoint_url: str | None = None
35
36
  region_name: str | None = None
36
37
  bucket: str
37
- access_key: str
38
- secret_key: str
38
+ access_key: SerializableSecretStr
39
+ secret_key: SerializableSecretStr
39
40
 
40
41
 
41
42
  class EvalFilterEntryConfig(BaseModel):
@@ -80,7 +81,7 @@ class EvalDatasetJsonConfig(EvalDatasetBaseConfig, name="json"):
80
81
 
81
82
 
82
83
  def read_jsonl(file_path: FilePath):
83
- with open(file_path, 'r', encoding='utf-8') as f:
84
+ with open(file_path, encoding='utf-8') as f:
84
85
  data = [json.loads(line) for line in f]
85
86
  return pd.DataFrame(data)
86
87
 
@@ -15,6 +15,10 @@
15
15
 
16
16
  import typing
17
17
 
18
+ from pydantic import Field
19
+ from pydantic import field_validator
20
+ from pydantic import model_validator
21
+
18
22
  from .common import BaseModelRegistryTag
19
23
  from .common import TypedBaseModel
20
24
 
@@ -23,8 +27,38 @@ class FunctionBaseConfig(TypedBaseModel, BaseModelRegistryTag):
23
27
  pass
24
28
 
25
29
 
30
+ class FunctionGroupBaseConfig(TypedBaseModel, BaseModelRegistryTag):
31
+ """Base configuration for function groups.
32
+
33
+ Function groups enable sharing of configurations and resources across multiple functions.
34
+ """
35
+ include: list[str] = Field(
36
+ default_factory=list,
37
+ description="The list of function names which should be added to the global Function registry",
38
+ )
39
+ exclude: list[str] = Field(
40
+ default_factory=list,
41
+ description="The list of function names which should be excluded from default access to the group",
42
+ )
43
+
44
+ @field_validator("include", "exclude")
45
+ @classmethod
46
+ def _validate_fields_include_exclude(cls, value: list[str]) -> list[str]:
47
+ if len(set(value)) != len(value):
48
+ raise ValueError("Function names must be unique")
49
+ return sorted(value)
50
+
51
+ @model_validator(mode="after")
52
+ def _validate_include_exclude(self):
53
+ if self.include and self.exclude:
54
+ raise ValueError("include and exclude cannot be used together")
55
+ return self
56
+
57
+
26
58
  class EmptyFunctionConfig(FunctionBaseConfig, name="EmptyFunctionConfig"):
27
59
  pass
28
60
 
29
61
 
30
62
  FunctionConfigT = typing.TypeVar("FunctionConfigT", bound=FunctionBaseConfig)
63
+
64
+ FunctionGroupConfigT = typing.TypeVar("FunctionGroupConfigT", bound=FunctionGroupBaseConfig)
@@ -23,6 +23,7 @@ class FunctionDependencies(BaseModel):
23
23
  A class to represent the dependencies of a function.
24
24
  """
25
25
  functions: set[str] = Field(default_factory=set)
26
+ function_groups: set[str] = Field(default_factory=set)
26
27
  llms: set[str] = Field(default_factory=set)
27
28
  embedders: set[str] = Field(default_factory=set)
28
29
  memory_clients: set[str] = Field(default_factory=set)
@@ -33,6 +34,10 @@ class FunctionDependencies(BaseModel):
33
34
  def serialize_functions(self, v: set[str]) -> list[str]:
34
35
  return list(v)
35
36
 
37
+ @field_serializer("function_groups", when_used="json")
38
+ def serialize_function_groups(self, v: set[str]) -> list[str]:
39
+ return list(v)
40
+
36
41
  @field_serializer("llms", when_used="json")
37
42
  def serialize_llms(self, v: set[str]) -> list[str]:
38
43
  return list(v)
@@ -56,6 +61,9 @@ class FunctionDependencies(BaseModel):
56
61
  def add_function(self, function: str):
57
62
  self.functions.add(function)
58
63
 
64
+ def add_function_group(self, function_group: str):
65
+ self.function_groups.add(function_group) # pylint: disable=no-member
66
+
59
67
  def add_llm(self, llm: str):
60
68
  self.llms.add(llm)
61
69
 
@@ -103,11 +103,19 @@ class ToolSchema(BaseModel):
103
103
  function: ToolDetails = Field(..., description="The function details.")
104
104
 
105
105
 
106
+ class ServerToolUseSchema(BaseModel):
107
+ name: str
108
+ arguments: str | dict[str, typing.Any] | typing.Any
109
+ output: typing.Any
110
+
111
+ model_config = ConfigDict(extra="ignore")
112
+
113
+
106
114
  class TraceMetadata(BaseModel):
107
115
  chat_responses: typing.Any | None = None
108
116
  chat_inputs: typing.Any | None = None
109
117
  tool_inputs: typing.Any | None = None
110
- tool_outputs: typing.Any | None = None
118
+ tool_outputs: list[ServerToolUseSchema] | typing.Any | None = None
111
119
  tool_info: typing.Any | None = None
112
120
  span_inputs: typing.Any | None = None
113
121
  span_outputs: typing.Any | None = None
nat/data_models/llm.py CHANGED
@@ -14,14 +14,28 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import typing
17
+ from enum import Enum
18
+
19
+ from pydantic import Field
17
20
 
18
21
  from .common import BaseModelRegistryTag
19
22
  from .common import TypedBaseModel
20
23
 
21
24
 
25
+ class APITypeEnum(str, Enum):
26
+ CHAT_COMPLETION = "chat_completion"
27
+ RESPONSES = "responses"
28
+
29
+
22
30
  class LLMBaseConfig(TypedBaseModel, BaseModelRegistryTag):
23
31
  """Base configuration for LLM providers."""
24
- pass
32
+
33
+ api_type: APITypeEnum = Field(default=APITypeEnum.CHAT_COMPLETION,
34
+ description="The type of API to use for the LLM provider.",
35
+ json_schema_extra={
36
+ "enum": [e.value for e in APITypeEnum],
37
+ "examples": [e.value for e in APITypeEnum],
38
+ })
25
39
 
26
40
 
27
41
  LLMBaseConfigT = typing.TypeVar("LLMBaseConfigT", bound=LLMBaseConfig)
@@ -0,0 +1,46 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from enum import Enum
17
+
18
+ from pydantic import BaseModel
19
+ from pydantic import ConfigDict
20
+ from pydantic import Field
21
+
22
+
23
+ class MCPApprovalRequiredEnum(str, Enum):
24
+ """
25
+ Enum to specify if approval is required for tool usage in the OpenAI MCP schema.
26
+ """
27
+ NEVER = "never"
28
+ ALWAYS = "always"
29
+ AUTO = "auto"
30
+
31
+
32
+ class OpenAIMCPSchemaTool(BaseModel):
33
+ """
34
+ Represents a tool in the OpenAI MCP schema.
35
+ """
36
+ type: str = "mcp"
37
+ server_label: str = Field(description="Label for the server where the tool is hosted.")
38
+ server_url: str = Field(description="URL of the server hosting the tool.")
39
+ allowed_tools: list[str] | None = Field(default=None,
40
+ description="List of allowed tool names that can be used by the agent.")
41
+ require_approval: MCPApprovalRequiredEnum = Field(default=MCPApprovalRequiredEnum.NEVER,
42
+ description="Specifies if approval is required for tool usage.")
43
+ headers: dict[str, str] | None = Field(default=None,
44
+ description="Optional headers to include in requests to the tool server.")
45
+
46
+ model_config = ConfigDict(use_enum_values=True)
@@ -0,0 +1,208 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections.abc import Sequence
17
+ from typing import Any
18
+ from typing import Generic
19
+ from typing import TypeVar
20
+
21
+ import numpy as np
22
+ from optuna import Trial
23
+ from pydantic import BaseModel
24
+ from pydantic import ConfigDict
25
+ from pydantic import Field
26
+ from pydantic import model_validator
27
+ from pydantic_core import PydanticUndefined
28
+
29
+ T = TypeVar("T", int, float, bool, str)
30
+
31
+
32
+ # --------------------------------------------------------------------- #
33
+ # 1. Hyper‑parameter metadata container #
34
+ # --------------------------------------------------------------------- #
35
class SearchSpace(BaseModel, Generic[T]):
    """
    Declarative description of a single hyper-parameter's search space.

    A space is configured in exactly one of three mutually exclusive ways:

    1. Categorical: ``values=[...]`` — an explicit, non-empty list of choices.
    2. Range: ``low``/``high`` (optionally ``log``/``step``) — a numeric interval.
    3. Prompt: ``is_prompt=True`` with ``prompt``/``prompt_purpose`` — a text
       prompt optimized by the genetic algorithm (not supported by Optuna).
    """
    values: Sequence[T] | None = None  # explicit categorical choices
    low: T | None = None  # inclusive lower bound of a numeric range
    high: T | None = None  # inclusive upper bound of a numeric range
    log: bool = False  # log scale
    step: float | None = None  # discretization step for the range
    is_prompt: bool = False  # marks this space as a prompt-optimization target
    prompt: str | None = None  # prompt to optimize
    prompt_purpose: str | None = None  # purpose of the prompt

    model_config = ConfigDict(protected_namespaces=(), extra="forbid")

    @model_validator(mode="after")
    def validate_search_space_parameters(self):
        """Validate that the configured fields form a coherent SearchSpace."""
        # 1. Prompt-specific validation: numeric range options are meaningless here.
        if self.is_prompt:
            # When optimizing prompts, numeric parameters don't make sense
            if self.low is not None or self.high is not None:
                raise ValueError("SearchSpace with 'is_prompt=True' cannot have 'low' or 'high' parameters")
            if self.log:
                raise ValueError("SearchSpace with 'is_prompt=True' cannot have 'log=True'")
            if self.step is not None:
                raise ValueError("SearchSpace with 'is_prompt=True' cannot have 'step' parameter")
            return self

        # 2. Values-based validation: categorical choices exclude a range.
        if self.values is not None:
            # If values is provided, we don't need high/low
            if self.high is not None or self.low is not None:
                raise ValueError("SearchSpace 'values' is mutually exclusive with 'high' and 'low'")
            # Ensure values is not empty
            if len(self.values) == 0:
                raise ValueError("SearchSpace 'values' must not be empty")
            return self

        # 3. Range-based validation: bounds come in pairs and must be ordered.
        if (self.low is None) != (self.high is None):  # XOR using !=
            raise ValueError(f"SearchSpace range requires both 'low' and 'high'; got low={self.low}, high={self.high}")
        if self.low is not None and self.high is not None and self.low >= self.high:
            raise ValueError(f"SearchSpace 'low' must be less than 'high'; got low={self.low}, high={self.high}")

        return self

    # Helper for Optuna Trials
    def suggest(self, trial: Trial, name: str):
        """
        Draw a value for this space from an Optuna ``trial``.

        Raises:
            ValueError: If this is a prompt space (unsupported by Optuna).
        """
        if self.is_prompt:
            raise ValueError("Prompt optimization not currently supported using Optuna. "
                             "Use the genetic algorithm implementation instead.")
        if self.values is not None:
            return trial.suggest_categorical(name, self.values)
        if isinstance(self.low, int):
            # Bug fix: Optuna's suggest_int requires an integer step (its default
            # is 1); the previous code forwarded step=None whenever no step was
            # configured, which raises inside Optuna. Only forward a real step.
            int_step = int(self.step) if self.step is not None else 1
            return trial.suggest_int(name, self.low, self.high, log=self.log, step=int_step)
        # suggest_float accepts step=None (meaning continuous), so pass through.
        return trial.suggest_float(name, self.low, self.high, log=self.log, step=self.step)

    def to_grid_values(self) -> list[Any]:
        """
        Convert SearchSpace to a list of values for GridSampler.

        Grid search requires explicit values. This can be provided in two ways:
        1. Explicit values: SearchSpace(values=[0.1, 0.5, 0.9])
        2. Range with step: SearchSpace(low=0.1, high=0.9, step=0.2)

        For ranges, step is required (no default will be applied) to avoid
        unintentional combinatorial explosion.

        Raises:
            ValueError: For prompt spaces, missing/invalid range parameters,
                non-positive step, or log-scaled ranges (unsupported here).
        """

        if self.is_prompt:
            raise ValueError("Prompt optimization not currently supported using Optuna. "
                             "Use the genetic algorithm implementation instead.")

        # Option 1: Explicit values provided
        if self.values is not None:
            return list(self.values)

        # Option 2: Range with required step
        if self.low is None or self.high is None:
            raise ValueError("Grid search requires either 'values' or both 'low' and 'high' to be defined")

        if self.step is None:
            raise ValueError(
                f"Grid search with range (low={self.low}, high={self.high}) requires 'step' to be specified. "
                "Please define the step size to discretize the range, for example: step=0.1")

        # Validate step is positive
        step_float = float(self.step)
        if step_float <= 0:
            raise ValueError(f"Grid search step must be positive; got step={self.step}")

        # Generate grid values from range with step
        # Use integer range only if low, high, and step are all integral
        if (isinstance(self.low, int) and isinstance(self.high, int) and step_float.is_integer()):
            step = int(step_float)

            if self.log:
                raise ValueError("Log scale is not supported for integer ranges in grid search. "
                                 "Please use linear scale or provide explicit values.")
            values = list(range(self.low, self.high + 1, step))
            # Guarantee the upper bound is explored even when the step skips it.
            if values and values[-1] != self.high:
                values.append(self.high)
            return values

        # Float range (including integer low/high with float step)
        low_val = float(self.low)
        high_val = float(self.high)
        step_val = step_float

        if self.log:
            raise ValueError("Log scale is not yet supported for grid search with ranges. "
                             "Please provide explicit values using the 'values' field.")

        # Use arange to respect step size
        values = np.arange(low_val, high_val, step_val).tolist()

        # Always include the high endpoint if not already present (within tolerance)
        # This ensures the full range is explored in grid search
        if not values or abs(values[-1] - high_val) > 1e-9:
            values.append(high_val)

        return values
155
+
156
+
157
def OptimizableField(
    default: Any = PydanticUndefined,
    *,
    space: SearchSpace | None = None,
    merge_conflict: str = "overwrite",
    **fld_kw,
):
    """
    Drop-in replacement for ``pydantic.Field`` that marks the field as optimizable.

    Tags the field with ``optimizable: True`` (and, when given, its
    ``search_space``) inside ``json_schema_extra`` so the optimizer can
    discover it from the model schema.

    Args:
        default: The field's default value (``PydanticUndefined`` = required field).
        space: Optional search space describing how to optimize this field.
        merge_conflict: Policy when the caller's ``json_schema_extra`` already
            contains a reserved key ("optimizable"/"search_space"):
            "overwrite" (default) replaces it, "keep" preserves the caller's
            value, "error" raises.
        **fld_kw: Remaining keyword arguments forwarded to ``pydantic.Field``.

    Returns:
        A normal Pydantic field definition with the merged ``json_schema_extra``.

    Raises:
        TypeError: If ``json_schema_extra`` is provided but is not a mapping.
        ValueError: On a reserved-key conflict with ``merge_conflict="error"``,
            or when a prompt space has no base prompt available.
    """
    # 1. Pull out any user-supplied extras (must be a dict)
    user_extra = fld_kw.pop("json_schema_extra", None) or {}
    if not isinstance(user_extra, dict):
        raise TypeError("`json_schema_extra` must be a mapping.")

    # 2. If the space is a prompt, ensure a concrete base prompt exists
    if space is not None and getattr(space, "is_prompt", False):
        if getattr(space, "prompt", None) is None:
            # Bug fix: a field declared without a default carries the
            # PydanticUndefined sentinel, which previously slipped past the
            # `is None` check and was stored as the base prompt. Treat the
            # sentinel the same as a missing default.
            if default is None or default is PydanticUndefined:
                raise ValueError("Prompt-optimized fields require a base prompt: provide a "
                                 "non-None field default or set space.prompt.")
            # Default prompt not provided in space; fall back to the field's default
            space.prompt = default

    # 3. Prepare our own metadata
    ours: dict[str, Any] = {"optimizable": True}
    if space is not None:
        ours["search_space"] = space

    # 4. Merge with user extras according to merge_conflict policy
    intersect = ours.keys() & user_extra.keys()
    if intersect:
        if merge_conflict == "error":
            raise ValueError("`json_schema_extra` already contains reserved key(s): "
                             f"{', '.join(intersect)}")
        if merge_conflict == "keep":
            # remove the ones the user already set so we don't overwrite them
            ours = {k: v for k, v in ours.items() if k not in intersect}
        # NOTE(review): any other policy string falls through to "overwrite"
        # semantics; kept as-is for backward compatibility.

    merged_extra = {**user_extra, **ours}  # ours wins if 'overwrite'

    # 5. Return a normal Pydantic Field with merged extras
    return Field(default, json_schema_extra=merged_extra, **fld_kw)
197
+
198
+
199
class OptimizableMixin(BaseModel):
    """
    Mixin for config models that expose tunable parameters to the optimizer.

    Both fields are excluded from serialization (``exclude=True``); they only
    steer the optimization process and are not part of the persisted config.
    """

    # Names of this model's fields that the optimizer is allowed to tune.
    optimizable_params: list[str] = Field(default_factory=list,
                                          description="List of parameters that can be optimized.",
                                          exclude=True)

    # Per-parameter SearchSpace overrides, keyed by field name.
    search_space: dict[str, SearchSpace] = Field(
        default_factory=dict,
        description="Optional search space overrides for optimizable parameters.",
        exclude=True,
    )
+ )