nvidia-nat-test 1.2.0a20250813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-test might be problematic. Click here for more details.
- nat/meta/pypi.md +23 -0
- nat/test/__init__.py +23 -0
- nat/test/embedder.py +44 -0
- nat/test/functions.py +91 -0
- nat/test/memory.py +41 -0
- nat/test/object_store_tests.py +117 -0
- nat/test/plugin.py +97 -0
- nat/test/register.py +24 -0
- nat/test/tool_test_runner.py +449 -0
- nvidia_nat_test-1.2.0a20250813.dist-info/METADATA +35 -0
- nvidia_nat_test-1.2.0a20250813.dist-info/RECORD +14 -0
- nvidia_nat_test-1.2.0a20250813.dist-info/WHEEL +5 -0
- nvidia_nat_test-1.2.0a20250813.dist-info/entry_points.txt +5 -0
- nvidia_nat_test-1.2.0a20250813.dist-info/top_level.txt +1 -0
nat/meta/pypi.md
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
<!--
|
|
2
|
+
SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
you may not use this file except in compliance with the License.
|
|
7
|
+
You may obtain a copy of the License at
|
|
8
|
+
|
|
9
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
|
|
11
|
+
Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
See the License for the specific language governing permissions and
|
|
15
|
+
limitations under the License.
|
|
16
|
+
-->
|
|
17
|
+
|
|
18
|
+

|
|
19
|
+
|
|
20
|
+
# NVIDIA NeMo Agent Toolkit Subpackage
|
|
21
|
+
This is a subpackage for NeMo Agent toolkit test utilities.
|
|
22
|
+
|
|
23
|
+
For more information about the NVIDIA NeMo Agent toolkit, please visit the [NeMo Agent toolkit GitHub Repo](https://github.com/NVIDIA/NeMo-Agent-Toolkit).
|
nat/test/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# Tool testing utilities
|
|
17
|
+
from .tool_test_runner import ToolTestRunner
|
|
18
|
+
from .tool_test_runner import with_mocked_dependencies
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"ToolTestRunner",
|
|
22
|
+
"with_mocked_dependencies",
|
|
23
|
+
]
|
nat/test/embedder.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from pydantic import ConfigDict
|
|
17
|
+
|
|
18
|
+
from nat.builder.builder import Builder
|
|
19
|
+
from nat.builder.embedder import EmbedderProviderInfo
|
|
20
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
21
|
+
from nat.cli.register_workflow import register_embedder_client
|
|
22
|
+
from nat.cli.register_workflow import register_embedder_provider
|
|
23
|
+
from nat.data_models.embedder import EmbedderBaseConfig
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class EmbedderTestConfig(EmbedderBaseConfig, name="test_embedder"):
    """Configuration for the test-only ``test_embedder`` embedder."""

    # Disable pydantic's protected ``model_`` namespace so the ``model_name`` field is accepted without warnings.
    model_config = ConfigDict(protected_namespaces=())

    # Nominal model identifier carried by the config.
    model_name: str = "nvidia/nv-embedqa-e5-v5"
    # Length of the fake embedding vectors produced by the test client.
    embedding_size: int = 768


@register_embedder_provider(config_type=EmbedderTestConfig)
async def embedder_test_provider(config: EmbedderTestConfig, builder: Builder):
    """Expose ``EmbedderTestConfig`` as an embedder provider; ``builder`` is not used."""
    provider_info = EmbedderProviderInfo(config=config, description="Test embedder provider")
    yield provider_info
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@register_embedder_client(config_type=EmbedderTestConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
async def embedder_langchain_test_client(config: EmbedderTestConfig, builder: Builder):
    """Yield a LangChain ``DeterministicFakeEmbedding`` whose vectors have ``config.embedding_size`` entries.

    The import is done lazily so ``langchain_community`` is only needed when this client is actually built.
    """
    from langchain_community.embeddings import DeterministicFakeEmbedding

    fake_embedder = DeterministicFakeEmbedding(size=config.embedding_size)
    yield fake_embedder
|
nat/test/functions.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections.abc import AsyncGenerator
|
|
17
|
+
|
|
18
|
+
from nat.builder.builder import Builder
|
|
19
|
+
from nat.builder.function_info import FunctionInfo
|
|
20
|
+
from nat.cli.register_workflow import register_function
|
|
21
|
+
from nat.data_models.api_server import ChatRequest
|
|
22
|
+
from nat.data_models.api_server import ChatResponse
|
|
23
|
+
from nat.data_models.api_server import ChatResponseChunk
|
|
24
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class EchoFunctionConfig(FunctionBaseConfig, name="test_echo"):
    """Config for ``test_echo``: a function that hands its input straight back."""

    # When True the function takes a ``ChatRequest`` and returns a ``ChatResponse``; otherwise str -> str.
    use_openai_api: bool = False


@register_function(config_type=EchoFunctionConfig)
async def echo_function(config: EchoFunctionConfig, builder: Builder):
    """Yield an echo callable whose signature is chosen by ``config.use_openai_api``."""

    async def inner(message: str) -> str:
        return message

    async def inner_oai(message: ChatRequest) -> ChatResponse:
        # Only the content of the first message in the request is echoed.
        first_content = message.messages[0].content
        return ChatResponse.from_string(first_content)

    yield inner_oai if config.use_openai_api else inner
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class StreamingEchoFunctionConfig(FunctionBaseConfig, name="test_streaming_echo"):
    """Config for ``test_streaming_echo``: streams each input item back, one per chunk."""

    # When True the streaming callable takes a ``ChatRequest`` and emits ``ChatResponseChunk`` objects.
    use_openai_api: bool = False


@register_function(config_type=StreamingEchoFunctionConfig)
async def streaming_function(config: StreamingEchoFunctionConfig, builder: Builder):
    """Yield a ``FunctionInfo`` wrapping a streaming echo.

    ``oai_to_list`` (ChatRequest -> list[str]) is registered as a converter; presumably this lets a
    ``ChatRequest`` be fed to the plain ``list[str]`` variant — confirm against ``FunctionInfo.from_fn``.
    """

    def oai_to_list(message: ChatRequest) -> list[str]:
        return [chat_message.content for chat_message in message.messages]

    async def inner(message: list[str]) -> AsyncGenerator[str]:
        for item in message:
            yield item

    async def inner_oai(message: ChatRequest) -> AsyncGenerator[ChatResponseChunk]:
        for text in oai_to_list(message):
            yield ChatResponseChunk.from_string(text)

    selected_fn = inner_oai if config.use_openai_api else inner
    yield FunctionInfo.from_fn(selected_fn, converters=[oai_to_list])
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class ConstantFunctionConfig(FunctionBaseConfig, name="test_constant"):
    """Config for ``test_constant``: a no-argument function with a fixed result."""

    # The value returned on every call.
    response: str


@register_function(config_type=ConstantFunctionConfig)
async def constant_function(config: ConstantFunctionConfig, builder: Builder):
    """Yield a zero-argument coroutine function that always returns ``config.response``."""

    async def inner() -> str:
        fixed_reply = config.response
        return fixed_reply

    yield inner
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class StreamingConstantFunctionConfig(FunctionBaseConfig, name="test_streaming_constant"):
    """Config for ``test_streaming_constant``: streams a fixed list of strings."""

    # Strings emitted, in order, on each invocation.
    responses: list[str]


@register_function(config_type=StreamingConstantFunctionConfig)
async def streaming_constant_function(config: StreamingConstantFunctionConfig, builder: Builder):
    """Yield a zero-argument async generator that streams every entry of ``config.responses``."""

    async def inner() -> AsyncGenerator[str]:
        for canned in config.responses:
            yield canned

    yield inner
|
nat/test/memory.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from nat.builder.builder import Builder
|
|
17
|
+
from nat.cli.register_workflow import register_memory
|
|
18
|
+
from nat.data_models.memory import MemoryBaseConfig
|
|
19
|
+
from nat.memory.interfaces import MemoryEditor
|
|
20
|
+
from nat.memory.models import MemoryItem
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DummyMemoryConfig(MemoryBaseConfig, name="test_dummy"):
    """Config for the ``test_dummy`` memory backend; it takes no settings."""


@register_memory(config_type=DummyMemoryConfig)
async def echo_function(config: DummyMemoryConfig, builder: Builder):
    """Yield a ``MemoryEditor`` that discards writes and never finds anything.

    NOTE(review): ``echo_function`` is a misleading name for a memory factory; kept for compatibility.
    """

    class DummyMemoryEditor(MemoryEditor):
        """No-op editor: additions and removals are ignored and searches return an empty list."""

        async def add_items(self, items: list[MemoryItem]) -> None:
            return None

        async def search(self, query: str, top_k: int = 5, **kwargs) -> list[MemoryItem]:
            return []

        async def remove_items(self, **kwargs) -> None:
            return None

    yield DummyMemoryEditor()
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import uuid
|
|
17
|
+
from abc import abstractmethod
|
|
18
|
+
from contextlib import asynccontextmanager
|
|
19
|
+
|
|
20
|
+
import pytest
|
|
21
|
+
import pytest_asyncio
|
|
22
|
+
|
|
23
|
+
from nat.data_models.object_store import KeyAlreadyExistsError
|
|
24
|
+
from nat.data_models.object_store import NoSuchKeyError
|
|
25
|
+
from nat.object_store.interfaces import ObjectStore
|
|
26
|
+
from nat.object_store.models import ObjectStoreItem
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@pytest.mark.asyncio(loop_scope="class")
class ObjectStoreTests:
    """Reusable conformance tests for ``ObjectStore`` implementations.

    Subclass this and override ``_get_store`` so it yields the store under test. The ``store`` fixture is
    class-scoped, so every test in a subclass shares one store instance; each test therefore uses a random
    key to avoid colliding with the others.
    """

    @abstractmethod
    @asynccontextmanager
    async def _get_store(self):
        """Async context manager yielding the ``ObjectStore`` under test. Subclasses must override it.

        This class does not derive from ``abc.ABC``, so ``@abstractmethod`` is not enforced when a subclass
        is created. Raising here (instead of yielding ``None``) makes a missing override fail immediately
        with a clear message rather than with confusing errors on a ``None`` store.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must override _get_store() to yield the ObjectStore under test")
        yield  # pylint: disable=unreachable  # keeps this an async generator, as asynccontextmanager requires

    @pytest_asyncio.fixture(loop_scope="class", scope="class")
    async def store(self):
        # One store per test class, opened and closed through the subclass's _get_store().
        async with self._get_store() as store:
            yield store

    async def test_create_object_store(self, store: ObjectStore):
        assert isinstance(store, ObjectStore)

    async def test_put_object(self, store: ObjectStore):

        # Use a random key to avoid conflicts with other tests
        key = f"test_key_{uuid.uuid4()}"

        initial_item = ObjectStoreItem(data=b"test_value")
        await store.put_object(key, initial_item)

        # Try to put the same object again
        with pytest.raises(KeyAlreadyExistsError):
            await store.put_object(key, initial_item)

    async def test_upsert_object(self, store: ObjectStore):
        key = f"test_key_{uuid.uuid4()}"

        initial_item = ObjectStoreItem(data=b"test_value", content_type="text/plain", metadata={"key": "value"})

        await store.upsert_object(key, initial_item)

        # Check that the object exists
        retrieved_item = await store.get_object(key)
        assert retrieved_item.data == initial_item.data
        assert retrieved_item.content_type == initial_item.content_type
        assert retrieved_item.metadata == initial_item.metadata

        # Upsert the object with a new value
        new_item = ObjectStoreItem(data=b"new_value", content_type="application/json", metadata={"key": "new_value"})
        await store.upsert_object(key, new_item)

        # Check that the object was updated
        retrieved_item = await store.get_object(key)
        assert retrieved_item.data == new_item.data
        assert retrieved_item.content_type == new_item.content_type
        assert retrieved_item.metadata == new_item.metadata

    async def test_get_object(self, store: ObjectStore):

        key = f"test_key_{uuid.uuid4()}"

        initial_item = ObjectStoreItem(data=b"test_value", content_type="text/plain", metadata={"key": "value"})
        await store.put_object(key, initial_item)

        retrieved_item = await store.get_object(key)
        assert retrieved_item.data == initial_item.data
        assert retrieved_item.content_type == initial_item.content_type
        assert retrieved_item.metadata == initial_item.metadata

        # Try to get an object that doesn't exist
        with pytest.raises(NoSuchKeyError):
            await store.get_object(f"test_key_{uuid.uuid4()}")

    async def test_delete_object(self, store: ObjectStore):

        key = f"test_key_{uuid.uuid4()}"

        initial_item = ObjectStoreItem(data=b"test_value")
        await store.put_object(key, initial_item)

        # Check that the object exists
        retrieved_item = await store.get_object(key)
        assert retrieved_item.data == initial_item.data

        # Delete the object
        await store.delete_object(key)

        # Try to get the object again
        with pytest.raises(NoSuchKeyError):
            await store.get_object(key)

        # Try to delete the object again
        with pytest.raises(NoSuchKeyError):
            await store.delete_object(key)
|
nat/test/plugin.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import pytest
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def pytest_addoption(parser: "pytest.Parser"):
    """
    Adds command line options for running specific tests that are disabled by default.

    Each option is a boolean ``store_true`` flag; ``pytest_runtest_setup`` skips tests carrying the
    matching marker (``e2e``, ``integration``, ``slow``) unless the flag is given.
    """
    parser.addoption(
        "--run_e2e",
        action="store_true",
        dest="run_e2e",
        help="Run end to end tests that would otherwise be skipped",
    )

    parser.addoption(
        "--run_integration",
        action="store_true",
        dest="run_integration",
        help=("Run integration tests that would otherwise be skipped. "
              "This will call out to external services instead of using mocks"),
    )

    parser.addoption(
        "--run_slow",
        action="store_true",
        dest="run_slow",
        # Previously copy-pasted from --run_e2e and described the wrong test category.
        help="Run slow tests that would otherwise be skipped",
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def pytest_runtest_setup(item):
    """Skip opt-in test categories unless their enabling command line flag was passed.

    Categories are checked in the order e2e, integration, slow. For each one whose flag is not set,
    a test carrying the matching marker is skipped with an explanatory message.
    """
    opt_in_categories = (
        ("--run_e2e", "e2e", "Skipping end to end tests by default. Use --run_e2e to enable"),
        ("--run_integration", "integration", "Skipping integration tests by default. Use --run_integration to enable"),
        ("--run_slow", "slow", "Skipping slow tests by default. Use --run_slow to enable"),
    )
    for flag, marker_name, reason in opt_in_categories:
        if item.config.getoption(flag):
            continue
        if item.get_closest_marker(marker_name) is not None:
            pytest.skip(reason)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@pytest.fixture(name="register_components", scope="session", autouse=True)
def register_components_fixture():
    """Register all discoverable plugin components plus the test-only components, once per session."""
    from nat.runtime.loader import PluginTypes
    from nat.runtime.loader import discover_and_register_plugins

    # Ensure that all components which need to be registered as part of an import are done so. This is necessary
    # because imports will not be reloaded between tests, so we need to ensure that all components are registered
    # before any tests are run.
    discover_and_register_plugins(PluginTypes.ALL)

    # Also import the nat.test.register module to register test-only components
    import nat.test.register  # pylint: disable=unused-import # noqa: F401
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@pytest.fixture(name="module_registry", scope="module", autouse=True)
def module_registry_fixture():
    """
    Provide a reset global type registry for the duration of one test module.

    Autoused at module scope so registrations made by one module do not leak into the next; the registry
    obtained from ``GlobalTypeRegistry.push()`` is yielded to tests that request ``module_registry``.
    """
    from nat.cli.type_registry import GlobalTypeRegistry

    with GlobalTypeRegistry.push() as module_scoped_registry:
        yield module_scoped_registry
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@pytest.fixture(name="registry", scope="function", autouse=True)
def function_registry_fixture():
    """
    Provide a reset global type registry for the duration of one test function.

    Autoused at function scope so registrations made by one test do not leak into the next; the registry
    obtained from ``GlobalTypeRegistry.push()`` is yielded to tests that request ``registry``.
    """
    from nat.cli.type_registry import GlobalTypeRegistry

    with GlobalTypeRegistry.push() as per_test_registry:
        yield per_test_registry
|
nat/test/register.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# pylint: disable=unused-import
|
|
17
|
+
# flake8: noqa
|
|
18
|
+
# isort:skip_file
|
|
19
|
+
|
|
20
|
+
# Import any providers which need to be automatically registered here
|
|
21
|
+
|
|
22
|
+
from . import embedder
|
|
23
|
+
from . import functions
|
|
24
|
+
from . import memory
|
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import logging
|
|
18
|
+
import typing
|
|
19
|
+
from contextlib import asynccontextmanager
|
|
20
|
+
from unittest.mock import AsyncMock
|
|
21
|
+
from unittest.mock import MagicMock
|
|
22
|
+
|
|
23
|
+
from nat.builder.builder import Builder
|
|
24
|
+
from nat.builder.function import Function
|
|
25
|
+
from nat.builder.function_info import FunctionInfo
|
|
26
|
+
from nat.cli.type_registry import GlobalTypeRegistry
|
|
27
|
+
from nat.data_models.function import FunctionBaseConfig
|
|
28
|
+
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
29
|
+
from nat.object_store.interfaces import ObjectStore
|
|
30
|
+
from nat.runtime.loader import PluginTypes
|
|
31
|
+
from nat.runtime.loader import discover_and_register_plugins
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MockBuilder(Builder):
|
|
37
|
+
"""
|
|
38
|
+
A lightweight mock builder for tool testing that provides minimal dependencies.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def __init__(self):
|
|
42
|
+
self._functions = {}
|
|
43
|
+
self._mocks = {}
|
|
44
|
+
|
|
45
|
+
def mock_function(self, name: str, mock_response: typing.Any):
|
|
46
|
+
"""Add a mock function that returns a fixed response."""
|
|
47
|
+
self._mocks[name] = mock_response
|
|
48
|
+
|
|
49
|
+
def mock_llm(self, name: str, mock_response: typing.Any):
|
|
50
|
+
"""Add a mock LLM that returns a fixed response."""
|
|
51
|
+
self._mocks[f"llm_{name}"] = mock_response
|
|
52
|
+
|
|
53
|
+
def mock_embedder(self, name: str, mock_response: typing.Any):
|
|
54
|
+
"""Add a mock embedder that returns a fixed response."""
|
|
55
|
+
self._mocks[f"embedder_{name}"] = mock_response
|
|
56
|
+
|
|
57
|
+
def mock_memory_client(self, name: str, mock_response: typing.Any):
|
|
58
|
+
"""Add a mock memory client that returns a fixed response."""
|
|
59
|
+
self._mocks[f"memory_{name}"] = mock_response
|
|
60
|
+
|
|
61
|
+
def mock_retriever(self, name: str, mock_response: typing.Any):
|
|
62
|
+
"""Add a mock retriever that returns a fixed response."""
|
|
63
|
+
self._mocks[f"retriever_{name}"] = mock_response
|
|
64
|
+
|
|
65
|
+
def mock_object_store(self, name: str, mock_response: typing.Any):
|
|
66
|
+
"""Add a mock object store that returns a fixed response."""
|
|
67
|
+
self._mocks[f"object_store_{name}"] = mock_response
|
|
68
|
+
|
|
69
|
+
def mock_ttc_strategy(self, name: str, mock_response: typing.Any):
|
|
70
|
+
"""Add a mock TTC strategy that returns a fixed response."""
|
|
71
|
+
self._mocks[f"ttc_strategy_{name}"] = mock_response
|
|
72
|
+
|
|
73
|
+
async def add_ttc_strategy(self, name: str, config):
|
|
74
|
+
"""Mock implementation (no‑op)."""
|
|
75
|
+
pass
|
|
76
|
+
|
|
77
|
+
async def get_ttc_strategy(self,
|
|
78
|
+
strategy_name: str,
|
|
79
|
+
pipeline_type: typing.Any = None,
|
|
80
|
+
stage_type: typing.Any = None):
|
|
81
|
+
"""Return a mock TTC strategy if one is configured."""
|
|
82
|
+
key = f"ttc_strategy_{strategy_name}"
|
|
83
|
+
if key in self._mocks:
|
|
84
|
+
mock_strategy = MagicMock()
|
|
85
|
+
# Provide common callable patterns used in tests
|
|
86
|
+
mock_strategy.invoke = MagicMock(return_value=self._mocks[key])
|
|
87
|
+
mock_strategy.ainvoke = AsyncMock(return_value=self._mocks[key])
|
|
88
|
+
return mock_strategy
|
|
89
|
+
raise ValueError(f"TTC strategy '{strategy_name}' not mocked. Use mock_ttc_strategy() to add it.")
|
|
90
|
+
|
|
91
|
+
async def get_ttc_strategy_config(self,
|
|
92
|
+
strategy_name: str,
|
|
93
|
+
pipeline_type: typing.Any = None,
|
|
94
|
+
stage_type: typing.Any = None):
|
|
95
|
+
"""Mock implementation."""
|
|
96
|
+
pass
|
|
97
|
+
|
|
98
|
+
async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
|
|
99
|
+
"""Mock implementation - not used in tool testing."""
|
|
100
|
+
raise NotImplementedError("Mock implementation does not support add_function")
|
|
101
|
+
|
|
102
|
+
def get_function(self, name: str) -> Function:
|
|
103
|
+
"""Return a mock function if one is configured."""
|
|
104
|
+
if name in self._mocks:
|
|
105
|
+
mock_fn = AsyncMock()
|
|
106
|
+
mock_fn.ainvoke = AsyncMock(return_value=self._mocks[name])
|
|
107
|
+
return mock_fn
|
|
108
|
+
raise ValueError(f"Function '{name}' not mocked. Use mock_function() to add it.")
|
|
109
|
+
|
|
110
|
+
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
|
111
|
+
"""Mock implementation."""
|
|
112
|
+
pass
|
|
113
|
+
|
|
114
|
+
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
|
115
|
+
"""Mock implementation."""
|
|
116
|
+
pass
|
|
117
|
+
|
|
118
|
+
def get_workflow(self) -> Function:
|
|
119
|
+
"""Mock implementation."""
|
|
120
|
+
pass
|
|
121
|
+
|
|
122
|
+
def get_workflow_config(self) -> FunctionBaseConfig:
|
|
123
|
+
"""Mock implementation."""
|
|
124
|
+
pass
|
|
125
|
+
|
|
126
|
+
def get_tool(self, fn_name: str, wrapper_type):
|
|
127
|
+
"""Mock implementation."""
|
|
128
|
+
pass
|
|
129
|
+
|
|
130
|
+
async def add_llm(self, name: str, config):
|
|
131
|
+
"""Mock implementation."""
|
|
132
|
+
pass
|
|
133
|
+
|
|
134
|
+
async def get_llm(self, llm_name: str, wrapper_type):
|
|
135
|
+
"""Return a mock LLM if one is configured."""
|
|
136
|
+
key = f"llm_{llm_name}"
|
|
137
|
+
if key in self._mocks:
|
|
138
|
+
mock_llm = MagicMock()
|
|
139
|
+
mock_llm.invoke = MagicMock(return_value=self._mocks[key])
|
|
140
|
+
mock_llm.ainvoke = AsyncMock(return_value=self._mocks[key])
|
|
141
|
+
return mock_llm
|
|
142
|
+
raise ValueError(f"LLM '{llm_name}' not mocked. Use mock_llm() to add it.")
|
|
143
|
+
|
|
144
|
+
def get_llm_config(self, llm_name: str):
|
|
145
|
+
"""Mock implementation."""
|
|
146
|
+
pass
|
|
147
|
+
|
|
148
|
+
async def add_embedder(self, name: str, config):
|
|
149
|
+
"""Mock implementation."""
|
|
150
|
+
pass
|
|
151
|
+
|
|
152
|
+
async def get_embedder(self, embedder_name: str, wrapper_type):
|
|
153
|
+
"""Return a mock embedder if one is configured."""
|
|
154
|
+
key = f"embedder_{embedder_name}"
|
|
155
|
+
if key in self._mocks:
|
|
156
|
+
mock_embedder = MagicMock()
|
|
157
|
+
mock_embedder.embed_query = MagicMock(return_value=self._mocks[key])
|
|
158
|
+
mock_embedder.embed_documents = MagicMock(return_value=self._mocks[key])
|
|
159
|
+
return mock_embedder
|
|
160
|
+
raise ValueError(f"Embedder '{embedder_name}' not mocked. Use mock_embedder() to add it.")
|
|
161
|
+
|
|
162
|
+
def get_embedder_config(self, embedder_name: str):
|
|
163
|
+
"""Mock implementation."""
|
|
164
|
+
pass
|
|
165
|
+
|
|
166
|
+
async def add_memory_client(self, name: str, config):
|
|
167
|
+
"""Mock implementation."""
|
|
168
|
+
pass
|
|
169
|
+
|
|
170
|
+
def get_memory_client(self, memory_name: str):
|
|
171
|
+
"""Return a mock memory client if one is configured."""
|
|
172
|
+
key = f"memory_{memory_name}"
|
|
173
|
+
if key in self._mocks:
|
|
174
|
+
mock_memory = MagicMock()
|
|
175
|
+
mock_memory.add = AsyncMock(return_value=self._mocks[key])
|
|
176
|
+
mock_memory.search = AsyncMock(return_value=self._mocks[key])
|
|
177
|
+
return mock_memory
|
|
178
|
+
raise ValueError(f"Memory client '{memory_name}' not mocked. Use mock_memory_client() to add it.")
|
|
179
|
+
|
|
180
|
+
def get_memory_client_config(self, memory_name: str):
|
|
181
|
+
"""Mock implementation."""
|
|
182
|
+
pass
|
|
183
|
+
|
|
184
|
+
async def add_retriever(self, name: str, config):
|
|
185
|
+
"""Mock implementation."""
|
|
186
|
+
pass
|
|
187
|
+
|
|
188
|
+
async def get_retriever(self, retriever_name: str, wrapper_type=None):
|
|
189
|
+
"""Return a mock retriever if one is configured."""
|
|
190
|
+
key = f"retriever_{retriever_name}"
|
|
191
|
+
if key in self._mocks:
|
|
192
|
+
mock_retriever = MagicMock()
|
|
193
|
+
mock_retriever.retrieve = AsyncMock(return_value=self._mocks[key])
|
|
194
|
+
return mock_retriever
|
|
195
|
+
raise ValueError(f"Retriever '{retriever_name}' not mocked. Use mock_retriever() to add it.")
|
|
196
|
+
|
|
197
|
+
async def get_retriever_config(self, retriever_name: str):
|
|
198
|
+
"""Mock implementation."""
|
|
199
|
+
pass
|
|
200
|
+
|
|
201
|
+
async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
|
|
202
|
+
"""Mock implementation for object store."""
|
|
203
|
+
pass
|
|
204
|
+
|
|
205
|
+
async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
|
|
206
|
+
"""Return a mock object store client if one is configured."""
|
|
207
|
+
key = f"object_store_{object_store_name}"
|
|
208
|
+
if key in self._mocks:
|
|
209
|
+
mock_object_store = MagicMock()
|
|
210
|
+
mock_object_store.put_object = AsyncMock(return_value=self._mocks[key])
|
|
211
|
+
mock_object_store.get_object = AsyncMock(return_value=self._mocks[key])
|
|
212
|
+
mock_object_store.delete_object = AsyncMock(return_value=self._mocks[key])
|
|
213
|
+
mock_object_store.list_objects = AsyncMock(return_value=self._mocks[key])
|
|
214
|
+
return mock_object_store
|
|
215
|
+
raise ValueError(f"Object store '{object_store_name}' not mocked. Use mock_object_store() to add it.")
|
|
216
|
+
|
|
217
|
+
def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
    """Return ``None``: object store configurations are not tracked by this mock builder."""
    return None
|
|
220
|
+
|
|
221
|
+
def get_user_manager(self):
    """Return a stand-in user manager whose ``get_id()`` always yields ``"test_user"``."""
    user_manager = MagicMock()
    user_manager.get_id = MagicMock(return_value="test_user")
    return user_manager
|
|
226
|
+
|
|
227
|
+
def get_function_dependencies(self, fn_name: str):
    """Return ``None``: function dependency tracking is not modelled by this mock builder."""
    return None
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class ToolTestRunner:
    """
    A test runner that enables isolated testing of NeMo Agent toolkit tools without
    requiring full workflow setup, LLMs, or complex dependencies.

    Usage:
        runner = ToolTestRunner()

        # Test a tool with minimal setup (an empty MockBuilder is used)
        result = await runner.test_tool(
            config_type=MyToolConfig,
            config_params={"param1": "value1"},
            input_data="test input"
        )

        # Test a tool with mocked dependencies. ``test_tool`` always creates its own
        # empty MockBuilder, so a pre-configured builder must be passed explicitly
        # through ``test_tool_with_builder``.
        async with runner.with_mocks() as mock_builder:
            mock_builder.mock_llm("my_llm", "mocked response")
            result = await runner.test_tool_with_builder(
                config_type=MyToolConfig,
                builder=mock_builder,
                config_params={"llm_name": "my_llm"},
                input_data="test input"
            )
    """

    def __init__(self):
        self._ensure_plugins_loaded()

    def _ensure_plugins_loaded(self):
        """Ensure all plugins are loaded for tool registration."""
        discover_and_register_plugins(PluginTypes.CONFIG_OBJECT)

    async def test_tool(self,
                        config_type: type[FunctionBaseConfig],
                        config_params: dict[str, typing.Any] | None = None,
                        input_data: typing.Any = None,
                        expected_output: typing.Any = None,
                        **kwargs) -> typing.Any:
        """
        Test a tool in isolation with minimal setup.

        A fresh, empty ``MockBuilder`` is used, so any dependency the tool requests
        from the builder that has not been mocked will raise.

        Args:
            config_type: The tool configuration class
            config_params: Parameters to pass to the config constructor
            input_data: Input data to pass to the tool; when ``None`` the tool is
                called with no arguments
            expected_output: Expected output for assertion (optional; ``None``
                disables the check)
            **kwargs: Accepted for backward compatibility; currently ignored

        Returns:
            The tool's output

        Raises:
            AssertionError: If expected_output is provided and doesn't match
            ValueError: If tool registration or execution fails
        """
        return await self._build_and_run(config_type=config_type,
                                         builder=None,
                                         config_params=config_params,
                                         input_data=input_data,
                                         expected_output=expected_output)

    @asynccontextmanager
    async def with_mocks(self):
        """
        Context manager that provides a mock builder for setting up dependencies.

        Usage:
            async with runner.with_mocks() as mock_builder:
                mock_builder.mock_llm("my_llm", "mocked response")
                result = await runner.test_tool_with_builder(
                    config_type=MyToolConfig,
                    builder=mock_builder,
                    input_data="test input"
                )
        """
        # No teardown is needed: the builder only holds in-memory mocks.
        yield MockBuilder()

    async def test_tool_with_builder(
        self,
        config_type: type[FunctionBaseConfig],
        builder: MockBuilder,
        config_params: dict[str, typing.Any] | None = None,
        input_data: typing.Any = None,
        expected_output: typing.Any = None,
    ) -> typing.Any:
        """
        Test a tool with a pre-configured mock builder.

        Args:
            config_type: The tool configuration class
            builder: Pre-configured MockBuilder with mocked dependencies
            config_params: Parameters to pass to the config constructor
            input_data: Input data to pass to the tool; when ``None`` the tool is
                called with no arguments
            expected_output: Expected output for assertion (optional; ``None``
                disables the check)

        Returns:
            The tool's output

        Raises:
            AssertionError: If expected_output is provided and doesn't match
            ValueError: If tool registration or execution fails
        """
        return await self._build_and_run(config_type=config_type,
                                         builder=builder,
                                         config_params=config_params,
                                         input_data=input_data,
                                         expected_output=expected_output)

    async def _build_and_run(self,
                             config_type: type[FunctionBaseConfig],
                             builder: MockBuilder | None,
                             config_params: dict[str, typing.Any] | None,
                             input_data: typing.Any,
                             expected_output: typing.Any) -> typing.Any:
        """Shared implementation: build the tool with ``builder`` (fresh MockBuilder if None), run it, check it."""
        config = config_type(**(config_params or {}))
        tool_registration = self._get_registration(config_type)
        if builder is None:
            builder = MockBuilder()

        async with tool_registration.build_fn(config, builder) as tool_result:
            tool_function = self._unwrap_tool(tool_result)
            result = await self._invoke(tool_function, input_data)

            # Checked inside the build context (as before) so the tool's
            # resources are still alive while the result is compared.
            if expected_output is not None:
                assert result == expected_output, f"Expected {expected_output}, got {result}"

            return result

    @staticmethod
    def _get_registration(config_type: type[FunctionBaseConfig]):
        """Return the registry entry for ``config_type`` or raise ValueError if it is not registered."""
        registry = GlobalTypeRegistry.get()
        try:
            return registry.get_function(config_type)
        except KeyError as err:
            raise ValueError(
                f"Tool {config_type} is not registered. Make sure it's imported and registered with @register_function."
            ) from err

    @staticmethod
    def _unwrap_tool(tool_result: typing.Any) -> typing.Any:
        """Extract the callable to execute from whatever the tool's build function yielded."""
        if isinstance(tool_result, Function):
            return tool_result
        if isinstance(tool_result, FunctionInfo):
            # Prefer the single-output function; fall back to the streaming one.
            if tool_result.single_fn is not None:
                return tool_result.single_fn
            if tool_result.stream_fn is not None:
                return tool_result.stream_fn
            raise ValueError("Tool function not found in FunctionInfo")
        if callable(tool_result):
            return tool_result
        raise ValueError(f"Unexpected tool result type: {type(tool_result)}")

    @staticmethod
    async def _invoke(tool_function: typing.Any, input_data: typing.Any) -> typing.Any:
        """
        Call ``tool_function`` with ``input_data`` (or no arguments when it is None).

        Coroutine functions are awaited; any other callable's return value is
        returned as-is (e.g. a streaming function's generator is not iterated).
        """
        call_args = () if input_data is None else (input_data, )
        if asyncio.iscoroutinefunction(tool_function):
            return await tool_function(*call_args)
        return tool_function(*call_args)
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
@asynccontextmanager
async def with_mocked_dependencies():
    """
    Yield a fresh ``(runner, mock_builder)`` pair for testing tools with mocked dependencies.

    A new ``ToolTestRunner`` is constructed first; the builder then comes from
    that runner's ``with_mocks()`` context, which stays open for the duration of
    the caller's ``async with`` body.

    Usage:
        async with with_mocked_dependencies() as (runner, mock_builder):
            mock_builder.mock_llm("my_llm", "mocked response")
            result = await runner.test_tool_with_builder(
                config_type=MyToolConfig,
                builder=mock_builder,
                input_data="test input"
            )
    """
    test_runner = ToolTestRunner()
    async with test_runner.with_mocks() as builder:
        runner_and_builder = (test_runner, builder)
        yield runner_and_builder
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: nvidia-nat-test
|
|
3
|
+
Version: 1.2.0a20250813
|
|
4
|
+
Summary: Testing utilities for NeMo Agent toolkit
|
|
5
|
+
Keywords: ai,rag,agents
|
|
6
|
+
Classifier: Programming Language :: Python
|
|
7
|
+
Requires-Python: <3.13,>=3.11
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
Requires-Dist: nvidia-nat==v1.2.0a20250813
|
|
10
|
+
Requires-Dist: langchain-community~=0.3
|
|
11
|
+
Requires-Dist: pytest~=8.3
|
|
12
|
+
|
|
13
|
+
<!--
|
|
14
|
+
SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
15
|
+
SPDX-License-Identifier: Apache-2.0
|
|
16
|
+
|
|
17
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
18
|
+
you may not use this file except in compliance with the License.
|
|
19
|
+
You may obtain a copy of the License at
|
|
20
|
+
|
|
21
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
22
|
+
|
|
23
|
+
Unless required by applicable law or agreed to in writing, software
|
|
24
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
25
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
26
|
+
See the License for the specific language governing permissions and
|
|
27
|
+
limitations under the License.
|
|
28
|
+
-->
|
|
29
|
+
|
|
30
|
+

|
|
31
|
+
|
|
32
|
+
# NVIDIA NeMo Agent Toolkit Subpackage
|
|
33
|
+
This is a subpackage for NeMo Agent toolkit test utilities.
|
|
34
|
+
|
|
35
|
+
For more information about the NVIDIA NeMo Agent toolkit, please visit the [NeMo Agent toolkit GitHub Repo](https://github.com/NVIDIA/NeMo-Agent-Toolkit).
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
nat/meta/pypi.md,sha256=XfzdATwjeiVEfSdXmjS5mhVbrHIwHNyK2CR4KtB9Qaw,1111
|
|
2
|
+
nat/test/__init__.py,sha256=_RnTJnsUucHvla_nYKqD4O4g8Bz0tcuDRzWk1bEhcy0,875
|
|
3
|
+
nat/test/embedder.py,sha256=ClDyK1kna4hCBSlz71gK1B-ZjlwcBHTDQRekoNM81Bs,1809
|
|
4
|
+
nat/test/functions.py,sha256=0ScrdsjcxCsPRLnyb5gfwukmvZxFi_ptCswLSIG0DVY,3095
|
|
5
|
+
nat/test/memory.py,sha256=xki_A2yiMhEZuQk60K7t04QRqf32nQqnfzD5Iv7fkvw,1456
|
|
6
|
+
nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2zQ,4281
|
|
7
|
+
nat/test/plugin.py,sha256=HUg_BidPzyjikSkkOZyPzQC9ALtbG21smmfRumUjWpc,3610
|
|
8
|
+
nat/test/register.py,sha256=jU1pW5wf20ZmCOTgkaQshKZfvYh8_-sMJ4P3xXilfTY,891
|
|
9
|
+
nat/test/tool_test_runner.py,sha256=UtoOd5irR9BQt5Nxloo_alzWzBAXJAhHHpSKhVWGDs0,17469
|
|
10
|
+
nvidia_nat_test-1.2.0a20250813.dist-info/METADATA,sha256=J2EDXFhF23pnzbC3ynLIHiyNFbTt3mb75mMTQOdvSUk,1477
|
|
11
|
+
nvidia_nat_test-1.2.0a20250813.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
12
|
+
nvidia_nat_test-1.2.0a20250813.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
|
|
13
|
+
nvidia_nat_test-1.2.0a20250813.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
|
14
|
+
nvidia_nat_test-1.2.0a20250813.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
nat
|