nvidia-nat 1.3a20250819__py3-none-any.whl → 1.3.0a20250823__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. aiq/__init__.py +66 -0
  2. nat/agent/base.py +16 -0
  3. nat/agent/react_agent/agent.py +38 -13
  4. nat/agent/react_agent/prompt.py +4 -1
  5. nat/agent/react_agent/register.py +1 -1
  6. nat/agent/register.py +0 -1
  7. nat/agent/rewoo_agent/agent.py +6 -3
  8. nat/agent/rewoo_agent/prompt.py +3 -0
  9. nat/agent/rewoo_agent/register.py +4 -3
  10. nat/agent/tool_calling_agent/agent.py +92 -22
  11. nat/agent/tool_calling_agent/register.py +9 -13
  12. nat/authentication/api_key/api_key_auth_provider.py +1 -1
  13. nat/authentication/register.py +0 -1
  14. nat/builder/builder.py +1 -1
  15. nat/builder/context.py +9 -1
  16. nat/builder/function_base.py +3 -3
  17. nat/builder/function_info.py +5 -7
  18. nat/builder/user_interaction_manager.py +2 -2
  19. nat/builder/workflow.py +3 -0
  20. nat/builder/workflow_builder.py +0 -1
  21. nat/cli/commands/evaluate.py +1 -1
  22. nat/cli/commands/info/list_components.py +7 -8
  23. nat/cli/commands/info/list_mcp.py +3 -4
  24. nat/cli/commands/registry/search.py +14 -16
  25. nat/cli/commands/start.py +0 -1
  26. nat/cli/commands/workflow/templates/pyproject.toml.j2 +3 -0
  27. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  28. nat/cli/commands/workflow/workflow_commands.py +0 -1
  29. nat/cli/type_registry.py +7 -9
  30. nat/data_models/config.py +1 -1
  31. nat/data_models/evaluate.py +1 -1
  32. nat/data_models/function_dependencies.py +6 -6
  33. nat/data_models/intermediate_step.py +3 -3
  34. nat/data_models/model_gated_field_mixin.py +125 -0
  35. nat/data_models/swe_bench_model.py +1 -1
  36. nat/data_models/temperature_mixin.py +36 -0
  37. nat/data_models/top_p_mixin.py +36 -0
  38. nat/embedder/azure_openai_embedder.py +46 -0
  39. nat/embedder/openai_embedder.py +1 -2
  40. nat/embedder/register.py +1 -1
  41. nat/eval/config.py +2 -0
  42. nat/eval/dataset_handler/dataset_handler.py +5 -6
  43. nat/eval/evaluate.py +64 -20
  44. nat/eval/rag_evaluator/register.py +2 -2
  45. nat/eval/register.py +0 -1
  46. nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
  47. nat/eval/utils/eval_trace_ctx.py +89 -0
  48. nat/eval/utils/weave_eval.py +14 -7
  49. nat/experimental/test_time_compute/models/strategy_base.py +3 -2
  50. nat/experimental/test_time_compute/register.py +0 -1
  51. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
  52. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
  53. nat/front_ends/fastapi/message_handler.py +13 -14
  54. nat/front_ends/fastapi/message_validator.py +4 -4
  55. nat/front_ends/fastapi/step_adaptor.py +1 -1
  56. nat/front_ends/register.py +0 -1
  57. nat/llm/aws_bedrock_llm.py +3 -3
  58. nat/llm/azure_openai_llm.py +49 -0
  59. nat/llm/nim_llm.py +4 -4
  60. nat/llm/openai_llm.py +4 -4
  61. nat/llm/register.py +1 -1
  62. nat/llm/utils/env_config_value.py +2 -3
  63. nat/meta/pypi.md +9 -9
  64. nat/object_store/models.py +2 -0
  65. nat/object_store/register.py +0 -1
  66. nat/observability/exporter/base_exporter.py +1 -1
  67. nat/observability/exporter/file_exporter.py +1 -1
  68. nat/observability/register.py +3 -3
  69. nat/profiler/callbacks/langchain_callback_handler.py +9 -2
  70. nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
  71. nat/profiler/data_frame_row.py +1 -1
  72. nat/profiler/decorators/framework_wrapper.py +1 -4
  73. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  74. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  75. nat/profiler/inference_optimization/data_models.py +3 -3
  76. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  77. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  78. nat/profiler/profile_runner.py +13 -8
  79. nat/registry_handlers/package_utils.py +0 -1
  80. nat/registry_handlers/pypi/pypi_handler.py +20 -23
  81. nat/registry_handlers/register.py +3 -4
  82. nat/registry_handlers/rest/rest_handler.py +8 -9
  83. nat/retriever/register.py +0 -1
  84. nat/runtime/session.py +23 -8
  85. nat/settings/global_settings.py +13 -2
  86. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
  87. nat/tool/datetime_tools.py +49 -9
  88. nat/tool/document_search.py +1 -1
  89. nat/tool/mcp/mcp_tool.py +1 -1
  90. nat/tool/register.py +0 -1
  91. nat/utils/data_models/schema_validator.py +2 -2
  92. nat/utils/exception_handlers/automatic_retries.py +0 -2
  93. nat/utils/exception_handlers/schemas.py +1 -1
  94. nat/utils/reactive/base/observable_base.py +2 -2
  95. nat/utils/reactive/base/observer_base.py +1 -1
  96. nat/utils/reactive/observable.py +2 -2
  97. nat/utils/reactive/observer.py +2 -2
  98. nat/utils/reactive/subscription.py +1 -1
  99. nat/utils/settings/global_settings.py +4 -6
  100. nat/utils/type_utils.py +4 -4
  101. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/METADATA +17 -15
  102. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/RECORD +107 -100
  103. nvidia_nat-1.3.0a20250823.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
  104. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/top_level.txt +1 -0
  105. nvidia_nat-1.3a20250819.dist-info/licenses/LICENSE-3rd-party.txt +0 -3686
  106. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/WHEEL +0 -0
  107. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/entry_points.txt +0 -0
  108. {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE.md +0 -0
@@ -231,7 +231,7 @@ class FunctionDescriptor:
         else:
             annotations = [param.annotation for param in sig.parameters.values()]

-        is_input_typed = all([a != sig.empty for a in annotations])  # pylint: disable=use-a-generator
+        is_input_typed = all([a != sig.empty for a in annotations])

         input_type = tuple[*annotations] if is_input_typed else None  # noqa: syntax-error

@@ -372,18 +372,16 @@ class FunctionInfo:

            if (stream_to_single_fn is not None):
                raise ValueError("Cannot provide both single_fn and stream_to_single_fn")
-        else:
-            if (stream_to_single_fn is not None and stream_fn is None):
-                raise ValueError("stream_fn must be provided if stream_to_single_fn is provided")
+        elif (stream_to_single_fn is not None and stream_fn is None):
+            raise ValueError("stream_fn must be provided if stream_to_single_fn is provided")

        if (stream_fn is not None):
            final_stream_fn = stream_fn

            if (single_to_stream_fn is not None):
                raise ValueError("Cannot provide both stream_fn and single_to_stream_fn")
-        else:
-            if (single_to_stream_fn is not None and single_fn is None):
-                raise ValueError("single_fn must be provided if single_to_stream_fn is provided")
+        elif (single_to_stream_fn is not None and single_fn is None):
+            raise ValueError("single_fn must be provided if single_to_stream_fn is provided")

        if (single_fn is None and stream_fn is None):
            raise ValueError("At least one of single_fn or stream_fn must be provided")
@@ -61,13 +61,13 @@ class UserInteractionManager:

        uuid_req = str(uuid.uuid4())
        status = InteractionStatus.IN_PROGRESS
-        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ")
+        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        sys_human_interaction = InteractionPrompt(id=uuid_req, status=status, timestamp=timestamp, content=content)

        resp = await self._context_state.user_input_callback.get()(sys_human_interaction)

        # Rebuild a InteractionResponse object with the response
-        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ")
+        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        status = InteractionStatus.COMPLETED
        sys_human_interaction = InteractionResponse(id=uuid_req, status=status, timestamp=timestamp, content=resp)
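Note on the hunk above: without an explicit time tuple, `time.strftime` formats the *local* time, so the hard-coded `Z` (UTC) suffix was misleading on any machine not running in UTC; passing `time.gmtime()` makes the timestamp genuinely UTC. A minimal standard-library illustration:

```python
import time

# No struct_time argument: strftime formats the current *local* time,
# so the literal "Z" suffix wrongly labels it as UTC.
local_stamp = time.strftime("%Y-%m-%dT%H:%M:%SZ")

# With time.gmtime(), the formatted time is actually UTC, matching the suffix.
utc_stamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

print(local_stamp, utc_stamp)  # differ by the local UTC offset unless TZ=UTC
```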
nat/builder/workflow.py CHANGED
@@ -83,6 +83,9 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):

        return self._entry_fn.has_single_output

+    async def get_all_exporters(self) -> dict[str, BaseExporter]:
+        return await self._exporter_manager.get_all_exporters()
+
    @asynccontextmanager
    async def run(self, message: InputT):
        """
@@ -127,7 +127,6 @@ class ConfiguredTTCStrategy:
    instance: StrategyBase


-# pylint: disable=too-many-public-methods
 class WorkflowBuilder(Builder, AbstractAsyncContextManager):

    def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
@@ -97,7 +97,7 @@ async def run_and_evaluate(config: EvaluationRunConfig):

 @eval_command.result_callback(replace=True)
 def process_nat_eval(
-        processors,  # pylint: disable=unused-argument
+        processors,
        *,
        config_file: Path,
        dataset: Path,
@@ -26,14 +26,13 @@ from nat.registry_handlers.schemas.search import SearchFields
 logger = logging.getLogger(__name__)


-async def search_artifacts(  # pylint: disable=R0917
-        registry_handler_config: RegistryHandlerBaseConfig,
-        component_types: list[ComponentEnum],
-        visualize: bool,
-        query: str,
-        num_results: int,
-        query_fields: list[SearchFields],
-        save_path: str | None) -> None:
+async def search_artifacts(registry_handler_config: RegistryHandlerBaseConfig,
+                           component_types: list[ComponentEnum],
+                           visualize: bool,
+                           query: str,
+                           num_results: int,
+                           query_fields: list[SearchFields],
+                           save_path: str | None) -> None:

    from nat.cli.type_registry import GlobalTypeRegistry
    from nat.registry_handlers.schemas.search import SearchQuery
@@ -297,8 +297,7 @@ def ping(url: str, timeout: int, json_output: bool) -> None:

    if json_output:
        click.echo(result.model_dump_json(indent=2))
+    elif result.status == "healthy":
+        click.echo(f"Server at {result.url} is healthy (response time: {result.response_time_ms}ms)")
    else:
-        if result.status == "healthy":
-            click.echo(f"Server at {result.url} is healthy (response time: {result.response_time_ms}ms)")
-        else:
-            click.echo(f"Server at {result.url} {result.status}: {result.error}")
+        click.echo(f"Server at {result.url} {result.status}: {result.error}")
@@ -29,14 +29,13 @@ from nat.utils.data_models.schema_validator import validate_yaml
 logger = logging.getLogger(__name__)


-async def search_artifacts(  # pylint: disable=R0917
-        registry_handler_config: RegistryHandlerBaseConfig,
-        query: str,
-        search_fields: list[SearchFields],
-        visualize: bool,
-        component_types: list[ComponentEnum],
-        save_path: str | None = None,
-        n_results: int = 10) -> None:
+async def search_artifacts(registry_handler_config: RegistryHandlerBaseConfig,
+                           query: str,
+                           search_fields: list[SearchFields],
+                           visualize: bool,
+                           component_types: list[ComponentEnum],
+                           save_path: str | None = None,
+                           n_results: int = 10) -> None:

    from nat.cli.type_registry import GlobalTypeRegistry
    from nat.registry_handlers.schemas.search import SearchQuery
@@ -116,14 +115,13 @@ async def search_artifacts(  # pylint: disable=R0917
        required=False,
        help=("The component types to include in search."),
    )
-def search(  # pylint: disable=R0917
-        config_file: str,
-        channel: str,
-        fields: list[str],
-        query: str,
-        component_types: list[ComponentEnum],
-        n_results: int,
-        output_path: str) -> None:
+def search(config_file: str,
+           channel: str,
+           fields: list[str],
+           query: str,
+           component_types: list[ComponentEnum],
+           n_results: int,
+           output_path: str) -> None:
    """
    Search for NAT artifacts with the specified configuration.
    """
nat/cli/commands/start.py CHANGED
@@ -35,7 +35,6 @@ logger = logging.getLogger(__name__)

 class StartCommandGroup(click.Group):

-    # pylint: disable=too-many-positional-arguments
    def __init__(
        self,
        name: str | None = None,
@@ -3,6 +3,9 @@ build-backend = "setuptools.build_meta"
 {% if editable %}requires = ["setuptools >= 64", "setuptools-scm>=8"]

 [tool.setuptools_scm]
+# NAT uses the --first-parent flag to avoid tags from previous releases which have been merged into the develop branch
+# from causing an unexpected version change. This can be safely removed if developing outside of the NAT repository.
+git_describe_command = "git describe --long --first-parent"
 root = "{{ rel_path_to_repo_root}}"{% else %}requires = ["setuptools >= 64"]{% endif %}

 [project]
@@ -1,4 +1,3 @@
-# pylint: disable=unused-import
 # flake8: noqa

 # Import any tools which need to be automatically registered here
@@ -161,7 +161,6 @@ def get_workflow_path_from_name(workflow_name: str):
    default="NAT function template. Please update the description.",
    help="""A description of the component being created. Will be used to populate the docstring and will describe the
 component when inspecting installed components using 'nat info component'""")
-# pylint: disable=missing-param-doc
 def create_command(workflow_name: str, install: bool, workflow_dir: str, description: str):
    """
    Create a new NAT workflow using templates.
nat/cli/type_registry.py CHANGED
@@ -298,7 +298,7 @@ class RegisteredPackage(BaseModel):
    discovery_metadata: DiscoveryMetadata


-class TypeRegistry:  # pylint: disable=too-many-public-methods
+class TypeRegistry:

    def __init__(self) -> None:
        # Telemetry Exporters
@@ -588,8 +588,8 @@ class TypeRegistry:  # pylint: disable=too-many-public-methods
        except KeyError as err:
            raise KeyError(
                f"An invalid Embedder config and wrapper combination was supplied. Config: `{config_type}`, "
-                "Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
-                "there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
+                f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
+                f"there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
                "Please provide an Embedder configuration from one of the following providers: "
                f"{set(self._embedder_client_provider_to_framework.keys())}") from err
@@ -703,8 +703,8 @@ class TypeRegistry:  # pylint: disable=too-many-public-methods
        except KeyError as err:
            raise KeyError(
                f"An invalid Retriever config and wrapper combination was supplied. Config: `{config_type}`, "
-                "Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
-                "there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
+                f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
+                f"there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
                "Please provide a Retriever configuration from one of the following providers: "
                f"{set(self._retriever_client_provider_to_framework.keys())}") from err
@@ -779,7 +779,7 @@ class TypeRegistry:  # pylint: disable=too-many-public-methods

        self._registration_changed()

-    def get_infos_by_type(self, component_type: ComponentEnum) -> dict:  # pylint: disable=R0911
+    def get_infos_by_type(self, component_type: ComponentEnum) -> dict:

        if component_type == ComponentEnum.FRONT_END:
            return self._registered_front_end_infos
@@ -849,8 +849,7 @@ class TypeRegistry:  # pylint: disable=too-many-public-methods

        raise ValueError(f"Supplied an unsupported component type {component_type}")

-    def get_registered_types_by_component_type(  # pylint: disable=R0911
-            self, component_type: ComponentEnum) -> list[str]:
+    def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:

        if component_type == ComponentEnum.FUNCTION:
            return [i.static_type() for i in self._registered_functions]
@@ -925,7 +924,6 @@ class TypeRegistry:  # pylint: disable=too-many-public-methods
            if (short_names[key.local_name] == 1):
                type_list.append((key.local_name, key.config_type))

-        # pylint: disable=consider-alternative-union-syntax
        return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]

    def compute_annotation(self, cls: type[TypedBaseModelT]):
nat/data_models/config.py CHANGED
@@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)


 def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
-    from nat.cli.type_registry import GlobalTypeRegistry  # pylint: disable=cyclic-import
+    from nat.cli.type_registry import GlobalTypeRegistry

    new_errors = []
    logged_once = False
@@ -108,7 +108,7 @@ class EvalConfig(BaseModel):
    @classmethod
    def rebuild_annotations(cls):

-        from nat.cli.type_registry import GlobalTypeRegistry  # pylint: disable=cyclic-import
+        from nat.cli.type_registry import GlobalTypeRegistry

        type_registry = GlobalTypeRegistry.get()

@@ -54,19 +54,19 @@ class FunctionDependencies(BaseModel):
        return list(v)

    def add_function(self, function: str):
-        self.functions.add(function)  # pylint: disable=no-member
+        self.functions.add(function)

    def add_llm(self, llm: str):
-        self.llms.add(llm)  # pylint: disable=no-member
+        self.llms.add(llm)

    def add_embedder(self, embedder: str):
-        self.embedders.add(embedder)  # pylint: disable=no-member
+        self.embedders.add(embedder)

    def add_memory_client(self, memory_client: str):
-        self.memory_clients.add(memory_client)  # pylint: disable=no-member
+        self.memory_clients.add(memory_client)

    def add_object_store(self, object_store: str):
-        self.object_stores.add(object_store)  # pylint: disable=no-member
+        self.object_stores.add(object_store)

    def add_retriever(self, retriever: str):
-        self.retrievers.add(retriever)  # pylint: disable=no-member
+        self.retrievers.add(retriever)
@@ -142,7 +142,7 @@ class IntermediateStepPayload(BaseModel):
    UUID: str = Field(default_factory=lambda: str(uuid.uuid4()))

    @property
-    def event_category(self) -> IntermediateStepCategory:  # pylint: disable=too-many-return-statements
+    def event_category(self) -> IntermediateStepCategory:
        match self.event_type:
            case IntermediateStepType.LLM_START:
                return IntermediateStepCategory.LLM
@@ -180,7 +180,7 @@ class IntermediateStepPayload(BaseModel):
        raise ValueError(f"Unknown event type: {self.event_type}")

    @property
-    def event_state(self) -> IntermediateStepState:  # pylint: disable=too-many-return-statements
+    def event_state(self) -> IntermediateStepState:
        match self.event_type:
            case IntermediateStepType.LLM_START:
                return IntermediateStepState.START
@@ -290,7 +290,7 @@ class IntermediateStep(BaseModel):
        return self.payload.usage_info

    @property
-    def UUID(self) -> str:  # pylint: disable=invalid-name
+    def UUID(self) -> str:
        return self.payload.UUID

    @property
@@ -0,0 +1,125 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Sequence
+from re import Pattern
+from typing import Generic
+from typing import TypeVar
+
+from pydantic import BaseModel
+from pydantic import model_validator
+
+T = TypeVar("T")
+
+
+class ModelGatedFieldMixin(Generic[T]):
+    """
+    A mixin that gates a field based on model support.
+
+    This should be used to automatically validate a field based on a given model.
+
+    Parameters
+    ----------
+    field_name: `str`
+        The name of the field.
+    default_if_supported: `T`
+        The default value of the field if it is supported for the model.
+    unsupported_models: `Sequence[Pattern[str]] | None`
+        A sequence of regex patterns that match the model names NOT supported for the field.
+        Defaults to None.
+    supported_models: `Sequence[Pattern[str]] | None`
+        A sequence of regex patterns that match the model names supported for the field.
+        Defaults to None.
+    model_keys: `Sequence[str]`
+        A sequence of keys that are used to validate the field.
+        Defaults to ("model_name", "model", "azure_deployment",)
+    """
+
+    def __init_subclass__(
+        cls,
+        field_name: str | None = None,
+        default_if_supported: T | None = None,
+        unsupported_models: Sequence[Pattern[str]] | None = None,
+        supported_models: Sequence[Pattern[str]] | None = None,
+        model_keys: Sequence[str] = ("model_name", "model", "azure_deployment"),
+    ) -> None:
+        """
+        Store the class variables for the field and define the model validator.
+        """
+        super().__init_subclass__()
+        if ModelGatedFieldMixin in cls.__bases__:
+            if field_name is None:
+                raise ValueError("field_name must be provided when subclassing ModelGatedFieldMixin")
+            if default_if_supported is None:
+                raise ValueError("default_if_supported must be provided when subclassing ModelGatedFieldMixin")
+            if unsupported_models is None and supported_models is None:
+                raise ValueError("Either unsupported_models or supported_models must be provided")
+            if unsupported_models is not None and supported_models is not None:
+                raise ValueError("Only one of unsupported_models or supported_models must be provided")
+            if model_keys is not None and len(model_keys) == 0:
+                raise ValueError("model_keys must be provided and non-empty when subclassing ModelGatedFieldMixin")
+            cls.field_name = field_name
+            cls.default_if_supported = default_if_supported
+            cls.unsupported_models = unsupported_models
+            cls.supported_models = supported_models
+            if model_keys is not None:
+                cls.model_keys = model_keys
+
+        @classmethod
+        def check_model(cls, model_name: str) -> bool:
+            """
+            Check if a model is supported for a given field.
+
+            Args:
+                model_name: The name of the model to check.
+            """
+            unsupported = getattr(cls, "unsupported_models", None)
+            supported = getattr(cls, "supported_models", None)
+            if unsupported is not None:
+                return not any(p.search(model_name) for p in unsupported)
+            if supported is not None:
+                return any(p.search(model_name) for p in supported)
+            return False
+
+        cls._model_gated_field_check_model = check_model
+
+        @classmethod
+        def detect_support(cls, instance: BaseModel) -> str | None:
+            for key in getattr(cls, "model_keys"):
+                if hasattr(instance, key):
+                    model_name_value = getattr(instance, key)
+                    is_supported = getattr(cls, "_model_gated_field_check_model")(str(model_name_value))
+                    return key if not is_supported else None
+            return None
+
+        cls._model_gated_field_detect_support = detect_support
+
+        @model_validator(mode="after")
+        def model_validate(self):
+            klass = self.__class__
+
+            field_name_local = getattr(klass, "field_name")
+            current_value = getattr(self, field_name_local, None)
+
+            found_key = klass._model_gated_field_detect_support(self)
+            if found_key is not None:
+                if current_value is not None:
+                    raise ValueError(
+                        f"{field_name_local} is not supported for {found_key}: {getattr(self, found_key)}")
+            elif current_value is None:
+                setattr(self, field_name_local, getattr(klass, "default_if_supported", None))
+            return self
+
+        cls._model_gated_field_model_validator = model_validate
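For orientation: the new mixin receives its configuration through `__init_subclass__` keyword arguments and attaches a pydantic after-validator that raises when the gated field is set for an unsupported model and fills in a default when the model supports it. A stripped-down, self-contained sketch of that behavior (simplified illustration, not the package code):

```python
import re

from pydantic import BaseModel, ConfigDict, model_validator

_UNSUPPORTED = (re.compile(r"gpt-?5", re.IGNORECASE), )


class GatedTemperature(BaseModel):
    """Illustration only: gate `temperature` on the model name."""

    model_config = ConfigDict(protected_namespaces=())

    model_name: str
    temperature: float | None = None

    @model_validator(mode="after")
    def _gate_temperature(self):
        supported = not any(p.search(self.model_name) for p in _UNSUPPORTED)
        if not supported and self.temperature is not None:
            raise ValueError(f"temperature is not supported for model_name: {self.model_name}")
        if supported and self.temperature is None:
            self.temperature = 0.0  # plays the role of default_if_supported
        return self


assert GatedTemperature(model_name="gpt-4o").temperature == 0.0
assert GatedTemperature(model_name="gpt-5").temperature is None
```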
@@ -39,7 +39,7 @@ class SWEBenchInput(BaseModel):

    # Handle improperly formatted JSON strings for list fields
    @field_validator("FAIL_TO_PASS", "PASS_TO_PASS", mode="before")
-    def parse_list_fields(cls, value):  # pylint: disable=no-self-argument
+    def parse_list_fields(cls, value):
        if isinstance(value, str):
            # Attempt to parse the string as a list
            return json.loads(value)
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.data_models.model_gated_field_mixin import ModelGatedFieldMixin
+
+_UNSUPPORTED_TEMPERATURE_MODELS = (re.compile(r"gpt-?5", re.IGNORECASE), )
+
+
+class TemperatureMixin(
+        BaseModel,
+        ModelGatedFieldMixin[float],
+        field_name="temperature",
+        default_if_supported=0.0,
+        unsupported_models=_UNSUPPORTED_TEMPERATURE_MODELS,
+):
+    """
+    Mixin class for temperature configuration.
+    """
+    temperature: float | None = Field(default=None, ge=0.0, le=1.0, description="Sampling temperature in [0, 1].")
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.data_models.model_gated_field_mixin import ModelGatedFieldMixin
+
+_UNSUPPORTED_TOP_P_MODELS = (re.compile(r"gpt-?5", re.IGNORECASE), )
+
+
+class TopPMixin(
+        BaseModel,
+        ModelGatedFieldMixin[float],
+        field_name="top_p",
+        default_if_supported=1.0,
+        unsupported_models=_UNSUPPORTED_TOP_P_MODELS,
+):
+    """
+    Mixin class for top-p configuration.
+    """
+    top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="Top-p for distribution sampling.")
@@ -0,0 +1,46 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import AliasChoices
+from pydantic import ConfigDict
+from pydantic import Field
+
+from nat.builder.builder import Builder
+from nat.builder.embedder import EmbedderProviderInfo
+from nat.cli.register_workflow import register_embedder_provider
+from nat.data_models.embedder import EmbedderBaseConfig
+from nat.data_models.retry_mixin import RetryMixin
+
+
+class AzureOpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="azure_openai"):
+    """An Azure OpenAI embedder provider to be used with an embedder client."""
+
+    model_config = ConfigDict(protected_namespaces=(), extra="allow")
+
+    api_key: str | None = Field(default=None, description="Azure OpenAI API key to interact with hosted model.")
+    api_version: str = Field(default="2025-04-01-preview", description="Azure OpenAI API version.")
+    azure_endpoint: str | None = Field(validation_alias=AliasChoices("azure_endpoint", "base_url"),
+                                       serialization_alias="azure_endpoint",
+                                       default=None,
+                                       description="Base URL for the hosted model.")
+    azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
+                                  serialization_alias="azure_deployment",
+                                  description="The Azure OpenAI hosted model/deployment name.")
+
+
+@register_embedder_provider(config_type=AzureOpenAIEmbedderModelConfig)
+async def azure_openai_embedder_model(config: AzureOpenAIEmbedderModelConfig, _builder: Builder):
+
+    yield EmbedderProviderInfo(config=config, description="An Azure OpenAI model for use with an Embedder client.")
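One detail of the new config worth calling out: `AliasChoices` lets `model_name` or `model` populate `azure_deployment` at validation time, which lines up with the default `model_keys` in the new `ModelGatedFieldMixin`. A small pydantic-only check (hypothetical `Demo` model, not the package class):

```python
from pydantic import AliasChoices, BaseModel, Field


class Demo(BaseModel):
    azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
                                  serialization_alias="azure_deployment")


# Any of the three aliases validates into the azure_deployment field.
assert Demo(model_name="text-embedding-3-small").azure_deployment == "text-embedding-3-small"
assert Demo(azure_deployment="embed-prod").azure_deployment == "embed-prod"

# Serialization uses the serialization alias as the output key.
assert Demo(model="e5").model_dump(by_alias=True) == {"azure_deployment": "e5"}
```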
@@ -34,10 +34,9 @@ class OpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="openai"):
    model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                            serialization_alias="model",
                            description="The OpenAI hosted model name.")
-    max_retries: int = Field(default=2, description="The max number of retries for the request.")


 @register_embedder_provider(config_type=OpenAIEmbedderModelConfig)
-async def openai_llm(config: OpenAIEmbedderModelConfig, builder: Builder):
+async def openai_embedder_model(config: OpenAIEmbedderModelConfig, _builder: Builder):

    yield EmbedderProviderInfo(config=config, description="An OpenAI model for use with an Embedder client.")
nat/embedder/register.py CHANGED
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# pylint: disable=unused-import
 # flake8: noqa
 # isort:skip_file

 # Import any providers which need to be automatically registered here
+from . import azure_openai_embedder
 from . import nim_embedder
 from . import openai_embedder
nat/eval/config.py CHANGED
@@ -44,6 +44,8 @@ class EvaluationRunConfig(BaseModel):
    # number of passes at each concurrency, if 0 the dataset is adjusted to a multiple of the
    # concurrency. The is only used if adjust_dataset_size is true
    num_passes: int = 0
+    # timeout for waiting for trace export tasks to complete
+    export_timeout: float = 60.0


 class EvaluationRunOutput(BaseModel):
@@ -146,13 +146,12 @@ class DatasetHandler:
            # When num_passes is specified, always use concurrency * num_passes
            # This respects the user's intent for exact number of passes
            target_size = self.concurrency * self.num_passes
+        # When num_passes = 0, use the largest multiple of concurrency <= original_size
+        # If original_size < concurrency, we need at least concurrency rows
+        elif original_size >= self.concurrency:
+            target_size = (original_size // self.concurrency) * self.concurrency
        else:
-            # When num_passes = 0, use the largest multiple of concurrency <= original_size
-            # If original_size < concurrency, we need at least concurrency rows
-            if original_size >= self.concurrency:
-                target_size = (original_size // self.concurrency) * self.concurrency
-            else:
-                target_size = self.concurrency
+            target_size = self.concurrency

        if target_size == 0:
            raise ValueError("Input dataset too small for even one batch at given concurrency.")