nvidia-nat 1.3.0a20250827__py3-none-any.whl → 1.3.0a20250828__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. nat/agent/base.py +6 -6
  2. nat/agent/dual_node.py +7 -2
  3. nat/agent/react_agent/agent.py +6 -1
  4. nat/agent/react_agent/register.py +4 -0
  5. nat/agent/rewoo_agent/agent.py +7 -2
  6. nat/agent/rewoo_agent/register.py +5 -1
  7. nat/agent/tool_calling_agent/agent.py +6 -1
  8. nat/agent/tool_calling_agent/register.py +4 -0
  9. nat/builder/context.py +7 -2
  10. nat/cli/commands/object_store/__init__.py +14 -0
  11. nat/cli/commands/object_store/object_store.py +227 -0
  12. nat/cli/entrypoint.py +3 -1
  13. nat/data_models/gated_field_mixin.py +12 -14
  14. nat/data_models/temperature_mixin.py +1 -1
  15. nat/data_models/thinking_mixin.py +68 -0
  16. nat/data_models/top_p_mixin.py +1 -1
  17. nat/llm/aws_bedrock_llm.py +10 -9
  18. nat/llm/azure_openai_llm.py +9 -1
  19. nat/llm/nim_llm.py +2 -1
  20. nat/llm/openai_llm.py +2 -1
  21. nat/llm/utils/thinking.py +215 -0
  22. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  23. nat/observability/processor/processor_factory.py +70 -0
  24. nat/profiler/decorators/function_tracking.py +125 -0
  25. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/METADATA +3 -1
  26. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/RECORD +31 -25
  27. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/WHEEL +0 -0
  28. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/entry_points.txt +0 -0
  29. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  30. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/licenses/LICENSE.md +0 -0
  31. {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250828.dist-info}/top_level.txt +0 -0
nat/agent/base.py CHANGED
@@ -70,12 +70,14 @@ class BaseAgent(ABC):
                  llm: BaseChatModel,
                  tools: list[BaseTool],
                  callbacks: list[AsyncCallbackHandler] | None = None,
-                 detailed_logs: bool = False) -> None:
+                 detailed_logs: bool = False,
+                 log_response_max_chars: int = 1000) -> None:
         logger.debug("Initializing Agent Graph")
         self.llm = llm
         self.tools = tools
         self.callbacks = callbacks or []
         self.detailed_logs = detailed_logs
+        self.log_response_max_chars = log_response_max_chars
         self.graph = None
 
     async def _stream_llm(self,
@@ -184,7 +186,7 @@ class BaseAgent(ABC):
             logger.error("%s %s", AGENT_LOG_PREFIX, error_content)
             return ToolMessage(name=tool.name, tool_call_id=tool.name, content=error_content, status="error")
 
-    def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str, max_chars: int = 1000) -> None:
+    def _log_tool_response(self, tool_name: str, tool_input: Any, tool_response: str) -> None:
         """
         Log tool response with consistent formatting and length limits.
 
@@ -196,13 +198,11 @@
             The input that was passed to the tool
         tool_response : str
             The response from the tool
-        max_chars : int
-            Maximum number of characters to log (default: 1000)
         """
         if self.detailed_logs:
             # Truncate tool response if too long
-            display_response = tool_response[:max_chars] + "...(rest of response truncated)" if len(
-                tool_response) > max_chars else tool_response
+            display_response = tool_response[:self.log_response_max_chars] + "...(rest of response truncated)" if len(
+                tool_response) > self.log_response_max_chars else tool_response
 
             # Format the tool input for display
             tool_input_str = str(tool_input)
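Note: the truncation limit that used to be a per-call max_chars argument of _log_tool_response is now read from the agent instance. A minimal standalone sketch of the same truncation rule; the helper name truncate_for_log is illustrative and not part of the package:

    def truncate_for_log(tool_response: str, log_response_max_chars: int = 1000) -> str:
        # Mirrors the logic in BaseAgent._log_tool_response: keep the first
        # log_response_max_chars characters and mark the cut-off.
        if len(tool_response) > log_response_max_chars:
            return tool_response[:log_response_max_chars] + "...(rest of response truncated)"
        return tool_response

    print(truncate_for_log("x" * 5000, log_response_max_chars=100))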
nat/agent/dual_node.py CHANGED
@@ -35,8 +35,13 @@ class DualNodeAgent(BaseAgent):
                  llm: BaseChatModel,
                  tools: list[BaseTool],
                  callbacks: list[AsyncCallbackHandler] | None = None,
-                 detailed_logs: bool = False):
-        super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
+                 detailed_logs: bool = False,
+                 log_response_max_chars: int = 1000):
+        super().__init__(llm=llm,
+                         tools=tools,
+                         callbacks=callbacks,
+                         detailed_logs=detailed_logs,
+                         log_response_max_chars=log_response_max_chars)
 
     @abstractmethod
     async def agent_node(self, state: BaseModel) -> BaseModel:
nat/agent/react_agent/agent.py CHANGED
@@ -73,11 +73,16 @@ class ReActAgentGraph(DualNodeAgent):
                  use_tool_schema: bool = True,
                  callbacks: list[AsyncCallbackHandler] | None = None,
                  detailed_logs: bool = False,
+                 log_response_max_chars: int = 1000,
                  retry_agent_response_parsing_errors: bool = True,
                  parse_agent_response_max_retries: int = 1,
                  tool_call_max_retries: int = 1,
                  pass_tool_call_errors_to_agent: bool = True):
-        super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
+        super().__init__(llm=llm,
+                         tools=tools,
+                         callbacks=callbacks,
+                         detailed_logs=detailed_logs,
+                         log_response_max_chars=log_response_max_chars)
         self.parse_agent_response_max_retries = (parse_agent_response_max_retries
                                                  if retry_agent_response_parsing_errors else 1)
         self.tool_call_max_retries = tool_call_max_retries
nat/agent/react_agent/register.py CHANGED
@@ -17,6 +17,7 @@ import logging
 
 from pydantic import AliasChoices
 from pydantic import Field
+from pydantic import PositiveInt
 
 from nat.builder.builder import Builder
 from nat.builder.framework_enum import LLMFrameworkEnum
@@ -65,6 +66,8 @@ class ReActAgentWorkflowConfig(FunctionBaseConfig, name="react_agent"):
         default=None,
         description="Provides the SYSTEM_PROMPT to use with the agent")  # defaults to SYSTEM_PROMPT in prompt.py
     max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
+    log_response_max_chars: PositiveInt = Field(
+        default=1000, description="Maximum number of characters to display in logs when logging tool responses.")
     use_openai_api: bool = Field(default=False,
                                  description=("Use OpenAI API for the input/output types to the function. "
                                               "If False, strings will be used."))
@@ -100,6 +103,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
         tools=tools,
         use_tool_schema=config.include_tool_input_schema_in_tool_description,
         detailed_logs=config.verbose,
+        log_response_max_chars=config.log_response_max_chars,
         retry_agent_response_parsing_errors=config.retry_agent_response_parsing_errors,
         parse_agent_response_max_retries=config.parse_agent_response_max_retries,
         tool_call_max_retries=config.tool_call_max_retries,
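Note: because log_response_max_chars is declared as a PositiveInt, a zero or negative limit is rejected when the workflow config is validated; the same field is added to the ReWOO and tool-calling agent configs below. A self-contained sketch of that validation behaviour, where LoggingOptions is a made-up stand-in rather than a class from the package:

    from pydantic import BaseModel, Field, PositiveInt, ValidationError

    class LoggingOptions(BaseModel):
        # Stand-in for the new field on the agent workflow configs.
        log_response_max_chars: PositiveInt = Field(default=1000)

    print(LoggingOptions().log_response_max_chars)  # 1000
    try:
        LoggingOptions(log_response_max_chars=0)  # rejected: must be greater than 0
    except ValidationError as err:
        print(err.errors()[0]["type"])  # "greater_than" under pydantic v2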
nat/agent/rewoo_agent/agent.py CHANGED
@@ -66,8 +66,13 @@ class ReWOOAgentGraph(BaseAgent):
                  tools: list[BaseTool],
                  use_tool_schema: bool = True,
                  callbacks: list[AsyncCallbackHandler] | None = None,
-                 detailed_logs: bool = False):
-        super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
+                 detailed_logs: bool = False,
+                 log_response_max_chars: int = 1000):
+        super().__init__(llm=llm,
+                         tools=tools,
+                         callbacks=callbacks,
+                         detailed_logs=detailed_logs,
+                         log_response_max_chars=log_response_max_chars)
 
         logger.debug(
             "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
nat/agent/rewoo_agent/register.py CHANGED
@@ -17,6 +17,7 @@ import logging
 
 from pydantic import AliasChoices
 from pydantic import Field
+from pydantic import PositiveInt
 
 from nat.builder.builder import Builder
 from nat.builder.framework_enum import LLMFrameworkEnum
@@ -52,6 +53,8 @@ class ReWOOAgentWorkflowConfig(FunctionBaseConfig, name="rewoo_agent"):
         default=None,
         description="Provides the SOLVER_PROMPT to use with the agent")  # defaults to SOLVER_PROMPT in prompt.py
     max_history: int = Field(default=15, description="Maximum number of messages to keep in the conversation history.")
+    log_response_max_chars: PositiveInt = Field(
+        default=1000, description="Maximum number of characters to display in logs when logging tool responses.")
     use_openai_api: bool = Field(default=False,
                                  description=("Use OpenAI API for the input/output types to the function. "
                                               "If False, strings will be used."))
@@ -113,7 +116,8 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
         solver_prompt=solver_prompt,
         tools=tools,
         use_tool_schema=config.include_tool_input_schema_in_tool_description,
-        detailed_logs=config.verbose).build_graph()
+        detailed_logs=config.verbose,
+        log_response_max_chars=config.log_response_max_chars).build_graph()
 
     async def _response_fn(input_message: ChatRequest) -> ChatResponse:
         try:
nat/agent/tool_calling_agent/agent.py CHANGED
@@ -55,9 +55,14 @@ class ToolCallAgentGraph(DualNodeAgent):
                  prompt: str | None = None,
                  callbacks: list[AsyncCallbackHandler] = None,
                  detailed_logs: bool = False,
+                 log_response_max_chars: int = 1000,
                  handle_tool_errors: bool = True,
                  ):
-        super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)
+        super().__init__(llm=llm,
+                         tools=tools,
+                         callbacks=callbacks,
+                         detailed_logs=detailed_logs,
+                         log_response_max_chars=log_response_max_chars)
         # some LLMs support tool calling
         # these models accept the tool's input schema and decide when to use a tool based on the input's relevance
         try:
nat/agent/tool_calling_agent/register.py CHANGED
@@ -16,6 +16,7 @@
 import logging
 
 from pydantic import Field
+from pydantic import PositiveInt
 
 from nat.builder.builder import Builder
 from nat.builder.framework_enum import LLMFrameworkEnum
@@ -41,6 +42,8 @@ class ToolCallAgentWorkflowConfig(FunctionBaseConfig, name="tool_calling_agent")
     handle_tool_errors: bool = Field(default=True, description="Specify ability to handle tool calling errors.")
     description: str = Field(default="Tool Calling Agent Workflow", description="Description of this functions use.")
     max_iterations: int = Field(default=15, description="Number of tool calls before stoping the tool calling agent.")
+    log_response_max_chars: PositiveInt = Field(
+        default=1000, description="Maximum number of characters to display in logs when logging tool responses.")
     system_prompt: str | None = Field(default=None, description="Provides the system prompt to use with the agent.")
     additional_instructions: str | None = Field(default=None,
                                                 description="Additional instructions appended to the system prompt.")
@@ -70,6 +73,7 @@ async def tool_calling_agent_workflow(config: ToolCallAgentWorkflowConfig, build
         tools=tools,
         prompt=prompt,
         detailed_logs=config.verbose,
+        log_response_max_chars=config.log_response_max_chars,
         handle_tool_errors=config.handle_tool_errors).build_graph()
 
     async def _response_fn(input_message: str) -> str:
nat/builder/context.py CHANGED
@@ -31,6 +31,7 @@ from nat.data_models.intermediate_step import IntermediateStep
 from nat.data_models.intermediate_step import IntermediateStepPayload
 from nat.data_models.intermediate_step import IntermediateStepType
 from nat.data_models.intermediate_step import StreamEventData
+from nat.data_models.intermediate_step import TraceMetadata
 from nat.data_models.invocation_node import InvocationNode
 from nat.runtime.user_metadata import RequestAttributes
 from nat.utils.reactive.subject import Subject
@@ -174,7 +175,10 @@ class Context:
         return self._context_state.user_message_id.get()
 
     @contextmanager
-    def push_active_function(self, function_name: str, input_data: typing.Any | None):
+    def push_active_function(self,
+                             function_name: str,
+                             input_data: typing.Any | None,
+                             metadata: dict[str, typing.Any] | TraceMetadata | None = None):
         """
         Set the 'active_function' in context, push an invocation node,
         AND create an OTel child span for that function call.
@@ -195,7 +199,8 @@
             IntermediateStepPayload(UUID=current_function_id,
                                     event_type=IntermediateStepType.FUNCTION_START,
                                     name=function_name,
-                                    data=StreamEventData(input=input_data)))
+                                    data=StreamEventData(input=input_data),
+                                    metadata=metadata))
 
         manager = ActiveFunctionContextManager()
 
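Note: the new metadata argument is forwarded into the FUNCTION_START intermediate step, so callers can attach trace attributes when they open a function span. A hedged sketch of a call site; it assumes ctx is an existing nat.builder.context.Context instance, and the function name and metadata keys are illustrative:

    # Assumes `ctx` is an existing nat.builder.context.Context instance.
    with ctx.push_active_function("lookup_weather",
                                  input_data={"city": "Berlin"},
                                  metadata={"caller": "react_agent", "attempt": 1}):
        pass  # function body runs here; metadata rides along on the FUNCTION_START event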
nat/cli/commands/object_store/__init__.py ADDED
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
nat/cli/commands/object_store/object_store.py ADDED
@@ -0,0 +1,227 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import importlib
+import logging
+import mimetypes
+import time
+from pathlib import Path
+
+import click
+
+from nat.builder.workflow_builder import WorkflowBuilder
+from nat.data_models.object_store import ObjectStoreBaseConfig
+from nat.object_store.interfaces import ObjectStore
+from nat.object_store.models import ObjectStoreItem
+
+logger = logging.getLogger(__name__)
+
+STORE_CONFIGS = {
+    "s3": {
+        "module": "nat.plugins.s3.object_store", "config_class": "S3ObjectStoreClientConfig"
+    },
+    "mysql": {
+        "module": "nat.plugins.mysql.object_store", "config_class": "MySQLObjectStoreClientConfig"
+    },
+    "redis": {
+        "module": "nat.plugins.redis.object_store", "config_class": "RedisObjectStoreClientConfig"
+    }
+}
+
+
+def get_object_store_config(**kwargs) -> ObjectStoreBaseConfig:
+    """Process common object store arguments and return the config class"""
+    store_type = kwargs.pop("store_type")
+    config = STORE_CONFIGS[store_type]
+    module = importlib.import_module(config["module"])
+    config_class = getattr(module, config["config_class"])
+    return config_class(**kwargs)
+
+
+async def upload_file(object_store: ObjectStore, file_path: Path, key: str):
+    """
+    Upload a single file to object store.
+
+    Args:
+        object_store: The object store instance to use.
+        file_path: The path to the file to upload.
+        key: The key to upload the file to.
+    """
+    try:
+        data = await asyncio.to_thread(file_path.read_bytes)
+
+        item = ObjectStoreItem(data=data,
+                               content_type=mimetypes.guess_type(str(file_path))[0],
+                               metadata={
+                                   "original_filename": file_path.name,
+                                   "file_size": str(len(data)),
+                                   "file_extension": file_path.suffix,
+                                   "upload_timestamp": str(int(time.time()))
+                               })
+
+        # Upload using upsert to allow overwriting
+        await object_store.upsert_object(key, item)
+        click.echo(f"✅ Uploaded: {file_path.name} -> {key}")
+
+    except Exception as e:
+        raise RuntimeError(f"Failed to upload {file_path.name}:\n{e}") from e
+
+
+def object_store_command_decorator(async_func):
+    """
+    Decorator that handles the common object store command pattern.
+
+    The decorated function should take (store: ObjectStore, kwargs) as parameters
+    and return an exit code (0 for success).
+    """
+
+    @click.pass_context
+    def wrapper(ctx: click.Context, **kwargs):
+        config = ctx.obj["store_config"]
+
+        async def work():
+            async with WorkflowBuilder() as builder:
+                await builder.add_object_store(name="store", config=config)
+                store = await builder.get_object_store_client("store")
+                return await async_func(store, **kwargs)
+
+        try:
+            exit_code = asyncio.run(work())
+        except Exception as e:
+            raise click.ClickException(f"Command failed: {e}") from e
+        if exit_code != 0:
+            raise click.ClickException(f"Command failed with exit code {exit_code}")
+        return exit_code
+
+    return wrapper
+
+
+@click.command(name="upload", help="Upload a directory to an object store.")
+@click.argument("local_dir",
+                type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path),
+                required=True)
+@click.help_option("--help", "-h")
+@object_store_command_decorator
+async def upload_command(store: ObjectStore, local_dir: Path, **_kwargs):
+    """
+    Upload a directory to an object store.
+
+    Args:
+        local_dir: The local directory to upload.
+        store: The object store to use.
+        _kwargs: Additional keyword arguments.
+    """
+    try:
+        click.echo(f"📁 Processing directory: {local_dir}")
+        file_count = 0
+
+        # Process each file recursively
+        for file_path in local_dir.rglob("*"):
+            if file_path.is_file():
+                key = file_path.relative_to(local_dir).as_posix()
+                await upload_file(store, file_path, key)
+                file_count += 1
+
+        click.echo(f"✅ Directory uploaded successfully! {file_count} files uploaded.")
+        return 0
+
+    except Exception as e:
+        raise click.ClickException(f"❌ Failed to upload directory {local_dir}:\n {e}") from e
+
+
+@click.command(name="delete", help="Delete files from an object store.")
+@click.argument("keys", type=str, required=True, nargs=-1)
+@click.help_option("--help", "-h")
+@object_store_command_decorator
+async def delete_command(store: ObjectStore, keys: list[str], **_kwargs):
+    """
+    Delete files from an object store.
+
+    Args:
+        store: The object store to use.
+        keys: The keys to delete.
+        _kwargs: Additional keyword arguments.
+    """
+    deleted_count = 0
+    failed_count = 0
+    for key in keys:
+        try:
+            await store.delete_object(key)
+            click.echo(f"✅ Deleted: {key}")
+            deleted_count += 1
+        except Exception as e:
+            click.echo(f"❌ Failed to delete {key}: {e}")
+            failed_count += 1
+
+    click.echo(f"✅ Deletion completed! {deleted_count} keys deleted. {failed_count} keys failed to delete.")
+    return 0 if failed_count == 0 else 1
+
+
+@click.group(name="object-store", invoke_without_command=False, help="Manage object store operations.")
+def object_store_command(**_kwargs):
+    """Manage object store operations including uploading files and directories."""
+    pass
+
+
+def register_object_store_commands():
+
+    @click.group(name="s3", invoke_without_command=False, help="S3 object store operations.")
+    @click.argument("bucket_name", type=str, required=True)
+    @click.option("--endpoint-url", type=str, help="S3 endpoint URL")
+    @click.option("--access-key", type=str, help="S3 access key")
+    @click.option("--secret-key", type=str, help="S3 secret key")
+    @click.option("--region", type=str, help="S3 region")
+    @click.pass_context
+    def s3(ctx: click.Context, **kwargs):
+        ctx.ensure_object(dict)
+        ctx.obj["store_config"] = get_object_store_config(store_type="s3", **kwargs)
+
+    @click.group(name="mysql", invoke_without_command=False, help="MySQL object store operations.")
+    @click.argument("bucket_name", type=str, required=True)
+    @click.option("--host", type=str, help="MySQL host")
+    @click.option("--port", type=int, help="MySQL port")
+    @click.option("--db", type=str, help="MySQL database name")
+    @click.option("--username", type=str, help="MySQL username")
+    @click.option("--password", type=str, help="MySQL password")
+    @click.pass_context
+    def mysql(ctx: click.Context, **kwargs):
+        ctx.ensure_object(dict)
+        ctx.obj["store_config"] = get_object_store_config(store_type="mysql", **kwargs)
+
+    @click.group(name="redis", invoke_without_command=False, help="Redis object store operations.")
+    @click.argument("bucket_name", type=str, required=True)
+    @click.option("--host", type=str, help="Redis host")
+    @click.option("--port", type=int, help="Redis port")
+    @click.option("--db", type=int, help="Redis db")
+    @click.pass_context
+    def redis(ctx: click.Context, **kwargs):
+        ctx.ensure_object(dict)
+        ctx.obj["store_config"] = get_object_store_config(store_type="redis", **kwargs)
+
+    commands = {"s3": s3, "mysql": mysql, "redis": redis}
+
+    for store_type, config in STORE_CONFIGS.items():
+        try:
+            importlib.import_module(config["module"])
+            command = commands[store_type]
+            object_store_command.add_command(command, name=store_type)
+            command.add_command(upload_command, name="upload")
+            command.add_command(delete_command, name="delete")
+        except ImportError:
+            pass
+
+
+register_object_store_commands()
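Note: reading the click wiring above, the commands nest as object-store <s3|mysql|redis> <bucket and connection options> <upload|delete>, and each backend subgroup is only registered when its plugin module imports successfully. The same work can be driven programmatically through the pattern the wrapper uses; the sketch below follows it for S3, where the keyword arguments passed to get_object_store_config are inferred from the CLI options and may differ from the real S3ObjectStoreClientConfig signature:

    import asyncio
    from pathlib import Path

    from nat.builder.workflow_builder import WorkflowBuilder
    from nat.cli.commands.object_store.object_store import get_object_store_config, upload_file

    async def upload_one(path: Path) -> None:
        # Keyword arguments mirror the CLI options above; treat them as assumptions.
        config = get_object_store_config(store_type="s3",
                                         bucket_name="my-bucket",
                                         endpoint_url="http://localhost:9000",
                                         access_key="minio",
                                         secret_key="minio123")
        async with WorkflowBuilder() as builder:
            await builder.add_object_store(name="store", config=config)
            store = await builder.get_object_store_client("store")
            await upload_file(store, path, key=path.name)

    asyncio.run(upload_one(Path("report.pdf")))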
nat/cli/entrypoint.py CHANGED
@@ -33,6 +33,7 @@ import nest_asyncio
 from .commands.configure.configure import configure_command
 from .commands.evaluate import eval_command
 from .commands.info.info import info_command
+from .commands.object_store.object_store import object_store_command
 from .commands.registry.registry import registry_command
 from .commands.sizing.sizing import sizing
 from .commands.start import start_command
@@ -107,11 +108,12 @@ cli.add_command(uninstall_command, name="uninstall")
 cli.add_command(validate_command, name="validate")
 cli.add_command(workflow_command, name="workflow")
 cli.add_command(sizing, name="sizing")
+cli.add_command(object_store_command, name="object-store")
 
 # Aliases
 cli.add_command(start_command.get_command(None, "console"), name="run")  # type: ignore
 cli.add_command(start_command.get_command(None, "fastapi"), name="serve")  # type: ignore
-cli.add_command(start_command.get_command(None, "mcp"), name="mcp")
+cli.add_command(start_command.get_command(None, "mcp"), name="mcp")  # type: ignore
 
 
 @cli.result_callback()
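Note: with this registration the object-store group hangs off the main CLI alongside run/serve/mcp. One way to inspect the wiring without touching a real backend is click's test runner; this assumes at least one backend plugin module is importable so a subgroup gets registered:

    from click.testing import CliRunner

    from nat.cli.entrypoint import cli

    result = CliRunner().invoke(cli, ["object-store", "--help"])
    print(result.output)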
nat/data_models/gated_field_mixin.py CHANGED
@@ -16,8 +16,6 @@
 from collections.abc import Sequence
 from dataclasses import dataclass
 from re import Pattern
-from typing import Generic
-from typing import TypeVar
 
 from pydantic import model_validator
 
@@ -33,10 +31,7 @@ class GatedFieldMixinConfig:
     keys: Sequence[str]
 
 
-T = TypeVar("T")
-
-
-class GatedFieldMixin(Generic[T]):
+class GatedFieldMixin:
     """
     A mixin that gates a field based on specified keys.
 
@@ -46,7 +41,7 @@ class GatedFieldMixin(Generic[T]):
     ----------
     field_name: `str`
         The name of the field.
-    default_if_supported: `T | None`
+    default_if_supported: `object | None`
        The default value of the field if it is supported for the key.
     keys: `Sequence[str]`
        A sequence of keys that are used to validate the field.
@@ -61,7 +56,7 @@ class GatedFieldMixin(Generic[T]):
     def __init_subclass__(
         cls,
         field_name: str | None = None,
-        default_if_supported: T | None = None,
+        default_if_supported: object | None = None,
        keys: Sequence[str] | None = None,
        unsupported: Sequence[Pattern[str]] | None = None,
        supported: Sequence[Pattern[str]] | None = None,
@@ -90,7 +85,7 @@
     def _setup_direct_mixin(
         cls,
         field_name: str,
-        default_if_supported: T | None,
+        default_if_supported: object | None,
        unsupported: Sequence[Pattern[str]] | None,
        supported: Sequence[Pattern[str]] | None,
        keys: Sequence[str],
@@ -135,7 +130,7 @@
     def _create_gated_field_validator(
         cls,
         field_name: str,
-        default_if_supported: T | None,
+        default_if_supported: object | None,
        unsupported: Sequence[Pattern[str]] | None,
        supported: Sequence[Pattern[str]] | None,
        keys: Sequence[str],
@@ -167,16 +162,19 @@
         keys: Sequence[str],
     ) -> bool:
         """Check if a specific field is supported based on its configuration and keys."""
+        seen = False
         for key in keys:
             if not hasattr(instance, key):
                 continue
+            seen = True
             value = str(getattr(instance, key))
             if supported is not None:
-                return any(p.search(value) for p in supported)
+                if any(p.search(value) for p in supported):
+                    return True
             elif unsupported is not None:
-                return not any(p.search(value) for p in unsupported)
-        # Default to supported if no model keys found
-        return True
+                if any(p.search(value) for p in unsupported):
+                    return False
+        return True if not seen else (unsupported is not None)
 
     @classmethod
     def _find_blocking_key(
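Note: the _is_field_supported change alters how multiple model keys are resolved. The old code returned on the first key it found; the new code keeps scanning: any match against the supported patterns returns True, any match against the unsupported patterns returns False, and only after the loop does it fall back (True when no key was present at all, otherwise allow for unsupported-lists and deny for supported-lists). A free-standing restatement of that rule, with the hasattr/getattr lookup replaced by a pre-collected list of values for brevity:

    import re
    from collections.abc import Sequence
    from re import Pattern

    def is_supported(values: Sequence[str],
                     supported: Sequence[Pattern[str]] | None = None,
                     unsupported: Sequence[Pattern[str]] | None = None) -> bool:
        # Same control flow as the revised GatedFieldMixin._is_field_supported.
        seen = False
        for value in values:
            seen = True
            if supported is not None:
                if any(p.search(value) for p in supported):
                    return True
            elif unsupported is not None:
                if any(p.search(value) for p in unsupported):
                    return False
        return True if not seen else (unsupported is not None)

    print(is_supported(["nvidia/llama-3.3-nemotron"], supported=(re.compile("nemotron"),)))  # True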
nat/data_models/temperature_mixin.py CHANGED
@@ -23,7 +23,7 @@ from nat.data_models.gated_field_mixin import GatedFieldMixin
 
 class TemperatureMixin(
         BaseModel,
-        GatedFieldMixin[float],
+        GatedFieldMixin,
         field_name="temperature",
         default_if_supported=0.0,
         keys=("model_name", "model", "azure_deployment"),
nat/data_models/thinking_mixin.py ADDED
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from pydantic import BaseModel
+from pydantic import Field
+
+from nat.data_models.gated_field_mixin import GatedFieldMixin
+
+# The system prompt format for thinking is different for these, so we need to distinguish them here with two separate
+# regex patterns
+_NVIDIA_NEMOTRON_REGEX = re.compile(r"^nvidia/nvidia.*nemotron", re.IGNORECASE)
+_LLAMA_NEMOTRON_REGEX = re.compile(r"^nvidia/llama.*nemotron", re.IGNORECASE)
+_MODEL_KEYS = ("model_name", "model", "azure_deployment")
+
+
+class ThinkingMixin(
+        BaseModel,
+        GatedFieldMixin,
+        field_name="thinking",
+        default_if_supported=None,
+        keys=_MODEL_KEYS,
+        supported=(_NVIDIA_NEMOTRON_REGEX, _LLAMA_NEMOTRON_REGEX),
+):
+    """
+    Mixin class for thinking configuration. Only supported on Nemotron models.
+
+    Attributes:
+        thinking: Whether to enable thinking. Defaults to None when supported on the model.
+    """
+    thinking: bool | None = Field(
+        default=None,
+        description="Whether to enable thinking. Defaults to None when supported on the model.",
+        exclude=True,
+    )
+
+    @property
+    def thinking_system_prompt(self) -> str | None:
+        """
+        Returns the system prompt to use for thinking.
+        For NVIDIA Nemotron, returns "/think" if enabled, else "/no_think".
+        For Llama Nemotron, returns "detailed thinking on" if enabled, else "detailed thinking off".
+        If thinking is not supported on the model, returns None.
+
+        Returns:
+            str | None: The system prompt to use for thinking.
+        """
+        if self.thinking is None:
+            return None
+        for key in _MODEL_KEYS:
+            if hasattr(self, key):
+                if _NVIDIA_NEMOTRON_REGEX.match(getattr(self, key)):
+                    return "/think" if self.thinking else "/no_think"
+                elif _LLAMA_NEMOTRON_REGEX.match(getattr(self, key)):
+                    return f"detailed thinking {'on' if self.thinking else 'off'}"
nat/data_models/top_p_mixin.py CHANGED
@@ -23,7 +23,7 @@ from nat.data_models.gated_field_mixin import GatedFieldMixin
 
 class TopPMixin(
         BaseModel,
-        GatedFieldMixin[float],
+        GatedFieldMixin,
         field_name="top_p",
         default_if_supported=1.0,
         keys=("model_name", "model", "azure_deployment"),