nvidia-nat-test 1.3.0a20250827__py3-none-any.whl → 1.3.0a20250829__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/test/tool_test_runner.py +74 -36
- {nvidia_nat_test-1.3.0a20250827.dist-info → nvidia_nat_test-1.3.0a20250829.dist-info}/METADATA +2 -2
- {nvidia_nat_test-1.3.0a20250827.dist-info → nvidia_nat_test-1.3.0a20250829.dist-info}/RECORD +6 -6
- {nvidia_nat_test-1.3.0a20250827.dist-info → nvidia_nat_test-1.3.0a20250829.dist-info}/WHEEL +0 -0
- {nvidia_nat_test-1.3.0a20250827.dist-info → nvidia_nat_test-1.3.0a20250829.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_test-1.3.0a20250827.dist-info → nvidia_nat_test-1.3.0a20250829.dist-info}/top_level.txt +0 -0
nat/test/tool_test_runner.py
CHANGED
@@ -14,18 +14,29 @@
|
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
16
|
import asyncio
|
17
|
+
import inspect
|
17
18
|
import logging
|
18
19
|
import typing
|
19
20
|
from contextlib import asynccontextmanager
|
20
21
|
from unittest.mock import AsyncMock
|
21
22
|
from unittest.mock import MagicMock
|
22
23
|
|
24
|
+
from nat.authentication.interfaces import AuthProviderBase
|
23
25
|
from nat.builder.builder import Builder
|
24
26
|
from nat.builder.function import Function
|
25
27
|
from nat.builder.function_info import FunctionInfo
|
26
28
|
from nat.cli.type_registry import GlobalTypeRegistry
|
29
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
30
|
+
from nat.data_models.embedder import EmbedderBaseConfig
|
27
31
|
from nat.data_models.function import FunctionBaseConfig
|
32
|
+
from nat.data_models.function_dependencies import FunctionDependencies
|
33
|
+
from nat.data_models.llm import LLMBaseConfig
|
34
|
+
from nat.data_models.memory import MemoryBaseConfig
|
28
35
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
36
|
+
from nat.data_models.retriever import RetrieverBaseConfig
|
37
|
+
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
38
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
39
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
29
40
|
from nat.object_store.interfaces import ObjectStore
|
30
41
|
from nat.runtime.loader import PluginTypes
|
31
42
|
from nat.runtime.loader import discover_and_register_plugins
|
@@ -70,14 +81,16 @@ class MockBuilder(Builder):
|
|
70
81
|
"""Add a mock TTC strategy that returns a fixed response."""
|
71
82
|
self._mocks[f"ttc_strategy_{name}"] = mock_response
|
72
83
|
|
73
|
-
|
84
|
+
def mock_auth_provider(self, name: str, mock_response: typing.Any):
|
85
|
+
"""Add a mock auth provider that returns a fixed response."""
|
86
|
+
self._mocks[f"auth_provider_{name}"] = mock_response
|
87
|
+
|
88
|
+
async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
|
74
89
|
"""Mock implementation (no‑op)."""
|
75
90
|
pass
|
76
91
|
|
77
|
-
async def get_ttc_strategy(self,
|
78
|
-
|
79
|
-
pipeline_type: typing.Any = None,
|
80
|
-
stage_type: typing.Any = None):
|
92
|
+
async def get_ttc_strategy(self, strategy_name: str, pipeline_type: PipelineTypeEnum,
|
93
|
+
stage_type: StageTypeEnum) -> typing.Any:
|
81
94
|
"""Return a mock TTC strategy if one is configured."""
|
82
95
|
key = f"ttc_strategy_{strategy_name}"
|
83
96
|
if key in self._mocks:
|
@@ -90,11 +103,24 @@ class MockBuilder(Builder):
|
|
90
103
|
|
91
104
|
async def get_ttc_strategy_config(self,
|
92
105
|
strategy_name: str,
|
93
|
-
pipeline_type:
|
94
|
-
stage_type:
|
106
|
+
pipeline_type: PipelineTypeEnum,
|
107
|
+
stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
|
95
108
|
"""Mock implementation."""
|
109
|
+
return TTCStrategyBaseConfig()
|
110
|
+
|
111
|
+
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> None:
|
112
|
+
"""Mock implementation (no‑op)."""
|
96
113
|
pass
|
97
114
|
|
115
|
+
async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
|
116
|
+
"""Return a mock auth provider if one is configured."""
|
117
|
+
key = f"auth_provider_{auth_provider_name}"
|
118
|
+
if key in self._mocks:
|
119
|
+
mock_auth = MagicMock()
|
120
|
+
mock_auth.authenticate = AsyncMock(return_value=self._mocks[key])
|
121
|
+
return mock_auth
|
122
|
+
raise ValueError(f"Auth provider '{auth_provider_name}' not mocked. Use mock_auth_provider() to add it.")
|
123
|
+
|
98
124
|
async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
|
99
125
|
"""Mock implementation - not used in tool testing."""
|
100
126
|
raise NotImplementedError("Mock implementation does not support add_function")
|
@@ -109,25 +135,29 @@ class MockBuilder(Builder):
|
|
109
135
|
|
110
136
|
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
111
137
|
"""Mock implementation."""
|
112
|
-
|
138
|
+
return FunctionBaseConfig()
|
113
139
|
|
114
140
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
115
141
|
"""Mock implementation."""
|
116
|
-
|
142
|
+
mock_fn = AsyncMock()
|
143
|
+
mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
|
144
|
+
return mock_fn
|
117
145
|
|
118
146
|
def get_workflow(self) -> Function:
|
119
147
|
"""Mock implementation."""
|
120
|
-
|
148
|
+
mock_fn = AsyncMock()
|
149
|
+
mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
|
150
|
+
return mock_fn
|
121
151
|
|
122
152
|
def get_workflow_config(self) -> FunctionBaseConfig:
|
123
153
|
"""Mock implementation."""
|
124
|
-
|
154
|
+
return FunctionBaseConfig()
|
125
155
|
|
126
156
|
def get_tool(self, fn_name: str, wrapper_type):
|
127
157
|
"""Mock implementation."""
|
128
158
|
pass
|
129
159
|
|
130
|
-
async def add_llm(self, name: str, config):
|
160
|
+
async def add_llm(self, name: str, config) -> None:
|
131
161
|
"""Mock implementation."""
|
132
162
|
pass
|
133
163
|
|
@@ -141,11 +171,11 @@ class MockBuilder(Builder):
|
|
141
171
|
return mock_llm
|
142
172
|
raise ValueError(f"LLM '{llm_name}' not mocked. Use mock_llm() to add it.")
|
143
173
|
|
144
|
-
def get_llm_config(self, llm_name: str):
|
174
|
+
def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
|
145
175
|
"""Mock implementation."""
|
146
|
-
|
176
|
+
return LLMBaseConfig()
|
147
177
|
|
148
|
-
async def add_embedder(self, name: str, config):
|
178
|
+
async def add_embedder(self, name: str, config) -> None:
|
149
179
|
"""Mock implementation."""
|
150
180
|
pass
|
151
181
|
|
@@ -159,11 +189,11 @@ class MockBuilder(Builder):
|
|
159
189
|
return mock_embedder
|
160
190
|
raise ValueError(f"Embedder '{embedder_name}' not mocked. Use mock_embedder() to add it.")
|
161
191
|
|
162
|
-
def get_embedder_config(self, embedder_name: str):
|
192
|
+
def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
|
163
193
|
"""Mock implementation."""
|
164
|
-
|
194
|
+
return EmbedderBaseConfig()
|
165
195
|
|
166
|
-
async def add_memory_client(self, name: str, config):
|
196
|
+
async def add_memory_client(self, name: str, config) -> None:
|
167
197
|
"""Mock implementation."""
|
168
198
|
pass
|
169
199
|
|
@@ -177,11 +207,11 @@ class MockBuilder(Builder):
|
|
177
207
|
return mock_memory
|
178
208
|
raise ValueError(f"Memory client '{memory_name}' not mocked. Use mock_memory_client() to add it.")
|
179
209
|
|
180
|
-
def get_memory_client_config(self, memory_name: str):
|
210
|
+
def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
|
181
211
|
"""Mock implementation."""
|
182
|
-
|
212
|
+
return MemoryBaseConfig()
|
183
213
|
|
184
|
-
async def add_retriever(self, name: str, config):
|
214
|
+
async def add_retriever(self, name: str, config) -> None:
|
185
215
|
"""Mock implementation."""
|
186
216
|
pass
|
187
217
|
|
@@ -194,11 +224,11 @@ class MockBuilder(Builder):
|
|
194
224
|
return mock_retriever
|
195
225
|
raise ValueError(f"Retriever '{retriever_name}' not mocked. Use mock_retriever() to add it.")
|
196
226
|
|
197
|
-
async def get_retriever_config(self, retriever_name: str):
|
227
|
+
async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
|
198
228
|
"""Mock implementation."""
|
199
|
-
|
229
|
+
return RetrieverBaseConfig()
|
200
230
|
|
201
|
-
async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
|
231
|
+
async def add_object_store(self, name: str, config: ObjectStoreBaseConfig) -> None:
|
202
232
|
"""Mock implementation for object store."""
|
203
233
|
pass
|
204
234
|
|
@@ -216,7 +246,7 @@ class MockBuilder(Builder):
|
|
216
246
|
|
217
247
|
def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
|
218
248
|
"""Mock implementation for object store config."""
|
219
|
-
|
249
|
+
return ObjectStoreBaseConfig()
|
220
250
|
|
221
251
|
def get_user_manager(self):
|
222
252
|
"""Mock implementation."""
|
@@ -224,9 +254,9 @@ class MockBuilder(Builder):
|
|
224
254
|
mock_user.get_id = MagicMock(return_value="test_user")
|
225
255
|
return mock_user
|
226
256
|
|
227
|
-
def get_function_dependencies(self, fn_name: str):
|
257
|
+
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
|
228
258
|
"""Mock implementation."""
|
229
|
-
|
259
|
+
return FunctionDependencies()
|
230
260
|
|
231
261
|
|
232
262
|
class ToolTestRunner:
|
@@ -322,10 +352,15 @@ class ToolTestRunner:
|
|
322
352
|
|
323
353
|
# Execute the tool
|
324
354
|
if input_data is not None:
|
325
|
-
if
|
355
|
+
if isinstance(tool_function, Function):
|
356
|
+
result = await tool_function.ainvoke(input_data)
|
357
|
+
elif asyncio.iscoroutinefunction(tool_function):
|
326
358
|
result = await tool_function(input_data)
|
327
359
|
else:
|
328
360
|
result = tool_function(input_data)
|
361
|
+
elif isinstance(tool_function, Function):
|
362
|
+
# Function objects require input, so pass None if no input_data
|
363
|
+
result = await tool_function.ainvoke(None)
|
329
364
|
elif asyncio.iscoroutinefunction(tool_function):
|
330
365
|
result = await tool_function()
|
331
366
|
else:
|
@@ -401,8 +436,8 @@ class ToolTestRunner:
|
|
401
436
|
elif isinstance(tool_result, FunctionInfo):
|
402
437
|
if tool_result.single_fn:
|
403
438
|
tool_function = tool_result.single_fn
|
404
|
-
elif tool_result.
|
405
|
-
tool_function = tool_result.
|
439
|
+
elif tool_result.stream_fn:
|
440
|
+
tool_function = tool_result.stream_fn
|
406
441
|
else:
|
407
442
|
raise ValueError("Tool function not found in FunctionInfo")
|
408
443
|
elif callable(tool_result):
|
@@ -412,14 +447,17 @@ class ToolTestRunner:
|
|
412
447
|
|
413
448
|
# Execute the tool
|
414
449
|
if input_data is not None:
|
415
|
-
if
|
416
|
-
result = await tool_function(input_data)
|
450
|
+
if isinstance(tool_function, Function):
|
451
|
+
result = await tool_function.ainvoke(input_data)
|
417
452
|
else:
|
418
|
-
|
419
|
-
|
420
|
-
|
453
|
+
maybe_result = tool_function(input_data)
|
454
|
+
result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
|
455
|
+
elif isinstance(tool_function, Function):
|
456
|
+
# Function objects require input, so pass None if no input_data
|
457
|
+
result = await tool_function.ainvoke(None)
|
421
458
|
else:
|
422
|
-
|
459
|
+
maybe_result = tool_function()
|
460
|
+
result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
|
423
461
|
|
424
462
|
# Assert expected output if provided
|
425
463
|
if expected_output is not None:
|
{nvidia_nat_test-1.3.0a20250827.dist-info → nvidia_nat_test-1.3.0a20250829.dist-info}/METADATA
RENAMED
@@ -1,12 +1,12 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: nvidia-nat-test
|
3
|
-
Version: 1.3.
|
3
|
+
Version: 1.3.0a20250829
|
4
4
|
Summary: Testing utilities for NeMo Agent toolkit
|
5
5
|
Keywords: ai,rag,agents
|
6
6
|
Classifier: Programming Language :: Python
|
7
7
|
Requires-Python: <3.13,>=3.11
|
8
8
|
Description-Content-Type: text/markdown
|
9
|
-
Requires-Dist: nvidia-nat==v1.3.
|
9
|
+
Requires-Dist: nvidia-nat==v1.3.0a20250829
|
10
10
|
Requires-Dist: langchain-community~=0.3
|
11
11
|
Requires-Dist: pytest~=8.3
|
12
12
|
|
{nvidia_nat_test-1.3.0a20250827.dist-info → nvidia_nat_test-1.3.0a20250829.dist-info}/RECORD
RENAMED
@@ -7,9 +7,9 @@ nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2z
|
|
7
7
|
nat/test/plugin.py,sha256=sMZ7xupCgEpQCuwUksUDYMjbBj0VNlhR6SK5UcOrBzg,6953
|
8
8
|
nat/test/register.py,sha256=fbCLr3E4u8PYMFUlkRNlg53Td2YJ80iQCyxpRIbGId4,859
|
9
9
|
nat/test/test_env_fixtures.py,sha256=zGhFBiZmdDYuj8kOU__RL9LOrood3L58KG8OWXnyOjQ,2375
|
10
|
-
nat/test/tool_test_runner.py,sha256=
|
11
|
-
nvidia_nat_test-1.3.
|
12
|
-
nvidia_nat_test-1.3.
|
13
|
-
nvidia_nat_test-1.3.
|
14
|
-
nvidia_nat_test-1.3.
|
15
|
-
nvidia_nat_test-1.3.
|
10
|
+
nat/test/tool_test_runner.py,sha256=2kCydvJ6LBZ3Lh04e5_Qg-8_kOpcIkNdwGeQOPtadek,20089
|
11
|
+
nvidia_nat_test-1.3.0a20250829.dist-info/METADATA,sha256=LgoT5OxCrGslq9U0rp6lPbeVovcKrqGb6nt9-rhNUYA,1466
|
12
|
+
nvidia_nat_test-1.3.0a20250829.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
13
|
+
nvidia_nat_test-1.3.0a20250829.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
|
14
|
+
nvidia_nat_test-1.3.0a20250829.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
15
|
+
nvidia_nat_test-1.3.0a20250829.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{nvidia_nat_test-1.3.0a20250827.dist-info → nvidia_nat_test-1.3.0a20250829.dist-info}/top_level.txt
RENAMED
File without changes
|