nvidia-nat 1.3.0a20250822__py3-none-any.whl → 1.3.0a20250823__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +0 -1
- nat/agent/react_agent/agent.py +21 -3
- nat/agent/react_agent/register.py +1 -1
- nat/agent/register.py +0 -1
- nat/agent/rewoo_agent/agent.py +0 -1
- nat/agent/rewoo_agent/register.py +1 -1
- nat/agent/tool_calling_agent/agent.py +0 -1
- nat/agent/tool_calling_agent/register.py +1 -1
- nat/authentication/api_key/api_key_auth_provider.py +1 -1
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +1 -1
- nat/builder/context.py +9 -1
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +5 -7
- nat/builder/workflow_builder.py +0 -1
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/info/list_mcp.py +3 -4
- nat/cli/commands/registry/search.py +14 -16
- nat/cli/commands/start.py +0 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +0 -1
- nat/cli/type_registry.py +3 -5
- nat/data_models/config.py +1 -1
- nat/data_models/evaluate.py +1 -1
- nat/data_models/function_dependencies.py +6 -6
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/model_gated_field_mixin.py +125 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +36 -0
- nat/data_models/top_p_mixin.py +36 -0
- nat/embedder/register.py +0 -1
- nat/eval/dataset_handler/dataset_handler.py +5 -6
- nat/eval/evaluate.py +7 -8
- nat/eval/rag_evaluator/register.py +2 -2
- nat/eval/register.py +0 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +0 -3
- nat/eval/utils/weave_eval.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +3 -2
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +0 -2
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +48 -49
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/front_ends/register.py +0 -1
- nat/llm/aws_bedrock_llm.py +3 -3
- nat/llm/azure_openai_llm.py +3 -4
- nat/llm/nim_llm.py +4 -4
- nat/llm/openai_llm.py +4 -4
- nat/llm/register.py +0 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/object_store/register.py +0 -1
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/register.py +3 -3
- nat/profiler/callbacks/langchain_callback_handler.py +1 -1
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +1 -1
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +1 -4
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/profile_runner.py +13 -8
- nat/registry_handlers/package_utils.py +0 -1
- nat/registry_handlers/pypi/pypi_handler.py +20 -23
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +8 -9
- nat/retriever/register.py +0 -1
- nat/runtime/session.py +23 -8
- nat/settings/global_settings.py +0 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/document_search.py +1 -1
- nat/tool/mcp/mcp_tool.py +1 -1
- nat/tool/register.py +0 -1
- nat/utils/data_models/schema_validator.py +2 -2
- nat/utils/exception_handlers/automatic_retries.py +0 -2
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +2 -2
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +4 -6
- nat/utils/type_utils.py +4 -4
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/METADATA +1 -1
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/RECORD +94 -91
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250822.dist-info → nvidia_nat-1.3.0a20250823.dist-info}/top_level.txt +0 -0
aiq/__init__.py
CHANGED
|
@@ -13,10 +13,10 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import sys
|
|
17
16
|
import importlib
|
|
18
17
|
import importlib.abc
|
|
19
18
|
import importlib.util
|
|
19
|
+
import sys
|
|
20
20
|
import warnings
|
|
21
21
|
|
|
22
22
|
|
|
@@ -26,7 +26,7 @@ class CompatFinder(importlib.abc.MetaPathFinder):
|
|
|
26
26
|
self.alias_prefix = alias_prefix
|
|
27
27
|
self.target_prefix = target_prefix
|
|
28
28
|
|
|
29
|
-
def find_spec(self, fullname, path, target=None):
|
|
29
|
+
def find_spec(self, fullname, path, target=None):
|
|
30
30
|
if fullname == self.alias_prefix or fullname.startswith(self.alias_prefix + "."):
|
|
31
31
|
# Map aiq.something -> nat.something
|
|
32
32
|
target_name = self.target_prefix + fullname[len(self.alias_prefix):]
|
nat/agent/base.py
CHANGED
|
@@ -179,7 +179,6 @@ class BaseAgent(ABC):
|
|
|
179
179
|
logger.debug("%s Retrying tool call for %s in %d seconds...", AGENT_LOG_PREFIX, tool.name, sleep_time)
|
|
180
180
|
await asyncio.sleep(sleep_time)
|
|
181
181
|
|
|
182
|
-
# pylint: disable=C0209
|
|
183
182
|
# All retries exhausted, return error message
|
|
184
183
|
error_content = "Tool call failed after all retry attempts. Last error: %s" % str(last_exception)
|
|
185
184
|
logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
|
nat/agent/react_agent/agent.py
CHANGED
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
|
-
# pylint: disable=R0917
|
|
18
17
|
import logging
|
|
18
|
+
import re
|
|
19
19
|
import typing
|
|
20
20
|
from json import JSONDecodeError
|
|
21
21
|
|
|
@@ -23,12 +23,14 @@ from langchain_core.agents import AgentAction
|
|
|
23
23
|
from langchain_core.agents import AgentFinish
|
|
24
24
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
25
25
|
from langchain_core.language_models import BaseChatModel
|
|
26
|
+
from langchain_core.language_models import LanguageModelInput
|
|
26
27
|
from langchain_core.messages.ai import AIMessage
|
|
27
28
|
from langchain_core.messages.base import BaseMessage
|
|
28
29
|
from langchain_core.messages.human import HumanMessage
|
|
29
30
|
from langchain_core.messages.tool import ToolMessage
|
|
30
31
|
from langchain_core.prompts import ChatPromptTemplate
|
|
31
32
|
from langchain_core.prompts import MessagesPlaceholder
|
|
33
|
+
from langchain_core.runnables import Runnable
|
|
32
34
|
from langchain_core.runnables.config import RunnableConfig
|
|
33
35
|
from langchain_core.tools import BaseTool
|
|
34
36
|
from pydantic import BaseModel
|
|
@@ -97,11 +99,27 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
97
99
|
f"{INPUT_SCHEMA_MESSAGE.format(schema=tools[-1].input_schema.model_fields)}")
|
|
98
100
|
prompt = prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
|
|
99
101
|
# construct the ReAct Agent
|
|
100
|
-
|
|
101
|
-
self.agent = prompt | bound_llm
|
|
102
|
+
self.agent = prompt | self._maybe_bind_llm_and_yield()
|
|
102
103
|
self.tools_dict = {tool.name: tool for tool in tools}
|
|
103
104
|
logger.debug("%s Initialized ReAct Agent Graph", AGENT_LOG_PREFIX)
|
|
104
105
|
|
|
106
|
+
def _maybe_bind_llm_and_yield(self) -> Runnable[LanguageModelInput, BaseMessage]:
|
|
107
|
+
"""
|
|
108
|
+
Bind additional parameters to the LLM if needed
|
|
109
|
+
- if the LLM is a smart model, no need to bind any additional parameters
|
|
110
|
+
- if the LLM is a non-smart model, bind a stop sequence to the LLM
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
Runnable[LanguageModelInput, BaseMessage]: The LLM with any additional parameters bound.
|
|
114
|
+
"""
|
|
115
|
+
# models that don't need (or don't support)a stop sequence
|
|
116
|
+
smart_models = re.compile(r"gpt-?5", re.IGNORECASE)
|
|
117
|
+
if any(smart_models.search(getattr(self.llm, model, "")) for model in ["model", "model_name"]):
|
|
118
|
+
# no need to bind any additional parameters to the LLM
|
|
119
|
+
return self.llm
|
|
120
|
+
# add a stop sequence to the LLM
|
|
121
|
+
return self.llm.bind(stop=["Observation:"])
|
|
122
|
+
|
|
105
123
|
def _get_tool(self, tool_name: str):
|
|
106
124
|
try:
|
|
107
125
|
return self.tools_dict.get(tool_name)
|
|
@@ -125,7 +125,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
125
125
|
|
|
126
126
|
# get and return the output from the state
|
|
127
127
|
state = ReActGraphState(**state)
|
|
128
|
-
output_message = state.messages[-1]
|
|
128
|
+
output_message = state.messages[-1]
|
|
129
129
|
return ChatResponse.from_string(str(output_message.content))
|
|
130
130
|
|
|
131
131
|
except Exception as ex:
|
nat/agent/register.py
CHANGED
nat/agent/rewoo_agent/agent.py
CHANGED
|
@@ -133,7 +133,7 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
133
133
|
|
|
134
134
|
# get and return the output from the state
|
|
135
135
|
state = ReWOOGraphState(**state)
|
|
136
|
-
output_message = state.result.content
|
|
136
|
+
output_message = state.result.content
|
|
137
137
|
return ChatResponse.from_string(output_message)
|
|
138
138
|
|
|
139
139
|
except Exception as ex:
|
|
@@ -86,7 +86,7 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
86
86
|
|
|
87
87
|
# get and return the output from the state
|
|
88
88
|
state = ToolCallAgentGraphState(**state)
|
|
89
|
-
output_message = state.messages[-1]
|
|
89
|
+
output_message = state.messages[-1]
|
|
90
90
|
return output_message.content
|
|
91
91
|
except Exception as ex:
|
|
92
92
|
logger.exception("%s Tool Calling Agent failed with exception: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
|
|
@@ -31,7 +31,7 @@ class APIKeyAuthProvider(AuthProviderBase[APIKeyAuthProviderConfig]):
|
|
|
31
31
|
# fmt: off
|
|
32
32
|
def __init__(self,
|
|
33
33
|
config: APIKeyAuthProviderConfig,
|
|
34
|
-
config_name: str | None = None) -> None:
|
|
34
|
+
config_name: str | None = None) -> None:
|
|
35
35
|
assert isinstance(config, APIKeyAuthProviderConfig), ("Config is not APIKeyAuthProviderConfig")
|
|
36
36
|
super().__init__(config)
|
|
37
37
|
# fmt: on
|
nat/authentication/register.py
CHANGED
nat/builder/builder.py
CHANGED
|
@@ -58,7 +58,7 @@ class UserManagerHolder():
|
|
|
58
58
|
return self._context.user_manager.get_id()
|
|
59
59
|
|
|
60
60
|
|
|
61
|
-
class Builder(ABC):
|
|
61
|
+
class Builder(ABC):
|
|
62
62
|
|
|
63
63
|
@abstractmethod
|
|
64
64
|
async def add_function(self, name: str | FunctionRef, config: FunctionBaseConfig) -> Function:
|
nat/builder/context.py
CHANGED
|
@@ -38,7 +38,7 @@ from nat.utils.reactive.subject import Subject
|
|
|
38
38
|
|
|
39
39
|
class Singleton(type):
|
|
40
40
|
|
|
41
|
-
def __init__(cls, name, bases, dict):
|
|
41
|
+
def __init__(cls, name, bases, dict):
|
|
42
42
|
super(Singleton, cls).__init__(name, bases, dict)
|
|
43
43
|
cls.instance = None
|
|
44
44
|
|
|
@@ -65,6 +65,7 @@ class ContextState(metaclass=Singleton):
|
|
|
65
65
|
|
|
66
66
|
def __init__(self):
|
|
67
67
|
self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
|
|
68
|
+
self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
|
|
68
69
|
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
|
|
69
70
|
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
|
|
70
71
|
self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes())
|
|
@@ -165,6 +166,13 @@ class Context:
|
|
|
165
166
|
"""
|
|
166
167
|
return self._context_state.conversation_id.get()
|
|
167
168
|
|
|
169
|
+
@property
|
|
170
|
+
def user_message_id(self) -> str | None:
|
|
171
|
+
"""
|
|
172
|
+
This property retrieves the user message ID which is the unique identifier for the current user message.
|
|
173
|
+
"""
|
|
174
|
+
return self._context_state.user_message_id.get()
|
|
175
|
+
|
|
168
176
|
@contextmanager
|
|
169
177
|
def push_active_function(self, function_name: str, input_data: typing.Any | None):
|
|
170
178
|
"""
|
nat/builder/function_base.py
CHANGED
|
@@ -111,7 +111,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
|
|
|
111
111
|
ValueError
|
|
112
112
|
If the input type cannot be determined from the class definition
|
|
113
113
|
"""
|
|
114
|
-
for base_cls in self.__class__.__orig_bases__:
|
|
114
|
+
for base_cls in self.__class__.__orig_bases__:
|
|
115
115
|
|
|
116
116
|
base_cls_args = typing.get_args(base_cls)
|
|
117
117
|
|
|
@@ -196,7 +196,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
|
|
|
196
196
|
ValueError
|
|
197
197
|
If the streaming output type cannot be determined from the class definition
|
|
198
198
|
"""
|
|
199
|
-
for base_cls in self.__class__.__orig_bases__:
|
|
199
|
+
for base_cls in self.__class__.__orig_bases__:
|
|
200
200
|
|
|
201
201
|
base_cls_args = typing.get_args(base_cls)
|
|
202
202
|
|
|
@@ -269,7 +269,7 @@ class FunctionBase(typing.Generic[InputT, StreamingOutputT, SingleOutputT], ABC)
|
|
|
269
269
|
ValueError
|
|
270
270
|
If the single output type cannot be determined from the class definition
|
|
271
271
|
"""
|
|
272
|
-
for base_cls in self.__class__.__orig_bases__:
|
|
272
|
+
for base_cls in self.__class__.__orig_bases__:
|
|
273
273
|
|
|
274
274
|
base_cls_args = typing.get_args(base_cls)
|
|
275
275
|
|
nat/builder/function_info.py
CHANGED
|
@@ -231,7 +231,7 @@ class FunctionDescriptor:
|
|
|
231
231
|
else:
|
|
232
232
|
annotations = [param.annotation for param in sig.parameters.values()]
|
|
233
233
|
|
|
234
|
-
is_input_typed = all([a != sig.empty for a in annotations])
|
|
234
|
+
is_input_typed = all([a != sig.empty for a in annotations])
|
|
235
235
|
|
|
236
236
|
input_type = tuple[*annotations] if is_input_typed else None # noqa: syntax-error
|
|
237
237
|
|
|
@@ -372,18 +372,16 @@ class FunctionInfo:
|
|
|
372
372
|
|
|
373
373
|
if (stream_to_single_fn is not None):
|
|
374
374
|
raise ValueError("Cannot provide both single_fn and stream_to_single_fn")
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
raise ValueError("stream_fn must be provided if stream_to_single_fn is provided")
|
|
375
|
+
elif (stream_to_single_fn is not None and stream_fn is None):
|
|
376
|
+
raise ValueError("stream_fn must be provided if stream_to_single_fn is provided")
|
|
378
377
|
|
|
379
378
|
if (stream_fn is not None):
|
|
380
379
|
final_stream_fn = stream_fn
|
|
381
380
|
|
|
382
381
|
if (single_to_stream_fn is not None):
|
|
383
382
|
raise ValueError("Cannot provide both stream_fn and single_to_stream_fn")
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
raise ValueError("single_fn must be provided if single_to_stream_fn is provided")
|
|
383
|
+
elif (single_to_stream_fn is not None and single_fn is None):
|
|
384
|
+
raise ValueError("single_fn must be provided if single_to_stream_fn is provided")
|
|
387
385
|
|
|
388
386
|
if (single_fn is None and stream_fn is None):
|
|
389
387
|
raise ValueError("At least one of single_fn or stream_fn must be provided")
|
nat/builder/workflow_builder.py
CHANGED
|
@@ -127,7 +127,6 @@ class ConfiguredTTCStrategy:
|
|
|
127
127
|
instance: StrategyBase
|
|
128
128
|
|
|
129
129
|
|
|
130
|
-
# pylint: disable=too-many-public-methods
|
|
131
130
|
class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
132
131
|
|
|
133
132
|
def __init__(self, *, general_config: GeneralConfig | None = None, registry: TypeRegistry | None = None):
|
nat/cli/commands/evaluate.py
CHANGED
|
@@ -26,14 +26,13 @@ from nat.registry_handlers.schemas.search import SearchFields
|
|
|
26
26
|
logger = logging.getLogger(__name__)
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
async def search_artifacts(
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
save_path: str | None) -> None:
|
|
29
|
+
async def search_artifacts(registry_handler_config: RegistryHandlerBaseConfig,
|
|
30
|
+
component_types: list[ComponentEnum],
|
|
31
|
+
visualize: bool,
|
|
32
|
+
query: str,
|
|
33
|
+
num_results: int,
|
|
34
|
+
query_fields: list[SearchFields],
|
|
35
|
+
save_path: str | None) -> None:
|
|
37
36
|
|
|
38
37
|
from nat.cli.type_registry import GlobalTypeRegistry
|
|
39
38
|
from nat.registry_handlers.schemas.search import SearchQuery
|
|
@@ -297,8 +297,7 @@ def ping(url: str, timeout: int, json_output: bool) -> None:
|
|
|
297
297
|
|
|
298
298
|
if json_output:
|
|
299
299
|
click.echo(result.model_dump_json(indent=2))
|
|
300
|
+
elif result.status == "healthy":
|
|
301
|
+
click.echo(f"Server at {result.url} is healthy (response time: {result.response_time_ms}ms)")
|
|
300
302
|
else:
|
|
301
|
-
|
|
302
|
-
click.echo(f"Server at {result.url} is healthy (response time: {result.response_time_ms}ms)")
|
|
303
|
-
else:
|
|
304
|
-
click.echo(f"Server at {result.url} {result.status}: {result.error}")
|
|
303
|
+
click.echo(f"Server at {result.url} {result.status}: {result.error}")
|
|
@@ -29,14 +29,13 @@ from nat.utils.data_models.schema_validator import validate_yaml
|
|
|
29
29
|
logger = logging.getLogger(__name__)
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
async def search_artifacts(
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
n_results: int = 10) -> None:
|
|
32
|
+
async def search_artifacts(registry_handler_config: RegistryHandlerBaseConfig,
|
|
33
|
+
query: str,
|
|
34
|
+
search_fields: list[SearchFields],
|
|
35
|
+
visualize: bool,
|
|
36
|
+
component_types: list[ComponentEnum],
|
|
37
|
+
save_path: str | None = None,
|
|
38
|
+
n_results: int = 10) -> None:
|
|
40
39
|
|
|
41
40
|
from nat.cli.type_registry import GlobalTypeRegistry
|
|
42
41
|
from nat.registry_handlers.schemas.search import SearchQuery
|
|
@@ -116,14 +115,13 @@ async def search_artifacts( # pylint: disable=R0917
|
|
|
116
115
|
required=False,
|
|
117
116
|
help=("The component types to include in search."),
|
|
118
117
|
)
|
|
119
|
-
def search(
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
output_path: str) -> None:
|
|
118
|
+
def search(config_file: str,
|
|
119
|
+
channel: str,
|
|
120
|
+
fields: list[str],
|
|
121
|
+
query: str,
|
|
122
|
+
component_types: list[ComponentEnum],
|
|
123
|
+
n_results: int,
|
|
124
|
+
output_path: str) -> None:
|
|
127
125
|
"""
|
|
128
126
|
Search for NAT artifacts with the specified configuration.
|
|
129
127
|
"""
|
nat/cli/commands/start.py
CHANGED
|
@@ -161,7 +161,6 @@ def get_workflow_path_from_name(workflow_name: str):
|
|
|
161
161
|
default="NAT function template. Please update the description.",
|
|
162
162
|
help="""A description of the component being created. Will be used to populate the docstring and will describe the
|
|
163
163
|
component when inspecting installed components using 'nat info component'""")
|
|
164
|
-
# pylint: disable=missing-param-doc
|
|
165
164
|
def create_command(workflow_name: str, install: bool, workflow_dir: str, description: str):
|
|
166
165
|
"""
|
|
167
166
|
Create a new NAT workflow using templates.
|
nat/cli/type_registry.py
CHANGED
|
@@ -298,7 +298,7 @@ class RegisteredPackage(BaseModel):
|
|
|
298
298
|
discovery_metadata: DiscoveryMetadata
|
|
299
299
|
|
|
300
300
|
|
|
301
|
-
class TypeRegistry:
|
|
301
|
+
class TypeRegistry:
|
|
302
302
|
|
|
303
303
|
def __init__(self) -> None:
|
|
304
304
|
# Telemetry Exporters
|
|
@@ -779,7 +779,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
779
779
|
|
|
780
780
|
self._registration_changed()
|
|
781
781
|
|
|
782
|
-
def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
|
|
782
|
+
def get_infos_by_type(self, component_type: ComponentEnum) -> dict:
|
|
783
783
|
|
|
784
784
|
if component_type == ComponentEnum.FRONT_END:
|
|
785
785
|
return self._registered_front_end_infos
|
|
@@ -849,8 +849,7 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
849
849
|
|
|
850
850
|
raise ValueError(f"Supplied an unsupported component type {component_type}")
|
|
851
851
|
|
|
852
|
-
def get_registered_types_by_component_type(
|
|
853
|
-
self, component_type: ComponentEnum) -> list[str]:
|
|
852
|
+
def get_registered_types_by_component_type(self, component_type: ComponentEnum) -> list[str]:
|
|
854
853
|
|
|
855
854
|
if component_type == ComponentEnum.FUNCTION:
|
|
856
855
|
return [i.static_type() for i in self._registered_functions]
|
|
@@ -925,7 +924,6 @@ class TypeRegistry: # pylint: disable=too-many-public-methods
|
|
|
925
924
|
if (short_names[key.local_name] == 1):
|
|
926
925
|
type_list.append((key.local_name, key.config_type))
|
|
927
926
|
|
|
928
|
-
# pylint: disable=consider-alternative-union-syntax
|
|
929
927
|
return typing.Union[tuple(typing.Annotated[x_type, Tag(x_id)] for x_id, x_type in type_list)]
|
|
930
928
|
|
|
931
929
|
def compute_annotation(self, cls: type[TypedBaseModelT]):
|
nat/data_models/config.py
CHANGED
|
@@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)
|
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
def _process_validation_error(err: ValidationError, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
|
50
|
-
from nat.cli.type_registry import GlobalTypeRegistry
|
|
50
|
+
from nat.cli.type_registry import GlobalTypeRegistry
|
|
51
51
|
|
|
52
52
|
new_errors = []
|
|
53
53
|
logged_once = False
|
nat/data_models/evaluate.py
CHANGED
|
@@ -108,7 +108,7 @@ class EvalConfig(BaseModel):
|
|
|
108
108
|
@classmethod
|
|
109
109
|
def rebuild_annotations(cls):
|
|
110
110
|
|
|
111
|
-
from nat.cli.type_registry import GlobalTypeRegistry
|
|
111
|
+
from nat.cli.type_registry import GlobalTypeRegistry
|
|
112
112
|
|
|
113
113
|
type_registry = GlobalTypeRegistry.get()
|
|
114
114
|
|
|
@@ -54,19 +54,19 @@ class FunctionDependencies(BaseModel):
|
|
|
54
54
|
return list(v)
|
|
55
55
|
|
|
56
56
|
def add_function(self, function: str):
|
|
57
|
-
self.functions.add(function)
|
|
57
|
+
self.functions.add(function)
|
|
58
58
|
|
|
59
59
|
def add_llm(self, llm: str):
|
|
60
|
-
self.llms.add(llm)
|
|
60
|
+
self.llms.add(llm)
|
|
61
61
|
|
|
62
62
|
def add_embedder(self, embedder: str):
|
|
63
|
-
self.embedders.add(embedder)
|
|
63
|
+
self.embedders.add(embedder)
|
|
64
64
|
|
|
65
65
|
def add_memory_client(self, memory_client: str):
|
|
66
|
-
self.memory_clients.add(memory_client)
|
|
66
|
+
self.memory_clients.add(memory_client)
|
|
67
67
|
|
|
68
68
|
def add_object_store(self, object_store: str):
|
|
69
|
-
self.object_stores.add(object_store)
|
|
69
|
+
self.object_stores.add(object_store)
|
|
70
70
|
|
|
71
71
|
def add_retriever(self, retriever: str):
|
|
72
|
-
self.retrievers.add(retriever)
|
|
72
|
+
self.retrievers.add(retriever)
|
|
@@ -142,7 +142,7 @@ class IntermediateStepPayload(BaseModel):
|
|
|
142
142
|
UUID: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
|
143
143
|
|
|
144
144
|
@property
|
|
145
|
-
def event_category(self) -> IntermediateStepCategory:
|
|
145
|
+
def event_category(self) -> IntermediateStepCategory:
|
|
146
146
|
match self.event_type:
|
|
147
147
|
case IntermediateStepType.LLM_START:
|
|
148
148
|
return IntermediateStepCategory.LLM
|
|
@@ -180,7 +180,7 @@ class IntermediateStepPayload(BaseModel):
|
|
|
180
180
|
raise ValueError(f"Unknown event type: {self.event_type}")
|
|
181
181
|
|
|
182
182
|
@property
|
|
183
|
-
def event_state(self) -> IntermediateStepState:
|
|
183
|
+
def event_state(self) -> IntermediateStepState:
|
|
184
184
|
match self.event_type:
|
|
185
185
|
case IntermediateStepType.LLM_START:
|
|
186
186
|
return IntermediateStepState.START
|
|
@@ -290,7 +290,7 @@ class IntermediateStep(BaseModel):
|
|
|
290
290
|
return self.payload.usage_info
|
|
291
291
|
|
|
292
292
|
@property
|
|
293
|
-
def UUID(self) -> str:
|
|
293
|
+
def UUID(self) -> str:
|
|
294
294
|
return self.payload.UUID
|
|
295
295
|
|
|
296
296
|
@property
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections.abc import Sequence
|
|
17
|
+
from re import Pattern
|
|
18
|
+
from typing import Generic
|
|
19
|
+
from typing import TypeVar
|
|
20
|
+
|
|
21
|
+
from pydantic import BaseModel
|
|
22
|
+
from pydantic import model_validator
|
|
23
|
+
|
|
24
|
+
T = TypeVar("T")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ModelGatedFieldMixin(Generic[T]):
|
|
28
|
+
"""
|
|
29
|
+
A mixin that gates a field based on model support.
|
|
30
|
+
|
|
31
|
+
This should be used to automatically validate a field based on a given model.
|
|
32
|
+
|
|
33
|
+
Parameters
|
|
34
|
+
----------
|
|
35
|
+
field_name: `str`
|
|
36
|
+
The name of the field.
|
|
37
|
+
default_if_supported: `T`
|
|
38
|
+
The default value of the field if it is supported for the model.
|
|
39
|
+
unsupported_models: `Sequence[Pattern[str]] | None`
|
|
40
|
+
A sequence of regex patterns that match the model names NOT supported for the field.
|
|
41
|
+
Defaults to None.
|
|
42
|
+
supported_models: `Sequence[Pattern[str]] | None`
|
|
43
|
+
A sequence of regex patterns that match the model names supported for the field.
|
|
44
|
+
Defaults to None.
|
|
45
|
+
model_keys: `Sequence[str]`
|
|
46
|
+
A sequence of keys that are used to validate the field.
|
|
47
|
+
Defaults to ("model_name", "model", "azure_deployment",)
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
def __init_subclass__(
|
|
51
|
+
cls,
|
|
52
|
+
field_name: str | None = None,
|
|
53
|
+
default_if_supported: T | None = None,
|
|
54
|
+
unsupported_models: Sequence[Pattern[str]] | None = None,
|
|
55
|
+
supported_models: Sequence[Pattern[str]] | None = None,
|
|
56
|
+
model_keys: Sequence[str] = ("model_name", "model", "azure_deployment"),
|
|
57
|
+
) -> None:
|
|
58
|
+
"""
|
|
59
|
+
Store the class variables for the field and define the model validator.
|
|
60
|
+
"""
|
|
61
|
+
super().__init_subclass__()
|
|
62
|
+
if ModelGatedFieldMixin in cls.__bases__:
|
|
63
|
+
if field_name is None:
|
|
64
|
+
raise ValueError("field_name must be provided when subclassing ModelGatedFieldMixin")
|
|
65
|
+
if default_if_supported is None:
|
|
66
|
+
raise ValueError("default_if_supported must be provided when subclassing ModelGatedFieldMixin")
|
|
67
|
+
if unsupported_models is None and supported_models is None:
|
|
68
|
+
raise ValueError("Either unsupported_models or supported_models must be provided")
|
|
69
|
+
if unsupported_models is not None and supported_models is not None:
|
|
70
|
+
raise ValueError("Only one of unsupported_models or supported_models must be provided")
|
|
71
|
+
if model_keys is not None and len(model_keys) == 0:
|
|
72
|
+
raise ValueError("model_keys must be provided and non-empty when subclassing ModelGatedFieldMixin")
|
|
73
|
+
cls.field_name = field_name
|
|
74
|
+
cls.default_if_supported = default_if_supported
|
|
75
|
+
cls.unsupported_models = unsupported_models
|
|
76
|
+
cls.supported_models = supported_models
|
|
77
|
+
if model_keys is not None:
|
|
78
|
+
cls.model_keys = model_keys
|
|
79
|
+
|
|
80
|
+
@classmethod
|
|
81
|
+
def check_model(cls, model_name: str) -> bool:
|
|
82
|
+
"""
|
|
83
|
+
Check if a model is supported for a given field.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
model_name: The name of the model to check.
|
|
87
|
+
"""
|
|
88
|
+
unsupported = getattr(cls, "unsupported_models", None)
|
|
89
|
+
supported = getattr(cls, "supported_models", None)
|
|
90
|
+
if unsupported is not None:
|
|
91
|
+
return not any(p.search(model_name) for p in unsupported)
|
|
92
|
+
if supported is not None:
|
|
93
|
+
return any(p.search(model_name) for p in supported)
|
|
94
|
+
return False
|
|
95
|
+
|
|
96
|
+
cls._model_gated_field_check_model = check_model
|
|
97
|
+
|
|
98
|
+
@classmethod
|
|
99
|
+
def detect_support(cls, instance: BaseModel) -> str | None:
|
|
100
|
+
for key in getattr(cls, "model_keys"):
|
|
101
|
+
if hasattr(instance, key):
|
|
102
|
+
model_name_value = getattr(instance, key)
|
|
103
|
+
is_supported = getattr(cls, "_model_gated_field_check_model")(str(model_name_value))
|
|
104
|
+
return key if not is_supported else None
|
|
105
|
+
return None
|
|
106
|
+
|
|
107
|
+
cls._model_gated_field_detect_support = detect_support
|
|
108
|
+
|
|
109
|
+
@model_validator(mode="after")
|
|
110
|
+
def model_validate(self):
|
|
111
|
+
klass = self.__class__
|
|
112
|
+
|
|
113
|
+
field_name_local = getattr(klass, "field_name")
|
|
114
|
+
current_value = getattr(self, field_name_local, None)
|
|
115
|
+
|
|
116
|
+
found_key = klass._model_gated_field_detect_support(self)
|
|
117
|
+
if found_key is not None:
|
|
118
|
+
if current_value is not None:
|
|
119
|
+
raise ValueError(
|
|
120
|
+
f"{field_name_local} is not supported for {found_key}: {getattr(self, found_key)}")
|
|
121
|
+
elif current_value is None:
|
|
122
|
+
setattr(self, field_name_local, getattr(klass, "default_if_supported", None))
|
|
123
|
+
return self
|
|
124
|
+
|
|
125
|
+
cls._model_gated_field_model_validator = model_validate
|
|
@@ -39,7 +39,7 @@ class SWEBenchInput(BaseModel):
|
|
|
39
39
|
|
|
40
40
|
# Handle improperly formatted JSON strings for list fields
|
|
41
41
|
@field_validator("FAIL_TO_PASS", "PASS_TO_PASS", mode="before")
|
|
42
|
-
def parse_list_fields(cls, value):
|
|
42
|
+
def parse_list_fields(cls, value):
|
|
43
43
|
if isinstance(value, str):
|
|
44
44
|
# Attempt to parse the string as a list
|
|
45
45
|
return json.loads(value)
|