nvidia-nat-test 1.2.0__py3-none-any.whl → 1.4.0a20251212__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-test might be problematic. Click here for more details.
- nat/test/functions.py +9 -1
- nat/test/llm.py +244 -0
- nat/test/plugin.py +789 -12
- nat/test/register.py +2 -1
- nat/test/tool_test_runner.py +203 -47
- nat/test/utils.py +148 -0
- {nvidia_nat_test-1.2.0.dist-info → nvidia_nat_test-1.4.0a20251212.dist-info}/METADATA +14 -3
- nvidia_nat_test-1.4.0a20251212.dist-info/RECORD +18 -0
- nvidia_nat_test-1.4.0a20251212.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_test-1.4.0a20251212.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_test-1.2.0.dist-info/RECORD +0 -14
- {nvidia_nat_test-1.2.0.dist-info → nvidia_nat_test-1.4.0a20251212.dist-info}/WHEEL +0 -0
- {nvidia_nat_test-1.2.0.dist-info → nvidia_nat_test-1.4.0a20251212.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_test-1.2.0.dist-info → nvidia_nat_test-1.4.0a20251212.dist-info}/top_level.txt +0 -0
nat/test/register.py
CHANGED
|
@@ -13,7 +13,6 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
# pylint: disable=unused-import
|
|
17
16
|
# flake8: noqa
|
|
18
17
|
# isort:skip_file
|
|
19
18
|
|
|
@@ -22,3 +21,5 @@
|
|
|
22
21
|
from . import embedder
|
|
23
22
|
from . import functions
|
|
24
23
|
from . import memory
|
|
24
|
+
from . import llm
|
|
25
|
+
from . import utils
|
nat/test/tool_test_runner.py
CHANGED
|
@@ -14,18 +14,42 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import asyncio
|
|
17
|
+
import inspect
|
|
17
18
|
import logging
|
|
18
19
|
import typing
|
|
20
|
+
from collections.abc import Sequence
|
|
19
21
|
from contextlib import asynccontextmanager
|
|
20
22
|
from unittest.mock import AsyncMock
|
|
21
23
|
from unittest.mock import MagicMock
|
|
22
24
|
|
|
25
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
23
26
|
from nat.builder.builder import Builder
|
|
24
27
|
from nat.builder.function import Function
|
|
28
|
+
from nat.builder.function import FunctionGroup
|
|
25
29
|
from nat.builder.function_info import FunctionInfo
|
|
26
30
|
from nat.cli.type_registry import GlobalTypeRegistry
|
|
31
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
32
|
+
from nat.data_models.component_ref import MiddlewareRef
|
|
33
|
+
from nat.data_models.embedder import EmbedderBaseConfig
|
|
34
|
+
from nat.data_models.finetuning import TrainerAdapterConfig
|
|
35
|
+
from nat.data_models.finetuning import TrainerConfig
|
|
36
|
+
from nat.data_models.finetuning import TrajectoryBuilderConfig
|
|
27
37
|
from nat.data_models.function import FunctionBaseConfig
|
|
38
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
39
|
+
from nat.data_models.function_dependencies import FunctionDependencies
|
|
40
|
+
from nat.data_models.llm import LLMBaseConfig
|
|
41
|
+
from nat.data_models.memory import MemoryBaseConfig
|
|
42
|
+
from nat.data_models.middleware import FunctionMiddlewareBaseConfig
|
|
28
43
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
44
|
+
from nat.data_models.retriever import RetrieverBaseConfig
|
|
45
|
+
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
46
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
47
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
48
|
+
from nat.finetuning.interfaces.finetuning_runner import Trainer
|
|
49
|
+
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
|
|
50
|
+
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
|
|
51
|
+
from nat.memory.interfaces import MemoryEditor
|
|
52
|
+
from nat.middleware import FunctionMiddleware
|
|
29
53
|
from nat.object_store.interfaces import ObjectStore
|
|
30
54
|
from nat.runtime.loader import PluginTypes
|
|
31
55
|
from nat.runtime.loader import discover_and_register_plugins
|
|
@@ -46,6 +70,10 @@ class MockBuilder(Builder):
|
|
|
46
70
|
"""Add a mock function that returns a fixed response."""
|
|
47
71
|
self._mocks[name] = mock_response
|
|
48
72
|
|
|
73
|
+
def mock_function_group(self, name: str, mock_response: typing.Any):
|
|
74
|
+
"""Add a mock function group that returns a fixed response."""
|
|
75
|
+
self._mocks[name] = mock_response
|
|
76
|
+
|
|
49
77
|
def mock_llm(self, name: str, mock_response: typing.Any):
|
|
50
78
|
"""Add a mock LLM that returns a fixed response."""
|
|
51
79
|
self._mocks[f"llm_{name}"] = mock_response
|
|
@@ -70,14 +98,28 @@ class MockBuilder(Builder):
|
|
|
70
98
|
"""Add a mock TTC strategy that returns a fixed response."""
|
|
71
99
|
self._mocks[f"ttc_strategy_{name}"] = mock_response
|
|
72
100
|
|
|
73
|
-
|
|
101
|
+
def mock_auth_provider(self, name: str, mock_response: typing.Any):
|
|
102
|
+
"""Add a mock auth provider that returns a fixed response."""
|
|
103
|
+
self._mocks[f"auth_provider_{name}"] = mock_response
|
|
104
|
+
|
|
105
|
+
def mock_trainer(self, name: str, mock_response: typing.Any):
|
|
106
|
+
"""Add a mock trainer that returns a fixed response."""
|
|
107
|
+
self._mocks[f"trainer_{name}"] = mock_response
|
|
108
|
+
|
|
109
|
+
def mock_trainer_adapter(self, name: str, mock_response: typing.Any):
|
|
110
|
+
"""Add a mock trainer adapter that returns a fixed response."""
|
|
111
|
+
self._mocks[f"trainer_adapter_{name}"] = mock_response
|
|
112
|
+
|
|
113
|
+
def mock_trajectory_builder(self, name: str, mock_response: typing.Any):
|
|
114
|
+
"""Add a mock trajectory builder that returns a fixed response."""
|
|
115
|
+
self._mocks[f"trajectory_builder_{name}"] = mock_response
|
|
116
|
+
|
|
117
|
+
async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
|
|
74
118
|
"""Mock implementation (no‑op)."""
|
|
75
119
|
pass
|
|
76
120
|
|
|
77
|
-
async def get_ttc_strategy(self,
|
|
78
|
-
|
|
79
|
-
pipeline_type: typing.Any = None,
|
|
80
|
-
stage_type: typing.Any = None):
|
|
121
|
+
async def get_ttc_strategy(self, strategy_name: str, pipeline_type: PipelineTypeEnum,
|
|
122
|
+
stage_type: StageTypeEnum) -> typing.Any:
|
|
81
123
|
"""Return a mock TTC strategy if one is configured."""
|
|
82
124
|
key = f"ttc_strategy_{strategy_name}"
|
|
83
125
|
if key in self._mocks:
|
|
@@ -90,16 +132,29 @@ class MockBuilder(Builder):
|
|
|
90
132
|
|
|
91
133
|
async def get_ttc_strategy_config(self,
|
|
92
134
|
strategy_name: str,
|
|
93
|
-
pipeline_type:
|
|
94
|
-
stage_type:
|
|
135
|
+
pipeline_type: PipelineTypeEnum,
|
|
136
|
+
stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
|
|
95
137
|
"""Mock implementation."""
|
|
138
|
+
return TTCStrategyBaseConfig()
|
|
139
|
+
|
|
140
|
+
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> None:
|
|
141
|
+
"""Mock implementation (no‑op)."""
|
|
96
142
|
pass
|
|
97
143
|
|
|
144
|
+
async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
|
|
145
|
+
"""Return a mock auth provider if one is configured."""
|
|
146
|
+
key = f"auth_provider_{auth_provider_name}"
|
|
147
|
+
if key in self._mocks:
|
|
148
|
+
mock_auth = MagicMock()
|
|
149
|
+
mock_auth.authenticate = AsyncMock(return_value=self._mocks[key])
|
|
150
|
+
return mock_auth
|
|
151
|
+
raise ValueError(f"Auth provider '{auth_provider_name}' not mocked. Use mock_auth_provider() to add it.")
|
|
152
|
+
|
|
98
153
|
async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
|
|
99
154
|
"""Mock implementation - not used in tool testing."""
|
|
100
155
|
raise NotImplementedError("Mock implementation does not support add_function")
|
|
101
156
|
|
|
102
|
-
def get_function(self, name: str) -> Function:
|
|
157
|
+
async def get_function(self, name: str) -> Function:
|
|
103
158
|
"""Return a mock function if one is configured."""
|
|
104
159
|
if name in self._mocks:
|
|
105
160
|
mock_fn = AsyncMock()
|
|
@@ -109,25 +164,49 @@ class MockBuilder(Builder):
|
|
|
109
164
|
|
|
110
165
|
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
|
111
166
|
"""Mock implementation."""
|
|
112
|
-
|
|
167
|
+
return FunctionBaseConfig()
|
|
168
|
+
|
|
169
|
+
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
|
170
|
+
"""Mock implementation - not used in tool testing."""
|
|
171
|
+
raise NotImplementedError("Mock implementation does not support add_function_group")
|
|
172
|
+
|
|
173
|
+
async def get_function_group(self, name: str) -> FunctionGroup:
|
|
174
|
+
"""Return a mock function group if one is configured."""
|
|
175
|
+
if name in self._mocks:
|
|
176
|
+
mock_fn_group = MagicMock(spec=FunctionGroup)
|
|
177
|
+
mock_fn_group.ainvoke = AsyncMock(return_value=self._mocks[name])
|
|
178
|
+
return mock_fn_group
|
|
179
|
+
raise ValueError(f"Function group '{name}' not mocked. Use mock_function_group() to add it.")
|
|
180
|
+
|
|
181
|
+
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
|
|
182
|
+
"""Mock implementation."""
|
|
183
|
+
return FunctionGroupBaseConfig()
|
|
113
184
|
|
|
114
185
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
115
186
|
"""Mock implementation."""
|
|
116
|
-
|
|
187
|
+
mock_fn = AsyncMock()
|
|
188
|
+
mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
|
|
189
|
+
return mock_fn
|
|
117
190
|
|
|
118
191
|
def get_workflow(self) -> Function:
|
|
119
192
|
"""Mock implementation."""
|
|
120
|
-
|
|
193
|
+
mock_fn = AsyncMock()
|
|
194
|
+
mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
|
|
195
|
+
return mock_fn
|
|
121
196
|
|
|
122
197
|
def get_workflow_config(self) -> FunctionBaseConfig:
|
|
123
198
|
"""Mock implementation."""
|
|
124
|
-
|
|
199
|
+
return FunctionBaseConfig()
|
|
125
200
|
|
|
126
|
-
def
|
|
201
|
+
async def get_tools(self, tool_names: Sequence[str], wrapper_type) -> list[typing.Any]:
|
|
202
|
+
"""Mock implementation."""
|
|
203
|
+
return []
|
|
204
|
+
|
|
205
|
+
async def get_tool(self, fn_name: str, wrapper_type) -> typing.Any:
|
|
127
206
|
"""Mock implementation."""
|
|
128
207
|
pass
|
|
129
208
|
|
|
130
|
-
async def add_llm(self, name: str, config):
|
|
209
|
+
async def add_llm(self, name: str, config) -> None:
|
|
131
210
|
"""Mock implementation."""
|
|
132
211
|
pass
|
|
133
212
|
|
|
@@ -141,11 +220,11 @@ class MockBuilder(Builder):
|
|
|
141
220
|
return mock_llm
|
|
142
221
|
raise ValueError(f"LLM '{llm_name}' not mocked. Use mock_llm() to add it.")
|
|
143
222
|
|
|
144
|
-
def get_llm_config(self, llm_name: str):
|
|
223
|
+
def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
|
|
145
224
|
"""Mock implementation."""
|
|
146
|
-
|
|
225
|
+
return LLMBaseConfig()
|
|
147
226
|
|
|
148
|
-
async def add_embedder(self, name: str, config):
|
|
227
|
+
async def add_embedder(self, name: str, config) -> None:
|
|
149
228
|
"""Mock implementation."""
|
|
150
229
|
pass
|
|
151
230
|
|
|
@@ -159,15 +238,14 @@ class MockBuilder(Builder):
|
|
|
159
238
|
return mock_embedder
|
|
160
239
|
raise ValueError(f"Embedder '{embedder_name}' not mocked. Use mock_embedder() to add it.")
|
|
161
240
|
|
|
162
|
-
def get_embedder_config(self, embedder_name: str):
|
|
241
|
+
def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
|
|
163
242
|
"""Mock implementation."""
|
|
164
|
-
|
|
243
|
+
return EmbedderBaseConfig()
|
|
165
244
|
|
|
166
|
-
async def add_memory_client(self, name: str, config):
|
|
167
|
-
|
|
168
|
-
pass
|
|
245
|
+
async def add_memory_client(self, name: str, config) -> MemoryEditor:
|
|
246
|
+
return MagicMock(spec=MemoryEditor)
|
|
169
247
|
|
|
170
|
-
def get_memory_client(self, memory_name: str):
|
|
248
|
+
async def get_memory_client(self, memory_name: str) -> MemoryEditor:
|
|
171
249
|
"""Return a mock memory client if one is configured."""
|
|
172
250
|
key = f"memory_{memory_name}"
|
|
173
251
|
if key in self._mocks:
|
|
@@ -177,11 +255,11 @@ class MockBuilder(Builder):
|
|
|
177
255
|
return mock_memory
|
|
178
256
|
raise ValueError(f"Memory client '{memory_name}' not mocked. Use mock_memory_client() to add it.")
|
|
179
257
|
|
|
180
|
-
def get_memory_client_config(self, memory_name: str):
|
|
258
|
+
def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
|
|
181
259
|
"""Mock implementation."""
|
|
182
|
-
|
|
260
|
+
return MemoryBaseConfig()
|
|
183
261
|
|
|
184
|
-
async def add_retriever(self, name: str, config):
|
|
262
|
+
async def add_retriever(self, name: str, config) -> None:
|
|
185
263
|
"""Mock implementation."""
|
|
186
264
|
pass
|
|
187
265
|
|
|
@@ -194,13 +272,13 @@ class MockBuilder(Builder):
|
|
|
194
272
|
return mock_retriever
|
|
195
273
|
raise ValueError(f"Retriever '{retriever_name}' not mocked. Use mock_retriever() to add it.")
|
|
196
274
|
|
|
197
|
-
async def get_retriever_config(self, retriever_name: str):
|
|
275
|
+
async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
|
|
198
276
|
"""Mock implementation."""
|
|
199
|
-
|
|
277
|
+
return RetrieverBaseConfig()
|
|
200
278
|
|
|
201
|
-
async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
|
|
279
|
+
async def add_object_store(self, name: str, config: ObjectStoreBaseConfig) -> ObjectStore:
|
|
202
280
|
"""Mock implementation for object store."""
|
|
203
|
-
|
|
281
|
+
return MagicMock(spec=ObjectStore)
|
|
204
282
|
|
|
205
283
|
async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
|
|
206
284
|
"""Return a mock object store client if one is configured."""
|
|
@@ -216,7 +294,7 @@ class MockBuilder(Builder):
|
|
|
216
294
|
|
|
217
295
|
def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
|
|
218
296
|
"""Mock implementation for object store config."""
|
|
219
|
-
|
|
297
|
+
return ObjectStoreBaseConfig()
|
|
220
298
|
|
|
221
299
|
def get_user_manager(self):
|
|
222
300
|
"""Mock implementation."""
|
|
@@ -224,9 +302,81 @@ class MockBuilder(Builder):
|
|
|
224
302
|
mock_user.get_id = MagicMock(return_value="test_user")
|
|
225
303
|
return mock_user
|
|
226
304
|
|
|
227
|
-
def get_function_dependencies(self, fn_name: str):
|
|
305
|
+
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
228
306
|
"""Mock implementation."""
|
|
229
|
-
|
|
307
|
+
return FunctionDependencies()
|
|
308
|
+
|
|
309
|
+
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
|
310
|
+
"""Mock implementation."""
|
|
311
|
+
return FunctionDependencies()
|
|
312
|
+
|
|
313
|
+
async def get_middleware(self, middleware_name: str | MiddlewareRef) -> FunctionMiddleware:
|
|
314
|
+
"""Mock implementation."""
|
|
315
|
+
return FunctionMiddleware()
|
|
316
|
+
|
|
317
|
+
def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> FunctionMiddlewareBaseConfig:
|
|
318
|
+
"""Mock implementation."""
|
|
319
|
+
return FunctionMiddlewareBaseConfig()
|
|
320
|
+
|
|
321
|
+
async def add_middleware(self, name: str | MiddlewareRef,
|
|
322
|
+
config: FunctionMiddlewareBaseConfig) -> FunctionMiddleware:
|
|
323
|
+
"""Mock implementation."""
|
|
324
|
+
return FunctionMiddleware()
|
|
325
|
+
|
|
326
|
+
async def add_trainer(self, name: str, config: TrainerConfig) -> Trainer:
|
|
327
|
+
"""Mock implementation."""
|
|
328
|
+
return MagicMock(spec=Trainer)
|
|
329
|
+
|
|
330
|
+
async def get_trainer(self,
|
|
331
|
+
trainer_name: str,
|
|
332
|
+
trajectory_builder: TrajectoryBuilder,
|
|
333
|
+
trainer_adapter: TrainerAdapter) -> Trainer:
|
|
334
|
+
"""Return a mock trainer if one is configured."""
|
|
335
|
+
key = f"trainer_{trainer_name}"
|
|
336
|
+
if key in self._mocks:
|
|
337
|
+
mock_trainer = MagicMock()
|
|
338
|
+
mock_trainer.train = AsyncMock(return_value=self._mocks[key])
|
|
339
|
+
return mock_trainer
|
|
340
|
+
raise ValueError(f"Trainer '{trainer_name}' not mocked. Use mock_trainer() to add it.")
|
|
341
|
+
|
|
342
|
+
async def get_trainer_config(self, trainer_name: str) -> TrainerConfig:
|
|
343
|
+
"""Mock implementation."""
|
|
344
|
+
return TrainerConfig()
|
|
345
|
+
|
|
346
|
+
async def add_trainer_adapter(self, name: str, config: TrainerAdapterConfig) -> TrainerAdapter:
|
|
347
|
+
"""Mock implementation."""
|
|
348
|
+
return MagicMock(spec=TrainerAdapter)
|
|
349
|
+
|
|
350
|
+
async def get_trainer_adapter(self, trainer_adapter_name: str) -> TrainerAdapter:
|
|
351
|
+
"""Return a mock trainer adapter if one is configured."""
|
|
352
|
+
key = f"trainer_adapter_{trainer_adapter_name}"
|
|
353
|
+
if key in self._mocks:
|
|
354
|
+
mock_adapter = MagicMock()
|
|
355
|
+
mock_adapter.adapt = AsyncMock(return_value=self._mocks[key])
|
|
356
|
+
return mock_adapter
|
|
357
|
+
raise ValueError(f"Trainer adapter '{trainer_adapter_name}' not mocked. Use mock_trainer_adapter() to add it.")
|
|
358
|
+
|
|
359
|
+
async def get_trainer_adapter_config(self, trainer_adapter_name: str) -> TrainerAdapterConfig:
|
|
360
|
+
"""Mock implementation."""
|
|
361
|
+
return TrainerAdapterConfig()
|
|
362
|
+
|
|
363
|
+
async def add_trajectory_builder(self, name: str, config: TrajectoryBuilderConfig) -> TrajectoryBuilder:
|
|
364
|
+
"""Mock implementation."""
|
|
365
|
+
return MagicMock(spec=TrajectoryBuilder)
|
|
366
|
+
|
|
367
|
+
async def get_trajectory_builder(self, trajectory_builder_name: str) -> TrajectoryBuilder:
|
|
368
|
+
"""Return a mock trajectory builder if one is configured."""
|
|
369
|
+
key = f"trajectory_builder_{trajectory_builder_name}"
|
|
370
|
+
if key in self._mocks:
|
|
371
|
+
mock_builder = MagicMock()
|
|
372
|
+
mock_builder.build = AsyncMock(return_value=self._mocks[key])
|
|
373
|
+
return mock_builder
|
|
374
|
+
raise ValueError(
|
|
375
|
+
f"Trajectory builder '{trajectory_builder_name}' not mocked. Use mock_trajectory_builder() to add it.")
|
|
376
|
+
|
|
377
|
+
async def get_trajectory_builder_config(self, trajectory_builder_name: str) -> TrajectoryBuilderConfig:
|
|
378
|
+
"""Mock implementation."""
|
|
379
|
+
return TrajectoryBuilderConfig()
|
|
230
380
|
|
|
231
381
|
|
|
232
382
|
class ToolTestRunner:
|
|
@@ -322,15 +472,19 @@ class ToolTestRunner:
|
|
|
322
472
|
|
|
323
473
|
# Execute the tool
|
|
324
474
|
if input_data is not None:
|
|
325
|
-
if
|
|
475
|
+
if isinstance(tool_function, Function):
|
|
476
|
+
result = await tool_function.ainvoke(input_data)
|
|
477
|
+
elif asyncio.iscoroutinefunction(tool_function):
|
|
326
478
|
result = await tool_function(input_data)
|
|
327
479
|
else:
|
|
328
480
|
result = tool_function(input_data)
|
|
481
|
+
elif isinstance(tool_function, Function):
|
|
482
|
+
# Function objects require input, so pass None if no input_data
|
|
483
|
+
result = await tool_function.ainvoke(None)
|
|
484
|
+
elif asyncio.iscoroutinefunction(tool_function):
|
|
485
|
+
result = await tool_function()
|
|
329
486
|
else:
|
|
330
|
-
|
|
331
|
-
result = await tool_function()
|
|
332
|
-
else:
|
|
333
|
-
result = tool_function()
|
|
487
|
+
result = tool_function()
|
|
334
488
|
|
|
335
489
|
# Assert expected output if provided
|
|
336
490
|
if expected_output is not None:
|
|
@@ -402,8 +556,8 @@ class ToolTestRunner:
|
|
|
402
556
|
elif isinstance(tool_result, FunctionInfo):
|
|
403
557
|
if tool_result.single_fn:
|
|
404
558
|
tool_function = tool_result.single_fn
|
|
405
|
-
elif tool_result.
|
|
406
|
-
tool_function = tool_result.
|
|
559
|
+
elif tool_result.stream_fn:
|
|
560
|
+
tool_function = tool_result.stream_fn
|
|
407
561
|
else:
|
|
408
562
|
raise ValueError("Tool function not found in FunctionInfo")
|
|
409
563
|
elif callable(tool_result):
|
|
@@ -413,15 +567,17 @@ class ToolTestRunner:
|
|
|
413
567
|
|
|
414
568
|
# Execute the tool
|
|
415
569
|
if input_data is not None:
|
|
416
|
-
if
|
|
417
|
-
result = await tool_function(input_data)
|
|
570
|
+
if isinstance(tool_function, Function):
|
|
571
|
+
result = await tool_function.ainvoke(input_data)
|
|
418
572
|
else:
|
|
419
|
-
|
|
573
|
+
maybe_result = tool_function(input_data)
|
|
574
|
+
result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
|
|
575
|
+
elif isinstance(tool_function, Function):
|
|
576
|
+
# Function objects require input, so pass None if no input_data
|
|
577
|
+
result = await tool_function.ainvoke(None)
|
|
420
578
|
else:
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
else:
|
|
424
|
-
result = tool_function()
|
|
579
|
+
maybe_result = tool_function()
|
|
580
|
+
result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
|
|
425
581
|
|
|
426
582
|
# Assert expected output if provided
|
|
427
583
|
if expected_output is not None:
|
nat/test/utils.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import importlib.resources
|
|
17
|
+
import inspect
|
|
18
|
+
import json
|
|
19
|
+
import subprocess
|
|
20
|
+
import typing
|
|
21
|
+
from contextlib import asynccontextmanager
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
if typing.TYPE_CHECKING:
|
|
25
|
+
from collections.abc import AsyncIterator
|
|
26
|
+
|
|
27
|
+
from httpx import AsyncClient
|
|
28
|
+
|
|
29
|
+
from nat.data_models.config import Config
|
|
30
|
+
from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker
|
|
31
|
+
from nat.utils.type_utils import StrPath
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def locate_repo_root() -> Path:
    """
    Return the top-level directory of the git repository that contains the current working directory.

    Runs ``git rev-parse --show-toplevel`` in the process's current working directory and fails the calling test
    (via ``assert``) when git exits with a non-zero status.

    Returns:
        The repository root as reported by git, with surrounding whitespace removed.
    """
    git_command = ["git", "rev-parse", "--show-toplevel"]
    completed = subprocess.run(git_command, check=False, capture_output=True, text=True)

    assert completed.returncode == 0, f"Failed to get git root: {completed.stderr}"

    toplevel = completed.stdout.strip()
    return Path(toplevel)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def locate_example_src_dir(example_config_class: type) -> Path:
    """
    Locate the example src directory for an example's config class.

    The directory is resolved with ``importlib.resources.files`` from the package of the module that defines
    ``example_config_class``.

    Args:
        example_config_class: A class defined inside the example's package (typically its config class).

    Returns:
        The resource root of the package containing ``example_config_class``. NOTE(review): for regular on-disk
        packages this is expected to be a ``pathlib.Path``; namespace or zipped packages may yield a different
        ``Traversable`` type, so confirm before relying on ``Path``-only methods.

    Raises:
        ValueError: If the defining module cannot be found or is not part of a package.
    """
    module = inspect.getmodule(example_config_class)
    # `inspect.getmodule` returns None when the class's module is not loaded, and top-level modules have an empty
    # `__package__` that `importlib.resources.files` cannot resolve. Fail with a clear message in both cases
    # instead of an opaque AttributeError / "Empty module name" error.
    package_name = getattr(module, "__package__", None)
    if not package_name:
        raise ValueError(f"Unable to determine the package that defines {example_config_class!r}")
    return importlib.resources.files(package_name)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def locate_example_dir(example_config_class: type) -> Path:
    """
    Return the root directory of the example that defines ``example_config_class``.

    The root is taken to be two directory levels above the package source directory returned by
    ``locate_example_src_dir``. NOTE(review): this presumes the ``<example_dir>/src/<package>`` layout; confirm
    for examples structured differently.
    """
    package_src_dir = locate_example_src_dir(example_config_class)
    return package_src_dir.parent.parent
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def locate_example_config(example_config_class: type,
                          config_file: str = "config.yml",
                          assert_exists: bool = True) -> Path:
    """
    Return the absolute path to a config file of the example that defines ``example_config_class``.

    The file is resolved as ``configs/<config_file>`` relative to the package source directory returned by
    ``locate_example_src_dir``, so that package is expected to contain a ``configs`` directory.

    Args:
        example_config_class: A class defined in the example's package (typically its config class).
        config_file: Name of the file inside the ``configs`` directory. Defaults to ``"config.yml"``.
        assert_exists: When true, fail the calling test if the resolved file does not exist.

    Returns:
        The absolute path of the config file (which may not exist when ``assert_exists`` is false).
    """
    # Note: the lookup is relative to the package *source* directory, not the example root.
    package_src_dir = locate_example_src_dir(example_config_class)
    resolved = package_src_dir.joinpath("configs", config_file).absolute()

    if not assert_exists:
        return resolved

    assert resolved.exists(), f"Config file {resolved} does not exist"
    return resolved
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
async def run_workflow(*,
                       config: "Config | None" = None,
                       config_file: "StrPath | None" = None,
                       question: str,
                       expected_answer: str,
                       assert_expected_answer: bool = True,
                       **kwargs) -> str:
    """
    Run a workflow on ``question`` through ``nat.utils.run_workflow`` and optionally check the answer.

    This is a test-oriented wrapper: the workflow result is always requested as ``str`` (``to_type=str``), and any
    extra keyword arguments are forwarded unchanged to ``nat.utils.run_workflow``.

    Args:
        config: An already-loaded workflow configuration.
        config_file: Path to a workflow configuration file.
        question: Prompt passed to the workflow.
        expected_answer: Text that must appear (case-insensitively) in the workflow output.
        assert_expected_answer: When false, the output is returned without checking ``expected_answer``.

    Returns:
        The workflow output as a string.
    """
    from nat.utils import run_workflow as _run_nat_workflow

    workflow_output = await _run_nat_workflow(config=config,
                                              config_file=config_file,
                                              prompt=question,
                                              to_type=str,
                                              **kwargs)

    if not assert_expected_answer:
        return workflow_output

    # Case-insensitive substring match between the expected text and the produced answer.
    needle = expected_answer.lower()
    haystack = workflow_output.lower()
    assert needle in haystack, f"Expected '{expected_answer}' in '{workflow_output}'"
    return workflow_output
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@asynccontextmanager
async def build_nat_client(
        config: "Config",
        worker_class: "type[FastApiFrontEndPluginWorker] | None" = None) -> "AsyncIterator[AsyncClient]":
    """
    Yield an ``httpx.AsyncClient`` bound to an in-process NAT FastAPI application, for testing.

    The application is built by instantiating ``worker_class`` with ``config`` and calling its ``build_app()``. The
    app's lifespan is run with ``asgi_lifespan.LifespanManager`` for as long as the client is in use. Requests go
    through an ASGI transport with base URL ``http://test``, so no network server is started.

    Args:
        config: The NAT configuration to use for building the client.
        worker_class: Optional worker class to use. Defaults to FastApiFrontEndPluginWorker.

    Yields:
        An AsyncClient instance configured for testing.
    """
    from asgi_lifespan import LifespanManager
    from httpx import ASGITransport
    from httpx import AsyncClient

    from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker

    chosen_worker_class = FastApiFrontEndPluginWorker if worker_class is None else worker_class
    fastapi_app = chosen_worker_class(config).build_app()

    # Items in a multi-item `async with` are entered left to right. The lifespan therefore starts before the
    # transport/client are created, and the client is closed before the lifespan shuts down.
    async with LifespanManager(fastapi_app), AsyncClient(transport=ASGITransport(app=fastapi_app),
                                                         base_url="http://test") as client:
        yield client
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def validate_workflow_output(workflow_output_file: Path) -> None:
    """
    Validate the contents of the workflow output file.

    The file must exist and contain a non-empty JSON list of objects. Every object must carry the keys ``id``,
    ``question``, ``answer``, ``generated_answer`` and ``intermediate_steps`` with non-null values. Falsy but valid
    values such as ``0``, ``""`` or ``[]`` are accepted.

    WIP: output format should be published as a schema and this validation should be done against that schema.

    Args:
        workflow_output_file: Path to the ``workflow_output.json`` file to check.

    Raises:
        AssertionError: If the file is missing or its contents do not have the expected structure.
        RuntimeError: If the file is not valid JSON.
    """
    # Ensure the workflow_output.json file was created
    assert workflow_output_file.exists(), "The workflow_output.json file was not created"

    # Read and validate the workflow_output.json file
    try:
        with open(workflow_output_file, encoding="utf-8") as f:
            result_json = json.load(f)
    except json.JSONDecodeError as err:
        raise RuntimeError("Failed to parse workflow_output.json as valid JSON") from err

    assert isinstance(result_json, list), "The workflow_output.json file is not a list"
    assert len(result_json) > 0, "The workflow_output.json file is empty"
    # Check every entry, not just the first. A malformed later entry then produces a clear assertion instead of an
    # AttributeError from the key checks below.
    assert all(isinstance(item, dict) for item in result_json), \
        "The workflow_output.json file is not a list of dictionaries"

    # Ensure required keys exist. Use a presence/non-null check rather than truthiness, so that legitimate falsy
    # values (e.g. an `id` of 0 or an empty `intermediate_steps` list) are not reported as missing.
    required_keys = ["id", "question", "answer", "generated_answer", "intermediate_steps"]
    for key in required_keys:
        assert all(item.get(key) is not None for item in result_json), \
            f"The '{key}' key is missing in workflow_output.json"
|
|
@@ -1,14 +1,25 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nvidia-nat-test
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.4.0a20251212
|
|
4
4
|
Summary: Testing utilities for NeMo Agent toolkit
|
|
5
|
+
Author: NVIDIA Corporation
|
|
6
|
+
Maintainer: NVIDIA Corporation
|
|
7
|
+
License: Apache-2.0
|
|
8
|
+
Project-URL: documentation, https://docs.nvidia.com/nemo/agent-toolkit/latest/
|
|
9
|
+
Project-URL: source, https://github.com/NVIDIA/NeMo-Agent-Toolkit
|
|
5
10
|
Keywords: ai,rag,agents
|
|
6
11
|
Classifier: Programming Language :: Python
|
|
7
|
-
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
15
|
+
Requires-Python: <3.14,>=3.11
|
|
8
16
|
Description-Content-Type: text/markdown
|
|
9
|
-
|
|
17
|
+
License-File: LICENSE-3rd-party.txt
|
|
18
|
+
License-File: LICENSE.md
|
|
19
|
+
Requires-Dist: nvidia-nat==v1.4.0a20251212
|
|
10
20
|
Requires-Dist: langchain-community~=0.3
|
|
11
21
|
Requires-Dist: pytest~=8.3
|
|
22
|
+
Dynamic: license-file
|
|
12
23
|
|
|
13
24
|
<!--
|
|
14
25
|
SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
nat/meta/pypi.md,sha256=LLKJHg5oN1-M9Pqfk3Bmphkk4O2TFsyiixuK5T0Y-gw,1100
|
|
2
|
+
nat/test/__init__.py,sha256=_RnTJnsUucHvla_nYKqD4O4g8Bz0tcuDRzWk1bEhcy0,875
|
|
3
|
+
nat/test/embedder.py,sha256=ClDyK1kna4hCBSlz71gK1B-ZjlwcBHTDQRekoNM81Bs,1809
|
|
4
|
+
nat/test/functions.py,sha256=ZxXVzfaLBGOpR5qtmMrKU7q-M9-vVGGj3Xi5mrw4vHY,3557
|
|
5
|
+
nat/test/llm.py,sha256=dbFoWFrSAlUoKm6QGfS4VJdrhgxwkXzm1oaFd6K7jnM,9926
|
|
6
|
+
nat/test/memory.py,sha256=xki_A2yiMhEZuQk60K7t04QRqf32nQqnfzD5Iv7fkvw,1456
|
|
7
|
+
nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2zQ,4281
|
|
8
|
+
nat/test/plugin.py,sha256=HF25W2YPTiXaoIJggnZTstiTMaspQckvL_thQSseDEc,32434
|
|
9
|
+
nat/test/register.py,sha256=o1BEA5fyxyFyCxXhQ6ArmtuNpgRyTEfvw6HdBgECPLI,897
|
|
10
|
+
nat/test/tool_test_runner.py,sha256=O4FyZMlf-oc1XYbuAFxLzn9_zWi8TzWigFyg-Px7xBc,26126
|
|
11
|
+
nat/test/utils.py,sha256=GyhxIZ1CcUPcc8RMRyCzpHBEwVifeqiGxT3c9Pp0KAU,5774
|
|
12
|
+
nvidia_nat_test-1.4.0a20251212.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
|
|
13
|
+
nvidia_nat_test-1.4.0a20251212.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
|
14
|
+
nvidia_nat_test-1.4.0a20251212.dist-info/METADATA,sha256=lqR6z09voBpFqNfGMydq_0tDeACzfQVCkSf5muBIwO4,1925
|
|
15
|
+
nvidia_nat_test-1.4.0a20251212.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
16
|
+
nvidia_nat_test-1.4.0a20251212.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
|
|
17
|
+
nvidia_nat_test-1.4.0a20251212.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
|
18
|
+
nvidia_nat_test-1.4.0a20251212.dist-info/RECORD,,
|