nvidia-nat 1.3.0a20250826__py3-none-any.whl → 1.3.0a20250828__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +6 -6
- nat/agent/dual_node.py +7 -2
- nat/agent/react_agent/agent.py +6 -1
- nat/agent/react_agent/register.py +4 -0
- nat/agent/rewoo_agent/agent.py +7 -2
- nat/agent/rewoo_agent/register.py +5 -1
- nat/agent/tool_calling_agent/agent.py +6 -1
- nat/agent/tool_calling_agent/register.py +4 -0
- nat/builder/context.py +7 -2
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/entrypoint.py +3 -1
- nat/data_models/gated_field_mixin.py +12 -14
- nat/data_models/temperature_mixin.py +1 -1
- nat/data_models/thinking_mixin.py +68 -0
- nat/data_models/top_p_mixin.py +1 -1
- nat/llm/aws_bedrock_llm.py +10 -9
- nat/llm/azure_openai_llm.py +9 -1
- nat/llm/nim_llm.py +2 -1
- nat/llm/openai_llm.py +2 -1
- nat/llm/utils/thinking.py +215 -0
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/profiler/decorators/function_tracking.py +125 -0
- {nvidia_nat-1.3.0a20250826.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/METADATA +3 -1
- {nvidia_nat-1.3.0a20250826.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/RECORD +31 -25
- {nvidia_nat-1.3.0a20250826.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250826.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250826.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250826.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250826.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/top_level.txt +0 -0
nat/agent/base.py
CHANGED
|
@@ -70,12 +70,14 @@ class BaseAgent(ABC):
|
|
|
70
70
|
llm: BaseChatModel,
|
|
71
71
|
tools: list[BaseTool],
|
|
72
72
|
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
73
|
-
detailed_logs: bool = False) -> None:
|
|
73
|
+
detailed_logs: bool = False,
|
|
74
|
+
log_response_max_chars: int = 1000) -> None:
|
|
74
75
|
logger.debug("Initializing Agent Graph")
|
|
75
76
|
self.llm = llm
|
|
76
77
|
self.tools = tools
|
|
77
78
|
self.callbacks = callbacks or []
|
|
78
79
|
self.detailed_logs = detailed_logs
|
|
80
|
+
self.log_response_max_chars = log_response_max_chars
|
|
79
81
|
self.graph = None
|
|
80
82
|
|
|
81
83
|
async def _stream_llm(self,
|
|
@@ -184,7 +186,7 @@ class BaseAgent(ABC):
|
|
|
184
186
|
logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
|
|
185
187
|
return ToolMessage(name=tool.name, tool_call_id=tool.name, content=error_content, status="error")
|
|
186
188
|
|
|
187
|
-
def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str, max_chars: int = 1000) -> None:
|
|
189
|
+
def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str) -> None:
|
|
188
190
|
"""
|
|
189
191
|
Log tool response with consistent formatting and length limits.
|
|
190
192
|
|
|
@@ -196,13 +198,11 @@ class BaseAgent(ABC):
|
|
|
196
198
|
The input that was passed to the tool
|
|
197
199
|
tool_response : str
|
|
198
200
|
The response from the tool
|
|
199
|
-
max_chars : int
|
|
200
|
-
Maximum number of characters to log (default: 1000)
|
|
201
201
|
"""
|
|
202
202
|
if self.detailed_logs:
|
|
203
203
|
# Truncate tool response if too long
|
|
204
|
-
display_response = tool_response[:max_chars] + "...(rest of response truncated)" if len(
|
|
205
|
-
tool_response) > max_chars else tool_response
|
|
204
|
+
display_response = tool_response[:self.log_response_max_chars] + "...(rest of response truncated)" if len(
|
|
205
|
+
tool_response) > self.log_response_max_chars else tool_response
|
|
206
206
|
|
|
207
207
|
# Format the tool input for display
|
|
208
208
|
tool_input_str = str(tool_input)
|
nat/agent/dual_node.py
CHANGED
|
@@ -35,8 +35,13 @@ class DualNodeAgent(BaseAgent):
|
|
|
35
35
|
llm: BaseChatModel,
|
|
36
36
|
tools: list[BaseTool],
|
|
37
37
|
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
38
|
-
detailed_logs: bool = False):
|
|
39
|
-
|
|
38
|
+
detailed_logs: bool = False,
|
|
39
|
+
log_response_max_chars: int = 1000):
|
|
40
|
+
super().__init__(llm=llm,
|
|
41
|
+
tools=tools,
|
|
42
|
+
callbacks=callbacks,
|
|
43
|
+
detailed_logs=detailed_logs,
|
|
44
|
+
log_response_max_chars=log_response_max_chars)
|
|
40
45
|
|
|
41
46
|
@abstractmethod
|
|
42
47
|
async def agent_node(self, state: BaseModel) -> BaseModel:
|
nat/agent/react_agent/agent.py
CHANGED
|
@@ -73,11 +73,16 @@ class ReActAgentGraph(DualNodeAgent):
|
|
|
73
73
|
use_tool_schema: bool = True,
|
|
74
74
|
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
75
75
|
detailed_logs: bool = False,
|
|
76
|
+
log_response_max_chars: int = 1000,
|
|
76
77
|
retry_agent_response_parsing_errors: bool = True,
|
|
77
78
|
parse_agent_response_max_retries: int = 1,
|
|
78
79
|
tool_call_max_retries: int = 1,
|
|
79
80
|
pass_tool_call_errors_to_agent: bool = True):
|
|
80
|
-
super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
|
|
81
|
+
super().__init__(llm=llm,
|
|
82
|
+
tools=tools,
|
|
83
|
+
callbacks=callbacks,
|
|
84
|
+
detailed_logs=detailed_logs,
|
|
85
|
+
log_response_max_chars=log_response_max_chars)
|
|
81
86
|
self.parse_agent_response_max_retries = (parse_agent_response_max_retries
|
|
82
87
|
if retry_agent_response_parsing_errors else 1)
|
|
83
88
|
self.tool_call_max_retries = tool_call_max_retries
|
|
@@ -17,6 +17,7 @@ import logging
|
|
|
17
17
|
|
|
18
18
|
from pydantic import AliasChoices
|
|
19
19
|
from pydantic import Field
|
|
20
|
+
from pydantic import PositiveInt
|
|
20
21
|
|
|
21
22
|
from nat.builder.builder import Builder
|
|
22
23
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
@@ -65,6 +66,8 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
|
|
|
65
66
|
default=None,
|
|
66
67
|
description="Provides the SYSTEM_PROMPT to use with the agent") # defaults to SYSTEM_PROMPT in prompt.py
|
|
67
68
|
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
69
|
+
log_response_max_chars: PositiveInt = Field(
|
|
70
|
+
default=1000, description="Maximum number of characters to display in logs when logging tool responses.")
|
|
68
71
|
use_openai_api: bool = Field(default=False,
|
|
69
72
|
description=("Use OpenAI API for the input/output types to the function. "
|
|
70
73
|
"If False, strings will be used."))
|
|
@@ -100,6 +103,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
100
103
|
tools=tools,
|
|
101
104
|
use_tool_schema=config.include_tool_input_schema_in_tool_description,
|
|
102
105
|
detailed_logs=config.verbose,
|
|
106
|
+
log_response_max_chars=config.log_response_max_chars,
|
|
103
107
|
retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
|
|
104
108
|
parse_agent_response_max_retries=config.parse_agent_response_max_retries,
|
|
105
109
|
tool_call_max_retries=config.tool_call_max_retries,
|
nat/agent/rewoo_agent/agent.py
CHANGED
|
@@ -66,8 +66,13 @@ class ReWOOAgentGraph(BaseAgent):
|
|
|
66
66
|
tools: list[BaseTool],
|
|
67
67
|
use_tool_schema: bool = True,
|
|
68
68
|
callbacks: list[AsyncCallbackHandler] | None = None,
|
|
69
|
-
detailed_logs: bool = False):
|
|
70
|
-
|
|
69
|
+
detailed_logs: bool = False,
|
|
70
|
+
log_response_max_chars: int = 1000):
|
|
71
|
+
super().__init__(llm=llm,
|
|
72
|
+
tools=tools,
|
|
73
|
+
callbacks=callbacks,
|
|
74
|
+
detailed_logs=detailed_logs,
|
|
75
|
+
log_response_max_chars=log_response_max_chars)
|
|
71
76
|
|
|
72
77
|
logger.debug(
|
|
73
78
|
"%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
|
|
@@ -17,6 +17,7 @@ import logging
|
|
|
17
17
|
|
|
18
18
|
from pydantic import AliasChoices
|
|
19
19
|
from pydantic import Field
|
|
20
|
+
from pydantic import PositiveInt
|
|
20
21
|
|
|
21
22
|
from nat.builder.builder import Builder
|
|
22
23
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
@@ -52,6 +53,8 @@ class ReWOOAgentWorkflowConfig(FunctionBaseConfig, name="rewoo_agent"):
|
|
|
52
53
|
default=None,
|
|
53
54
|
description="Provides the SOLVER_PROMPT to use with the agent") # defaults to SOLVER_PROMPT in prompt.py
|
|
54
55
|
max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
|
|
56
|
+
log_response_max_chars: PositiveInt = Field(
|
|
57
|
+
default=1000, description="Maximum number of characters to display in logs when logging tool responses.")
|
|
55
58
|
use_openai_api: bool = Field(default=False,
|
|
56
59
|
description=("Use OpenAI API for the input/output types to the function. "
|
|
57
60
|
"If False, strings will be used."))
|
|
@@ -113,7 +116,8 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
113
116
|
solver_prompt=solver_prompt,
|
|
114
117
|
tools=tools,
|
|
115
118
|
use_tool_schema=config.include_tool_input_schema_in_tool_description,
|
|
116
|
-
detailed_logs=config.verbose).build_graph()
|
|
119
|
+
detailed_logs=config.verbose,
|
|
120
|
+
log_response_max_chars=config.log_response_max_chars).build_graph()
|
|
117
121
|
|
|
118
122
|
async def _response_fn(input_message: ChatRequest) -> ChatResponse:
|
|
119
123
|
try:
|
|
@@ -55,9 +55,14 @@ class ToolCallAgentGraph(DualNodeAgent):
|
|
|
55
55
|
prompt: str | None = None,
|
|
56
56
|
callbacks: list[AsyncCallbackHandler] = None,
|
|
57
57
|
detailed_logs: bool = False,
|
|
58
|
+
log_response_max_chars: int = 1000,
|
|
58
59
|
handle_tool_errors: bool = True,
|
|
59
60
|
):
|
|
60
|
-
super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
|
|
61
|
+
super().__init__(llm=llm,
|
|
62
|
+
tools=tools,
|
|
63
|
+
callbacks=callbacks,
|
|
64
|
+
detailed_logs=detailed_logs,
|
|
65
|
+
log_response_max_chars=log_response_max_chars)
|
|
61
66
|
# some LLMs support tool calling
|
|
62
67
|
# these models accept the tool's input schema and decide when to use a tool based on the input's relevance
|
|
63
68
|
try:
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import logging
|
|
17
17
|
|
|
18
18
|
from pydantic import Field
|
|
19
|
+
from pydantic import PositiveInt
|
|
19
20
|
|
|
20
21
|
from nat.builder.builder import Builder
|
|
21
22
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
@@ -41,6 +42,8 @@ class ToolCallAgentWorkflowConfig(FunctionBaseConfig, name="tool_calling_agent")
|
|
|
41
42
|
handle_tool_errors: bool = Field(default=True, description="Specify ability to handle tool calling errors.")
|
|
42
43
|
description: str = Field(default="Tool Calling Agent Workflow", description="Description of this functions use.")
|
|
43
44
|
max_iterations: int = Field(default=15, description="Number of tool calls before stoping the tool calling agent.")
|
|
45
|
+
log_response_max_chars: PositiveInt = Field(
|
|
46
|
+
default=1000, description="Maximum number of characters to display in logs when logging tool responses.")
|
|
44
47
|
system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
|
|
45
48
|
additional_instructions: str | None = Field(default=None,
|
|
46
49
|
description="Additional instructions appended to the system prompt.")
|
|
@@ -70,6 +73,7 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
|
|
|
70
73
|
tools=tools,
|
|
71
74
|
prompt=prompt,
|
|
72
75
|
detailed_logs=config.verbose,
|
|
76
|
+
log_response_max_chars=config.log_response_max_chars,
|
|
73
77
|
handle_tool_errors=config.handle_tool_errors).build_graph()
|
|
74
78
|
|
|
75
79
|
async def _response_fn(input_message: str) -> str:
|
nat/builder/context.py
CHANGED
|
@@ -31,6 +31,7 @@ from nat.data_models.intermediate_step import IntermediateStep
|
|
|
31
31
|
from nat.data_models.intermediate_step import IntermediateStepPayload
|
|
32
32
|
from nat.data_models.intermediate_step import IntermediateStepType
|
|
33
33
|
from nat.data_models.intermediate_step import StreamEventData
|
|
34
|
+
from nat.data_models.intermediate_step import TraceMetadata
|
|
34
35
|
from nat.data_models.invocation_node import InvocationNode
|
|
35
36
|
from nat.runtime.user_metadata import RequestAttributes
|
|
36
37
|
from nat.utils.reactive.subject import Subject
|
|
@@ -174,7 +175,10 @@ class Context:
|
|
|
174
175
|
return self._context_state.user_message_id.get()
|
|
175
176
|
|
|
176
177
|
@contextmanager
|
|
177
|
-
def push_active_function(self, function_name: str, input_data: typing.Any | None):
|
|
178
|
+
def push_active_function(self,
|
|
179
|
+
function_name: str,
|
|
180
|
+
input_data: typing.Any | None,
|
|
181
|
+
metadata: dict[str, typing.Any] | TraceMetadata | None = None):
|
|
178
182
|
"""
|
|
179
183
|
Set the 'active_function' in context, push an invocation node,
|
|
180
184
|
AND create an OTel child span for that function call.
|
|
@@ -195,7 +199,8 @@ class Context:
|
|
|
195
199
|
IntermediateStepPayload(UUID=current_function_id,
|
|
196
200
|
event_type=IntermediateStepType.FUNCTION_START,
|
|
197
201
|
name=function_name,
|
|
198
|
-
data=StreamEventData(input=input_data)))
|
|
202
|
+
data=StreamEventData(input=input_data),
|
|
203
|
+
metadata=metadata))
|
|
199
204
|
|
|
200
205
|
manager = ActiveFunctionContextManager()
|
|
201
206
|
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import importlib
|
|
18
|
+
import logging
|
|
19
|
+
import mimetypes
|
|
20
|
+
import time
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
import click
|
|
24
|
+
|
|
25
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
|
26
|
+
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
27
|
+
from nat.object_store.interfaces import ObjectStore
|
|
28
|
+
from nat.object_store.models import ObjectStoreItem
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
STORE_CONFIGS = {
|
|
33
|
+
"s3": {
|
|
34
|
+
"module": "nat.plugins.s3.object_store", "config_class": "S3ObjectStoreClientConfig"
|
|
35
|
+
},
|
|
36
|
+
"mysql": {
|
|
37
|
+
"module": "nat.plugins.mysql.object_store", "config_class": "MySQLObjectStoreClientConfig"
|
|
38
|
+
},
|
|
39
|
+
"redis": {
|
|
40
|
+
"module": "nat.plugins.redis.object_store", "config_class": "RedisObjectStoreClientConfig"
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_object_store_config(**kwargs) -> ObjectStoreBaseConfig:
|
|
46
|
+
"""Process common object store arguments and return the config class"""
|
|
47
|
+
store_type = kwargs.pop("store_type")
|
|
48
|
+
config = STORE_CONFIGS[store_type]
|
|
49
|
+
module = importlib.import_module(config["module"])
|
|
50
|
+
config_class = getattr(module, config["config_class"])
|
|
51
|
+
return config_class(**kwargs)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
async def upload_file(object_store: ObjectStore, file_path: Path, key: str):
|
|
55
|
+
"""
|
|
56
|
+
Upload a single file to object store.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
object_store: The object store instance to use.
|
|
60
|
+
file_path: The path to the file to upload.
|
|
61
|
+
key: The key to upload the file to.
|
|
62
|
+
"""
|
|
63
|
+
try:
|
|
64
|
+
data = await asyncio.to_thread(file_path.read_bytes)
|
|
65
|
+
|
|
66
|
+
item = ObjectStoreItem(data=data,
|
|
67
|
+
content_type=mimetypes.guess_type(str(file_path))[0],
|
|
68
|
+
metadata={
|
|
69
|
+
"original_filename": file_path.name,
|
|
70
|
+
"file_size": str(len(data)),
|
|
71
|
+
"file_extension": file_path.suffix,
|
|
72
|
+
"upload_timestamp": str(int(time.time()))
|
|
73
|
+
})
|
|
74
|
+
|
|
75
|
+
# Upload using upsert to allow overwriting
|
|
76
|
+
await object_store.upsert_object(key, item)
|
|
77
|
+
click.echo(f"✅ Uploaded: {file_path.name} -> {key}")
|
|
78
|
+
|
|
79
|
+
except Exception as e:
|
|
80
|
+
raise RuntimeError(f"Failed to upload {file_path.name}:\n{e}") from e
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def object_store_command_decorator(async_func):
|
|
84
|
+
"""
|
|
85
|
+
Decorator that handles the common object store command pattern.
|
|
86
|
+
|
|
87
|
+
The decorated function should take (store: ObjectStore, kwargs) as parameters
|
|
88
|
+
and return an exit code (0 for success).
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
@click.pass_context
|
|
92
|
+
def wrapper(ctx: click.Context, **kwargs):
|
|
93
|
+
config = ctx.obj["store_config"]
|
|
94
|
+
|
|
95
|
+
async def work():
|
|
96
|
+
async with WorkflowBuilder() as builder:
|
|
97
|
+
await builder.add_object_store(name="store", config=config)
|
|
98
|
+
store = await builder.get_object_store_client("store")
|
|
99
|
+
return await async_func(store, **kwargs)
|
|
100
|
+
|
|
101
|
+
try:
|
|
102
|
+
exit_code = asyncio.run(work())
|
|
103
|
+
except Exception as e:
|
|
104
|
+
raise click.ClickException(f"Command failed: {e}") from e
|
|
105
|
+
if exit_code != 0:
|
|
106
|
+
raise click.ClickException(f"Command failed with exit code {exit_code}")
|
|
107
|
+
return exit_code
|
|
108
|
+
|
|
109
|
+
return wrapper
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@click.command(name="upload", help="Upload a directory to an object store.")
|
|
113
|
+
@click.argument("local_dir",
|
|
114
|
+
type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path),
|
|
115
|
+
required=True)
|
|
116
|
+
@click.help_option("--help", "-h")
|
|
117
|
+
@object_store_command_decorator
|
|
118
|
+
async def upload_command(store: ObjectStore, local_dir: Path, **_kwargs):
|
|
119
|
+
"""
|
|
120
|
+
Upload a directory to an object store.
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
local_dir: The local directory to upload.
|
|
124
|
+
store: The object store to use.
|
|
125
|
+
_kwargs: Additional keyword arguments.
|
|
126
|
+
"""
|
|
127
|
+
try:
|
|
128
|
+
click.echo(f"📁 Processing directory: {local_dir}")
|
|
129
|
+
file_count = 0
|
|
130
|
+
|
|
131
|
+
# Process each file recursively
|
|
132
|
+
for file_path in local_dir.rglob("*"):
|
|
133
|
+
if file_path.is_file():
|
|
134
|
+
key = file_path.relative_to(local_dir).as_posix()
|
|
135
|
+
await upload_file(store, file_path, key)
|
|
136
|
+
file_count += 1
|
|
137
|
+
|
|
138
|
+
click.echo(f"✅ Directory uploaded successfully! {file_count} files uploaded.")
|
|
139
|
+
return 0
|
|
140
|
+
|
|
141
|
+
except Exception as e:
|
|
142
|
+
raise click.ClickException(f"❌ Failed to upload directory {local_dir}:\n {e}") from e
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@click.command(name="delete", help="Delete files from an object store.")
|
|
146
|
+
@click.argument("keys", type=str, required=True, nargs=-1)
|
|
147
|
+
@click.help_option("--help", "-h")
|
|
148
|
+
@object_store_command_decorator
|
|
149
|
+
async def delete_command(store: ObjectStore, keys: list[str], **_kwargs):
|
|
150
|
+
"""
|
|
151
|
+
Delete files from an object store.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
store: The object store to use.
|
|
155
|
+
keys: The keys to delete.
|
|
156
|
+
_kwargs: Additional keyword arguments.
|
|
157
|
+
"""
|
|
158
|
+
deleted_count = 0
|
|
159
|
+
failed_count = 0
|
|
160
|
+
for key in keys:
|
|
161
|
+
try:
|
|
162
|
+
await store.delete_object(key)
|
|
163
|
+
click.echo(f"✅ Deleted: {key}")
|
|
164
|
+
deleted_count += 1
|
|
165
|
+
except Exception as e:
|
|
166
|
+
click.echo(f"❌ Failed to delete {key}: {e}")
|
|
167
|
+
failed_count += 1
|
|
168
|
+
|
|
169
|
+
click.echo(f"✅ Deletion completed! {deleted_count} keys deleted. {failed_count} keys failed to delete.")
|
|
170
|
+
return 0 if failed_count == 0 else 1
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@click.group(name="object-store", invoke_without_command=False, help="Manage object store operations.")
|
|
174
|
+
def object_store_command(**_kwargs):
|
|
175
|
+
"""Manage object store operations including uploading files and directories."""
|
|
176
|
+
pass
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def register_object_store_commands():
|
|
180
|
+
|
|
181
|
+
@click.group(name="s3", invoke_without_command=False, help="S3 object store operations.")
|
|
182
|
+
@click.argument("bucket_name", type=str, required=True)
|
|
183
|
+
@click.option("--endpoint-url", type=str, help="S3 endpoint URL")
|
|
184
|
+
@click.option("--access-key", type=str, help="S3 access key")
|
|
185
|
+
@click.option("--secret-key", type=str, help="S3 secret key")
|
|
186
|
+
@click.option("--region", type=str, help="S3 region")
|
|
187
|
+
@click.pass_context
|
|
188
|
+
def s3(ctx: click.Context, **kwargs):
|
|
189
|
+
ctx.ensure_object(dict)
|
|
190
|
+
ctx.obj["store_config"] = get_object_store_config(store_type="s3", **kwargs)
|
|
191
|
+
|
|
192
|
+
@click.group(name="mysql", invoke_without_command=False, help="MySQL object store operations.")
|
|
193
|
+
@click.argument("bucket_name", type=str, required=True)
|
|
194
|
+
@click.option("--host", type=str, help="MySQL host")
|
|
195
|
+
@click.option("--port", type=int, help="MySQL port")
|
|
196
|
+
@click.option("--db", type=str, help="MySQL database name")
|
|
197
|
+
@click.option("--username", type=str, help="MySQL username")
|
|
198
|
+
@click.option("--password", type=str, help="MySQL password")
|
|
199
|
+
@click.pass_context
|
|
200
|
+
def mysql(ctx: click.Context, **kwargs):
|
|
201
|
+
ctx.ensure_object(dict)
|
|
202
|
+
ctx.obj["store_config"] = get_object_store_config(store_type="mysql", **kwargs)
|
|
203
|
+
|
|
204
|
+
@click.group(name="redis", invoke_without_command=False, help="Redis object store operations.")
|
|
205
|
+
@click.argument("bucket_name", type=str, required=True)
|
|
206
|
+
@click.option("--host", type=str, help="Redis host")
|
|
207
|
+
@click.option("--port", type=int, help="Redis port")
|
|
208
|
+
@click.option("--db", type=int, help="Redis db")
|
|
209
|
+
@click.pass_context
|
|
210
|
+
def redis(ctx: click.Context, **kwargs):
|
|
211
|
+
ctx.ensure_object(dict)
|
|
212
|
+
ctx.obj["store_config"] = get_object_store_config(store_type="redis", **kwargs)
|
|
213
|
+
|
|
214
|
+
commands = {"s3": s3, "mysql": mysql, "redis": redis}
|
|
215
|
+
|
|
216
|
+
for store_type, config in STORE_CONFIGS.items():
|
|
217
|
+
try:
|
|
218
|
+
importlib.import_module(config["module"])
|
|
219
|
+
command = commands[store_type]
|
|
220
|
+
object_store_command.add_command(command, name=store_type)
|
|
221
|
+
command.add_command(upload_command, name="upload")
|
|
222
|
+
command.add_command(delete_command, name="delete")
|
|
223
|
+
except ImportError:
|
|
224
|
+
pass
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
register_object_store_commands()
|
nat/cli/entrypoint.py
CHANGED
|
@@ -33,6 +33,7 @@ import nest_asyncio
|
|
|
33
33
|
from .commands.configure.configure import configure_command
|
|
34
34
|
from .commands.evaluate import eval_command
|
|
35
35
|
from .commands.info.info import info_command
|
|
36
|
+
from .commands.object_store.object_store import object_store_command
|
|
36
37
|
from .commands.registry.registry import registry_command
|
|
37
38
|
from .commands.sizing.sizing import sizing
|
|
38
39
|
from .commands.start import start_command
|
|
@@ -107,11 +108,12 @@ cli.add_command(uninstall_command, name="uninstall")
|
|
|
107
108
|
cli.add_command(validate_command, name="validate")
|
|
108
109
|
cli.add_command(workflow_command, name="workflow")
|
|
109
110
|
cli.add_command(sizing, name="sizing")
|
|
111
|
+
cli.add_command(object_store_command, name="object-store")
|
|
110
112
|
|
|
111
113
|
# Aliases
|
|
112
114
|
cli.add_command(start_command.get_command(None, "console"), name="run") # type: ignore
|
|
113
115
|
cli.add_command(start_command.get_command(None, "fastapi"), name="serve") # type: ignore
|
|
114
|
-
cli.add_command(start_command.get_command(None, "mcp"), name="mcp")
|
|
116
|
+
cli.add_command(start_command.get_command(None, "mcp"), name="mcp") # type: ignore
|
|
115
117
|
|
|
116
118
|
|
|
117
119
|
@cli.result_callback()
|
|
@@ -16,8 +16,6 @@
|
|
|
16
16
|
from collections.abc import Sequence
|
|
17
17
|
from dataclasses import dataclass
|
|
18
18
|
from re import Pattern
|
|
19
|
-
from typing import Generic
|
|
20
|
-
from typing import TypeVar
|
|
21
19
|
|
|
22
20
|
from pydantic import model_validator
|
|
23
21
|
|
|
@@ -33,10 +31,7 @@ class GatedFieldMixinConfig:
|
|
|
33
31
|
keys: Sequence[str]
|
|
34
32
|
|
|
35
33
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class GatedFieldMixin(Generic[T]):
|
|
34
|
+
class GatedFieldMixin:
|
|
40
35
|
"""
|
|
41
36
|
A mixin that gates a field based on specified keys.
|
|
42
37
|
|
|
@@ -46,7 +41,7 @@ class GatedFieldMixin(Generic[T]):
|
|
|
46
41
|
----------
|
|
47
42
|
field_name: `str`
|
|
48
43
|
The name of the field.
|
|
49
|
-
default_if_supported: `
|
|
44
|
+
default_if_supported: `object | None`
|
|
50
45
|
The default value of the field if it is supported for the key.
|
|
51
46
|
keys: `Sequence[str]`
|
|
52
47
|
A sequence of keys that are used to validate the field.
|
|
@@ -61,7 +56,7 @@ class GatedFieldMixin(Generic[T]):
|
|
|
61
56
|
def __init_subclass__(
|
|
62
57
|
cls,
|
|
63
58
|
field_name: str | None = None,
|
|
64
|
-
default_if_supported:
|
|
59
|
+
default_if_supported: object | None = None,
|
|
65
60
|
keys: Sequence[str] | None = None,
|
|
66
61
|
unsupported: Sequence[Pattern[str]] | None = None,
|
|
67
62
|
supported: Sequence[Pattern[str]] | None = None,
|
|
@@ -90,7 +85,7 @@ class GatedFieldMixin(Generic[T]):
|
|
|
90
85
|
def _setup_direct_mixin(
|
|
91
86
|
cls,
|
|
92
87
|
field_name: str,
|
|
93
|
-
default_if_supported:
|
|
88
|
+
default_if_supported: object | None,
|
|
94
89
|
unsupported: Sequence[Pattern[str]] | None,
|
|
95
90
|
supported: Sequence[Pattern[str]] | None,
|
|
96
91
|
keys: Sequence[str],
|
|
@@ -135,7 +130,7 @@ class GatedFieldMixin(Generic[T]):
|
|
|
135
130
|
def _create_gated_field_validator(
|
|
136
131
|
cls,
|
|
137
132
|
field_name: str,
|
|
138
|
-
default_if_supported:
|
|
133
|
+
default_if_supported: object | None,
|
|
139
134
|
unsupported: Sequence[Pattern[str]] | None,
|
|
140
135
|
supported: Sequence[Pattern[str]] | None,
|
|
141
136
|
keys: Sequence[str],
|
|
@@ -167,16 +162,19 @@ class GatedFieldMixin(Generic[T]):
|
|
|
167
162
|
keys: Sequence[str],
|
|
168
163
|
) -> bool:
|
|
169
164
|
"""Check if a specific field is supported based on its configuration and keys."""
|
|
165
|
+
seen = False
|
|
170
166
|
for key in keys:
|
|
171
167
|
if not hasattr(instance, key):
|
|
172
168
|
continue
|
|
169
|
+
seen = True
|
|
173
170
|
value = str(getattr(instance, key))
|
|
174
171
|
if supported is not None:
|
|
175
|
-
|
|
172
|
+
if any(p.search(value) for p in supported):
|
|
173
|
+
return True
|
|
176
174
|
elif unsupported is not None:
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
return True
|
|
175
|
+
if any(p.search(value) for p in unsupported):
|
|
176
|
+
return False
|
|
177
|
+
return True if not seen else (unsupported is not None)
|
|
180
178
|
|
|
181
179
|
@classmethod
|
|
182
180
|
def _find_blocking_key(
|
|
@@ -23,7 +23,7 @@ from nat.data_models.gated_field_mixin import GatedFieldMixin
|
|
|
23
23
|
|
|
24
24
|
class TemperatureMixin(
|
|
25
25
|
BaseModel,
|
|
26
|
-
GatedFieldMixin
|
|
26
|
+
GatedFieldMixin,
|
|
27
27
|
field_name="temperature",
|
|
28
28
|
default_if_supported=0.0,
|
|
29
29
|
keys=("model_name", "model", "azure_deployment"),
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
|
|
18
|
+
from pydantic import BaseModel
|
|
19
|
+
from pydantic import Field
|
|
20
|
+
|
|
21
|
+
from nat.data_models.gated_field_mixin import GatedFieldMixin
|
|
22
|
+
|
|
23
|
+
# The system prompt format for thinking is different for these, so we need to distinguish them here with two separate
|
|
24
|
+
# regex patterns
|
|
25
|
+
_NVIDIA_NEMOTRON_REGEX = re.compile(r"^nvidia/nvidia.*nemotron", re.IGNORECASE)
|
|
26
|
+
_LLAMA_NEMOTRON_REGEX = re.compile(r"^nvidia/llama.*nemotron", re.IGNORECASE)
|
|
27
|
+
_MODEL_KEYS = ("model_name", "model", "azure_deployment")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ThinkingMixin(
|
|
31
|
+
BaseModel,
|
|
32
|
+
GatedFieldMixin,
|
|
33
|
+
field_name="thinking",
|
|
34
|
+
default_if_supported=None,
|
|
35
|
+
keys=_MODEL_KEYS,
|
|
36
|
+
supported=(_NVIDIA_NEMOTRON_REGEX, _LLAMA_NEMOTRON_REGEX),
|
|
37
|
+
):
|
|
38
|
+
"""
|
|
39
|
+
Mixin class for thinking configuration. Only supported on Nemotron models.
|
|
40
|
+
|
|
41
|
+
Attributes:
|
|
42
|
+
thinking: Whether to enable thinking. Defaults to None when supported on the model.
|
|
43
|
+
"""
|
|
44
|
+
thinking: bool | None = Field(
|
|
45
|
+
default=None,
|
|
46
|
+
description="Whether to enable thinking. Defaults to None when supported on the model.",
|
|
47
|
+
exclude=True,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
@property
|
|
51
|
+
def thinking_system_prompt(self) -> str | None:
|
|
52
|
+
"""
|
|
53
|
+
Returns the system prompt to use for thinking.
|
|
54
|
+
For NVIDIA Nemotron, returns "/think" if enabled, else "/no_think".
|
|
55
|
+
For Llama Nemotron, returns "detailed thinking on" if enabled, else "detailed thinking off".
|
|
56
|
+
If thinking is not supported on the model, returns None.
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
str | None: The system prompt to use for thinking.
|
|
60
|
+
"""
|
|
61
|
+
if self.thinking is None:
|
|
62
|
+
return None
|
|
63
|
+
for key in _MODEL_KEYS:
|
|
64
|
+
if hasattr(self, key):
|
|
65
|
+
if _NVIDIA_NEMOTRON_REGEX.match(getattr(self, key)):
|
|
66
|
+
return "/think" if self.thinking else "/no_think"
|
|
67
|
+
elif _LLAMA_NEMOTRON_REGEX.match(getattr(self, key)):
|
|
68
|
+
return f"detailed thinking {'on' if self.thinking else 'off'}"
|
nat/data_models/top_p_mixin.py
CHANGED