nvidia-nat-test 1.3.dev0__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nat/test/llm.py ADDED
@@ -0,0 +1,205 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: disable=unused-argument,missing-class-docstring,missing-function-docstring,import-outside-toplevel
+ # pylint: disable=too-few-public-methods
+
+ import asyncio
+ import time
+ from collections.abc import AsyncGenerator
+ from collections.abc import Iterator
+ from itertools import cycle as iter_cycle
+ from typing import Any
+
+ from pydantic import Field
+
+ from nat.builder.builder import Builder
+ from nat.builder.framework_enum import LLMFrameworkEnum
+ from nat.builder.llm import LLMProviderInfo
+ from nat.cli.register_workflow import register_llm_client
+ from nat.cli.register_workflow import register_llm_provider
+ from nat.data_models.llm import LLMBaseConfig
+
+
+ class TestLLMConfig(LLMBaseConfig, name="nat_test_llm"):
+     """Test LLM configuration."""
+     __test__ = False
+     response_seq: list[str] = Field(
+         default=[],
+         description="Returns the next element in order (wraps)",
+     )
+     delay_ms: int = Field(default=0, ge=0, description="Artificial per-call delay in milliseconds to mimic latency")
+
+
+ class _ResponseChooser:
+     """
+     Helper class to choose the next response according to config using itertools.cycle and provide synchronous and
+     asynchronous sleep functions.
+     """
+
+     def __init__(self, response_seq: list[str], delay_ms: int):
+         self._cycler = iter_cycle(response_seq) if response_seq else None
+         self._delay_ms = delay_ms
+
+     def next_response(self) -> str:
+         """Return the next response in the cycle, or an empty string if no responses are configured."""
+         if self._cycler is None:
+             return ""
+         return next(self._cycler)
+
+     def sync_sleep(self) -> None:
+         time.sleep(self._delay_ms / 1000.0)
+
+     async def async_sleep(self) -> None:
+         await asyncio.sleep(self._delay_ms / 1000.0)
+
+
+ @register_llm_provider(config_type=TestLLMConfig)
+ async def test_llm_provider(config: TestLLMConfig, builder: Builder):
+     """Register the `nat_test_llm` provider for the NAT registry."""
+     yield LLMProviderInfo(config=config, description="Test LLM provider")
+
+
+ @register_llm_client(config_type=TestLLMConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+ async def test_llm_langchain(config: TestLLMConfig, builder: Builder):
+     """LLM client for LangChain/LangGraph."""
+
+     chooser = _ResponseChooser(response_seq=config.response_seq, delay_ms=config.delay_ms)
+
+     class LangChainTestLLM:
+
+         def invoke(self, messages: Any, **_kwargs: Any) -> str:
+             chooser.sync_sleep()
+             return chooser.next_response()
+
+         async def ainvoke(self, messages: Any, **_kwargs: Any) -> str:
+             await chooser.async_sleep()
+             return chooser.next_response()
+
+         def stream(self, messages: Any, **_kwargs: Any) -> Iterator[str]:
+             chooser.sync_sleep()
+             yield chooser.next_response()
+
+         async def astream(self, messages: Any, **_kwargs: Any) -> AsyncGenerator[str, None]:
+             await chooser.async_sleep()
+             yield chooser.next_response()
+
+     yield LangChainTestLLM()
+
+
+ @register_llm_client(config_type=TestLLMConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)
+ async def test_llm_llama_index(config: TestLLMConfig, builder: Builder):
+
+     try:
+         from llama_index.core.base.llms.types import ChatMessage
+         from llama_index.core.base.llms.types import ChatResponse
+     except ImportError as exc:
+         raise ImportError("llama_index is required for using the test_llm with llama_index. "
+                           "Please install the `nvidia-nat-llama-index` package.") from exc
+
+     chooser = _ResponseChooser(response_seq=config.response_seq, delay_ms=config.delay_ms)
+
+     class LITestLLM:
+
+         def chat(self, messages: list[Any] | None = None, **_kwargs: Any) -> ChatResponse:
+             chooser.sync_sleep()
+             return ChatResponse(message=ChatMessage(chooser.next_response()))
+
+         async def achat(self, messages: list[Any] | None = None, **_kwargs: Any) -> ChatResponse:
+             await chooser.async_sleep()
+             return ChatResponse(message=ChatMessage(chooser.next_response()))
+
+         def stream_chat(self, messages: list[Any] | None = None, **_kwargs: Any) -> Iterator[ChatResponse]:
+             chooser.sync_sleep()
+             yield ChatResponse(message=ChatMessage(chooser.next_response()))
+
+         async def astream_chat(self,
+                                messages: list[Any] | None = None,
+                                **_kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
+             await chooser.async_sleep()
+             yield ChatResponse(message=ChatMessage(chooser.next_response()))
+
+     yield LITestLLM()
+
+
+ @register_llm_client(config_type=TestLLMConfig, wrapper_type=LLMFrameworkEnum.CREWAI)
+ async def test_llm_crewai(config: TestLLMConfig, builder: Builder):
+     """LLM client for CrewAI."""
+
+     chooser = _ResponseChooser(response_seq=config.response_seq, delay_ms=config.delay_ms)
+
+     class CrewAITestLLM:
+
+         def call(self, messages: list[dict[str, str]] | None = None, **kwargs: Any) -> str:
+             chooser.sync_sleep()
+             return chooser.next_response()
+
+     yield CrewAITestLLM()
+
+
+ @register_llm_client(config_type=TestLLMConfig, wrapper_type=LLMFrameworkEnum.SEMANTIC_KERNEL)
+ async def test_llm_semantic_kernel(config: TestLLMConfig, builder: Builder):
+     """LLM client for Semantic Kernel."""
+
+     try:
+         from semantic_kernel.contents.chat_message_content import ChatMessageContent
+         from semantic_kernel.contents.utils.author_role import AuthorRole
+     except ImportError as exc:
+         raise ImportError("Semantic Kernel is required for using the test_llm with semantic_kernel. "
+                           "Please install the `nvidia-nat-semantic-kernel` package.") from exc
+
+     chooser = _ResponseChooser(response_seq=config.response_seq, delay_ms=config.delay_ms)
+
+     class SKTestLLM:
+
+         async def get_chat_message_contents(self, chat_history: Any, **_kwargs: Any) -> list[ChatMessageContent]:
+             await chooser.async_sleep()
+             text = chooser.next_response()
+             return [ChatMessageContent(role=AuthorRole.ASSISTANT, content=text)]
+
+         async def get_streaming_chat_message_contents(self, chat_history: Any,
+                                                       **_kwargs: Any) -> AsyncGenerator[ChatMessageContent, None]:
+             await chooser.async_sleep()
+             text = chooser.next_response()
+             yield ChatMessageContent(role=AuthorRole.ASSISTANT, content=text)
+
+     yield SKTestLLM()
+
+
+ @register_llm_client(config_type=TestLLMConfig, wrapper_type=LLMFrameworkEnum.AGNO)
+ async def test_llm_agno(config: TestLLMConfig, builder: Builder):
+     """LLM client for agno."""
+
+     chooser = _ResponseChooser(response_seq=config.response_seq, delay_ms=config.delay_ms)
+
+     class AgnoTestLLM:
+
+         def invoke(self, messages: Any | None = None, **_kwargs: Any) -> str:
+             chooser.sync_sleep()
+             return chooser.next_response()
+
+         async def ainvoke(self, messages: Any | None = None, **_kwargs: Any) -> str:
+             await chooser.async_sleep()
+             return chooser.next_response()
+
+         def invoke_stream(self, messages: Any | None = None, **_kwargs: Any) -> Iterator[str]:
+             chooser.sync_sleep()
+             yield chooser.next_response()
+
+         async def ainvoke_stream(self, messages: Any | None = None, **_kwargs: Any) -> AsyncGenerator[str, None]:
+             await chooser.async_sleep()
+             yield chooser.next_response()
+
+     yield AgnoTestLLM()
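
To see what the new `nat_test_llm` machinery does in isolation, here is a minimal sketch of the response-cycling behavior defined above. It assumes only the `_ResponseChooser` definition from this file; the `nat.test.llm` import path follows from the file location.

import asyncio

from nat.test.llm import _ResponseChooser  # defined in the file above

# Two canned responses and a 10 ms artificial delay per call, mirroring what
# TestLLMConfig(response_seq=["alpha", "beta"], delay_ms=10) would configure.
chooser = _ResponseChooser(response_seq=["alpha", "beta"], delay_ms=10)


async def main() -> None:
    await chooser.async_sleep()                # mimics per-call latency
    assert chooser.next_response() == "alpha"  # first element
    assert chooser.next_response() == "beta"   # second element
    assert chooser.next_response() == "alpha"  # itertools.cycle wraps around


asyncio.run(main())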
nat/test/plugin.py CHANGED
@@ -13,20 +13,21 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
+ import os
+ import subprocess
+ import typing
+ from pathlib import Path
+
  import pytest
 
+ if typing.TYPE_CHECKING:
+     from docker.client import DockerClient
+
 
  def pytest_addoption(parser: pytest.Parser):
      """
      Adds command line options for running specific tests that are disabled by default
      """
-     parser.addoption(
-         "--run_e2e",
-         action="store_true",
-         dest="run_e2e",
-         help="Run end to end tests that would otherwise be skipped",
-     )
-
      parser.addoption(
          "--run_integration",
          action="store_true",
@@ -42,12 +43,16 @@ def pytest_addoption(parser: pytest.Parser):
          help="Run end to end tests that would otherwise be skipped",
      )
 
+     parser.addoption(
+         "--fail_missing",
+         action="store_true",
+         dest="fail_missing",
+         help=("Tests requiring unmet dependencies are normally skipped. "
+               "Setting this flag will instead cause them to be reported as a failure"),
+     )
+
 
  def pytest_runtest_setup(item):
-     if (not item.config.getoption("--run_e2e")):
-         if (item.get_closest_marker("e2e") is not None):
-             pytest.skip("Skipping end to end tests by default. Use --run_e2e to enable")
-
      if (not item.config.getoption("--run_integration")):
          if (item.get_closest_marker("integration") is not None):
              pytest.skip("Skipping integration tests by default. Use --run_integration to enable")
@@ -68,7 +73,6 @@ def register_components_fixture():
      discover_and_register_plugins(PluginTypes.ALL)
 
      # Also import the nat.test.register module to register test-only components
-     import nat.test.register  # pylint: disable=unused-import  # noqa: F401
 
 
  @pytest.fixture(name="module_registry", scope="module", autouse=True)
@@ -95,3 +99,236 @@ def function_registry_fixture():
 
      with GlobalTypeRegistry.push() as registry:
          yield registry
+
+
+ @pytest.fixture(scope="session", name="fail_missing")
+ def fail_missing_fixture(pytestconfig: pytest.Config) -> bool:
+     """
+     Returns the value of the `fail_missing` flag; when False, tests requiring unmet dependencies will be skipped,
+     and when True they will fail.
+     """
+     yield pytestconfig.getoption("fail_missing")
+
+
+ def require_env_variables(varnames: list[str], reason: str, fail_missing: bool = False) -> dict[str, str]:
+     """
+     Checks that the given environment variables are set, and returns their values if they are. If a variable is not
+     set and `fail_missing` is False the test will be skipped, otherwise a `RuntimeError` will be raised.
+     """
+     env_variables = {}
+     try:
+         for varname in varnames:
+             env_variables[varname] = os.environ[varname]
+     except KeyError as e:
+         if fail_missing:
+             raise RuntimeError(reason) from e
+
+         pytest.skip(reason=reason)
+
+     return env_variables
+
+
+ @pytest.fixture(name="openai_api_key", scope='session')
+ def openai_api_key_fixture(fail_missing: bool):
+     """
+     Use for integration tests that require an OpenAI API key.
+     """
+     yield require_env_variables(
+         varnames=["OPENAI_API_KEY"],
+         reason="OpenAI integration tests require the `OPENAI_API_KEY` environment variable to be defined.",
+         fail_missing=fail_missing)
+
+
+ @pytest.fixture(name="nvidia_api_key", scope='session')
+ def nvidia_api_key_fixture(fail_missing: bool):
+     """
+     Use for integration tests that require an NVIDIA API key.
+     """
+     yield require_env_variables(
+         varnames=["NVIDIA_API_KEY"],
+         reason="NVIDIA integration tests require the `NVIDIA_API_KEY` environment variable to be defined.",
+         fail_missing=fail_missing)
+
+
+ @pytest.fixture(name="serp_api_key", scope='session')
+ def serp_api_key_fixture(fail_missing: bool):
+     """
+     Use for integration tests that require a SERP API key.
+     """
+     yield require_env_variables(
+         varnames=["SERP_API_KEY"],
+         reason="SERP integration tests require the `SERP_API_KEY` environment variable to be defined.",
+         fail_missing=fail_missing)
+
+
+ @pytest.fixture(name="tavily_api_key", scope='session')
+ def tavily_api_key_fixture(fail_missing: bool):
+     """
+     Use for integration tests that require a Tavily API key.
+     """
+     yield require_env_variables(
+         varnames=["TAVILY_API_KEY"],
+         reason="Tavily integration tests require the `TAVILY_API_KEY` environment variable to be defined.",
+         fail_missing=fail_missing)
+
+
+ @pytest.fixture(name="mem0_api_key", scope='session')
+ def mem0_api_key_fixture(fail_missing: bool):
+     """
+     Use for integration tests that require a Mem0 API key.
+     """
+     yield require_env_variables(
+         varnames=["MEM0_API_KEY"],
+         reason="Mem0 integration tests require the `MEM0_API_KEY` environment variable to be defined.",
+         fail_missing=fail_missing)
+
+
+ @pytest.fixture(name="aws_keys", scope='session')
+ def aws_keys_fixture(fail_missing: bool):
+     """
+     Use for integration tests that require AWS credentials.
+     """
+
+     yield require_env_variables(
+         varnames=["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
+         reason=
+         "AWS integration tests require the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables to be "
+         "defined.",
+         fail_missing=fail_missing)
+
+
+ @pytest.fixture(name="azure_openai_keys", scope='session')
+ def azure_openai_keys_fixture(fail_missing: bool):
+     """
+     Use for integration tests that require Azure OpenAI credentials.
+     """
+     yield require_env_variables(
+         varnames=["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT"],
+         reason="Azure integration tests require the `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_ENDPOINT` environment "
+                "variables to be defined.",
+         fail_missing=fail_missing)
+
+
+ @pytest.fixture(name="require_docker", scope='session')
+ def require_docker_fixture(fail_missing: bool) -> "DockerClient":
+     """
+     Use for integration tests that require Docker to be running.
+     """
+     try:
+         from docker.client import DockerClient
+         yield DockerClient()
+     except Exception as e:
+         reason = f"Unable to connect to Docker daemon: {e}"
+         if fail_missing:
+             raise RuntimeError(reason) from e
+         pytest.skip(reason=reason)
+
+
+ @pytest.fixture(name="restore_environ")
+ def restore_environ_fixture():
+     orig_vars = os.environ.copy()
+     yield os.environ
+
+     for key, value in orig_vars.items():
+         os.environ[key] = value
+
+     # Delete any new environment variables
+     # Iterating over a copy of the keys as we will potentially be deleting keys in the loop
+     for key in list(os.environ.keys()):
+         if key not in orig_vars:
+             del os.environ[key]
+
+
+ @pytest.fixture(name="root_repo_dir", scope='session')
+ def root_repo_dir_fixture() -> Path:
+     from nat.test.utils import locate_repo_root
+     return locate_repo_root()
+
+
+ @pytest.fixture(name="require_etcd", scope="session")
+ def require_etcd_fixture(fail_missing: bool = False) -> bool:
+     """
+     To run these tests, an etcd server must be running
+     """
+     import requests
+
+     host = os.getenv("NAT_CI_ETCD_HOST", "localhost")
+     port = os.getenv("NAT_CI_ETCD_PORT", "2379")
+     health_url = f"http://{host}:{port}/health"
+
+     try:
+         response = requests.get(health_url, timeout=5)
+         response.raise_for_status()
+         return True
+     except:  # noqa: E722
+         failure_reason = f"Unable to connect to etcd server at {health_url}"
+         if fail_missing:
+             raise RuntimeError(failure_reason)
+         pytest.skip(reason=failure_reason)
+
+
+ @pytest.fixture(name="milvus_uri", scope="session")
+ def milvus_uri_fixture(require_etcd: bool, fail_missing: bool = False) -> str:
+     """
+     To run these tests, a Milvus server must be running
+     """
+     host = os.getenv("NAT_CI_MILVUS_HOST", "localhost")
+     port = os.getenv("NAT_CI_MILVUS_PORT", "19530")
+     uri = f"http://{host}:{port}"
+     try:
+         from pymilvus import MilvusClient
+         MilvusClient(uri=uri)
+
+         return uri
+     except:  # noqa: E722
+         reason = f"Unable to connect to Milvus server at {uri}"
+         if fail_missing:
+             raise RuntimeError(reason)
+         pytest.skip(reason=reason)
+
+
+ @pytest.fixture(name="populate_milvus", scope="session")
+ def populate_milvus_fixture(milvus_uri: str, root_repo_dir: Path):
+     """
+     Populate Milvus with some test data.
+     """
+     populate_script = root_repo_dir / "scripts/langchain_web_ingest.py"
+
+     # Ingest default cuda docs
+     subprocess.run(["python", str(populate_script), "--milvus_uri", milvus_uri], check=True)
+
+     # Ingest MCP docs
+     subprocess.run([
+         "python",
+         str(populate_script),
+         "--milvus_uri",
+         milvus_uri,
+         "--urls",
+         "https://github.com/modelcontextprotocol/python-sdk",
+         "--urls",
+         "https://modelcontextprotocol.io/introduction",
+         "--urls",
+         "https://modelcontextprotocol.io/quickstart/server",
+         "--urls",
+         "https://modelcontextprotocol.io/quickstart/client",
+         "--urls",
+         "https://modelcontextprotocol.io/examples",
+         "--urls",
+         "https://modelcontextprotocol.io/docs/concepts/architecture",
+         "--collection_name",
+         "mcp_docs"
+     ],
+                    check=True)
+
+     # Ingest some wikipedia docs
+     subprocess.run([
+         "python",
+         str(populate_script),
+         "--milvus_uri",
+         milvus_uri,
+         "--urls",
+         "https://en.wikipedia.org/wiki/Aardvark",
+         "--collection_name",
+         "wikipedia_docs"
+     ],
+                    check=True)
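
A hedged sketch of how the new `fail_missing` machinery above is meant to be consumed. The test body and marker usage are hypothetical; the fixture name, the returned dict shape, and the `--fail_missing` flag come from this diff.

import pytest


@pytest.mark.integration
def test_openai_key_available(openai_api_key: dict[str, str]):
    # The fixture skips when OPENAI_API_KEY is undefined, or fails instead
    # when pytest is invoked with --fail_missing; otherwise it yields
    # {"OPENAI_API_KEY": "<value>"} via require_env_variables().
    assert openai_api_key["OPENAI_API_KEY"]

Invoking `pytest --run_integration --fail_missing` turns missing credentials into hard failures rather than skips.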
nat/test/register.py CHANGED
@@ -13,7 +13,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- # pylint: disable=unused-import
  # flake8: noqa
  # isort:skip_file
 
@@ -22,3 +21,5 @@
  from . import embedder
  from . import functions
  from . import memory
+ from . import llm
+ from . import utils
nat/test/tool_test_runner.py CHANGED
@@ -14,18 +14,33 @@
  # limitations under the License.
 
  import asyncio
+ import inspect
  import logging
  import typing
+ from collections.abc import Sequence
  from contextlib import asynccontextmanager
  from unittest.mock import AsyncMock
  from unittest.mock import MagicMock
 
+ from nat.authentication.interfaces import AuthProviderBase
  from nat.builder.builder import Builder
  from nat.builder.function import Function
+ from nat.builder.function import FunctionGroup
  from nat.builder.function_info import FunctionInfo
  from nat.cli.type_registry import GlobalTypeRegistry
+ from nat.data_models.authentication import AuthProviderBaseConfig
+ from nat.data_models.embedder import EmbedderBaseConfig
  from nat.data_models.function import FunctionBaseConfig
+ from nat.data_models.function import FunctionGroupBaseConfig
+ from nat.data_models.function_dependencies import FunctionDependencies
+ from nat.data_models.llm import LLMBaseConfig
+ from nat.data_models.memory import MemoryBaseConfig
  from nat.data_models.object_store import ObjectStoreBaseConfig
+ from nat.data_models.retriever import RetrieverBaseConfig
+ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
+ from nat.memory.interfaces import MemoryEditor
  from nat.object_store.interfaces import ObjectStore
  from nat.runtime.loader import PluginTypes
  from nat.runtime.loader import discover_and_register_plugins
@@ -46,6 +61,10 @@ class MockBuilder(Builder):
          """Add a mock function that returns a fixed response."""
          self._mocks[name] = mock_response
 
+     def mock_function_group(self, name: str, mock_response: typing.Any):
+         """Add a mock function group that returns a fixed response."""
+         self._mocks[name] = mock_response
+
      def mock_llm(self, name: str, mock_response: typing.Any):
          """Add a mock LLM that returns a fixed response."""
          self._mocks[f"llm_{name}"] = mock_response
@@ -70,14 +89,16 @@ class MockBuilder(Builder):
          """Add a mock TTC strategy that returns a fixed response."""
          self._mocks[f"ttc_strategy_{name}"] = mock_response
 
-     async def add_ttc_strategy(self, name: str, config):
+     def mock_auth_provider(self, name: str, mock_response: typing.Any):
+         """Add a mock auth provider that returns a fixed response."""
+         self._mocks[f"auth_provider_{name}"] = mock_response
+
+     async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
          """Mock implementation (no-op)."""
          pass
 
-     async def get_ttc_strategy(self,
-                                strategy_name: str,
-                                pipeline_type: typing.Any = None,
-                                stage_type: typing.Any = None):
+     async def get_ttc_strategy(self, strategy_name: str, pipeline_type: PipelineTypeEnum,
+                                stage_type: StageTypeEnum) -> typing.Any:
          """Return a mock TTC strategy if one is configured."""
          key = f"ttc_strategy_{strategy_name}"
          if key in self._mocks:
@@ -90,16 +111,29 @@ class MockBuilder(Builder):
 
      async def get_ttc_strategy_config(self,
                                        strategy_name: str,
-                                       pipeline_type: typing.Any = None,
-                                       stage_type: typing.Any = None):
+                                       pipeline_type: PipelineTypeEnum,
+                                       stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
          """Mock implementation."""
+         return TTCStrategyBaseConfig()
+
+     async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> None:
+         """Mock implementation (no-op)."""
          pass
 
+     async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
+         """Return a mock auth provider if one is configured."""
+         key = f"auth_provider_{auth_provider_name}"
+         if key in self._mocks:
+             mock_auth = MagicMock()
+             mock_auth.authenticate = AsyncMock(return_value=self._mocks[key])
+             return mock_auth
+         raise ValueError(f"Auth provider '{auth_provider_name}' not mocked. Use mock_auth_provider() to add it.")
+
      async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
          """Mock implementation - not used in tool testing."""
          raise NotImplementedError("Mock implementation does not support add_function")
 
-     def get_function(self, name: str) -> Function:
+     async def get_function(self, name: str) -> Function:
          """Return a mock function if one is configured."""
          if name in self._mocks:
              mock_fn = AsyncMock()
@@ -109,25 +143,49 @@ class MockBuilder(Builder):
 
      def get_function_config(self, name: str) -> FunctionBaseConfig:
          """Mock implementation."""
-         pass
+         return FunctionBaseConfig()
+
+     async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
+         """Mock implementation - not used in tool testing."""
+         raise NotImplementedError("Mock implementation does not support add_function_group")
+
+     async def get_function_group(self, name: str) -> FunctionGroup:
+         """Return a mock function group if one is configured."""
+         if name in self._mocks:
+             mock_fn_group = MagicMock(spec=FunctionGroup)
+             mock_fn_group.ainvoke = AsyncMock(return_value=self._mocks[name])
+             return mock_fn_group
+         raise ValueError(f"Function group '{name}' not mocked. Use mock_function_group() to add it.")
+
+     def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
+         """Mock implementation."""
+         return FunctionGroupBaseConfig()
 
      async def set_workflow(self, config: FunctionBaseConfig) -> Function:
          """Mock implementation."""
-         pass
+         mock_fn = AsyncMock()
+         mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
+         return mock_fn
 
      def get_workflow(self) -> Function:
          """Mock implementation."""
-         pass
+         mock_fn = AsyncMock()
+         mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
+         return mock_fn
 
      def get_workflow_config(self) -> FunctionBaseConfig:
          """Mock implementation."""
-         pass
+         return FunctionBaseConfig()
+
+     async def get_tools(self, tool_names: Sequence[str], wrapper_type) -> list[typing.Any]:
+         """Mock implementation."""
+         return []
 
-     def get_tool(self, fn_name: str, wrapper_type):
+     async def get_tool(self, fn_name: str, wrapper_type) -> typing.Any:
          """Mock implementation."""
          pass
 
-     async def add_llm(self, name: str, config):
+     async def add_llm(self, name: str, config) -> None:
          """Mock implementation."""
          pass
 
@@ -141,11 +199,11 @@ class MockBuilder(Builder):
              return mock_llm
          raise ValueError(f"LLM '{llm_name}' not mocked. Use mock_llm() to add it.")
 
-     def get_llm_config(self, llm_name: str):
+     def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
          """Mock implementation."""
-         pass
+         return LLMBaseConfig()
 
-     async def add_embedder(self, name: str, config):
+     async def add_embedder(self, name: str, config) -> None:
          """Mock implementation."""
          pass
 
@@ -159,15 +217,14 @@ class MockBuilder(Builder):
              return mock_embedder
          raise ValueError(f"Embedder '{embedder_name}' not mocked. Use mock_embedder() to add it.")
 
-     def get_embedder_config(self, embedder_name: str):
+     def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
          """Mock implementation."""
-         pass
+         return EmbedderBaseConfig()
 
-     async def add_memory_client(self, name: str, config):
-         """Mock implementation."""
-         pass
+     async def add_memory_client(self, name: str, config) -> MemoryEditor:
+         return MagicMock(spec=MemoryEditor)
 
-     def get_memory_client(self, memory_name: str):
+     async def get_memory_client(self, memory_name: str) -> MemoryEditor:
          """Return a mock memory client if one is configured."""
          key = f"memory_{memory_name}"
          if key in self._mocks:
@@ -177,11 +234,11 @@ class MockBuilder(Builder):
              return mock_memory
          raise ValueError(f"Memory client '{memory_name}' not mocked. Use mock_memory_client() to add it.")
 
-     def get_memory_client_config(self, memory_name: str):
+     def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
          """Mock implementation."""
-         pass
+         return MemoryBaseConfig()
 
-     async def add_retriever(self, name: str, config):
+     async def add_retriever(self, name: str, config) -> None:
          """Mock implementation."""
          pass
 
@@ -194,13 +251,13 @@ class MockBuilder(Builder):
              return mock_retriever
          raise ValueError(f"Retriever '{retriever_name}' not mocked. Use mock_retriever() to add it.")
 
-     async def get_retriever_config(self, retriever_name: str):
+     async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
          """Mock implementation."""
-         pass
+         return RetrieverBaseConfig()
 
-     async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
+     async def add_object_store(self, name: str, config: ObjectStoreBaseConfig) -> ObjectStore:
          """Mock implementation for object store."""
-         pass
+         return MagicMock(spec=ObjectStore)
 
      async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
          """Return a mock object store client if one is configured."""
@@ -216,7 +273,7 @@ class MockBuilder(Builder):
 
      def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
          """Mock implementation for object store config."""
-         pass
+         return ObjectStoreBaseConfig()
 
      def get_user_manager(self):
          """Mock implementation."""
@@ -224,9 +281,13 @@ class MockBuilder(Builder):
          mock_user.get_id = MagicMock(return_value="test_user")
          return mock_user
 
-     def get_function_dependencies(self, fn_name: str):
+     def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
          """Mock implementation."""
-         pass
+         return FunctionDependencies()
+
+     def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
+         """Mock implementation."""
+         return FunctionDependencies()
 
 
  class ToolTestRunner:
@@ -322,15 +383,19 @@ class ToolTestRunner:
 
          # Execute the tool
          if input_data is not None:
-             if asyncio.iscoroutinefunction(tool_function):
+             if isinstance(tool_function, Function):
+                 result = await tool_function.ainvoke(input_data)
+             elif asyncio.iscoroutinefunction(tool_function):
                  result = await tool_function(input_data)
              else:
                  result = tool_function(input_data)
+         elif isinstance(tool_function, Function):
+             # Function objects require input, so pass None if no input_data
+             result = await tool_function.ainvoke(None)
+         elif asyncio.iscoroutinefunction(tool_function):
+             result = await tool_function()
          else:
-             if asyncio.iscoroutinefunction(tool_function):
-                 result = await tool_function()
-             else:
-                 result = tool_function()
+             result = tool_function()
 
          # Assert expected output if provided
          if expected_output is not None:
@@ -402,8 +467,8 @@ class ToolTestRunner:
          elif isinstance(tool_result, FunctionInfo):
              if tool_result.single_fn:
                  tool_function = tool_result.single_fn
-             elif tool_result.streaming_fn:
-                 tool_function = tool_result.streaming_fn
+             elif tool_result.stream_fn:
+                 tool_function = tool_result.stream_fn
              else:
                  raise ValueError("Tool function not found in FunctionInfo")
          elif callable(tool_result):
@@ -413,15 +478,17 @@ class ToolTestRunner:
 
          # Execute the tool
          if input_data is not None:
-             if asyncio.iscoroutinefunction(tool_function):
-                 result = await tool_function(input_data)
+             if isinstance(tool_function, Function):
+                 result = await tool_function.ainvoke(input_data)
              else:
-                 result = tool_function(input_data)
+                 maybe_result = tool_function(input_data)
+                 result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
+         elif isinstance(tool_function, Function):
+             # Function objects require input, so pass None if no input_data
+             result = await tool_function.ainvoke(None)
          else:
-             if asyncio.iscoroutinefunction(tool_function):
-                 result = await tool_function()
-             else:
-                 result = tool_function()
+             maybe_result = tool_function()
+             result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
 
          # Assert expected output if provided
          if expected_output is not None:
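
For orientation, a small sketch of the MockBuilder flow the hunks above extend. The `nat.test.tool_test_runner` import path is inferred from this diff, the `mock_function` method name is inferred from its docstring shown above, and the body of `get_function` is only partially visible here, so its exact return value is an assumption based on the sibling getters.

import asyncio

from nat.test.tool_test_runner import MockBuilder  # module path inferred from the diff


async def main() -> None:
    builder = MockBuilder()
    builder.mock_function("search", "canned search result")

    # get_function is now async (see the signature change above); per the
    # sibling getters it presumably returns an AsyncMock whose ainvoke
    # resolves to the mocked response.
    fn = await builder.get_function("search")
    assert await fn.ainvoke("query") == "canned search result"


asyncio.run(main())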
nat/test/utils.py ADDED
@@ -0,0 +1,87 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib.resources
+ import inspect
+ import subprocess
+ import typing
+ from pathlib import Path
+
+ if typing.TYPE_CHECKING:
+     from nat.data_models.config import Config
+     from nat.utils.type_utils import StrPath
+
+
+ def locate_repo_root() -> Path:
+     result = subprocess.run(["git", "rev-parse", "--show-toplevel"], check=False, capture_output=True, text=True)
+     assert result.returncode == 0, f"Failed to get git root: {result.stderr}"
+     return Path(result.stdout.strip())
+
+
+ def locate_example_src_dir(example_config_class: type) -> Path:
+     """
+     Locate the example src directory for an example's config class.
+     """
+     package_name = inspect.getmodule(example_config_class).__package__
+     return importlib.resources.files(package_name)
+
+
+ def locate_example_dir(example_config_class: type) -> Path:
+     """
+     Locate the example directory for an example's config class.
+     """
+     src_dir = locate_example_src_dir(example_config_class)
+     example_dir = src_dir.parent.parent
+     return example_dir
+
+
+ def locate_example_config(example_config_class: type,
+                           config_file: str = "config.yml",
+                           assert_exists: bool = True) -> Path:
+     """
+     Locate the example config file for an example's config class; assumes the example contains a 'configs' directory.
+     """
+     example_dir = locate_example_src_dir(example_config_class)
+     config_path = example_dir.joinpath("configs", config_file).absolute()
+     if assert_exists:
+         assert config_path.exists(), f"Config file {config_path} does not exist"
+
+     return config_path
+
+
+ async def run_workflow(
+     config_file: "StrPath | None",
+     question: str,
+     expected_answer: str,
+     assert_expected_answer: bool = True,
+     config: "Config | None" = None,
+ ) -> str:
+     from nat.builder.workflow_builder import WorkflowBuilder
+     from nat.runtime.loader import load_config
+     from nat.runtime.session import SessionManager
+
+     if config is None:
+         assert config_file is not None, "Either config_file or config must be provided"
+         config = load_config(config_file)
+
+     async with WorkflowBuilder.from_config(config=config) as workflow_builder:
+         workflow = SessionManager(await workflow_builder.build())
+         async with workflow.run(question) as runner:
+             result = await runner.result(to_type=str)
+
+     if assert_expected_answer:
+         assert expected_answer.lower() in result.lower(), f"Expected '{expected_answer}' in '{result}'"
+
+     return result
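
A hedged sketch of driving `run_workflow` from an async test. The config path, question, and `pytest.mark.asyncio` marker are assumptions; the helper's signature and its case-insensitive assertion come from the file above.

import pytest

from nat.test.utils import run_workflow


@pytest.mark.asyncio  # assumes pytest-asyncio (or equivalent) in the test environment
async def test_example_workflow():
    # When no pre-built Config is passed, run_workflow loads config_file via
    # load_config before building and running the workflow.
    result = await run_workflow(
        config_file="configs/config.yml",  # hypothetical path
        question="What color is the sky?",
        expected_answer="blue",
    )
    assert result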
nvidia_nat_test-1.3.dev0.dist-info/METADATA → nvidia_nat_test-1.3.0rc1.dist-info/METADATA RENAMED
@@ -1,12 +1,15 @@
  Metadata-Version: 2.4
  Name: nvidia-nat-test
- Version: 1.3.dev0
+ Version: 1.3.0rc1
  Summary: Testing utilities for NeMo Agent toolkit
  Keywords: ai,rag,agents
  Classifier: Programming Language :: Python
- Requires-Python: <3.13,>=3.11
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Requires-Python: <3.14,>=3.11
  Description-Content-Type: text/markdown
- Requires-Dist: nvidia-nat==v1.3-dev
+ Requires-Dist: nvidia-nat==v1.3.0-rc1
  Requires-Dist: langchain-community~=0.3
  Requires-Dist: pytest~=8.3
 
nvidia_nat_test-1.3.0rc1.dist-info/RECORD ADDED
@@ -0,0 +1,16 @@
+ nat/meta/pypi.md,sha256=LLKJHg5oN1-M9Pqfk3Bmphkk4O2TFsyiixuK5T0Y-gw,1100
+ nat/test/__init__.py,sha256=_RnTJnsUucHvla_nYKqD4O4g8Bz0tcuDRzWk1bEhcy0,875
+ nat/test/embedder.py,sha256=ClDyK1kna4hCBSlz71gK1B-ZjlwcBHTDQRekoNM81Bs,1809
+ nat/test/functions.py,sha256=0ScrdsjcxCsPRLnyb5gfwukmvZxFi_ptCswLSIG0DVY,3095
+ nat/test/llm.py,sha256=osJWGsJN7x-JGOaitueKeSwuJPVmnIFqJUCz28ngSRg,8215
+ nat/test/memory.py,sha256=xki_A2yiMhEZuQk60K7t04QRqf32nQqnfzD5Iv7fkvw,1456
+ nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2zQ,4281
+ nat/test/plugin.py,sha256=P6awHPN-iVmb8VIzLX_pd9rMnR4PD8Of3H_ypPPFr8Q,11246
+ nat/test/register.py,sha256=o1BEA5fyxyFyCxXhQ6ArmtuNpgRyTEfvw6HdBgECPLI,897
+ nat/test/tool_test_runner.py,sha256=SxavwXHkvCQDl_PUiiiqgvGfexKJJTeBdI5i1qk6AzI,21712
+ nat/test/utils.py,sha256=y77p5uVpRPSQqVOnetBLvJVsSRgS4_fEgcuRoHwvRKE,3187
+ nvidia_nat_test-1.3.0rc1.dist-info/METADATA,sha256=vMpSoSRpL1C-MTe1epDZueqZA_l2qcyC4MmVzmTcxqE,1608
+ nvidia_nat_test-1.3.0rc1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ nvidia_nat_test-1.3.0rc1.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
+ nvidia_nat_test-1.3.0rc1.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
+ nvidia_nat_test-1.3.0rc1.dist-info/RECORD,,
nvidia_nat_test-1.3.dev0.dist-info/RECORD DELETED
@@ -1,14 +0,0 @@
- nat/meta/pypi.md,sha256=LLKJHg5oN1-M9Pqfk3Bmphkk4O2TFsyiixuK5T0Y-gw,1100
- nat/test/__init__.py,sha256=_RnTJnsUucHvla_nYKqD4O4g8Bz0tcuDRzWk1bEhcy0,875
- nat/test/embedder.py,sha256=ClDyK1kna4hCBSlz71gK1B-ZjlwcBHTDQRekoNM81Bs,1809
- nat/test/functions.py,sha256=0ScrdsjcxCsPRLnyb5gfwukmvZxFi_ptCswLSIG0DVY,3095
- nat/test/memory.py,sha256=xki_A2yiMhEZuQk60K7t04QRqf32nQqnfzD5Iv7fkvw,1456
- nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2zQ,4281
- nat/test/plugin.py,sha256=fp39ib0W63vfqX6Ssvq4sCuSd8Lm6yQyknL3_qRijgI,3610
- nat/test/register.py,sha256=jU1pW5wf20ZmCOTgkaQshKZfvYh8_-sMJ4P3xXilfTY,891
- nat/test/tool_test_runner.py,sha256=ccErldob2VwBbVL0_pmLrOcKLc18qYjxeAEACYoKKGQ,17469
- nvidia_nat_test-1.3.dev0.dist-info/METADATA,sha256=pu_tjl4f97yMYq27FcNYF9hGIvNzfigmSfU8sEyfiZE,1453
- nvidia_nat_test-1.3.dev0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- nvidia_nat_test-1.3.dev0.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
- nvidia_nat_test-1.3.dev0.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
- nvidia_nat_test-1.3.dev0.dist-info/RECORD,,