nvidia-nat-test 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-test might be problematic. Click here for more details.

nat/meta/pypi.md ADDED
@@ -0,0 +1,23 @@
1
+ <!--
2
+ SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ SPDX-License-Identifier: Apache-2.0
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+ -->
17
+
18
+ ![NVIDIA NeMo Agent Toolkit](https://media.githubusercontent.com/media/NVIDIA/NeMo-Agent-Toolkit/refs/heads/main/docs/source/_static/banner.png "NeMo Agent toolkit banner image")
19
+
20
+ # NVIDIA NeMo Agent Toolkit Subpackage
21
+ This is a subpackage for NeMo Agent toolkit test utilities.
22
+
23
+ For more information about the NVIDIA NeMo Agent toolkit, please visit the [NeMo Agent toolkit GitHub Repo](https://github.com/NVIDIA/NeMo-Agent-Toolkit).
nat/test/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Tool testing utilities
17
+ from .tool_test_runner import ToolTestRunner
18
+ from .tool_test_runner import with_mocked_dependencies
19
+
20
# Public API of ``nat.test``: the tool-testing helpers imported above.
__all__ = ["ToolTestRunner", "with_mocked_dependencies"]
nat/test/embedder.py ADDED
@@ -0,0 +1,44 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from pydantic import ConfigDict
17
+
18
+ from nat.builder.builder import Builder
19
+ from nat.builder.embedder import EmbedderProviderInfo
20
+ from nat.builder.framework_enum import LLMFrameworkEnum
21
+ from nat.cli.register_workflow import register_embedder_client
22
+ from nat.cli.register_workflow import register_embedder_provider
23
+ from nat.data_models.embedder import EmbedderBaseConfig
24
+
25
+
26
class EmbedderTestConfig(EmbedderBaseConfig, name="test_embedder"):
    # Configuration for the "test_embedder" provider registered in this module.
    # An empty protected_namespaces tuple lets the `model_name` field use pydantic's reserved `model_` prefix
    # without a warning.
    model_config = ConfigDict(protected_namespaces=())

    # Not read by the LangChain test client in this module; kept so configs can carry a realistic model name.
    model_name: str = "nvidia/nv-embedqa-e5-v5"
    # Size passed to DeterministicFakeEmbedding by the LangChain test client.
    embedding_size: int = 768
31
+
32
+
33
@register_embedder_provider(config_type=EmbedderTestConfig)
async def embedder_test_provider(config: EmbedderTestConfig, builder: Builder):
    # Provider side of the test embedder: it only publishes provider metadata for the given config.
    provider_info = EmbedderProviderInfo(config=config, description="Test embedder provider")
    yield provider_info
37
+
38
+
39
@register_embedder_client(config_type=EmbedderTestConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
async def embedder_langchain_test_client(config: EmbedderTestConfig, builder: Builder):
    # LangChain client for the test embedder: a deterministic fake embedding of the configured size.
    # The import is deferred so langchain_community is only needed when this client is actually built.
    from langchain_community.embeddings import DeterministicFakeEmbedding

    fake_embedder = DeterministicFakeEmbedding(size=config.embedding_size)
    yield fake_embedder
nat/test/functions.py ADDED
@@ -0,0 +1,91 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections.abc import AsyncGenerator
17
+
18
+ from nat.builder.builder import Builder
19
+ from nat.builder.function_info import FunctionInfo
20
+ from nat.cli.register_workflow import register_function
21
+ from nat.data_models.api_server import ChatRequest
22
+ from nat.data_models.api_server import ChatResponse
23
+ from nat.data_models.api_server import ChatResponseChunk
24
+ from nat.data_models.function import FunctionBaseConfig
25
+
26
+
27
class EchoFunctionConfig(FunctionBaseConfig, name="test_echo"):
    # When True the echo function takes a ChatRequest and returns a ChatResponse instead of plain strings.
    use_openai_api: bool = False
29
+
30
+
31
@register_function(config_type=EchoFunctionConfig)
async def echo_function(config: EchoFunctionConfig, builder: Builder):
    # Returns its input unchanged. In OpenAI mode, only the content of the first chat message is echoed.

    async def inner_oai(message: ChatRequest) -> ChatResponse:
        first_content = message.messages[0].content
        return ChatResponse.from_string(first_content)

    async def inner(message: str) -> str:
        return message

    yield inner_oai if config.use_openai_api else inner
44
+
45
+
46
class StreamingEchoFunctionConfig(FunctionBaseConfig, name="test_streaming_echo"):
    # When True the streaming echo takes a ChatRequest and yields ChatResponseChunk objects.
    use_openai_api: bool = False
48
+
49
+
50
@register_function(config_type=StreamingEchoFunctionConfig)
async def streaming_function(config: StreamingEchoFunctionConfig, builder: Builder):
    # Streams every element of its input back as a separate item.

    def oai_to_list(message: ChatRequest) -> list[str]:
        # Flattens a chat request into the contents of its messages; also registered as a converter below.
        return [chat_message.content for chat_message in message.messages]

    async def inner(message: list[str]) -> AsyncGenerator[str]:
        for item in message:
            yield item

    async def inner_oai(message: ChatRequest) -> AsyncGenerator[ChatResponseChunk]:
        for item in oai_to_list(message):
            yield ChatResponseChunk.from_string(item)

    selected_fn = inner_oai if config.use_openai_api else inner
    yield FunctionInfo.from_fn(selected_fn, converters=[oai_to_list])
65
+
66
+
67
class ConstantFunctionConfig(FunctionBaseConfig, name="test_constant"):
    # The string returned on every call to the constant function.
    response: str
69
+
70
+
71
@register_function(config_type=ConstantFunctionConfig)
async def constant_function(config: ConstantFunctionConfig, builder: Builder):
    # A zero-argument function that always answers with `config.response` (read at call time).

    async def inner() -> str:
        return config.response

    yield inner
78
+
79
+
80
class StreamingConstantFunctionConfig(FunctionBaseConfig, name="test_streaming_constant"):
    # The strings streamed, in order, on every call to the streaming constant function.
    responses: list[str]
82
+
83
+
84
@register_function(config_type=StreamingConstantFunctionConfig)
async def streaming_constant_function(config: StreamingConstantFunctionConfig, builder: Builder):
    # A zero-argument streaming function that yields each configured response in order.

    async def inner() -> AsyncGenerator[str]:
        for chunk in config.responses:
            yield chunk

    yield inner
nat/test/memory.py ADDED
@@ -0,0 +1,41 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from nat.builder.builder import Builder
17
+ from nat.cli.register_workflow import register_memory
18
+ from nat.data_models.memory import MemoryBaseConfig
19
+ from nat.memory.interfaces import MemoryEditor
20
+ from nat.memory.models import MemoryItem
21
+
22
+
23
class DummyMemoryConfig(MemoryBaseConfig, name="test_dummy"):
    # The dummy memory backend takes no settings.
    ...
25
+
26
+
27
@register_memory(config_type=DummyMemoryConfig)
async def echo_function(config: DummyMemoryConfig, builder: Builder):
    # NOTE(review): despite its name, this registers the dummy memory client for DummyMemoryConfig.
    # The name is kept because it is a module-level attribute that other code may reference.

    class DummyMemoryEditor(MemoryEditor):
        # A MemoryEditor that stores nothing: writes and removals do nothing and every search is empty.

        async def add_items(self, items: list[MemoryItem]) -> None:
            return None

        async def search(self, query: str, top_k: int = 5, **kwargs) -> list[MemoryItem]:
            return []

        async def remove_items(self, **kwargs) -> None:
            return None

    yield DummyMemoryEditor()
@@ -0,0 +1,117 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import uuid
17
+ from abc import abstractmethod
18
+ from contextlib import asynccontextmanager
19
+
20
+ import pytest
21
+ import pytest_asyncio
22
+
23
+ from nat.data_models.object_store import KeyAlreadyExistsError
24
+ from nat.data_models.object_store import NoSuchKeyError
25
+ from nat.object_store.interfaces import ObjectStore
26
+ from nat.object_store.models import ObjectStoreItem
27
+
28
+
29
@pytest.mark.asyncio(loop_scope="class")
class ObjectStoreTests:
    """
    Reusable conformance tests for ``ObjectStore`` implementations.

    Subclass this and override ``_get_store`` to yield the store under test. By default pytest only collects
    classes whose names start with ``Test``, so give the subclass such a name. The ``store`` fixture is
    class-scoped, so one store instance is shared by every test in the class. Each test therefore uses random
    keys to stay independent of the others.
    """

    @abstractmethod
    @asynccontextmanager
    async def _get_store(self):
        """
        Yield the ``ObjectStore`` under test. Subclasses must override this.

        ``@abstractmethod`` is only enforced on classes whose metaclass is ``ABCMeta``, which is not the case
        here. The base implementation therefore raises instead of silently yielding ``None``, which would
        otherwise surface as confusing failures in every test.
        """
        raise NotImplementedError(f"{type(self).__name__} must override _get_store() to yield an ObjectStore")
        # Unreachable, but the presence of `yield` keeps this an async generator, as asynccontextmanager requires.
        yield  # pylint: disable=unreachable

    @pytest_asyncio.fixture(loop_scope="class", scope="class")
    async def store(self):

        async with self._get_store() as store:
            yield store

    async def test_create_object_store(self, store: ObjectStore):
        assert isinstance(store, ObjectStore)

    async def test_put_object(self, store: ObjectStore):

        # Use a random key to avoid conflicts with other tests
        key = f"test_key_{uuid.uuid4()}"

        initial_item = ObjectStoreItem(data=b"test_value")
        await store.put_object(key, initial_item)

        # put_object must refuse to overwrite an existing key
        with pytest.raises(KeyAlreadyExistsError):
            await store.put_object(key, initial_item)

    async def test_upsert_object(self, store: ObjectStore):
        key = f"test_key_{uuid.uuid4()}"

        initial_item = ObjectStoreItem(data=b"test_value", content_type="text/plain", metadata={"key": "value"})

        # Upserting a missing key creates it
        await store.upsert_object(key, initial_item)

        retrieved_item = await store.get_object(key)
        assert retrieved_item.data == initial_item.data
        assert retrieved_item.content_type == initial_item.content_type
        assert retrieved_item.metadata == initial_item.metadata

        # Upserting an existing key replaces data, content type and metadata
        new_item = ObjectStoreItem(data=b"new_value", content_type="application/json", metadata={"key": "new_value"})
        await store.upsert_object(key, new_item)

        retrieved_item = await store.get_object(key)
        assert retrieved_item.data == new_item.data
        assert retrieved_item.content_type == new_item.content_type
        assert retrieved_item.metadata == new_item.metadata

    async def test_get_object(self, store: ObjectStore):

        key = f"test_key_{uuid.uuid4()}"

        initial_item = ObjectStoreItem(data=b"test_value", content_type="text/plain", metadata={"key": "value"})
        await store.put_object(key, initial_item)

        retrieved_item = await store.get_object(key)
        assert retrieved_item.data == initial_item.data
        assert retrieved_item.content_type == initial_item.content_type
        assert retrieved_item.metadata == initial_item.metadata

        # A key that was never written must raise NoSuchKeyError
        with pytest.raises(NoSuchKeyError):
            await store.get_object(f"test_key_{uuid.uuid4()}")

    async def test_delete_object(self, store: ObjectStore):

        key = f"test_key_{uuid.uuid4()}"

        initial_item = ObjectStoreItem(data=b"test_value")
        await store.put_object(key, initial_item)

        # Sanity check: the object exists before deletion
        retrieved_item = await store.get_object(key)
        assert retrieved_item.data == initial_item.data

        await store.delete_object(key)

        # The deleted key is no longer readable
        with pytest.raises(NoSuchKeyError):
            await store.get_object(key)

        # Deleting a missing key is an error, not a no-op
        with pytest.raises(NoSuchKeyError):
            await store.delete_object(key)
nat/test/plugin.py ADDED
@@ -0,0 +1,97 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import pytest
17
+
18
+
19
+ def pytest_addoption(parser: pytest.Parser):
20
+ """
21
+ Adds command line options for running specfic tests that are disabled by default
22
+ """
23
+ parser.addoption(
24
+ "--run_e2e",
25
+ action="store_true",
26
+ dest="run_e2e",
27
+ help="Run end to end tests that would otherwise be skipped",
28
+ )
29
+
30
+ parser.addoption(
31
+ "--run_integration",
32
+ action="store_true",
33
+ dest="run_integration",
34
+ help=("Run integrations tests that would otherwise be skipped. "
35
+ "This will call out to external services instead of using mocks"),
36
+ )
37
+
38
+ parser.addoption(
39
+ "--run_slow",
40
+ action="store_true",
41
+ dest="run_slow",
42
+ help="Run end to end tests that would otherwise be skipped",
43
+ )
44
+
45
+
46
def pytest_runtest_setup(item):
    # Each entry: (opt-in command line flag, marker name, skip reason). Entries are checked in order and the
    # first marker whose flag was not passed causes the test to be skipped.
    gated_markers = (
        ("--run_e2e", "e2e", "Skipping end to end tests by default. Use --run_e2e to enable"),
        ("--run_integration", "integration", "Skipping integration tests by default. Use --run_integration to enable"),
        ("--run_slow", "slow", "Skipping slow tests by default. Use --run_slow to enable"),
    )
    for flag, marker, reason in gated_markers:
        enabled = item.config.getoption(flag)
        if not enabled and item.get_closest_marker(marker) is not None:
            pytest.skip(reason)
58
+
59
+
60
@pytest.fixture(name="register_components", scope="session", autouse=True)
def register_components_fixture():
    """
    Register every discoverable plugin component once, at the start of the test session.

    Components register themselves as a side effect of being imported, and imports are not re-executed between
    tests, so all of them are loaded up front before any test runs.
    """
    import nat.runtime.loader as plugin_loader

    plugin_loader.discover_and_register_plugins(plugin_loader.PluginTypes.ALL)

    # Importing nat.test.register pulls in the test-only components so they register too.
    import nat.test.register  # pylint: disable=unused-import # noqa: F401
72
+
73
+
74
@pytest.fixture(name="module_registry", scope="module", autouse=True)
def module_registry_fixture():
    """
    Enter ``GlobalTypeRegistry.push()`` for each test module and yield the resulting registry.

    Autoused at module level so that registry state is not leaked between test modules.
    """
    from nat.cli import type_registry

    with type_registry.GlobalTypeRegistry.push() as scoped_registry:
        yield scoped_registry
85
+
86
+
87
@pytest.fixture(name="registry", scope="function", autouse=True)
def function_registry_fixture():
    """
    Enter ``GlobalTypeRegistry.push()`` for each test function and yield the resulting registry.

    Autoused at function level so that registry state is not leaked between individual tests.
    """
    from nat.cli import type_registry

    with type_registry.GlobalTypeRegistry.push() as scoped_registry:
        yield scoped_registry
nat/test/register.py ADDED
@@ -0,0 +1,24 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: disable=unused-import
17
+ # flake8: noqa
18
+ # isort:skip_file
19
+
20
+ # Import any providers which need to be automatically registered here
21
+
22
+ from . import embedder
23
+ from . import functions
24
+ from . import memory
@@ -0,0 +1,449 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import logging
18
+ import typing
19
+ from contextlib import asynccontextmanager
20
+ from unittest.mock import AsyncMock
21
+ from unittest.mock import MagicMock
22
+
23
+ from nat.builder.builder import Builder
24
+ from nat.builder.function import Function
25
+ from nat.builder.function_info import FunctionInfo
26
+ from nat.cli.type_registry import GlobalTypeRegistry
27
+ from nat.data_models.function import FunctionBaseConfig
28
+ from nat.data_models.object_store import ObjectStoreBaseConfig
29
+ from nat.object_store.interfaces import ObjectStore
30
+ from nat.runtime.loader import PluginTypes
31
+ from nat.runtime.loader import discover_and_register_plugins
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
class MockBuilder(Builder):
    """
    A lightweight mock builder for tool testing that provides minimal dependencies.

    Register canned responses with the ``mock_*`` methods. The matching ``get_*`` method then returns a
    ``MagicMock`` whose relevant methods return that response. Requesting a component that was not mocked raises
    ``ValueError``. ``add_function`` raises ``NotImplementedError``. Most other ``Builder`` methods are no-op
    stubs that return ``None``.
    """

    def __init__(self):
        # Not read by any method of this mock; kept for code that may inspect it.
        self._functions = {}
        # Canned responses. Keys are "<kind>_<name>" (e.g. "llm_my_llm"); functions use their bare name.
        self._mocks = {}

    def mock_function(self, name: str, mock_response: typing.Any):
        """Add a mock function that returns a fixed response."""
        self._mocks[name] = mock_response

    def mock_llm(self, name: str, mock_response: typing.Any):
        """Add a mock LLM that returns a fixed response."""
        self._mocks[f"llm_{name}"] = mock_response

    def mock_embedder(self, name: str, mock_response: typing.Any):
        """Add a mock embedder that returns a fixed response."""
        self._mocks[f"embedder_{name}"] = mock_response

    def mock_memory_client(self, name: str, mock_response: typing.Any):
        """Add a mock memory client that returns a fixed response."""
        self._mocks[f"memory_{name}"] = mock_response

    def mock_retriever(self, name: str, mock_response: typing.Any):
        """Add a mock retriever that returns a fixed response."""
        self._mocks[f"retriever_{name}"] = mock_response

    def mock_object_store(self, name: str, mock_response: typing.Any):
        """Add a mock object store that returns a fixed response."""
        self._mocks[f"object_store_{name}"] = mock_response

    def mock_ttc_strategy(self, name: str, mock_response: typing.Any):
        """Add a mock TTC strategy that returns a fixed response."""
        self._mocks[f"ttc_strategy_{name}"] = mock_response

    async def add_ttc_strategy(self, name: str, config):
        """Mock implementation (no-op)."""
        pass

    async def get_ttc_strategy(self,
                               strategy_name: str,
                               pipeline_type: typing.Any = None,
                               stage_type: typing.Any = None):
        """Return a mock TTC strategy if one is configured."""
        key = f"ttc_strategy_{strategy_name}"
        if key in self._mocks:
            mock_strategy = MagicMock()
            # Provide common callable patterns used in tests
            mock_strategy.invoke = MagicMock(return_value=self._mocks[key])
            mock_strategy.ainvoke = AsyncMock(return_value=self._mocks[key])
            return mock_strategy
        raise ValueError(f"TTC strategy '{strategy_name}' not mocked. Use mock_ttc_strategy() to add it.")

    async def get_ttc_strategy_config(self,
                                      strategy_name: str,
                                      pipeline_type: typing.Any = None,
                                      stage_type: typing.Any = None):
        """Mock implementation."""
        pass

    async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
        """Mock implementation - not used in tool testing."""
        raise NotImplementedError("Mock implementation does not support add_function")

    def get_function(self, name: str) -> Function:
        """Return a mock function (awaitable, with an awaitable ``ainvoke``) if one is configured."""
        if name in self._mocks:
            mock_fn = AsyncMock()
            mock_fn.ainvoke = AsyncMock(return_value=self._mocks[name])
            return mock_fn
        raise ValueError(f"Function '{name}' not mocked. Use mock_function() to add it.")

    def get_function_config(self, name: str) -> FunctionBaseConfig:
        """Mock implementation."""
        pass

    async def set_workflow(self, config: FunctionBaseConfig) -> Function:
        """Mock implementation."""
        pass

    def get_workflow(self) -> Function:
        """Mock implementation."""
        pass

    def get_workflow_config(self) -> FunctionBaseConfig:
        """Mock implementation."""
        pass

    def get_tool(self, fn_name: str, wrapper_type):
        """Mock implementation."""
        pass

    async def add_llm(self, name: str, config):
        """Mock implementation."""
        pass

    async def get_llm(self, llm_name: str, wrapper_type):
        """Return a mock LLM (``invoke`` / awaitable ``ainvoke``) if one is configured."""
        key = f"llm_{llm_name}"
        if key in self._mocks:
            mock_llm = MagicMock()
            mock_llm.invoke = MagicMock(return_value=self._mocks[key])
            mock_llm.ainvoke = AsyncMock(return_value=self._mocks[key])
            return mock_llm
        raise ValueError(f"LLM '{llm_name}' not mocked. Use mock_llm() to add it.")

    def get_llm_config(self, llm_name: str):
        """Mock implementation."""
        pass

    async def add_embedder(self, name: str, config):
        """Mock implementation."""
        pass

    async def get_embedder(self, embedder_name: str, wrapper_type):
        """Return a mock embedder (``embed_query`` / ``embed_documents``) if one is configured."""
        key = f"embedder_{embedder_name}"
        if key in self._mocks:
            mock_embedder = MagicMock()
            mock_embedder.embed_query = MagicMock(return_value=self._mocks[key])
            mock_embedder.embed_documents = MagicMock(return_value=self._mocks[key])
            return mock_embedder
        raise ValueError(f"Embedder '{embedder_name}' not mocked. Use mock_embedder() to add it.")

    def get_embedder_config(self, embedder_name: str):
        """Mock implementation."""
        pass

    async def add_memory_client(self, name: str, config):
        """Mock implementation."""
        pass

    def get_memory_client(self, memory_name: str):
        """
        Return a mock memory client if one is configured.

        ``search`` (and the legacy ``add``) return the configured response. ``add_items`` and ``remove_items``
        return ``None``, matching the ``MemoryEditor`` interface. All four are awaitable.
        """
        key = f"memory_{memory_name}"
        if key in self._mocks:
            mock_memory = MagicMock()
            # Kept for backward compatibility with tests that already use ``add``.
            mock_memory.add = AsyncMock(return_value=self._mocks[key])
            mock_memory.search = AsyncMock(return_value=self._mocks[key])
            # MemoryEditor's write/remove methods are coroutines; a plain MagicMock attribute is not awaitable,
            # so ``await client.add_items(...)`` would raise TypeError without these.
            mock_memory.add_items = AsyncMock(return_value=None)
            mock_memory.remove_items = AsyncMock(return_value=None)
            return mock_memory
        raise ValueError(f"Memory client '{memory_name}' not mocked. Use mock_memory_client() to add it.")

    def get_memory_client_config(self, memory_name: str):
        """Mock implementation."""
        pass

    async def add_retriever(self, name: str, config):
        """Mock implementation."""
        pass

    async def get_retriever(self, retriever_name: str, wrapper_type=None):
        """Return a mock retriever (awaitable ``retrieve``) if one is configured."""
        key = f"retriever_{retriever_name}"
        if key in self._mocks:
            mock_retriever = MagicMock()
            mock_retriever.retrieve = AsyncMock(return_value=self._mocks[key])
            return mock_retriever
        raise ValueError(f"Retriever '{retriever_name}' not mocked. Use mock_retriever() to add it.")

    async def get_retriever_config(self, retriever_name: str):
        """Mock implementation."""
        pass

    async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
        """Mock implementation for object store."""
        pass

    async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
        """
        Return a mock object store client if one is configured.

        Every mocked method is awaitable and returns the configured response.
        """
        key = f"object_store_{object_store_name}"
        if key in self._mocks:
            mock_object_store = MagicMock()
            mock_object_store.put_object = AsyncMock(return_value=self._mocks[key])
            # upsert_object is part of the ObjectStore interface; without this it would be a non-awaitable MagicMock.
            mock_object_store.upsert_object = AsyncMock(return_value=self._mocks[key])
            mock_object_store.get_object = AsyncMock(return_value=self._mocks[key])
            mock_object_store.delete_object = AsyncMock(return_value=self._mocks[key])
            mock_object_store.list_objects = AsyncMock(return_value=self._mocks[key])
            return mock_object_store
        raise ValueError(f"Object store '{object_store_name}' not mocked. Use mock_object_store() to add it.")

    def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
        """Mock implementation for object store config."""
        pass

    def get_user_manager(self):
        """Mock implementation: a user manager whose ``get_id()`` returns ``"test_user"``."""
        mock_user = MagicMock()
        mock_user.get_id = MagicMock(return_value="test_user")
        return mock_user

    def get_function_dependencies(self, fn_name: str):
        """Mock implementation."""
        pass
230
+
231
+
232
class ToolTestRunner:
    """
    A test runner that enables isolated testing of NAT tools without requiring
    full workflow setup, LLMs, or complex dependencies.

    Usage:
        runner = ToolTestRunner()

        # Test a tool with minimal setup
        result = await runner.test_tool(
            config_type=MyToolConfig,
            config_params={"param1": "value1"},
            input_data="test input"
        )

        # Test a tool with mocked dependencies (test_tool always uses a fresh, empty MockBuilder,
        # so pass the configured builder explicitly)
        async with runner.with_mocks() as mock_builder:
            mock_builder.mock_llm("my_llm", "mocked response")
            result = await runner.test_tool_with_builder(
                config_type=MyToolConfig,
                builder=mock_builder,
                config_params={"llm_name": "my_llm"},
                input_data="test input"
            )
    """

    def __init__(self):
        self._ensure_plugins_loaded()

    def _ensure_plugins_loaded(self):
        """Ensure all plugins are loaded for tool registration."""
        discover_and_register_plugins(PluginTypes.CONFIG_OBJECT)

    async def test_tool(self,
                        config_type: type[FunctionBaseConfig],
                        config_params: dict[str, typing.Any] | None = None,
                        input_data: typing.Any = None,
                        expected_output: typing.Any = None,
                        **kwargs) -> typing.Any:
        """
        Test a tool in isolation with minimal setup.

        Args:
            config_type: The tool configuration class
            config_params: Parameters to pass to the config constructor
            input_data: Input data to pass to the tool; ``None`` calls the tool with no arguments
            expected_output: Expected output for assertion (optional)
            **kwargs: Accepted for forward compatibility; currently ignored

        Returns:
            The tool's output

        Raises:
            AssertionError: If expected_output is provided and doesn't match
            ValueError: If the tool is not registered or its build result has no usable callable
        """
        # A fresh builder with no mocks: tools that request dependencies will get ValueError from it.
        return await self._run_tool(config_type, MockBuilder(), config_params, input_data, expected_output)

    @asynccontextmanager
    async def with_mocks(self):
        """
        Context manager that provides a mock builder for setting up dependencies.

        Usage:
            async with runner.with_mocks() as mock_builder:
                mock_builder.mock_llm("my_llm", "mocked response")
                result = await runner.test_tool_with_builder(
                    config_type=MyToolConfig,
                    builder=mock_builder,
                    input_data="test input"
                )
        """
        yield MockBuilder()

    async def test_tool_with_builder(
        self,
        config_type: type[FunctionBaseConfig],
        builder: MockBuilder,
        config_params: dict[str, typing.Any] | None = None,
        input_data: typing.Any = None,
        expected_output: typing.Any = None,
    ) -> typing.Any:
        """
        Test a tool with a pre-configured mock builder.

        Args:
            config_type: The tool configuration class
            builder: Pre-configured MockBuilder with mocked dependencies
            config_params: Parameters to pass to the config constructor
            input_data: Input data to pass to the tool; ``None`` calls the tool with no arguments
            expected_output: Expected output for assertion (optional)

        Returns:
            The tool's output

        Raises:
            AssertionError: If expected_output is provided and doesn't match
            ValueError: If the tool is not registered or its build result has no usable callable
        """
        return await self._run_tool(config_type, builder, config_params, input_data, expected_output)

    async def _run_tool(self,
                        config_type: type[FunctionBaseConfig],
                        builder: MockBuilder,
                        config_params: dict[str, typing.Any] | None,
                        input_data: typing.Any,
                        expected_output: typing.Any) -> typing.Any:
        """Build the tool registered for ``config_type`` with ``builder``, invoke it and check the result."""
        config = config_type(**(config_params or {}))
        tool_registration = self._get_registration(config_type)

        async with tool_registration.build_fn(config, builder) as tool_result:
            tool_function = self._resolve_callable(tool_result)
            result = await self._invoke(tool_function, input_data)

            if expected_output is not None:
                assert result == expected_output, f"Expected {expected_output}, got {result}"

            return result

    @staticmethod
    def _get_registration(config_type: type[FunctionBaseConfig]):
        """Look up the function registration for ``config_type``, translating a miss into ``ValueError``."""
        registry = GlobalTypeRegistry.get()
        try:
            return registry.get_function(config_type)
        except KeyError as err:
            raise ValueError(
                f"Tool {config_type} is not registered. Make sure it's imported and registered with @register_function."
            ) from err

    @staticmethod
    def _resolve_callable(tool_result: typing.Any):
        """Extract the callable to invoke from whatever the tool's build function yielded."""
        if isinstance(tool_result, Function):
            return tool_result
        if isinstance(tool_result, FunctionInfo):
            # Prefer the single-output implementation; fall back to the streaming one.
            if tool_result.single_fn:
                return tool_result.single_fn
            if tool_result.stream_fn:
                return tool_result.stream_fn
            raise ValueError("Tool function not found in FunctionInfo")
        if callable(tool_result):
            return tool_result
        raise ValueError(f"Unexpected tool result type: {type(tool_result)}")

    @staticmethod
    async def _invoke(tool_function: typing.Any, input_data: typing.Any) -> typing.Any:
        """Call ``tool_function`` with ``input_data`` (or no arguments when it is ``None``), awaiting coroutines."""
        args = () if input_data is None else (input_data, )
        if asyncio.iscoroutinefunction(tool_function):
            return await tool_function(*args)
        # NOTE: an async generator (streaming) function is not a coroutine function, so its generator object is
        # returned here unconsumed.
        return tool_function(*args)
431
+
432
+
433
@asynccontextmanager
async def with_mocked_dependencies():
    """
    Yield a ``(ToolTestRunner, MockBuilder)`` pair for testing tools with mocked dependencies.

    Usage:
        async with with_mocked_dependencies() as (runner, mock_builder):
            mock_builder.mock_llm("my_llm", "mocked response")
            result = await runner.test_tool_with_builder(
                config_type=MyToolConfig,
                builder=mock_builder,
                input_data="test input"
            )
    """
    test_runner = ToolTestRunner()
    async with test_runner.with_mocks() as builder:
        yield (test_runner, builder)
@@ -0,0 +1,35 @@
1
+ Metadata-Version: 2.4
2
+ Name: nvidia-nat-test
3
+ Version: 1.2.0
4
+ Summary: Testing utilities for NeMo Agent toolkit
5
+ Keywords: ai,rag,agents
6
+ Classifier: Programming Language :: Python
7
+ Requires-Python: <3.13,>=3.11
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: nvidia-nat==v1.2.0
10
+ Requires-Dist: langchain-community~=0.3
11
+ Requires-Dist: pytest~=8.3
12
+
13
+ <!--
14
+ SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
15
+ SPDX-License-Identifier: Apache-2.0
16
+
17
+ Licensed under the Apache License, Version 2.0 (the "License");
18
+ you may not use this file except in compliance with the License.
19
+ You may obtain a copy of the License at
20
+
21
+ http://www.apache.org/licenses/LICENSE-2.0
22
+
23
+ Unless required by applicable law or agreed to in writing, software
24
+ distributed under the License is distributed on an "AS IS" BASIS,
25
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ See the License for the specific language governing permissions and
27
+ limitations under the License.
28
+ -->
29
+
30
+ ![NVIDIA NeMo Agent Toolkit](https://media.githubusercontent.com/media/NVIDIA/NeMo-Agent-Toolkit/refs/heads/main/docs/source/_static/banner.png "NeMo Agent toolkit banner image")
31
+
32
+ # NVIDIA NeMo Agent Toolkit Subpackage
33
+ This is a subpackage for NeMo Agent toolkit test utilities.
34
+
35
+ For more information about the NVIDIA NeMo Agent toolkit, please visit the [NeMo Agent toolkit GitHub Repo](https://github.com/NVIDIA/NeMo-Agent-Toolkit).
@@ -0,0 +1,14 @@
1
+ nat/meta/pypi.md,sha256=LLKJHg5oN1-M9Pqfk3Bmphkk4O2TFsyiixuK5T0Y-gw,1100
2
+ nat/test/__init__.py,sha256=_RnTJnsUucHvla_nYKqD4O4g8Bz0tcuDRzWk1bEhcy0,875
3
+ nat/test/embedder.py,sha256=ClDyK1kna4hCBSlz71gK1B-ZjlwcBHTDQRekoNM81Bs,1809
4
+ nat/test/functions.py,sha256=0ScrdsjcxCsPRLnyb5gfwukmvZxFi_ptCswLSIG0DVY,3095
5
+ nat/test/memory.py,sha256=xki_A2yiMhEZuQk60K7t04QRqf32nQqnfzD5Iv7fkvw,1456
6
+ nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2zQ,4281
7
+ nat/test/plugin.py,sha256=fp39ib0W63vfqX6Ssvq4sCuSd8Lm6yQyknL3_qRijgI,3610
8
+ nat/test/register.py,sha256=jU1pW5wf20ZmCOTgkaQshKZfvYh8_-sMJ4P3xXilfTY,891
9
+ nat/test/tool_test_runner.py,sha256=ccErldob2VwBbVL0_pmLrOcKLc18qYjxeAEACYoKKGQ,17469
10
+ nvidia_nat_test-1.2.0.dist-info/METADATA,sha256=7nR5ymwMR124gv8C6UhR7bVn2vlExEC5UvG9juwZkhY,1448
11
+ nvidia_nat_test-1.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
12
+ nvidia_nat_test-1.2.0.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
13
+ nvidia_nat_test-1.2.0.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
14
+ nvidia_nat_test-1.2.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,5 @@
1
+ [nat.components]
2
+ nvidia-nat-test = nat.test.register
3
+
4
+ [pytest11]
5
+ nvidia-nat-test = nat.test.plugin
@@ -0,0 +1 @@
1
+ nat