nvidia-nat 1.3a20250819__py3-none-any.whl → 1.3.0a20250823__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +66 -0
- nat/agent/base.py +16 -0
- nat/agent/react_agent/agent.py +38 -13
- nat/agent/react_agent/prompt.py +4 -1
- nat/agent/react_agent/register.py +1 -1
- nat/agent/register.py +0 -1
- nat/agent/rewoo_agent/agent.py +6 -3
- nat/agent/rewoo_agent/prompt.py +3 -0
- nat/agent/rewoo_agent/register.py +4 -3
- nat/agent/tool_calling_agent/agent.py +92 -22
- nat/agent/tool_calling_agent/register.py +9 -13
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +1 -1
- nat/builder/context.py +9 -1
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +5 -7
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +3 -0
- nat/builder/workflow_builder.py +0 -1
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/info/list_mcp.py +3 -4
- nat/cli/commands/registry/search.py +14 -16
- nat/cli/commands/start.py +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +3 -0
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +0 -1
- nat/cli/type_registry.py +7 -9
- nat/data_models/config.py +1 -1
- nat/data_models/evaluate.py +1 -1
- nat/data_models/function_dependencies.py +6 -6
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/model_gated_field_mixin.py +125 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +36 -0
- nat/data_models/top_p_mixin.py +36 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/openai_embedder.py +1 -2
- nat/embedder/register.py +1 -1
- nat/eval/config.py +2 -0
- nat/eval/dataset_handler/dataset_handler.py +5 -6
- nat/eval/evaluate.py +64 -20
- nat/eval/rag_evaluator/register.py +2 -2
- nat/eval/register.py +0 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +14 -7
- nat/experimental/test_time_compute/models/strategy_base.py +3 -2
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/register.py +0 -1
- nat/llm/aws_bedrock_llm.py +3 -3
- nat/llm/azure_openai_llm.py +49 -0
- nat/llm/nim_llm.py +4 -4
- nat/llm/openai_llm.py +4 -4
- nat/llm/register.py +1 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/meta/pypi.md +9 -9
- nat/object_store/models.py +2 -0
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/register.py +3 -3
- nat/profiler/callbacks/langchain_callback_handler.py +9 -2
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +1 -4
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/profile_runner.py +13 -8
- nat/registry_handlers/package_utils.py +0 -1
- nat/registry_handlers/pypi/pypi_handler.py +20 -23
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +8 -9
- nat/retriever/register.py +0 -1
- nat/runtime/session.py +23 -8
- nat/settings/global_settings.py +13 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +1 -1
- nat/tool/mcp/mcp_tool.py +1 -1
- nat/tool/register.py +0 -1
- nat/utils/data_models/schema_validator.py +2 -2
- nat/utils/exception_handlers/automatic_retries.py +0 -2
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +2 -2
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +4 -6
- nat/utils/type_utils.py +4 -4
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/METADATA +17 -15
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/RECORD +107 -100
- nvidia_nat-1.3.0a20250823.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/top_level.txt +1 -0
- nvidia_nat-1.3a20250819.dist-info/licenses/LICENSE-3rd-party.txt +0 -3686
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3a20250819.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE.md +0 -0
nat/builder/function_info.py
CHANGED
@@ -231,7 +231,7 @@ class FunctionDescriptor:
         else:
             annotations = [param.annotation for param in sig.parameters.values()]

-            is_input_typed = all([a != sig.empty for a in annotations])
+            is_input_typed = all([a != sig.empty for a in annotations])

         input_type = tuple[*annotations] if is_input_typed else None  # noqa: syntax-error

@@ -372,18 +372,16 @@ class FunctionInfo:

         if (stream_to_single_fn is not None):
             raise ValueError("Cannot provide both single_fn and stream_to_single_fn")
-
-
-            raise ValueError("stream_fn must be provided if stream_to_single_fn is provided")
+        elif (stream_to_single_fn is not None and stream_fn is None):
+            raise ValueError("stream_fn must be provided if stream_to_single_fn is provided")

         if (stream_fn is not None):
             final_stream_fn = stream_fn

             if (single_to_stream_fn is not None):
                 raise ValueError("Cannot provide both stream_fn and single_to_stream_fn")
-
-
-            raise ValueError("single_fn must be provided if single_to_stream_fn is provided")
+        elif (single_to_stream_fn is not None and single_fn is None):
+            raise ValueError("single_fn must be provided if single_to_stream_fn is provided")

         if (single_fn is None and stream_fn is None):
             raise ValueError("At least one of single_fn or stream_fn must be provided")
nat/builder/user_interaction_manager.py
CHANGED

@@ -61,13 +61,13 @@ class UserInteractionManager:

         uuid_req = str(uuid.uuid4())
         status = InteractionStatus.IN_PROGRESS
-        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ")
+        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
         sys_human_interaction = InteractionPrompt(id=uuid_req, status=status, timestamp=timestamp, content=content)

         resp = await self._context_state.user_input_callback.get()(sys_human_interaction)

         # Rebuild a InteractionResponse object with the response
-        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ")
+        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
         status = InteractionStatus.COMPLETED
         sys_human_interaction = InteractionResponse(id=uuid_req, status=status, timestamp=timestamp, content=resp)
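The whole fix above is the added time.gmtime() argument: without an explicit struct_time, time.strftime() formats local time while the format string still appends a literal "Z" (the UTC designator). A minimal standard-library sketch of the difference, not package code:

    import time

    # Formats *local* time but still claims UTC via the trailing "Z".
    local_stamp = time.strftime("%Y-%m-%dT%H:%M:%SZ")

    # Formats UTC, so the "Z" suffix is accurate.
    utc_stamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    print(local_stamp, utc_stamp)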
nat/builder/workflow.py
CHANGED
@@ -83,6 +83,9 @@ class Workflow(FunctionBase[InputT, StreamingOutputT, SingleOutputT]):

         return self._entry_fn.has_single_output

+    async def get_all_exporters(self) -> dict[str, BaseExporter]:
+        return await self._exporter_manager.get_all_exporters()
+
     @asynccontextmanager
     async def run(self, message: InputT):
         """
nat/builder/workflow_builder.py
CHANGED
@@ -127,7 +127,6 @@ class ConfiguredTTCStrategy:
     instance: StrategyBase


-# pylint: disable=too-many-public-methods
 class WorkflowBuilder(Builder, AbstractAsyncContextManager):

     def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
nat/cli/commands/evaluate.py
CHANGED
nat/cli/commands/info/list_components.py
CHANGED

@@ -26,14 +26,13 @@ from nat.registry_handlers.schemas.search import SearchFields
 logger = logging.getLogger(__name__)


-async def search_artifacts(
-
-
-
-
-
-
-                           save_path: str | None) -> None:
+async def search_artifacts(registry_handler_config: RegistryHandlerBaseConfig,
+                           component_types: list[ComponentEnum],
+                           visualize: bool,
+                           query: str,
+                           num_results: int,
+                           query_fields: list[SearchFields],
+                           save_path: str | None) -> None:

     from nat.cli.type_registry import GlobalTypeRegistry
     from nat.registry_handlers.schemas.search import SearchQuery
nat/cli/commands/info/list_mcp.py
CHANGED

@@ -297,8 +297,7 @@ def ping(url: str, timeout: int, json_output: bool) -> None:

     if json_output:
         click.echo(result.model_dump_json(indent=2))
+    elif result.status == "healthy":
+        click.echo(f"Server at {result.url} is healthy (response time: {result.response_time_ms}ms)")
     else:
-
-        click.echo(f"Server at {result.url} is healthy (response time: {result.response_time_ms}ms)")
-    else:
-        click.echo(f"Server at {result.url} {result.status}: {result.error}")
+        click.echo(f"Server at {result.url} {result.status}: {result.error}")
nat/cli/commands/registry/search.py
CHANGED

@@ -29,14 +29,13 @@ from nat.utils.data_models.schema_validator import validate_yaml
 logger = logging.getLogger(__name__)


-async def search_artifacts(
-
-
-
-
-
-
-                           n_results: int = 10) -> None:
+async def search_artifacts(registry_handler_config: RegistryHandlerBaseConfig,
+                           query: str,
+                           search_fields: list[SearchFields],
+                           visualize: bool,
+                           component_types: list[ComponentEnum],
+                           save_path: str | None = None,
+                           n_results: int = 10) -> None:

     from nat.cli.type_registry import GlobalTypeRegistry
     from nat.registry_handlers.schemas.search import SearchQuery

@@ -116,14 +115,13 @@ async def search_artifacts( # pylint: disable=R0917
     required=False,
     help=("The component types to include in search."),
 )
-def search(
-
-
-
-
-
-
-           output_path: str) -> None:
+def search(config_file: str,
+           channel: str,
+           fields: list[str],
+           query: str,
+           component_types: list[ComponentEnum],
+           n_results: int,
+           output_path: str) -> None:
     """
     Search for NAT artifacts with the specified configuration.
     """
nat/cli/commands/start.py
CHANGED
nat/cli/commands/workflow/templates/pyproject.toml.j2
CHANGED

@@ -3,6 +3,9 @@ build-backend = "setuptools.build_meta"
 {% if editable %}requires = ["setuptools >= 64", "setuptools-scm>=8"]

 [tool.setuptools_scm]
+# NAT uses the --first-parent flag to avoid tags from previous releases which have been merged into the develop branch
+# from causing an unexpected version change. This can be safely removed if developing outside of the NAT repository.
+git_describe_command = "git describe --long --first-parent"
 root = "{{ rel_path_to_repo_root}}"{% else %}requires = ["setuptools >= 64"]{% endif %}

 [project]

nat/cli/commands/workflow/workflow_commands.py
CHANGED

@@ -161,7 +161,6 @@ def get_workflow_path_from_name(workflow_name: str):
     default="NAT function template. Please update the description.",
     help="""A description of the component being created. Will be used to populate the docstring and will describe the
     component when inspecting installed components using 'nat info component'""")
-# pylint: disable=missing-param-doc
 def create_command(workflow_name: str, install: bool, workflow_dir: str, description: str):
     """
     Create a new NAT workflow using templates.
nat/cli/type_registry.py
CHANGED
@@ -298,7 +298,7 @@ class RegisteredPackage(BaseModel):
     discovery_metadata: DiscoveryMetadata


-class TypeRegistry:
+class TypeRegistry:

     def __init__(self) -> None:
         # Telemetry Exporters

@@ -588,8 +588,8 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
         except KeyError as err:
             raise KeyError(
                 f"An invalid Embedder config and wrapper combination was supplied. Config: `{config_type}`, "
-                "Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
-                "there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
+                f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Embedder client but "
+                f"there is no registered conversion from that Embedder provider to LLM framework: {wrapper_type}. "
                 "Please provide an Embedder configuration from one of the following providers: "
                 f"{set(self._embedder_client_provider_to_framework.keys())}") from err

@@ -703,8 +703,8 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
         except KeyError as err:
             raise KeyError(
                 f"An invalid Retriever config and wrapper combination was supplied. Config: `{config_type}`, "
-                "Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
-                "there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
+                f"Wrapper: `{wrapper_type}`. The workflow is requesting a {wrapper_type} Retriever client but "
+                f"there is no registered conversion from that Retriever provider to LLM framework: {wrapper_type}. "
                 "Please provide a Retriever configuration from one of the following providers: "
                 f"{set(self._retriever_client_provider_to_framework.keys())}") from err

@@ -779,7 +779,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods

         self._registration_changed()

-    def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
+    def get_infos_by_type(self, component_type: ComponentEnum) -> dict:

         if component_type == ComponentEnum.FRONT_END:
             return self._registered_front_end_infos

@@ -849,8 +849,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods

         raise ValueError(f"Supplied an unsupported component type {component_type}")

-    def get_registered_types_by_component_type(
-            self, component_type: ComponentEnum) -> list[str]:
+    def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:

         if component_type == ComponentEnum.FUNCTION:
             return [i.static_type() for i in self._registered_functions]

@@ -925,7 +924,6 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
             if (short_names[key.local_name] == 1):
                 type_list.append((key.local_name, key.config_type))

-        # pylint: disable=consider-alternative-union-syntax
         return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]

     def compute_annotation(self, cls: type[TypedBaseModelT]):
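The Embedder and Retriever hunks above add missing f prefixes: without them the braces are emitted literally rather than interpolated, so the KeyError message never named the offending wrapper. A quick stand-alone illustration, plain Python rather than package code:

    wrapper_type = "langchain"

    print("Wrapper: `{wrapper_type}`.")   # -> Wrapper: `{wrapper_type}`.
    print(f"Wrapper: `{wrapper_type}`.")  # -> Wrapper: `langchain`.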
nat/data_models/config.py
CHANGED
@@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)


 def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
-    from nat.cli.type_registry import GlobalTypeRegistry
+    from nat.cli.type_registry import GlobalTypeRegistry

     new_errors = []
     logged_once = False
nat/data_models/evaluate.py
CHANGED
@@ -108,7 +108,7 @@ class EvalConfig(BaseModel):
     @classmethod
     def rebuild_annotations(cls):

-        from nat.cli.type_registry import GlobalTypeRegistry
+        from nat.cli.type_registry import GlobalTypeRegistry

         type_registry = GlobalTypeRegistry.get()
nat/data_models/function_dependencies.py
CHANGED

@@ -54,19 +54,19 @@ class FunctionDependencies(BaseModel):
         return list(v)

     def add_function(self, function: str):
-        self.functions.add(function)
+        self.functions.add(function)

     def add_llm(self, llm: str):
-        self.llms.add(llm)
+        self.llms.add(llm)

     def add_embedder(self, embedder: str):
-        self.embedders.add(embedder)
+        self.embedders.add(embedder)

     def add_memory_client(self, memory_client: str):
-        self.memory_clients.add(memory_client)
+        self.memory_clients.add(memory_client)

     def add_object_store(self, object_store: str):
-        self.object_stores.add(object_store)
+        self.object_stores.add(object_store)

     def add_retriever(self, retriever: str):
-        self.retrievers.add(retriever)
+        self.retrievers.add(retriever)
nat/data_models/intermediate_step.py
CHANGED

@@ -142,7 +142,7 @@ class IntermediateStepPayload(BaseModel):
     UUID: str = Field(default_factory=lambda: str(uuid.uuid4()))

     @property
-    def event_category(self) -> IntermediateStepCategory:
+    def event_category(self) -> IntermediateStepCategory:
         match self.event_type:
             case IntermediateStepType.LLM_START:
                 return IntermediateStepCategory.LLM

@@ -180,7 +180,7 @@ class IntermediateStepPayload(BaseModel):
         raise ValueError(f"Unknown event type: {self.event_type}")

     @property
-    def event_state(self) -> IntermediateStepState:
+    def event_state(self) -> IntermediateStepState:
         match self.event_type:
             case IntermediateStepType.LLM_START:
                 return IntermediateStepState.START

@@ -290,7 +290,7 @@ class IntermediateStep(BaseModel):
         return self.payload.usage_info

     @property
-    def UUID(self) -> str:
+    def UUID(self) -> str:
         return self.payload.UUID

     @property
nat/data_models/model_gated_field_mixin.py
ADDED

@@ -0,0 +1,125 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Sequence
+from re import Pattern
+from typing import Generic
+from typing import TypeVar
+
+from pydantic import BaseModel
+from pydantic import model_validator
+
+T = TypeVar("T")
+
+
+class ModelGatedFieldMixin(Generic[T]):
+    """
+    A mixin that gates a field based on model support.
+
+    This should be used to automatically validate a field based on a given model.
+
+    Parameters
+    ----------
+    field_name: `str`
+        The name of the field.
+    default_if_supported: `T`
+        The default value of the field if it is supported for the model.
+    unsupported_models: `Sequence[Pattern[str]] | None`
+        A sequence of regex patterns that match the model names NOT supported for the field.
+        Defaults to None.
+    supported_models: `Sequence[Pattern[str]] | None`
+        A sequence of regex patterns that match the model names supported for the field.
+        Defaults to None.
+    model_keys: `Sequence[str]`
+        A sequence of keys that are used to validate the field.
+        Defaults to ("model_name", "model", "azure_deployment",)
+    """
+
+    def __init_subclass__(
+        cls,
+        field_name: str | None = None,
+        default_if_supported: T | None = None,
+        unsupported_models: Sequence[Pattern[str]] | None = None,
+        supported_models: Sequence[Pattern[str]] | None = None,
+        model_keys: Sequence[str] = ("model_name", "model", "azure_deployment"),
+    ) -> None:
+        """
+        Store the class variables for the field and define the model validator.
+        """
+        super().__init_subclass__()
+        if ModelGatedFieldMixin in cls.__bases__:
+            if field_name is None:
+                raise ValueError("field_name must be provided when subclassing ModelGatedFieldMixin")
+            if default_if_supported is None:
+                raise ValueError("default_if_supported must be provided when subclassing ModelGatedFieldMixin")
+            if unsupported_models is None and supported_models is None:
+                raise ValueError("Either unsupported_models or supported_models must be provided")
+            if unsupported_models is not None and supported_models is not None:
+                raise ValueError("Only one of unsupported_models or supported_models must be provided")
+            if model_keys is not None and len(model_keys) == 0:
+                raise ValueError("model_keys must be provided and non-empty when subclassing ModelGatedFieldMixin")
+            cls.field_name = field_name
+            cls.default_if_supported = default_if_supported
+            cls.unsupported_models = unsupported_models
+            cls.supported_models = supported_models
+            if model_keys is not None:
+                cls.model_keys = model_keys
+
+        @classmethod
+        def check_model(cls, model_name: str) -> bool:
+            """
+            Check if a model is supported for a given field.
+
+            Args:
+                model_name: The name of the model to check.
+            """
+            unsupported = getattr(cls, "unsupported_models", None)
+            supported = getattr(cls, "supported_models", None)
+            if unsupported is not None:
+                return not any(p.search(model_name) for p in unsupported)
+            if supported is not None:
+                return any(p.search(model_name) for p in supported)
+            return False
+
+        cls._model_gated_field_check_model = check_model
+
+        @classmethod
+        def detect_support(cls, instance: BaseModel) -> str | None:
+            for key in getattr(cls, "model_keys"):
+                if hasattr(instance, key):
+                    model_name_value = getattr(instance, key)
+                    is_supported = getattr(cls, "_model_gated_field_check_model")(str(model_name_value))
+                    return key if not is_supported else None
+            return None
+
+        cls._model_gated_field_detect_support = detect_support
+
+        @model_validator(mode="after")
+        def model_validate(self):
+            klass = self.__class__
+
+            field_name_local = getattr(klass, "field_name")
+            current_value = getattr(self, field_name_local, None)
+
+            found_key = klass._model_gated_field_detect_support(self)
+            if found_key is not None:
+                if current_value is not None:
+                    raise ValueError(
+                        f"{field_name_local} is not supported for {found_key}: {getattr(self, found_key)}")
+            elif current_value is None:
+                setattr(self, field_name_local, getattr(klass, "default_if_supported", None))
+            return self
+
+        cls._model_gated_field_model_validator = model_validate
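The gate itself is just a regex scan over the configured model name. A minimal stand-alone sketch of the unsupported_models path in check_model, reusing the gpt-5 pattern that the temperature/top-p mixins below register (plain Python, not package code):

    import re

    _UNSUPPORTED = (re.compile(r"gpt-?5", re.IGNORECASE), )

    def is_supported(model_name: str) -> bool:
        # Mirrors ModelGatedFieldMixin.check_model for the unsupported_models case.
        return not any(p.search(model_name) for p in _UNSUPPORTED)

    print(is_supported("gpt-4o"))      # True  -> field gets default_if_supported
    print(is_supported("gpt-5-mini"))  # False -> field stays unset; setting it raises
    print(is_supported("GPT5"))        # False -> case-insensitive, hyphen optional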
nat/data_models/swe_bench_model.py
CHANGED

@@ -39,7 +39,7 @@ class SWEBenchInput(BaseModel):

     # Handle improperly formatted JSON strings for list fields
     @field_validator("FAIL_TO_PASS", "PASS_TO_PASS", mode="before")
-    def parse_list_fields(cls, value):
+    def parse_list_fields(cls, value):
         if isinstance(value, str):
             # Attempt to parse the string as a list
             return json.loads(value)
nat/data_models/temperature_mixin.py
ADDED

@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.data_models.model_gated_field_mixin import ModelGatedFieldMixin
+
+_UNSUPPORTED_TEMPERATURE_MODELS = (re.compile(r"gpt-?5", re.IGNORECASE), )
+
+
+class TemperatureMixin(
+        BaseModel,
+        ModelGatedFieldMixin[float],
+        field_name="temperature",
+        default_if_supported=0.0,
+        unsupported_models=_UNSUPPORTED_TEMPERATURE_MODELS,
+):
+    """
+    Mixin class for temperature configuration.
+    """
+    temperature: float | None = Field(default=None, ge=0.0, le=1.0, description="Sampling temperature in [0, 1].")
nat/data_models/top_p_mixin.py
ADDED

@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.data_models.model_gated_field_mixin import ModelGatedFieldMixin
+
+_UNSUPPORTED_TOP_P_MODELS = (re.compile(r"gpt-?5", re.IGNORECASE), )
+
+
+class TopPMixin(
+        BaseModel,
+        ModelGatedFieldMixin[float],
+        field_name="top_p",
+        default_if_supported=1.0,
+        unsupported_models=_UNSUPPORTED_TOP_P_MODELS,
+):
+    """
+    Mixin class for top-p configuration.
+    """
+    top_p: float | None = Field(default=None, ge=0.0, le=1.0, description="Top-p for distribution sampling.")
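A minimal sketch of how the gated temperature field is meant to behave according to the mixin code above; MyLLMConfig is a hypothetical model used only for illustration, not part of the package:

    from nat.data_models.temperature_mixin import TemperatureMixin

    class MyLLMConfig(TemperatureMixin):
        # Hypothetical config; model_name is one of the default model_keys the mixin inspects.
        model_name: str

    MyLLMConfig(model_name="gpt-4o")                  # temperature defaults to 0.0 (supported model)
    MyLLMConfig(model_name="gpt-5")                   # temperature stays None (gated model)
    MyLLMConfig(model_name="gpt-5", temperature=0.2)  # expected to raise ValueError per the validator above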
nat/embedder/azure_openai_embedder.py
ADDED

@@ -0,0 +1,46 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pydantic import AliasChoices
+from pydantic import ConfigDict
+from pydantic import Field
+
+from nat.builder.builder import Builder
+from nat.builder.embedder import EmbedderProviderInfo
+from nat.cli.register_workflow import register_embedder_provider
+from nat.data_models.embedder import EmbedderBaseConfig
+from nat.data_models.retry_mixin import RetryMixin
+
+
+class AzureOpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="azure_openai"):
+    """An Azure OpenAI embedder provider to be used with an embedder client."""
+
+    model_config = ConfigDict(protected_namespaces=(), extra="allow")
+
+    api_key: str | None = Field(default=None, description="Azure OpenAI API key to interact with hosted model.")
+    api_version: str = Field(default="2025-04-01-preview", description="Azure OpenAI API version.")
+    azure_endpoint: str | None = Field(validation_alias=AliasChoices("azure_endpoint", "base_url"),
+                                       serialization_alias="azure_endpoint",
+                                       default=None,
+                                       description="Base URL for the hosted model.")
+    azure_deployment: str = Field(validation_alias=AliasChoices("azure_deployment", "model_name", "model"),
+                                  serialization_alias="azure_deployment",
+                                  description="The Azure OpenAI hosted model/deployment name.")
+
+
+@register_embedder_provider(config_type=AzureOpenAIEmbedderModelConfig)
+async def azure_openai_embedder_model(config: AzureOpenAIEmbedderModelConfig, _builder: Builder):
+
+    yield EmbedderProviderInfo(config=config, description="An Azure OpenAI model for use with an Embedder client.")
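A minimal sketch of what the alias handling above allows; the values are placeholders and direct instantiation of the config model is shown only for illustration:

    from nat.embedder.azure_openai_embedder import AzureOpenAIEmbedderModelConfig

    # "base_url" and "model_name" are accepted via AliasChoices and serialized back
    # to the canonical Azure field names.
    cfg = AzureOpenAIEmbedderModelConfig(
        api_key="<placeholder>",
        base_url="https://my-resource.openai.azure.com",  # stored as azure_endpoint
        model_name="text-embedding-3-large",              # stored as azure_deployment
    )
    print(cfg.model_dump(by_alias=True))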
nat/embedder/openai_embedder.py
CHANGED
@@ -34,10 +34,9 @@ class OpenAIEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, name="openai"):
     model_name: str = Field(validation_alias=AliasChoices("model_name", "model"),
                             serialization_alias="model",
                             description="The OpenAI hosted model name.")
-    max_retries: int = Field(default=2, description="The max number of retries for the request.")


 @register_embedder_provider(config_type=OpenAIEmbedderModelConfig)
-async def
+async def openai_embedder_model(config: OpenAIEmbedderModelConfig, _builder: Builder):

     yield EmbedderProviderInfo(config=config, description="An OpenAI model for use with an Embedder client.")
nat/embedder/register.py
CHANGED
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# pylint: disable=unused-import
 # flake8: noqa
 # isort:skip_file

 # Import any providers which need to be automatically registered here
+from . import azure_openai_embedder
 from . import nim_embedder
 from . import openai_embedder
nat/eval/config.py
CHANGED
@@ -44,6 +44,8 @@ class EvaluationRunConfig(BaseModel):
     # number of passes at each concurrency, if 0 the dataset is adjusted to a multiple of the
     # concurrency. The is only used if adjust_dataset_size is true
     num_passes: int = 0
+    # timeout for waiting for trace export tasks to complete
+    export_timeout: float = 60.0


 class EvaluationRunOutput(BaseModel):
nat/eval/dataset_handler/dataset_handler.py
CHANGED

@@ -146,13 +146,12 @@ class DatasetHandler:
             # When num_passes is specified, always use concurrency * num_passes
             # This respects the user's intent for exact number of passes
             target_size = self.concurrency * self.num_passes
+        # When num_passes = 0, use the largest multiple of concurrency <= original_size
+        # If original_size < concurrency, we need at least concurrency rows
+        elif original_size >= self.concurrency:
+            target_size = (original_size // self.concurrency) * self.concurrency
         else:
-
-            # If original_size < concurrency, we need at least concurrency rows
-            if original_size >= self.concurrency:
-                target_size = (original_size // self.concurrency) * self.concurrency
-            else:
-                target_size = self.concurrency
+            target_size = self.concurrency

         if target_size == 0:
             raise ValueError("Input dataset too small for even one batch at given concurrency.")