nvidia-nat-test 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/test/functions.py +9 -1
- nat/test/llm.py +205 -0
- nat/test/plugin.py +249 -12
- nat/test/register.py +2 -1
- nat/test/tool_test_runner.py +114 -47
- nat/test/utils.py +87 -0
- {nvidia_nat_test-1.3.0.dev2.dist-info → nvidia_nat_test-1.3.0rc2.dist-info}/METADATA +6 -3
- nvidia_nat_test-1.3.0rc2.dist-info/RECORD +16 -0
- nvidia_nat_test-1.3.0.dev2.dist-info/RECORD +0 -14
- {nvidia_nat_test-1.3.0.dev2.dist-info → nvidia_nat_test-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat_test-1.3.0.dev2.dist-info → nvidia_nat_test-1.3.0rc2.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_test-1.3.0.dev2.dist-info → nvidia_nat_test-1.3.0rc2.dist-info}/top_level.txt +0 -0
nat/test/functions.py
CHANGED
@@ -21,6 +21,7 @@ from nat.cli.register_workflow import register_function
|
|
21
21
|
from nat.data_models.api_server import ChatRequest
|
22
22
|
from nat.data_models.api_server import ChatResponse
|
23
23
|
from nat.data_models.api_server import ChatResponseChunk
|
24
|
+
from nat.data_models.api_server import Usage
|
24
25
|
from nat.data_models.function import FunctionBaseConfig
|
25
26
|
|
26
27
|
|
@@ -35,7 +36,14 @@ async def echo_function(config: EchoFunctionConfig, builder: Builder):
|
|
35
36
|
return message
|
36
37
|
|
37
38
|
async def inner_oai(message: ChatRequest) -> ChatResponse:
|
38
|
-
|
39
|
+
content = message.messages[0].content
|
40
|
+
|
41
|
+
# Create usage statistics for the response
|
42
|
+
prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
|
43
|
+
completion_tokens = len(content.split()) if content else 0
|
44
|
+
total_tokens = prompt_tokens + completion_tokens
|
45
|
+
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
|
46
|
+
return ChatResponse.from_string(content, usage=usage)
|
39
47
|
|
40
48
|
if (config.use_openai_api):
|
41
49
|
yield inner_oai
|
nat/test/llm.py
ADDED
@@ -0,0 +1,205 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
# pylint: disable=unused-argument,missing-class-docstring,missing-function-docstring,import-outside-toplevel
|
17
|
+
# pylint: disable=too-few-public-methods
|
18
|
+
|
19
|
+
import asyncio
|
20
|
+
import time
|
21
|
+
from collections.abc import AsyncGenerator
|
22
|
+
from collections.abc import Iterator
|
23
|
+
from itertools import cycle as iter_cycle
|
24
|
+
from typing import Any
|
25
|
+
|
26
|
+
from pydantic import Field
|
27
|
+
|
28
|
+
from nat.builder.builder import Builder
|
29
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
30
|
+
from nat.builder.llm import LLMProviderInfo
|
31
|
+
from nat.cli.register_workflow import register_llm_client
|
32
|
+
from nat.cli.register_workflow import register_llm_provider
|
33
|
+
from nat.data_models.llm import LLMBaseConfig
|
34
|
+
|
35
|
+
|
36
|
+
class TestLLMConfig(LLMBaseConfig, name="nat_test_llm"):
|
37
|
+
"""Test LLM configuration."""
|
38
|
+
__test__ = False
|
39
|
+
response_seq: list[str] = Field(
|
40
|
+
default=[],
|
41
|
+
description="Returns the next element in order (wraps)",
|
42
|
+
)
|
43
|
+
delay_ms: int = Field(default=0, ge=0, description="Artificial per-call delay in milliseconds to mimic latency")
|
44
|
+
|
45
|
+
|
46
|
+
class _ResponseChooser:
|
47
|
+
"""
|
48
|
+
Helper class to choose the next response according to config using itertools.cycle and provide synchronous and
|
49
|
+
asynchronous sleep functions.
|
50
|
+
"""
|
51
|
+
|
52
|
+
def __init__(self, response_seq: list[str], delay_ms: int):
|
53
|
+
self._cycler = iter_cycle(response_seq) if response_seq else None
|
54
|
+
self._delay_ms = delay_ms
|
55
|
+
|
56
|
+
def next_response(self) -> str:
|
57
|
+
"""Return the next response in the cycle, or an empty string if no responses are configured."""
|
58
|
+
if self._cycler is None:
|
59
|
+
return ""
|
60
|
+
return next(self._cycler)
|
61
|
+
|
62
|
+
def sync_sleep(self) -> None:
|
63
|
+
time.sleep(self._delay_ms / 1000.0)
|
64
|
+
|
65
|
+
async def async_sleep(self) -> None:
|
66
|
+
await asyncio.sleep(self._delay_ms / 1000.0)
|
67
|
+
|
68
|
+
|
69
|
+
@register_llm_provider(config_type=TestLLMConfig)
|
70
|
+
async def test_llm_provider(config: TestLLMConfig, builder: Builder):
|
71
|
+
"""Register the `nat_test_llm` provider for the NAT registry."""
|
72
|
+
yield LLMProviderInfo(config=config, description="Test LLM provider")
|
73
|
+
|
74
|
+
|
75
|
+
@register_llm_client(config_type=TestLLMConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
76
|
+
async def test_llm_langchain(config: TestLLMConfig, builder: Builder):
|
77
|
+
"""LLM client for LangChain/LangGraph."""
|
78
|
+
|
79
|
+
chooser = _ResponseChooser(response_seq=config.response_seq, delay_ms=config.delay_ms)
|
80
|
+
|
81
|
+
class LangChainTestLLM:
|
82
|
+
|
83
|
+
def invoke(self, messages: Any, **_kwargs: Any) -> str:
|
84
|
+
chooser.sync_sleep()
|
85
|
+
return chooser.next_response()
|
86
|
+
|
87
|
+
async def ainvoke(self, messages: Any, **_kwargs: Any) -> str:
|
88
|
+
await chooser.async_sleep()
|
89
|
+
return chooser.next_response()
|
90
|
+
|
91
|
+
def stream(self, messages: Any, **_kwargs: Any) -> Iterator[str]:
|
92
|
+
chooser.sync_sleep()
|
93
|
+
yield chooser.next_response()
|
94
|
+
|
95
|
+
async def astream(self, messages: Any, **_kwargs: Any) -> AsyncGenerator[str]:
|
96
|
+
await chooser.async_sleep()
|
97
|
+
yield chooser.next_response()
|
98
|
+
|
99
|
+
yield LangChainTestLLM()
|
100
|
+
|
101
|
+
|
102
|
+
@register_llm_client(config_type=TestLLMConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)
|
103
|
+
async def test_llm_llama_index(config: TestLLMConfig, builder: Builder):
|
104
|
+
|
105
|
+
try:
|
106
|
+
from llama_index.core.base.llms.types import ChatMessage
|
107
|
+
from llama_index.core.base.llms.types import ChatResponse
|
108
|
+
except ImportError as exc:
|
109
|
+
raise ImportError("llama_index is required for using the test_llm with llama_index. "
|
110
|
+
"Please install the `nvidia-nat-llama-index` package. ") from exc
|
111
|
+
|
112
|
+
chooser = _ResponseChooser(response_seq=config.response_seq, delay_ms=config.delay_ms)
|
113
|
+
|
114
|
+
class LITestLLM:
|
115
|
+
|
116
|
+
def chat(self, messages: list[Any] | None = None, **_kwargs: Any) -> ChatResponse:
|
117
|
+
chooser.sync_sleep()
|
118
|
+
return ChatResponse(message=ChatMessage(chooser.next_response()))
|
119
|
+
|
120
|
+
async def achat(self, messages: list[Any] | None = None, **_kwargs: Any) -> ChatResponse:
|
121
|
+
await chooser.async_sleep()
|
122
|
+
return ChatResponse(message=ChatMessage(chooser.next_response()))
|
123
|
+
|
124
|
+
def stream_chat(self, messages: list[Any] | None = None, **_kwargs: Any) -> Iterator[ChatResponse]:
|
125
|
+
chooser.sync_sleep()
|
126
|
+
yield ChatResponse(message=ChatMessage(chooser.next_response()))
|
127
|
+
|
128
|
+
async def astream_chat(self,
|
129
|
+
messages: list[Any] | None = None,
|
130
|
+
**_kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
|
131
|
+
await chooser.async_sleep()
|
132
|
+
yield ChatResponse(message=ChatMessage(chooser.next_response()))
|
133
|
+
|
134
|
+
yield LITestLLM()
|
135
|
+
|
136
|
+
|
137
|
+
@register_llm_client(config_type=TestLLMConfig, wrapper_type=LLMFrameworkEnum.CREWAI)
|
138
|
+
async def test_llm_crewai(config: TestLLMConfig, builder: Builder):
|
139
|
+
"""LLM client for CrewAI."""
|
140
|
+
|
141
|
+
chooser = _ResponseChooser(response_seq=config.response_seq, delay_ms=config.delay_ms)
|
142
|
+
|
143
|
+
class CrewAITestLLM:
|
144
|
+
|
145
|
+
def call(self, messages: list[dict[str, str]] | None = None, **kwargs: Any) -> str:
|
146
|
+
chooser.sync_sleep()
|
147
|
+
return chooser.next_response()
|
148
|
+
|
149
|
+
yield CrewAITestLLM()
|
150
|
+
|
151
|
+
|
152
|
+
@register_llm_client(config_type=TestLLMConfig, wrapper_type=LLMFrameworkEnum.SEMANTIC_KERNEL)
|
153
|
+
async def test_llm_semantic_kernel(config: TestLLMConfig, builder: Builder):
|
154
|
+
"""LLM client for SemanticKernel."""
|
155
|
+
|
156
|
+
try:
|
157
|
+
from semantic_kernel.contents.chat_message_content import ChatMessageContent
|
158
|
+
from semantic_kernel.contents.utils.author_role import AuthorRole
|
159
|
+
except ImportError as exc:
|
160
|
+
raise ImportError("Semantic Kernel is required for using the test_llm with semantic_kernel. "
|
161
|
+
"Please install the `nvidia-nat-semantic-kernel` package. ") from exc
|
162
|
+
|
163
|
+
chooser = _ResponseChooser(response_seq=config.response_seq, delay_ms=config.delay_ms)
|
164
|
+
|
165
|
+
class SKTestLLM:
|
166
|
+
|
167
|
+
async def get_chat_message_contents(self, chat_history: Any, **_kwargs: Any) -> list[ChatMessageContent]:
|
168
|
+
await chooser.async_sleep()
|
169
|
+
text = chooser.next_response()
|
170
|
+
return [ChatMessageContent(role=AuthorRole.ASSISTANT, content=text)]
|
171
|
+
|
172
|
+
async def get_streaming_chat_message_contents(self, chat_history: Any,
|
173
|
+
**_kwargs: Any) -> AsyncGenerator[ChatMessageContent, None]:
|
174
|
+
await chooser.async_sleep()
|
175
|
+
text = chooser.next_response()
|
176
|
+
yield ChatMessageContent(role=AuthorRole.ASSISTANT, content=text)
|
177
|
+
|
178
|
+
yield SKTestLLM()
|
179
|
+
|
180
|
+
|
181
|
+
@register_llm_client(config_type=TestLLMConfig, wrapper_type=LLMFrameworkEnum.AGNO)
|
182
|
+
async def test_llm_agno(config: TestLLMConfig, builder: Builder):
|
183
|
+
"""LLM client for agno."""
|
184
|
+
|
185
|
+
chooser = _ResponseChooser(response_seq=config.response_seq, delay_ms=config.delay_ms)
|
186
|
+
|
187
|
+
class AgnoTestLLM:
|
188
|
+
|
189
|
+
def invoke(self, messages: Any | None = None, **_kwargs: Any) -> str:
|
190
|
+
chooser.sync_sleep()
|
191
|
+
return chooser.next_response()
|
192
|
+
|
193
|
+
async def ainvoke(self, messages: Any | None = None, **_kwargs: Any) -> str:
|
194
|
+
await chooser.async_sleep()
|
195
|
+
return chooser.next_response()
|
196
|
+
|
197
|
+
def invoke_stream(self, messages: Any | None = None, **_kwargs: Any) -> Iterator[str]:
|
198
|
+
chooser.sync_sleep()
|
199
|
+
yield chooser.next_response()
|
200
|
+
|
201
|
+
async def ainvoke_stream(self, messages: Any | None = None, **_kwargs: Any) -> AsyncGenerator[str, None]:
|
202
|
+
await chooser.async_sleep()
|
203
|
+
yield chooser.next_response()
|
204
|
+
|
205
|
+
yield AgnoTestLLM()
|
nat/test/plugin.py
CHANGED
@@ -13,20 +13,21 @@
|
|
13
13
|
# See the License for the specific language governing permissions and
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
|
+
import os
|
17
|
+
import subprocess
|
18
|
+
import typing
|
19
|
+
from pathlib import Path
|
20
|
+
|
16
21
|
import pytest
|
17
22
|
|
23
|
+
if typing.TYPE_CHECKING:
|
24
|
+
from docker.client import DockerClient
|
25
|
+
|
18
26
|
|
19
27
|
def pytest_addoption(parser: pytest.Parser):
|
20
28
|
"""
|
21
29
|
Adds command line options for running specfic tests that are disabled by default
|
22
30
|
"""
|
23
|
-
parser.addoption(
|
24
|
-
"--run_e2e",
|
25
|
-
action="store_true",
|
26
|
-
dest="run_e2e",
|
27
|
-
help="Run end to end tests that would otherwise be skipped",
|
28
|
-
)
|
29
|
-
|
30
31
|
parser.addoption(
|
31
32
|
"--run_integration",
|
32
33
|
action="store_true",
|
@@ -42,12 +43,16 @@ def pytest_addoption(parser: pytest.Parser):
|
|
42
43
|
help="Run end to end tests that would otherwise be skipped",
|
43
44
|
)
|
44
45
|
|
46
|
+
parser.addoption(
|
47
|
+
"--fail_missing",
|
48
|
+
action="store_true",
|
49
|
+
dest="fail_missing",
|
50
|
+
help=("Tests requiring unmet dependencies are normally skipped. "
|
51
|
+
"Setting this flag will instead cause them to be reported as a failure"),
|
52
|
+
)
|
53
|
+
|
45
54
|
|
46
55
|
def pytest_runtest_setup(item):
|
47
|
-
if (not item.config.getoption("--run_e2e")):
|
48
|
-
if (item.get_closest_marker("e2e") is not None):
|
49
|
-
pytest.skip("Skipping end to end tests by default. Use --run_e2e to enable")
|
50
|
-
|
51
56
|
if (not item.config.getoption("--run_integration")):
|
52
57
|
if (item.get_closest_marker("integration") is not None):
|
53
58
|
pytest.skip("Skipping integration tests by default. Use --run_integration to enable")
|
@@ -68,7 +73,6 @@ def register_components_fixture():
|
|
68
73
|
discover_and_register_plugins(PluginTypes.ALL)
|
69
74
|
|
70
75
|
# Also import the nat.test.register module to register test-only components
|
71
|
-
import nat.test.register # pylint: disable=unused-import # noqa: F401
|
72
76
|
|
73
77
|
|
74
78
|
@pytest.fixture(name="module_registry", scope="module", autouse=True)
|
@@ -95,3 +99,236 @@ def function_registry_fixture():
|
|
95
99
|
|
96
100
|
with GlobalTypeRegistry.push() as registry:
|
97
101
|
yield registry
|
102
|
+
|
103
|
+
|
104
|
+
@pytest.fixture(scope="session", name="fail_missing")
|
105
|
+
def fail_missing_fixture(pytestconfig: pytest.Config) -> bool:
|
106
|
+
"""
|
107
|
+
Returns the value of the `fail_missing` flag, when false tests requiring unmet dependencies will be skipped, when
|
108
|
+
True they will fail.
|
109
|
+
"""
|
110
|
+
yield pytestconfig.getoption("fail_missing")
|
111
|
+
|
112
|
+
|
113
|
+
def require_env_variables(varnames: list[str], reason: str, fail_missing: bool = False) -> dict[str, str]:
|
114
|
+
"""
|
115
|
+
Checks if the given environment variable is set, and returns its value if it is. If the variable is not set, and
|
116
|
+
`fail_missing` is False the test will ve skipped, otherwise a `RuntimeError` will be raised.
|
117
|
+
"""
|
118
|
+
env_variables = {}
|
119
|
+
try:
|
120
|
+
for varname in varnames:
|
121
|
+
env_variables[varname] = os.environ[varname]
|
122
|
+
except KeyError as e:
|
123
|
+
if fail_missing:
|
124
|
+
raise RuntimeError(reason) from e
|
125
|
+
|
126
|
+
pytest.skip(reason=reason)
|
127
|
+
|
128
|
+
return env_variables
|
129
|
+
|
130
|
+
|
131
|
+
@pytest.fixture(name="openai_api_key", scope='session')
|
132
|
+
def openai_api_key_fixture(fail_missing: bool):
|
133
|
+
"""
|
134
|
+
Use for integration tests that require an Openai API key.
|
135
|
+
"""
|
136
|
+
yield require_env_variables(
|
137
|
+
varnames=["OPENAI_API_KEY"],
|
138
|
+
reason="openai integration tests require the `OPENAI_API_KEY` environment variable to be defined.",
|
139
|
+
fail_missing=fail_missing)
|
140
|
+
|
141
|
+
|
142
|
+
@pytest.fixture(name="nvidia_api_key", scope='session')
|
143
|
+
def nvidia_api_key_fixture(fail_missing: bool):
|
144
|
+
"""
|
145
|
+
Use for integration tests that require an Nvidia API key.
|
146
|
+
"""
|
147
|
+
yield require_env_variables(
|
148
|
+
varnames=["NVIDIA_API_KEY"],
|
149
|
+
reason="Nvidia integration tests require the `NVIDIA_API_KEY` environment variable to be defined.",
|
150
|
+
fail_missing=fail_missing)
|
151
|
+
|
152
|
+
|
153
|
+
@pytest.fixture(name="serp_api_key", scope='session')
|
154
|
+
def serp_api_key_fixture(fail_missing: bool):
|
155
|
+
"""
|
156
|
+
Use for integration tests that require a SERP API key.
|
157
|
+
"""
|
158
|
+
yield require_env_variables(
|
159
|
+
varnames=["SERP_API_KEY"],
|
160
|
+
reason="SERP integration tests require the `SERP_API_KEY` environment variable to be defined.",
|
161
|
+
fail_missing=fail_missing)
|
162
|
+
|
163
|
+
|
164
|
+
@pytest.fixture(name="tavily_api_key", scope='session')
|
165
|
+
def tavily_api_key_fixture(fail_missing: bool):
|
166
|
+
"""
|
167
|
+
Use for integration tests that require a Tavily API key.
|
168
|
+
"""
|
169
|
+
yield require_env_variables(
|
170
|
+
varnames=["TAVILY_API_KEY"],
|
171
|
+
reason="Tavily integration tests require the `TAVILY_API_KEY` environment variable to be defined.",
|
172
|
+
fail_missing=fail_missing)
|
173
|
+
|
174
|
+
|
175
|
+
@pytest.fixture(name="mem0_api_key", scope='session')
|
176
|
+
def mem0_api_key_fixture(fail_missing: bool):
|
177
|
+
"""
|
178
|
+
Use for integration tests that require a Mem0 API key.
|
179
|
+
"""
|
180
|
+
yield require_env_variables(
|
181
|
+
varnames=["MEM0_API_KEY"],
|
182
|
+
reason="Mem0 integration tests require the `MEM0_API_KEY` environment variable to be defined.",
|
183
|
+
fail_missing=fail_missing)
|
184
|
+
|
185
|
+
|
186
|
+
@pytest.fixture(name="aws_keys", scope='session')
|
187
|
+
def aws_keys_fixture(fail_missing: bool):
|
188
|
+
"""
|
189
|
+
Use for integration tests that require AWS credentials.
|
190
|
+
"""
|
191
|
+
|
192
|
+
yield require_env_variables(
|
193
|
+
varnames=["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
|
194
|
+
reason=
|
195
|
+
"AWS integration tests require the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables to be "
|
196
|
+
"defined.",
|
197
|
+
fail_missing=fail_missing)
|
198
|
+
|
199
|
+
|
200
|
+
@pytest.fixture(name="azure_openai_keys", scope='session')
|
201
|
+
def azure_openai_keys_fixture(fail_missing: bool):
|
202
|
+
"""
|
203
|
+
Use for integration tests that require Azure OpenAI credentials.
|
204
|
+
"""
|
205
|
+
yield require_env_variables(
|
206
|
+
varnames=["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT"],
|
207
|
+
reason="Azure integration tests require the `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_ENDPOINT` environment "
|
208
|
+
"variable to be defined.",
|
209
|
+
fail_missing=fail_missing)
|
210
|
+
|
211
|
+
|
212
|
+
@pytest.fixture(name="require_docker", scope='session')
|
213
|
+
def require_docker_fixture(fail_missing: bool) -> "DockerClient":
|
214
|
+
"""
|
215
|
+
Use for integration tests that require Docker to be running.
|
216
|
+
"""
|
217
|
+
try:
|
218
|
+
from docker.client import DockerClient
|
219
|
+
yield DockerClient()
|
220
|
+
except Exception as e:
|
221
|
+
reason = f"Unable to connect to Docker daemon: {e}"
|
222
|
+
if fail_missing:
|
223
|
+
raise RuntimeError(reason) from e
|
224
|
+
pytest.skip(reason=reason)
|
225
|
+
|
226
|
+
|
227
|
+
@pytest.fixture(name="restore_environ")
|
228
|
+
def restore_environ_fixture():
|
229
|
+
orig_vars = os.environ.copy()
|
230
|
+
yield os.environ
|
231
|
+
|
232
|
+
for key, value in orig_vars.items():
|
233
|
+
os.environ[key] = value
|
234
|
+
|
235
|
+
# Delete any new environment variables
|
236
|
+
# Iterating over a copy of the keys as we will potentially be deleting keys in the loop
|
237
|
+
for key in list(os.environ.keys()):
|
238
|
+
if key not in orig_vars:
|
239
|
+
del (os.environ[key])
|
240
|
+
|
241
|
+
|
242
|
+
@pytest.fixture(name="root_repo_dir", scope='session')
|
243
|
+
def root_repo_dir_fixture() -> Path:
|
244
|
+
from nat.test.utils import locate_repo_root
|
245
|
+
return locate_repo_root()
|
246
|
+
|
247
|
+
|
248
|
+
@pytest.fixture(name="require_etcd", scope="session")
|
249
|
+
def require_etcd_fixture(fail_missing: bool = False) -> bool:
|
250
|
+
"""
|
251
|
+
To run these tests, an etcd server must be running
|
252
|
+
"""
|
253
|
+
import requests
|
254
|
+
|
255
|
+
host = os.getenv("NAT_CI_ETCD_HOST", "localhost")
|
256
|
+
port = os.getenv("NAT_CI_ETCD_PORT", "2379")
|
257
|
+
health_url = f"http://{host}:{port}/health"
|
258
|
+
|
259
|
+
try:
|
260
|
+
response = requests.get(health_url, timeout=5)
|
261
|
+
response.raise_for_status()
|
262
|
+
return True
|
263
|
+
except: # noqa: E722
|
264
|
+
failure_reason = f"Unable to connect to etcd server at {health_url}"
|
265
|
+
if fail_missing:
|
266
|
+
raise RuntimeError(failure_reason)
|
267
|
+
pytest.skip(reason=failure_reason)
|
268
|
+
|
269
|
+
|
270
|
+
@pytest.fixture(name="milvus_uri", scope="session")
|
271
|
+
def milvus_uri_fixture(require_etcd: bool, fail_missing: bool = False) -> str:
|
272
|
+
"""
|
273
|
+
To run these tests, a Milvus server must be running
|
274
|
+
"""
|
275
|
+
host = os.getenv("NAT_CI_MILVUS_HOST", "localhost")
|
276
|
+
port = os.getenv("NAT_CI_MILVUS_PORT", "19530")
|
277
|
+
uri = f"http://{host}:{port}"
|
278
|
+
try:
|
279
|
+
from pymilvus import MilvusClient
|
280
|
+
MilvusClient(uri=uri)
|
281
|
+
|
282
|
+
return uri
|
283
|
+
except: # noqa: E722
|
284
|
+
reason = f"Unable to connect to Milvus server at {uri}"
|
285
|
+
if fail_missing:
|
286
|
+
raise RuntimeError(reason)
|
287
|
+
pytest.skip(reason=reason)
|
288
|
+
|
289
|
+
|
290
|
+
@pytest.fixture(name="populate_milvus", scope="session")
|
291
|
+
def populate_milvus_fixture(milvus_uri: str, root_repo_dir: Path):
|
292
|
+
"""
|
293
|
+
Populate Milvus with some test data.
|
294
|
+
"""
|
295
|
+
populate_script = root_repo_dir / "scripts/langchain_web_ingest.py"
|
296
|
+
|
297
|
+
# Ingest default cuda docs
|
298
|
+
subprocess.run(["python", str(populate_script), "--milvus_uri", milvus_uri], check=True)
|
299
|
+
|
300
|
+
# Ingest MCP docs
|
301
|
+
subprocess.run([
|
302
|
+
"python",
|
303
|
+
str(populate_script),
|
304
|
+
"--milvus_uri",
|
305
|
+
milvus_uri,
|
306
|
+
"--urls",
|
307
|
+
"https://github.com/modelcontextprotocol/python-sdk",
|
308
|
+
"--urls",
|
309
|
+
"https://modelcontextprotocol.io/introduction",
|
310
|
+
"--urls",
|
311
|
+
"https://modelcontextprotocol.io/quickstart/server",
|
312
|
+
"--urls",
|
313
|
+
"https://modelcontextprotocol.io/quickstart/client",
|
314
|
+
"--urls",
|
315
|
+
"https://modelcontextprotocol.io/examples",
|
316
|
+
"--urls",
|
317
|
+
"https://modelcontextprotocol.io/docs/concepts/architecture",
|
318
|
+
"--collection_name",
|
319
|
+
"mcp_docs"
|
320
|
+
],
|
321
|
+
check=True)
|
322
|
+
|
323
|
+
# Ingest some wikipedia docs
|
324
|
+
subprocess.run([
|
325
|
+
"python",
|
326
|
+
str(populate_script),
|
327
|
+
"--milvus_uri",
|
328
|
+
milvus_uri,
|
329
|
+
"--urls",
|
330
|
+
"https://en.wikipedia.org/wiki/Aardvark",
|
331
|
+
"--collection_name",
|
332
|
+
"wikipedia_docs"
|
333
|
+
],
|
334
|
+
check=True)
|
nat/test/register.py
CHANGED
@@ -13,7 +13,6 @@
|
|
13
13
|
# See the License for the specific language governing permissions and
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
|
-
# pylint: disable=unused-import
|
17
16
|
# flake8: noqa
|
18
17
|
# isort:skip_file
|
19
18
|
|
@@ -22,3 +21,5 @@
|
|
22
21
|
from . import embedder
|
23
22
|
from . import functions
|
24
23
|
from . import memory
|
24
|
+
from . import llm
|
25
|
+
from . import utils
|
nat/test/tool_test_runner.py
CHANGED
@@ -14,18 +14,33 @@
|
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
16
|
import asyncio
|
17
|
+
import inspect
|
17
18
|
import logging
|
18
19
|
import typing
|
20
|
+
from collections.abc import Sequence
|
19
21
|
from contextlib import asynccontextmanager
|
20
22
|
from unittest.mock import AsyncMock
|
21
23
|
from unittest.mock import MagicMock
|
22
24
|
|
25
|
+
from nat.authentication.interfaces import AuthProviderBase
|
23
26
|
from nat.builder.builder import Builder
|
24
27
|
from nat.builder.function import Function
|
28
|
+
from nat.builder.function import FunctionGroup
|
25
29
|
from nat.builder.function_info import FunctionInfo
|
26
30
|
from nat.cli.type_registry import GlobalTypeRegistry
|
31
|
+
from nat.data_models.authentication import AuthProviderBaseConfig
|
32
|
+
from nat.data_models.embedder import EmbedderBaseConfig
|
27
33
|
from nat.data_models.function import FunctionBaseConfig
|
34
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
35
|
+
from nat.data_models.function_dependencies import FunctionDependencies
|
36
|
+
from nat.data_models.llm import LLMBaseConfig
|
37
|
+
from nat.data_models.memory import MemoryBaseConfig
|
28
38
|
from nat.data_models.object_store import ObjectStoreBaseConfig
|
39
|
+
from nat.data_models.retriever import RetrieverBaseConfig
|
40
|
+
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
|
41
|
+
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
|
42
|
+
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
|
43
|
+
from nat.memory.interfaces import MemoryEditor
|
29
44
|
from nat.object_store.interfaces import ObjectStore
|
30
45
|
from nat.runtime.loader import PluginTypes
|
31
46
|
from nat.runtime.loader import discover_and_register_plugins
|
@@ -46,6 +61,10 @@ class MockBuilder(Builder):
|
|
46
61
|
"""Add a mock function that returns a fixed response."""
|
47
62
|
self._mocks[name] = mock_response
|
48
63
|
|
64
|
+
def mock_function_group(self, name: str, mock_response: typing.Any):
|
65
|
+
"""Add a mock function group that returns a fixed response."""
|
66
|
+
self._mocks[name] = mock_response
|
67
|
+
|
49
68
|
def mock_llm(self, name: str, mock_response: typing.Any):
|
50
69
|
"""Add a mock LLM that returns a fixed response."""
|
51
70
|
self._mocks[f"llm_{name}"] = mock_response
|
@@ -70,14 +89,16 @@ class MockBuilder(Builder):
|
|
70
89
|
"""Add a mock TTC strategy that returns a fixed response."""
|
71
90
|
self._mocks[f"ttc_strategy_{name}"] = mock_response
|
72
91
|
|
73
|
-
|
92
|
+
def mock_auth_provider(self, name: str, mock_response: typing.Any):
|
93
|
+
"""Add a mock auth provider that returns a fixed response."""
|
94
|
+
self._mocks[f"auth_provider_{name}"] = mock_response
|
95
|
+
|
96
|
+
async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
|
74
97
|
"""Mock implementation (no‑op)."""
|
75
98
|
pass
|
76
99
|
|
77
|
-
async def get_ttc_strategy(self,
|
78
|
-
|
79
|
-
pipeline_type: typing.Any = None,
|
80
|
-
stage_type: typing.Any = None):
|
100
|
+
async def get_ttc_strategy(self, strategy_name: str, pipeline_type: PipelineTypeEnum,
|
101
|
+
stage_type: StageTypeEnum) -> typing.Any:
|
81
102
|
"""Return a mock TTC strategy if one is configured."""
|
82
103
|
key = f"ttc_strategy_{strategy_name}"
|
83
104
|
if key in self._mocks:
|
@@ -90,16 +111,29 @@ class MockBuilder(Builder):
|
|
90
111
|
|
91
112
|
async def get_ttc_strategy_config(self,
|
92
113
|
strategy_name: str,
|
93
|
-
pipeline_type:
|
94
|
-
stage_type:
|
114
|
+
pipeline_type: PipelineTypeEnum,
|
115
|
+
stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
|
95
116
|
"""Mock implementation."""
|
117
|
+
return TTCStrategyBaseConfig()
|
118
|
+
|
119
|
+
async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> None:
|
120
|
+
"""Mock implementation (no‑op)."""
|
96
121
|
pass
|
97
122
|
|
123
|
+
async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
|
124
|
+
"""Return a mock auth provider if one is configured."""
|
125
|
+
key = f"auth_provider_{auth_provider_name}"
|
126
|
+
if key in self._mocks:
|
127
|
+
mock_auth = MagicMock()
|
128
|
+
mock_auth.authenticate = AsyncMock(return_value=self._mocks[key])
|
129
|
+
return mock_auth
|
130
|
+
raise ValueError(f"Auth provider '{auth_provider_name}' not mocked. Use mock_auth_provider() to add it.")
|
131
|
+
|
98
132
|
async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
|
99
133
|
"""Mock implementation - not used in tool testing."""
|
100
134
|
raise NotImplementedError("Mock implementation does not support add_function")
|
101
135
|
|
102
|
-
def get_function(self, name: str) -> Function:
|
136
|
+
async def get_function(self, name: str) -> Function:
|
103
137
|
"""Return a mock function if one is configured."""
|
104
138
|
if name in self._mocks:
|
105
139
|
mock_fn = AsyncMock()
|
@@ -109,25 +143,49 @@ class MockBuilder(Builder):
|
|
109
143
|
|
110
144
|
def get_function_config(self, name: str) -> FunctionBaseConfig:
|
111
145
|
"""Mock implementation."""
|
112
|
-
|
146
|
+
return FunctionBaseConfig()
|
147
|
+
|
148
|
+
async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
|
149
|
+
"""Mock implementation - not used in tool testing."""
|
150
|
+
raise NotImplementedError("Mock implementation does not support add_function_group")
|
151
|
+
|
152
|
+
async def get_function_group(self, name: str) -> FunctionGroup:
|
153
|
+
"""Return a mock function group if one is configured."""
|
154
|
+
if name in self._mocks:
|
155
|
+
mock_fn_group = MagicMock(spec=FunctionGroup)
|
156
|
+
mock_fn_group.ainvoke = AsyncMock(return_value=self._mocks[name])
|
157
|
+
return mock_fn_group
|
158
|
+
raise ValueError(f"Function group '{name}' not mocked. Use mock_function_group() to add it.")
|
159
|
+
|
160
|
+
def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
|
161
|
+
"""Mock implementation."""
|
162
|
+
return FunctionGroupBaseConfig()
|
113
163
|
|
114
164
|
async def set_workflow(self, config: FunctionBaseConfig) -> Function:
|
115
165
|
"""Mock implementation."""
|
116
|
-
|
166
|
+
mock_fn = AsyncMock()
|
167
|
+
mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
|
168
|
+
return mock_fn
|
117
169
|
|
118
170
|
def get_workflow(self) -> Function:
|
119
171
|
"""Mock implementation."""
|
120
|
-
|
172
|
+
mock_fn = AsyncMock()
|
173
|
+
mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
|
174
|
+
return mock_fn
|
121
175
|
|
122
176
|
def get_workflow_config(self) -> FunctionBaseConfig:
|
123
177
|
"""Mock implementation."""
|
124
|
-
|
178
|
+
return FunctionBaseConfig()
|
179
|
+
|
180
|
+
async def get_tools(self, tool_names: Sequence[str], wrapper_type) -> list[typing.Any]:
|
181
|
+
"""Mock implementation."""
|
182
|
+
return []
|
125
183
|
|
126
|
-
def get_tool(self, fn_name: str, wrapper_type):
|
184
|
+
async def get_tool(self, fn_name: str, wrapper_type) -> typing.Any:
|
127
185
|
"""Mock implementation."""
|
128
186
|
pass
|
129
187
|
|
130
|
-
async def add_llm(self, name: str, config):
|
188
|
+
async def add_llm(self, name: str, config) -> None:
|
131
189
|
"""Mock implementation."""
|
132
190
|
pass
|
133
191
|
|
@@ -141,11 +199,11 @@ class MockBuilder(Builder):
|
|
141
199
|
return mock_llm
|
142
200
|
raise ValueError(f"LLM '{llm_name}' not mocked. Use mock_llm() to add it.")
|
143
201
|
|
144
|
-
def get_llm_config(self, llm_name: str):
|
202
|
+
def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
|
145
203
|
"""Mock implementation."""
|
146
|
-
|
204
|
+
return LLMBaseConfig()
|
147
205
|
|
148
|
-
async def add_embedder(self, name: str, config):
|
206
|
+
async def add_embedder(self, name: str, config) -> None:
|
149
207
|
"""Mock implementation."""
|
150
208
|
pass
|
151
209
|
|
@@ -159,15 +217,14 @@ class MockBuilder(Builder):
|
|
159
217
|
return mock_embedder
|
160
218
|
raise ValueError(f"Embedder '{embedder_name}' not mocked. Use mock_embedder() to add it.")
|
161
219
|
|
162
|
-
def get_embedder_config(self, embedder_name: str):
|
220
|
+
def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
|
163
221
|
"""Mock implementation."""
|
164
|
-
|
222
|
+
return EmbedderBaseConfig()
|
165
223
|
|
166
|
-
async def add_memory_client(self, name: str, config):
|
167
|
-
|
168
|
-
pass
|
224
|
+
async def add_memory_client(self, name: str, config) -> MemoryEditor:
|
225
|
+
return MagicMock(spec=MemoryEditor)
|
169
226
|
|
170
|
-
def get_memory_client(self, memory_name: str):
|
227
|
+
async def get_memory_client(self, memory_name: str) -> MemoryEditor:
|
171
228
|
"""Return a mock memory client if one is configured."""
|
172
229
|
key = f"memory_{memory_name}"
|
173
230
|
if key in self._mocks:
|
@@ -177,11 +234,11 @@ class MockBuilder(Builder):
|
|
177
234
|
return mock_memory
|
178
235
|
raise ValueError(f"Memory client '{memory_name}' not mocked. Use mock_memory_client() to add it.")
|
179
236
|
|
180
|
-
def get_memory_client_config(self, memory_name: str):
|
237
|
+
def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
|
181
238
|
"""Mock implementation."""
|
182
|
-
|
239
|
+
return MemoryBaseConfig()
|
183
240
|
|
184
|
-
async def add_retriever(self, name: str, config):
|
241
|
+
async def add_retriever(self, name: str, config) -> None:
|
185
242
|
"""Mock implementation."""
|
186
243
|
pass
|
187
244
|
|
@@ -194,13 +251,13 @@ class MockBuilder(Builder):
|
|
194
251
|
return mock_retriever
|
195
252
|
raise ValueError(f"Retriever '{retriever_name}' not mocked. Use mock_retriever() to add it.")
|
196
253
|
|
197
|
-
async def get_retriever_config(self, retriever_name: str):
|
254
|
+
async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
|
198
255
|
"""Mock implementation."""
|
199
|
-
|
256
|
+
return RetrieverBaseConfig()
|
200
257
|
|
201
|
-
async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
|
258
|
+
async def add_object_store(self, name: str, config: ObjectStoreBaseConfig) -> ObjectStore:
|
202
259
|
"""Mock implementation for object store."""
|
203
|
-
|
260
|
+
return MagicMock(spec=ObjectStore)
|
204
261
|
|
205
262
|
async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
|
206
263
|
"""Return a mock object store client if one is configured."""
|
@@ -216,7 +273,7 @@ class MockBuilder(Builder):
|
|
216
273
|
|
217
274
|
def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
|
218
275
|
"""Mock implementation for object store config."""
|
219
|
-
|
276
|
+
return ObjectStoreBaseConfig()
|
220
277
|
|
221
278
|
def get_user_manager(self):
|
222
279
|
"""Mock implementation."""
|
@@ -224,9 +281,13 @@ class MockBuilder(Builder):
|
|
224
281
|
mock_user.get_id = MagicMock(return_value="test_user")
|
225
282
|
return mock_user
|
226
283
|
|
227
|
-
def get_function_dependencies(self, fn_name: str):
|
284
|
+
def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
|
228
285
|
"""Mock implementation."""
|
229
|
-
|
286
|
+
return FunctionDependencies()
|
287
|
+
|
288
|
+
def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
|
289
|
+
"""Mock implementation."""
|
290
|
+
return FunctionDependencies()
|
230
291
|
|
231
292
|
|
232
293
|
class ToolTestRunner:
|
@@ -322,15 +383,19 @@ class ToolTestRunner:
|
|
322
383
|
|
323
384
|
# Execute the tool
|
324
385
|
if input_data is not None:
|
325
|
-
if
|
386
|
+
if isinstance(tool_function, Function):
|
387
|
+
result = await tool_function.ainvoke(input_data)
|
388
|
+
elif asyncio.iscoroutinefunction(tool_function):
|
326
389
|
result = await tool_function(input_data)
|
327
390
|
else:
|
328
391
|
result = tool_function(input_data)
|
392
|
+
elif isinstance(tool_function, Function):
|
393
|
+
# Function objects require input, so pass None if no input_data
|
394
|
+
result = await tool_function.ainvoke(None)
|
395
|
+
elif asyncio.iscoroutinefunction(tool_function):
|
396
|
+
result = await tool_function()
|
329
397
|
else:
|
330
|
-
|
331
|
-
result = await tool_function()
|
332
|
-
else:
|
333
|
-
result = tool_function()
|
398
|
+
result = tool_function()
|
334
399
|
|
335
400
|
# Assert expected output if provided
|
336
401
|
if expected_output is not None:
|
@@ -402,8 +467,8 @@ class ToolTestRunner:
|
|
402
467
|
elif isinstance(tool_result, FunctionInfo):
|
403
468
|
if tool_result.single_fn:
|
404
469
|
tool_function = tool_result.single_fn
|
405
|
-
elif tool_result.
|
406
|
-
tool_function = tool_result.
|
470
|
+
elif tool_result.stream_fn:
|
471
|
+
tool_function = tool_result.stream_fn
|
407
472
|
else:
|
408
473
|
raise ValueError("Tool function not found in FunctionInfo")
|
409
474
|
elif callable(tool_result):
|
@@ -413,15 +478,17 @@ class ToolTestRunner:
|
|
413
478
|
|
414
479
|
# Execute the tool
|
415
480
|
if input_data is not None:
|
416
|
-
if
|
417
|
-
result = await tool_function(input_data)
|
481
|
+
if isinstance(tool_function, Function):
|
482
|
+
result = await tool_function.ainvoke(input_data)
|
418
483
|
else:
|
419
|
-
|
484
|
+
maybe_result = tool_function(input_data)
|
485
|
+
result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
|
486
|
+
elif isinstance(tool_function, Function):
|
487
|
+
# Function objects require input, so pass None if no input_data
|
488
|
+
result = await tool_function.ainvoke(None)
|
420
489
|
else:
|
421
|
-
|
422
|
-
|
423
|
-
else:
|
424
|
-
result = tool_function()
|
490
|
+
maybe_result = tool_function()
|
491
|
+
result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
|
425
492
|
|
426
493
|
# Assert expected output if provided
|
427
494
|
if expected_output is not None:
|
nat/test/utils.py
ADDED
@@ -0,0 +1,87 @@
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import importlib.resources
|
17
|
+
import inspect
|
18
|
+
import subprocess
|
19
|
+
import typing
|
20
|
+
from pathlib import Path
|
21
|
+
|
22
|
+
if typing.TYPE_CHECKING:
|
23
|
+
from nat.data_models.config import Config
|
24
|
+
from nat.utils.type_utils import StrPath
|
25
|
+
|
26
|
+
|
27
|
+
def locate_repo_root() -> Path:
|
28
|
+
result = subprocess.run(["git", "rev-parse", "--show-toplevel"], check=False, capture_output=True, text=True)
|
29
|
+
assert result.returncode == 0, f"Failed to get git root: {result.stderr}"
|
30
|
+
return Path(result.stdout.strip())
|
31
|
+
|
32
|
+
|
33
|
+
def locate_example_src_dir(example_config_class: type) -> Path:
|
34
|
+
"""
|
35
|
+
Locate the example src directory for an example's config class.
|
36
|
+
"""
|
37
|
+
package_name = inspect.getmodule(example_config_class).__package__
|
38
|
+
return importlib.resources.files(package_name)
|
39
|
+
|
40
|
+
|
41
|
+
def locate_example_dir(example_config_class: type) -> Path:
|
42
|
+
"""
|
43
|
+
Locate the example directory for an example's config class.
|
44
|
+
"""
|
45
|
+
src_dir = locate_example_src_dir(example_config_class)
|
46
|
+
example_dir = src_dir.parent.parent
|
47
|
+
return example_dir
|
48
|
+
|
49
|
+
|
50
|
+
def locate_example_config(example_config_class: type,
|
51
|
+
config_file: str = "config.yml",
|
52
|
+
assert_exists: bool = True) -> Path:
|
53
|
+
"""
|
54
|
+
Locate the example config file for an example's config class, assumes the example contains a 'configs' directory
|
55
|
+
"""
|
56
|
+
example_dir = locate_example_src_dir(example_config_class)
|
57
|
+
config_path = example_dir.joinpath("configs", config_file).absolute()
|
58
|
+
if assert_exists:
|
59
|
+
assert config_path.exists(), f"Config file {config_path} does not exist"
|
60
|
+
|
61
|
+
return config_path
|
62
|
+
|
63
|
+
|
64
|
+
async def run_workflow(
|
65
|
+
config_file: "StrPath | None",
|
66
|
+
question: str,
|
67
|
+
expected_answer: str,
|
68
|
+
assert_expected_answer: bool = True,
|
69
|
+
config: "Config | None" = None,
|
70
|
+
) -> str:
|
71
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
72
|
+
from nat.runtime.loader import load_config
|
73
|
+
from nat.runtime.session import SessionManager
|
74
|
+
|
75
|
+
if config is None:
|
76
|
+
assert config_file is not None, "Either config_file or config must be provided"
|
77
|
+
config = load_config(config_file)
|
78
|
+
|
79
|
+
async with WorkflowBuilder.from_config(config=config) as workflow_builder:
|
80
|
+
workflow = SessionManager(await workflow_builder.build())
|
81
|
+
async with workflow.run(question) as runner:
|
82
|
+
result = await runner.result(to_type=str)
|
83
|
+
|
84
|
+
if assert_expected_answer:
|
85
|
+
assert expected_answer.lower() in result.lower(), f"Expected '{expected_answer}' in '{result}'"
|
86
|
+
|
87
|
+
return result
|
@@ -1,12 +1,15 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: nvidia-nat-test
|
3
|
-
Version: 1.3.
|
3
|
+
Version: 1.3.0rc2
|
4
4
|
Summary: Testing utilities for NeMo Agent toolkit
|
5
5
|
Keywords: ai,rag,agents
|
6
6
|
Classifier: Programming Language :: Python
|
7
|
-
|
7
|
+
Classifier: Programming Language :: Python :: 3.11
|
8
|
+
Classifier: Programming Language :: Python :: 3.12
|
9
|
+
Classifier: Programming Language :: Python :: 3.13
|
10
|
+
Requires-Python: <3.14,>=3.11
|
8
11
|
Description-Content-Type: text/markdown
|
9
|
-
Requires-Dist: nvidia-nat==v1.3.0-
|
12
|
+
Requires-Dist: nvidia-nat==v1.3.0-rc2
|
10
13
|
Requires-Dist: langchain-community~=0.3
|
11
14
|
Requires-Dist: pytest~=8.3
|
12
15
|
|
@@ -0,0 +1,16 @@
|
|
1
|
+
nat/meta/pypi.md,sha256=LLKJHg5oN1-M9Pqfk3Bmphkk4O2TFsyiixuK5T0Y-gw,1100
|
2
|
+
nat/test/__init__.py,sha256=_RnTJnsUucHvla_nYKqD4O4g8Bz0tcuDRzWk1bEhcy0,875
|
3
|
+
nat/test/embedder.py,sha256=ClDyK1kna4hCBSlz71gK1B-ZjlwcBHTDQRekoNM81Bs,1809
|
4
|
+
nat/test/functions.py,sha256=ZxXVzfaLBGOpR5qtmMrKU7q-M9-vVGGj3Xi5mrw4vHY,3557
|
5
|
+
nat/test/llm.py,sha256=osJWGsJN7x-JGOaitueKeSwuJPVmnIFqJUCz28ngSRg,8215
|
6
|
+
nat/test/memory.py,sha256=xki_A2yiMhEZuQk60K7t04QRqf32nQqnfzD5Iv7fkvw,1456
|
7
|
+
nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2zQ,4281
|
8
|
+
nat/test/plugin.py,sha256=P6awHPN-iVmb8VIzLX_pd9rMnR4PD8Of3H_ypPPFr8Q,11246
|
9
|
+
nat/test/register.py,sha256=o1BEA5fyxyFyCxXhQ6ArmtuNpgRyTEfvw6HdBgECPLI,897
|
10
|
+
nat/test/tool_test_runner.py,sha256=SxavwXHkvCQDl_PUiiiqgvGfexKJJTeBdI5i1qk6AzI,21712
|
11
|
+
nat/test/utils.py,sha256=y77p5uVpRPSQqVOnetBLvJVsSRgS4_fEgcuRoHwvRKE,3187
|
12
|
+
nvidia_nat_test-1.3.0rc2.dist-info/METADATA,sha256=vyclpVYsAUUgANU3_ZbjWPqzGolQdTNjnVwq0_K-t5E,1608
|
13
|
+
nvidia_nat_test-1.3.0rc2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
14
|
+
nvidia_nat_test-1.3.0rc2.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
|
15
|
+
nvidia_nat_test-1.3.0rc2.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
16
|
+
nvidia_nat_test-1.3.0rc2.dist-info/RECORD,,
|
@@ -1,14 +0,0 @@
|
|
1
|
-
nat/meta/pypi.md,sha256=LLKJHg5oN1-M9Pqfk3Bmphkk4O2TFsyiixuK5T0Y-gw,1100
|
2
|
-
nat/test/__init__.py,sha256=_RnTJnsUucHvla_nYKqD4O4g8Bz0tcuDRzWk1bEhcy0,875
|
3
|
-
nat/test/embedder.py,sha256=ClDyK1kna4hCBSlz71gK1B-ZjlwcBHTDQRekoNM81Bs,1809
|
4
|
-
nat/test/functions.py,sha256=0ScrdsjcxCsPRLnyb5gfwukmvZxFi_ptCswLSIG0DVY,3095
|
5
|
-
nat/test/memory.py,sha256=xki_A2yiMhEZuQk60K7t04QRqf32nQqnfzD5Iv7fkvw,1456
|
6
|
-
nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2zQ,4281
|
7
|
-
nat/test/plugin.py,sha256=fp39ib0W63vfqX6Ssvq4sCuSd8Lm6yQyknL3_qRijgI,3610
|
8
|
-
nat/test/register.py,sha256=jU1pW5wf20ZmCOTgkaQshKZfvYh8_-sMJ4P3xXilfTY,891
|
9
|
-
nat/test/tool_test_runner.py,sha256=ccErldob2VwBbVL0_pmLrOcKLc18qYjxeAEACYoKKGQ,17469
|
10
|
-
nvidia_nat_test-1.3.0.dev2.dist-info/METADATA,sha256=LFC2hgfp8ITqrKuES1NOCXbkyBwWPm_8OhjiIPIjxe4,1458
|
11
|
-
nvidia_nat_test-1.3.0.dev2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
12
|
-
nvidia_nat_test-1.3.0.dev2.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
|
13
|
-
nvidia_nat_test-1.3.0.dev2.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
14
|
-
nvidia_nat_test-1.3.0.dev2.dist-info/RECORD,,
|
File without changes
|
{nvidia_nat_test-1.3.0.dev2.dist-info → nvidia_nat_test-1.3.0rc2.dist-info}/entry_points.txt
RENAMED
File without changes
|
File without changes
|