nvidia-nat-test 1.4.0a20260117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/test/__init__.py +23 -0
- nat/test/embedder.py +44 -0
- nat/test/functions.py +99 -0
- nat/test/llm.py +244 -0
- nat/test/memory.py +41 -0
- nat/test/object_store_tests.py +117 -0
- nat/test/plugin.py +890 -0
- nat/test/register.py +25 -0
- nat/test/tool_test_runner.py +612 -0
- nat/test/utils.py +215 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/METADATA +46 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/RECORD +18 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/WHEEL +5 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/entry_points.txt +5 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/top_level.txt +1 -0
nat/test/register.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# flake8: noqa
|
|
17
|
+
# isort:skip_file
|
|
18
|
+
|
|
19
|
+
# Import any providers which need to be automatically registered here
|
|
20
|
+
|
|
21
|
+
from . import embedder
|
|
22
|
+
from . import functions
|
|
23
|
+
from . import memory
|
|
24
|
+
from . import llm
|
|
25
|
+
from . import utils
|
|
@@ -0,0 +1,612 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import inspect
|
|
18
|
+
import logging
|
|
19
|
+
import typing
|
|
20
|
+
from collections.abc import Sequence
|
|
21
|
+
from contextlib import asynccontextmanager
|
|
22
|
+
from unittest.mock import AsyncMock
|
|
23
|
+
from unittest.mock import MagicMock
|
|
24
|
+
|
|
25
|
+
from nat.authentication.interfaces import AuthProviderBase
|
|
26
|
+
from nat.builder.builder import Builder
|
|
27
|
+
from nat.builder.function import Function
|
|
28
|
+
from nat.builder.function import FunctionGroup
|
|
29
|
+
from nat.builder.function_info import FunctionInfo
|
|
30
|
+
from nat.builder.sync_builder import SyncBuilder
|
|
31
|
+
from nat.cli.type_registry import GlobalTypeRegistry
|
|
32
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
|
33
|
+
from nat.data_models.component_ref import MiddlewareRef
|
|
34
|
+
from nat.data_models.embedder import EmbedderBaseConfig
|
|
35
|
+
from nat.data_models.finetuning import TrainerAdapterConfig
|
|
36
|
+
from nat.data_models.finetuning import TrainerConfig
|
|
37
|
+
from nat.data_models.finetuning import TrajectoryBuilderConfig
|
|
38
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
39
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
40
|
+
from nat.data_models.function_dependencies import FunctionDependencies
|
|
41
|
+
from nat.data_models.llm import LLMBaseConfig
|
|
42
|
+
from nat.data_models.memory import MemoryBaseConfig
|
|
43
|
+
from nat.data_models.middleware import FunctionMiddlewareBaseConfig
|
|
44
|
+
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
45
|
+
from nat.data_models.retriever import RetrieverBaseConfig
|
|
46
|
+
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
|
47
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
|
48
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
|
49
|
+
from nat.finetuning.interfaces.finetuning_runner import Trainer
|
|
50
|
+
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
|
|
51
|
+
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
|
|
52
|
+
from nat.memory.interfaces import MemoryEditor
|
|
53
|
+
from nat.middleware import FunctionMiddleware
|
|
54
|
+
from nat.object_store.interfaces import ObjectStore
|
|
55
|
+
from nat.runtime.loader import PluginTypes
|
|
56
|
+
from nat.runtime.loader import discover_and_register_plugins
|
|
57
|
+
from nat.utils.type_utils import override
|
|
58
|
+
|
|
59
|
+
# Module-level logger named after this module (currently not referenced elsewhere in this file).
logger = logging.getLogger(__name__)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class MockBuilder(Builder):
    """
    Lightweight stand-in for ``Builder`` used when testing a tool in isolation.

    Canned responses are registered through the ``mock_*`` methods and kept in
    ``self._mocks``. The component getters (``get_llm``, ``get_function``,
    ``get_retriever``, ...) wrap the canned response in a ``MagicMock``/``AsyncMock``
    and raise ``ValueError`` when nothing was registered under the requested name.
    The ``add_*`` and ``get_*_config`` methods return placeholders and never read
    or write ``self._mocks``.

    NOTE: ``mock_function`` and ``mock_function_group`` store under the bare name
    (no prefix), so they share a key space; every other kind uses a
    ``"<kind>_"`` key prefix.
    """

    def __init__(self):
        # Not read anywhere in this class.
        self._functions = {}
        # Canned responses keyed by (optionally prefixed) component name.
        self._mocks = {}

    @override
    @property
    def sync_builder(self) -> SyncBuilder:
        """Wrap this mock builder in a new ``SyncBuilder`` on every access."""
        return SyncBuilder(self)

    # ------------------------------------------------------------------ #
    # Registration of canned responses
    # ------------------------------------------------------------------ #

    def mock_function(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the result of function ``name``."""
        self._mocks[name] = mock_response

    def mock_function_group(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the result of function group ``name``."""
        self._mocks[name] = mock_response

    def mock_llm(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the reply of LLM ``name``."""
        self._mocks[f"llm_{name}"] = mock_response

    def mock_embedder(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the output of embedder ``name``."""
        self._mocks[f"embedder_{name}"] = mock_response

    def mock_memory_client(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the result of memory client ``name``."""
        self._mocks[f"memory_{name}"] = mock_response

    def mock_retriever(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the result of retriever ``name``."""
        self._mocks[f"retriever_{name}"] = mock_response

    def mock_object_store(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the result of object store ``name``."""
        self._mocks[f"object_store_{name}"] = mock_response

    def mock_ttc_strategy(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the result of TTC strategy ``name``."""
        self._mocks[f"ttc_strategy_{name}"] = mock_response

    def mock_auth_provider(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the result of auth provider ``name``."""
        self._mocks[f"auth_provider_{name}"] = mock_response

    def mock_trainer(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the result of trainer ``name``."""
        self._mocks[f"trainer_{name}"] = mock_response

    def mock_trainer_adapter(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the result of trainer adapter ``name``."""
        self._mocks[f"trainer_adapter_{name}"] = mock_response

    def mock_trajectory_builder(self, name: str, mock_response: typing.Any):
        """Register ``mock_response`` as the result of trajectory builder ``name``."""
        self._mocks[f"trajectory_builder_{name}"] = mock_response

    # ------------------------------------------------------------------ #
    # Test-time compute strategies
    # ------------------------------------------------------------------ #

    async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
        """No-op: strategies are only provided through ``mock_ttc_strategy``."""
        return None

    async def get_ttc_strategy(self, strategy_name: str, pipeline_type: PipelineTypeEnum,
                               stage_type: StageTypeEnum) -> typing.Any:
        """
        Return a MagicMock whose ``invoke``/``ainvoke`` yield the canned response.

        ``pipeline_type`` and ``stage_type`` are accepted but ignored.
        """
        key = f"ttc_strategy_{strategy_name}"
        if key not in self._mocks:
            raise ValueError(f"TTC strategy '{strategy_name}' not mocked. Use mock_ttc_strategy() to add it.")
        canned = self._mocks[key]
        strategy = MagicMock()
        strategy.invoke = MagicMock(return_value=canned)
        strategy.ainvoke = AsyncMock(return_value=canned)
        return strategy

    async def get_ttc_strategy_config(self,
                                      strategy_name: str,
                                      pipeline_type: PipelineTypeEnum,
                                      stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
        """Return a default-constructed ``TTCStrategyBaseConfig``."""
        return TTCStrategyBaseConfig()

    # ------------------------------------------------------------------ #
    # Authentication
    # ------------------------------------------------------------------ #

    async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> None:
        """No-op: auth providers are only provided through ``mock_auth_provider``."""
        return None

    async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
        """Return a MagicMock whose async ``authenticate`` yields the canned response."""
        key = f"auth_provider_{auth_provider_name}"
        if key not in self._mocks:
            raise ValueError(f"Auth provider '{auth_provider_name}' not mocked. Use mock_auth_provider() to add it.")
        provider = MagicMock()
        provider.authenticate = AsyncMock(return_value=self._mocks[key])
        return provider

    # ------------------------------------------------------------------ #
    # Functions, function groups and workflow
    # ------------------------------------------------------------------ #

    async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
        """Unsupported: always raises ``NotImplementedError``."""
        raise NotImplementedError("Mock implementation does not support add_function")

    async def get_function(self, name: str) -> Function:
        """Return an AsyncMock whose ``ainvoke`` yields the canned response for ``name``."""
        if name not in self._mocks:
            raise ValueError(f"Function '{name}' not mocked. Use mock_function() to add it.")
        function_stub = AsyncMock()
        function_stub.ainvoke = AsyncMock(return_value=self._mocks[name])
        return function_stub

    def get_function_config(self, name: str) -> FunctionBaseConfig:
        """Return a default-constructed ``FunctionBaseConfig``."""
        return FunctionBaseConfig()

    async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
        """Unsupported: always raises ``NotImplementedError``."""
        raise NotImplementedError("Mock implementation does not support add_function_group")

    async def get_function_group(self, name: str) -> FunctionGroup:
        """Return a ``FunctionGroup``-spec'd MagicMock whose ``ainvoke`` yields the canned response."""
        if name not in self._mocks:
            raise ValueError(f"Function group '{name}' not mocked. Use mock_function_group() to add it.")
        group_stub = MagicMock(spec=FunctionGroup)
        group_stub.ainvoke = AsyncMock(return_value=self._mocks[name])
        return group_stub

    def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
        """Return a default-constructed ``FunctionGroupBaseConfig``."""
        return FunctionGroupBaseConfig()

    @staticmethod
    def _make_workflow_stub() -> AsyncMock:
        """Build an AsyncMock workflow whose ``ainvoke`` returns ``"mock_workflow_result"``."""
        workflow = AsyncMock()
        workflow.ainvoke = AsyncMock(return_value="mock_workflow_result")
        return workflow

    async def set_workflow(self, config: FunctionBaseConfig) -> Function:
        """Ignore ``config`` and return a fresh workflow stub."""
        return self._make_workflow_stub()

    def get_workflow(self) -> Function:
        """Return a fresh workflow stub (nothing is remembered from ``set_workflow``)."""
        return self._make_workflow_stub()

    def get_workflow_config(self) -> FunctionBaseConfig:
        """Return a default-constructed ``FunctionBaseConfig``."""
        return FunctionBaseConfig()

    async def get_tools(self, tool_names: Sequence[str], wrapper_type) -> list[typing.Any]:
        """Return an empty list regardless of the requested tools."""
        return []

    async def get_tool(self, fn_name: str, wrapper_type) -> typing.Any:
        """Return ``None`` regardless of the requested tool."""
        return None

    # ------------------------------------------------------------------ #
    # LLMs and embedders
    # ------------------------------------------------------------------ #

    async def add_llm(self, name: str, config) -> None:
        """No-op: LLMs are only provided through ``mock_llm``."""
        return None

    async def get_llm(self, llm_name: str, wrapper_type):
        """Return a MagicMock LLM whose ``invoke``/``ainvoke`` yield the canned response."""
        key = f"llm_{llm_name}"
        if key not in self._mocks:
            raise ValueError(f"LLM '{llm_name}' not mocked. Use mock_llm() to add it.")
        canned = self._mocks[key]
        llm = MagicMock()
        llm.invoke = MagicMock(return_value=canned)
        llm.ainvoke = AsyncMock(return_value=canned)
        return llm

    def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
        """Return a default-constructed ``LLMBaseConfig``."""
        return LLMBaseConfig()

    async def add_embedder(self, name: str, config) -> None:
        """No-op: embedders are only provided through ``mock_embedder``."""
        return None

    async def get_embedder(self, embedder_name: str, wrapper_type):
        """Return a MagicMock whose ``embed_query``/``embed_documents`` both yield the canned response."""
        key = f"embedder_{embedder_name}"
        if key not in self._mocks:
            raise ValueError(f"Embedder '{embedder_name}' not mocked. Use mock_embedder() to add it.")
        canned = self._mocks[key]
        embedder = MagicMock()
        for method_name in ("embed_query", "embed_documents"):
            setattr(embedder, method_name, MagicMock(return_value=canned))
        return embedder

    def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
        """Return a default-constructed ``EmbedderBaseConfig``."""
        return EmbedderBaseConfig()

    # ------------------------------------------------------------------ #
    # Memory, retrievers and object stores
    # ------------------------------------------------------------------ #

    async def add_memory_client(self, name: str, config) -> MemoryEditor:
        """Return a ``MemoryEditor``-spec'd MagicMock without registering it."""
        return MagicMock(spec=MemoryEditor)

    async def get_memory_client(self, memory_name: str) -> MemoryEditor:
        """Return a MagicMock whose async ``add``/``search`` yield the canned response."""
        key = f"memory_{memory_name}"
        if key not in self._mocks:
            raise ValueError(f"Memory client '{memory_name}' not mocked. Use mock_memory_client() to add it.")
        canned = self._mocks[key]
        memory = MagicMock()
        for method_name in ("add", "search"):
            setattr(memory, method_name, AsyncMock(return_value=canned))
        return memory

    def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
        """Return a default-constructed ``MemoryBaseConfig``."""
        return MemoryBaseConfig()

    async def add_retriever(self, name: str, config) -> None:
        """No-op: retrievers are only provided through ``mock_retriever``."""
        return None

    async def get_retriever(self, retriever_name: str, wrapper_type=None):
        """Return a MagicMock whose async ``retrieve`` yields the canned response."""
        key = f"retriever_{retriever_name}"
        if key not in self._mocks:
            raise ValueError(f"Retriever '{retriever_name}' not mocked. Use mock_retriever() to add it.")
        retriever = MagicMock()
        retriever.retrieve = AsyncMock(return_value=self._mocks[key])
        return retriever

    async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
        """Return a default-constructed ``RetrieverBaseConfig``."""
        return RetrieverBaseConfig()

    async def add_object_store(self, name: str, config: ObjectStoreBaseConfig) -> ObjectStore:
        """Return an ``ObjectStore``-spec'd MagicMock without registering it."""
        return MagicMock(spec=ObjectStore)

    async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
        """Return a MagicMock whose async put/get/delete/list methods all yield the canned response."""
        key = f"object_store_{object_store_name}"
        if key not in self._mocks:
            raise ValueError(f"Object store '{object_store_name}' not mocked. Use mock_object_store() to add it.")
        canned = self._mocks[key]
        store = MagicMock()
        for method_name in ("put_object", "get_object", "delete_object", "list_objects"):
            setattr(store, method_name, AsyncMock(return_value=canned))
        return store

    def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
        """Return a default-constructed ``ObjectStoreBaseConfig``."""
        return ObjectStoreBaseConfig()

    # ------------------------------------------------------------------ #
    # Users, dependencies and middleware
    # ------------------------------------------------------------------ #

    def get_user_manager(self):
        """Return a MagicMock user manager whose ``get_id`` returns ``"test_user"``."""
        user_manager = MagicMock()
        user_manager.get_id = MagicMock(return_value="test_user")
        return user_manager

    def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
        """Return an empty ``FunctionDependencies``."""
        return FunctionDependencies()

    def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
        """Return an empty ``FunctionDependencies``."""
        return FunctionDependencies()

    async def get_middleware(self, middleware_name: str | MiddlewareRef) -> FunctionMiddleware:
        """Return a new default ``FunctionMiddleware``."""
        return FunctionMiddleware()

    def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> FunctionMiddlewareBaseConfig:
        """Return a default-constructed ``FunctionMiddlewareBaseConfig``."""
        return FunctionMiddlewareBaseConfig()

    async def add_middleware(self, name: str | MiddlewareRef,
                             config: FunctionMiddlewareBaseConfig) -> FunctionMiddleware:
        """Return a new default ``FunctionMiddleware`` without registering it."""
        return FunctionMiddleware()

    # ------------------------------------------------------------------ #
    # Finetuning components
    # ------------------------------------------------------------------ #

    async def add_trainer(self, name: str, config: TrainerConfig) -> Trainer:
        """Return a ``Trainer``-spec'd MagicMock without registering it."""
        return MagicMock(spec=Trainer)

    async def get_trainer(self,
                          trainer_name: str,
                          trajectory_builder: TrajectoryBuilder,
                          trainer_adapter: TrainerAdapter) -> Trainer:
        """
        Return a MagicMock whose async ``train`` yields the canned response.

        ``trajectory_builder`` and ``trainer_adapter`` are accepted but ignored.
        """
        key = f"trainer_{trainer_name}"
        if key not in self._mocks:
            raise ValueError(f"Trainer '{trainer_name}' not mocked. Use mock_trainer() to add it.")
        trainer = MagicMock()
        trainer.train = AsyncMock(return_value=self._mocks[key])
        return trainer

    async def get_trainer_config(self, trainer_name: str) -> TrainerConfig:
        """Return a default-constructed ``TrainerConfig``."""
        return TrainerConfig()

    async def add_trainer_adapter(self, name: str, config: TrainerAdapterConfig) -> TrainerAdapter:
        """Return a ``TrainerAdapter``-spec'd MagicMock without registering it."""
        return MagicMock(spec=TrainerAdapter)

    async def get_trainer_adapter(self, trainer_adapter_name: str) -> TrainerAdapter:
        """Return a MagicMock whose async ``adapt`` yields the canned response."""
        key = f"trainer_adapter_{trainer_adapter_name}"
        if key not in self._mocks:
            raise ValueError(f"Trainer adapter '{trainer_adapter_name}' not mocked. Use mock_trainer_adapter() to add it.")
        adapter = MagicMock()
        adapter.adapt = AsyncMock(return_value=self._mocks[key])
        return adapter

    async def get_trainer_adapter_config(self, trainer_adapter_name: str) -> TrainerAdapterConfig:
        """Return a default-constructed ``TrainerAdapterConfig``."""
        return TrainerAdapterConfig()

    async def add_trajectory_builder(self, name: str, config: TrajectoryBuilderConfig) -> TrajectoryBuilder:
        """Return a ``TrajectoryBuilder``-spec'd MagicMock without registering it."""
        return MagicMock(spec=TrajectoryBuilder)

    async def get_trajectory_builder(self, trajectory_builder_name: str) -> TrajectoryBuilder:
        """Return a MagicMock whose async ``build`` yields the canned response."""
        key = f"trajectory_builder_{trajectory_builder_name}"
        if key not in self._mocks:
            raise ValueError(
                f"Trajectory builder '{trajectory_builder_name}' not mocked. Use mock_trajectory_builder() to add it.")
        trajectory_builder = MagicMock()
        trajectory_builder.build = AsyncMock(return_value=self._mocks[key])
        return trajectory_builder

    async def get_trajectory_builder_config(self, trajectory_builder_name: str) -> TrajectoryBuilderConfig:
        """Return a default-constructed ``TrajectoryBuilderConfig``."""
        return TrajectoryBuilderConfig()
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
class ToolTestRunner:
    """
    A test runner that enables isolated testing of NAT tools without requiring
    full workflow setup, LLMs, or complex dependencies.

    Usage:
        runner = ToolTestRunner()

        # Test a tool with minimal setup (an empty MockBuilder is created internally)
        result = await runner.test_tool(
            config_type=MyToolConfig,
            config_params={"param1": "value1"},
            input_data="test input"
        )

        # Test a tool with mocked dependencies. The configured builder must be
        # passed to test_tool_with_builder(); test_tool() would ignore it and
        # build the tool against a fresh, empty MockBuilder instead.
        async with runner.with_mocks() as mock_builder:
            mock_builder.mock_llm("my_llm", "mocked response")
            result = await runner.test_tool_with_builder(
                config_type=MyToolConfig,
                builder=mock_builder,
                config_params={"llm_name": "my_llm"},
                input_data="test input"
            )
    """

    def __init__(self):
        self._ensure_plugins_loaded()

    def _ensure_plugins_loaded(self):
        """Discover and register CONFIG_OBJECT plugins so tool configs resolve in the type registry."""
        discover_and_register_plugins(PluginTypes.CONFIG_OBJECT)

    async def test_tool(self,
                        config_type: type[FunctionBaseConfig],
                        config_params: dict[str, typing.Any] | None = None,
                        input_data: typing.Any = None,
                        expected_output: typing.Any = None,
                        **kwargs) -> typing.Any:
        """
        Test a tool in isolation with minimal setup.

        The tool is built against a fresh, empty ``MockBuilder``, so dependency
        lookups such as ``get_llm`` raise ``ValueError``. Use
        ``test_tool_with_builder`` to supply mocked dependencies.

        Args:
            config_type: The tool configuration class
            config_params: Keyword arguments passed to the config constructor
            input_data: Input passed to the tool. When ``None``, a plain callable is
                invoked with no arguments and a ``Function`` receives ``None``.
            expected_output: Expected output for assertion (optional; ``None`` disables the check)
            **kwargs: Currently unused.

        Returns:
            The tool's output

        Raises:
            AssertionError: If expected_output is provided and doesn't match
            ValueError: If the tool is not registered or its build result has an
                unsupported type. Exceptions raised by the tool itself propagate unchanged.
        """
        return await self._run_tool(config_type=config_type,
                                    builder=MockBuilder(),
                                    config_params=config_params,
                                    input_data=input_data,
                                    expected_output=expected_output)

    @asynccontextmanager
    async def with_mocks(self):
        """
        Context manager that yields a fresh ``MockBuilder`` for setting up dependencies.

        Nothing is torn down on exit.

        Usage:
            async with runner.with_mocks() as mock_builder:
                mock_builder.mock_llm("my_llm", "mocked response")
                result = await runner.test_tool_with_builder(
                    config_type=MyToolConfig,
                    builder=mock_builder,
                    input_data="test input"
                )
        """
        yield MockBuilder()

    async def test_tool_with_builder(
        self,
        config_type: type[FunctionBaseConfig],
        builder: MockBuilder,
        config_params: dict[str, typing.Any] | None = None,
        input_data: typing.Any = None,
        expected_output: typing.Any = None,
    ) -> typing.Any:
        """
        Test a tool with a pre-configured mock builder.

        Args:
            config_type: The tool configuration class
            builder: Pre-configured MockBuilder with mocked dependencies
            config_params: Keyword arguments passed to the config constructor
            input_data: Input passed to the tool (see ``test_tool`` for ``None`` handling)
            expected_output: Expected output for assertion (optional; ``None`` disables the check)

        Returns:
            The tool's output

        Raises:
            AssertionError: If expected_output is provided and doesn't match
            ValueError: If the tool is not registered or its build result has an
                unsupported type. Exceptions raised by the tool itself propagate unchanged.
        """
        return await self._run_tool(config_type=config_type,
                                    builder=builder,
                                    config_params=config_params,
                                    input_data=input_data,
                                    expected_output=expected_output)

    async def _run_tool(self,
                        config_type: type[FunctionBaseConfig],
                        builder: Builder,
                        config_params: dict[str, typing.Any] | None,
                        input_data: typing.Any,
                        expected_output: typing.Any) -> typing.Any:
        """Build the tool registered for ``config_type`` with ``builder``, invoke it, and check the result."""
        config = config_type(**(config_params or {}))
        tool_registration = self._get_tool_registration(config_type)

        async with tool_registration.build_fn(config, builder) as tool_result:
            tool_function = self._resolve_tool_function(tool_result)
            result = await self._invoke_tool(tool_function, input_data)
            self._check_expected_output(result, expected_output)
            return result

    @staticmethod
    def _get_tool_registration(config_type: type[FunctionBaseConfig]):
        """Return the registry entry for ``config_type``; raise ``ValueError`` if it is not registered."""
        registry = GlobalTypeRegistry.get()
        try:
            return registry.get_function(config_type)
        except KeyError as err:
            raise ValueError(
                f"Tool {config_type} is not registered. Make sure it's imported and registered with @register_function."
            ) from err

    @staticmethod
    def _resolve_tool_function(tool_result: typing.Any) -> typing.Any:
        """Turn the object yielded by ``build_fn`` into something that can be invoked."""
        if isinstance(tool_result, Function):
            return tool_result
        if isinstance(tool_result, FunctionInfo):
            # Prefer the single-output function and fall back to the streaming one.
            if tool_result.single_fn:
                return tool_result.single_fn
            if tool_result.stream_fn:
                return tool_result.stream_fn
            raise ValueError("Tool function not found in FunctionInfo")
        if callable(tool_result):
            return tool_result
        raise ValueError(f"Unexpected tool result type: {type(tool_result)}")

    @staticmethod
    async def _invoke_tool(tool_function: typing.Any, input_data: typing.Any) -> typing.Any:
        """Call the tool and await its result when that result is awaitable."""
        if isinstance(tool_function, Function):
            # Function objects always take an input, so None is forwarded when no input_data was given.
            return await tool_function.ainvoke(input_data)

        # Inspect the returned value instead of the callable, so objects with an
        # ``async def __call__`` (which asyncio.iscoroutinefunction does not
        # detect) are awaited as well.
        maybe_result = tool_function() if input_data is None else tool_function(input_data)
        if inspect.isawaitable(maybe_result):
            return await maybe_result
        return maybe_result

    @staticmethod
    def _check_expected_output(result: typing.Any, expected_output: typing.Any) -> None:
        """Raise ``AssertionError`` when ``expected_output`` is given and differs from ``result``."""
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if expected_output is not None and result != expected_output:
            raise AssertionError(f"Expected {expected_output}, got {result}")
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
@asynccontextmanager
async def with_mocked_dependencies():
    """
    Yield a ``(ToolTestRunner, MockBuilder)`` pair for testing tools with mocked dependencies.

    A new runner is created on every entry. The builder comes from that runner's
    ``with_mocks()`` context, which stays open for the duration of this one.

    Usage:
        async with with_mocked_dependencies() as (runner, mock_builder):
            mock_builder.mock_llm("my_llm", "mocked response")
            result = await runner.test_tool_with_builder(
                config_type=MyToolConfig,
                builder=mock_builder,
                input_data="test input"
            )
    """
    test_runner = ToolTestRunner()
    async with test_runner.with_mocks() as builder:
        yield test_runner, builder
|