nvidia-nat-test 1.3.0a20251108__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nvidia-nat-test might be problematic. Click here for more details.
- nat/meta/pypi.md +23 -0
- nat/test/__init__.py +23 -0
- nat/test/embedder.py +44 -0
- nat/test/functions.py +99 -0
- nat/test/llm.py +236 -0
- nat/test/memory.py +41 -0
- nat/test/object_store_tests.py +117 -0
- nat/test/plugin.py +628 -0
- nat/test/register.py +25 -0
- nat/test/tool_test_runner.py +516 -0
- nat/test/utils.py +155 -0
- nvidia_nat_test-1.3.0a20251108.dist-info/METADATA +46 -0
- nvidia_nat_test-1.3.0a20251108.dist-info/RECORD +18 -0
- nvidia_nat_test-1.3.0a20251108.dist-info/WHEEL +5 -0
- nvidia_nat_test-1.3.0a20251108.dist-info/entry_points.txt +5 -0
- nvidia_nat_test-1.3.0a20251108.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_test-1.3.0a20251108.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_test-1.3.0a20251108.dist-info/top_level.txt +1 -0
nat/test/plugin.py
ADDED
|
@@ -0,0 +1,628 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import random
|
|
18
|
+
import subprocess
|
|
19
|
+
import time
|
|
20
|
+
import types
|
|
21
|
+
import typing
|
|
22
|
+
from collections.abc import AsyncGenerator
|
|
23
|
+
from collections.abc import Generator
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
|
|
26
|
+
import pytest
|
|
27
|
+
import pytest_asyncio
|
|
28
|
+
|
|
29
|
+
if typing.TYPE_CHECKING:
|
|
30
|
+
import langsmith.client
|
|
31
|
+
|
|
32
|
+
from docker.client import DockerClient
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def pytest_addoption(parser: pytest.Parser):
    """
    Register the command line flags that opt in to tests which are skipped by default.

    Every flag is a boolean ``store_true`` switch:

    * ``--run_integration`` enables tests marked ``integration`` (checked in ``pytest_runtest_setup``).
    * ``--run_slow`` enables tests marked ``slow`` (checked in ``pytest_runtest_setup``).
    * ``--fail_missing`` turns "unmet dependency" skips into failures (exposed via the ``fail_missing`` fixture).
    """
    flag_specs = (
        ("--run_integration",
         "run_integration",
         ("Run integrations tests that would otherwise be skipped. "
          "This will call out to external services instead of using mocks")),
        ("--run_slow", "run_slow", "Run end to end tests that would otherwise be skipped"),
        ("--fail_missing",
         "fail_missing",
         ("Tests requiring unmet dependencies are normally skipped. "
          "Setting this flag will instead cause them to be reported as a failure")),
    )
    for flag, destination, help_text in flag_specs:
        parser.addoption(flag, action="store_true", dest=destination, help=help_text)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def pytest_runtest_setup(item):
    """
    Skip tests carrying the ``integration`` or ``slow`` marker unless the matching opt-in flag was passed.

    Args:
        item: The pytest test item about to be set up.
    """
    opt_in_rules = (
        ("--run_integration", "integration", "Skipping integration tests by default. Use --run_integration to enable"),
        ("--run_slow", "slow", "Skipping slow tests by default. Use --run_slow to enable"),
    )
    for flag, marker_name, skip_message in opt_in_rules:
        flag_enabled = item.config.getoption(flag)
        # The marker is only inspected when the flag is off, mirroring the nested checks this replaces.
        if not flag_enabled and item.get_closest_marker(marker_name) is not None:
            pytest.skip(skip_message)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@pytest.fixture(name="register_components", scope="session", autouse=True)
def register_components_fixture():
    """
    Register every discoverable NAT component once, before any test in the session runs.

    The fixture is session-scoped and ``autouse``, so it is applied automatically to every test.
    """
    from nat.runtime.loader import PluginTypes
    from nat.runtime.loader import discover_and_register_plugins

    # Ensure that all components which need to be registered as part of an import are done so. This is necessary
    # because imports will not be reloaded between tests, so we need to ensure that all components are registered
    # before any tests are run.
    discover_and_register_plugins(PluginTypes.ALL)

    # Also import the nat.test.register module to register test-only components. Python caches imported modules, so
    # this is a no-op if plugin discovery above already imported nat.test.register.
    import nat.test.register  # noqa: F401
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@pytest.fixture(name="module_registry", scope="module", autouse=True)
def module_registry_fixture():
    """
    Yield a freshly pushed global type registry for the duration of a test module.

    The registry is pushed via ``GlobalTypeRegistry.push()`` and popped again when the module finishes, so
    registrations made by one module do not leak into the next. Applied automatically (``autouse``).
    """
    from nat.cli.type_registry import GlobalTypeRegistry as _TypeRegistry

    with _TypeRegistry.push() as pushed_registry:
        yield pushed_registry
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@pytest.fixture(name="registry", scope="function", autouse=True)
def function_registry_fixture():
    """
    Yield a freshly pushed global type registry for the duration of a single test function.

    The registry is pushed via ``GlobalTypeRegistry.push()`` and popped again after the test, so registrations
    made by one test do not leak into the next. Applied automatically (``autouse``).
    """
    from nat.cli.type_registry import GlobalTypeRegistry as _TypeRegistry

    with _TypeRegistry.push() as pushed_registry:
        yield pushed_registry
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@pytest.fixture(scope="session", name="fail_missing")
def fail_missing_fixture(pytestconfig: pytest.Config) -> bool:
    """
    Return the value of the ``--fail_missing`` command line flag.

    When False, tests requiring unmet dependencies are skipped; when True, they fail instead.
    """
    # Nothing needs tearing down, so a plain ``return`` (matching the ``bool`` annotation) is enough; pytest
    # treats returning and single-yield fixtures the same way.
    return pytestconfig.getoption("fail_missing")
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def require_env_variables(varnames: list[str], reason: str, fail_missing: bool = False) -> dict[str, str]:
    """
    Look up a set of environment variables that must all be defined.

    Args:
        varnames: Names of the environment variables that are required.
        reason: Message used for the skip, or for the ``RuntimeError`` when ``fail_missing`` is True.
        fail_missing: When True, a missing variable raises ``RuntimeError``; otherwise the calling test is skipped
            via ``pytest.skip``.

    Returns:
        A mapping of each requested name to its value (an empty dict when ``varnames`` is empty).

    Raises:
        RuntimeError: If any variable is not set and ``fail_missing`` is True (chained from the ``KeyError``).
    """
    try:
        return {varname: os.environ[varname] for varname in varnames}
    except KeyError as e:
        if fail_missing:
            raise RuntimeError(reason) from e

        # pytest.skip() raises, so control never falls off the end of this function.
        pytest.skip(reason=reason)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@pytest.fixture(name="openai_api_key", scope='session')
def openai_api_key_fixture(fail_missing: bool):
    """
    Provide the OpenAI API key to integration tests.

    Yields the ``{"OPENAI_API_KEY": value}`` mapping from ``require_env_variables``. When the variable is
    missing, the test is skipped, or a ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["OPENAI_API_KEY"]
    message = "openai integration tests require the `OPENAI_API_KEY` environment variable to be defined."
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@pytest.fixture(name="nvidia_api_key", scope='session')
def nvidia_api_key_fixture(fail_missing: bool):
    """
    Provide the NVIDIA API key to integration tests.

    Yields the ``{"NVIDIA_API_KEY": value}`` mapping from ``require_env_variables``. When the variable is
    missing, the test is skipped, or a ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["NVIDIA_API_KEY"]
    message = "Nvidia integration tests require the `NVIDIA_API_KEY` environment variable to be defined."
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@pytest.fixture(name="serp_api_key", scope='session')
def serp_api_key_fixture(fail_missing: bool):
    """
    Provide the SERP API (serpapi.com) key to integration tests.

    Yields the ``{"SERP_API_KEY": value}`` mapping from ``require_env_variables``. When the variable is
    missing, the test is skipped, or a ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["SERP_API_KEY"]
    message = "SERP integration tests require the `SERP_API_KEY` environment variable to be defined."
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@pytest.fixture(name="serperdev", scope='session')
def serperdev_api_key_fixture(fail_missing: bool):
    """
    Provide the Serper Dev API (https://serper.dev) key to integration tests.

    Note that the fixture is registered under the name ``serperdev``. Yields the ``{"SERPERDEV_API_KEY": value}``
    mapping from ``require_env_variables``. When the variable is missing, the test is skipped, or a
    ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["SERPERDEV_API_KEY"]
    message = "SERPERDEV integration tests require the `SERPERDEV_API_KEY` environment variable to be defined."
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
@pytest.fixture(name="tavily_api_key", scope='session')
def tavily_api_key_fixture(fail_missing: bool):
    """
    Provide the Tavily API key to integration tests.

    Yields the ``{"TAVILY_API_KEY": value}`` mapping from ``require_env_variables``. When the variable is
    missing, the test is skipped, or a ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["TAVILY_API_KEY"]
    message = "Tavily integration tests require the `TAVILY_API_KEY` environment variable to be defined."
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
@pytest.fixture(name="mem0_api_key", scope='session')
def mem0_api_key_fixture(fail_missing: bool):
    """
    Provide the Mem0 API key to integration tests.

    Yields the ``{"MEM0_API_KEY": value}`` mapping from ``require_env_variables``. When the variable is
    missing, the test is skipped, or a ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["MEM0_API_KEY"]
    message = "Mem0 integration tests require the `MEM0_API_KEY` environment variable to be defined."
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
@pytest.fixture(name="aws_keys", scope='session')
def aws_keys_fixture(fail_missing: bool):
    """
    Provide AWS credentials to integration tests.

    Yields a mapping with ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` from ``require_env_variables``.
    When either is missing, the test is skipped, or a ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]
    message = ("AWS integration tests require the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment "
               "variables to be defined.")
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@pytest.fixture(name="azure_openai_keys", scope='session')
def azure_openai_keys_fixture(fail_missing: bool):
    """
    Provide Azure OpenAI credentials to integration tests.

    Yields a mapping with ``AZURE_OPENAI_API_KEY`` and ``AZURE_OPENAI_ENDPOINT`` from ``require_env_variables``.
    When either is missing, the test is skipped, or a ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT"]
    message = ("Azure integration tests require the `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_ENDPOINT` environment "
               "variables to be defined.")
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
@pytest.fixture(name="langfuse_keys", scope='session')
def langfuse_keys_fixture(fail_missing: bool):
    """
    Provide Langfuse credentials to integration tests.

    Yields a mapping with ``LANGFUSE_PUBLIC_KEY`` and ``LANGFUSE_SECRET_KEY`` from ``require_env_variables``.
    When either is missing, the test is skipped, or a ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"]
    message = ("Langfuse integration tests require the `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` environment "
               "variables to be defined.")
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
@pytest.fixture(name="wandb_api_key", scope='session')
def wandb_api_key_fixture(fail_missing: bool):
    """
    Provide the Weights & Biases API key to integration tests.

    Yields the ``{"WANDB_API_KEY": value}`` mapping from ``require_env_variables``. When the variable is
    missing, the test is skipped, or a ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["WANDB_API_KEY"]
    message = "Weights & Biases integration tests require the `WANDB_API_KEY` environment variable to be defined."
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@pytest.fixture(name="weave", scope='session')
def require_weave_fixture(fail_missing: bool) -> types.ModuleType:
    """
    Return the imported ``weave`` module for tests that depend on it.

    Only importability is checked. If the import raises any ``Exception``, the test is skipped, or a chained
    ``RuntimeError`` is raised when ``fail_missing`` is True.
    """
    try:
        import weave as weave_module
    except Exception as e:
        reason = "Weave must be installed to run weave based tests"
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)
    return weave_module
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
@pytest.fixture(name="langsmith_api_key", scope='session')
def langsmith_api_key_fixture(fail_missing: bool):
    """
    Provide the LangSmith API key to integration tests.

    Yields the ``{"LANGSMITH_API_KEY": value}`` mapping from ``require_env_variables``. When the variable is
    missing, the test is skipped, or a ``RuntimeError`` is raised if ``fail_missing`` is True.
    """
    required = ["LANGSMITH_API_KEY"]
    message = "LangSmith integration tests require the `LANGSMITH_API_KEY` environment variable to be defined."
    yield require_env_variables(required, message, fail_missing=fail_missing)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
@pytest.fixture(name="langsmith_client")
def langsmith_client_fixture(langsmith_api_key: dict[str, str], fail_missing: bool) -> "langsmith.client.Client":
    """
    Return a ``langsmith.client.Client`` for LangSmith integration tests.

    Args:
        langsmith_api_key: Requested only so the test is skipped (or fails) when ``LANGSMITH_API_KEY`` is unset;
            its value (the mapping yielded by the ``langsmith_api_key`` fixture) is not used directly, since the
            client is constructed without arguments.
        fail_missing: When True, a missing ``langsmith`` package raises ``RuntimeError`` instead of skipping.

    Raises:
        RuntimeError: If ``langsmith`` cannot be imported and ``fail_missing`` is True.
    """
    try:
        import langsmith.client
        return langsmith.client.Client()
    except ImportError as e:
        reason = "LangSmith integration tests require the `langsmith` package to be installed."
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@pytest.fixture(name="langsmith_project_name")
def langsmith_project_name_fixture(langsmith_client: "langsmith.client.Client") -> Generator[str]:
    """
    Create a uniquely named LangSmith project for one test, yield its name, and delete it afterwards.

    The name has the form ``nat-e2e-test-<time.time()>-<random.random()>``.
    """
    # Create a unique project name for each test run
    timestamp = time.time()
    salt = random.random()
    unique_name = f"nat-e2e-test-{timestamp}-{salt}"

    langsmith_client.create_project(unique_name)
    yield unique_name

    langsmith_client.delete_project(project_name=unique_name)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
@pytest.fixture(name="require_docker", scope='session')
def require_docker_fixture(fail_missing: bool) -> "DockerClient":
    """
    Use for integration tests that require Docker to be running.

    Yields a ``docker.client.DockerClient`` for the whole session and closes it at session teardown. If the
    ``docker`` package cannot be imported or the client cannot be constructed, the test is skipped, or a
    chained ``RuntimeError`` is raised when ``fail_missing`` is True.
    """
    try:
        from docker.client import DockerClient
        client = DockerClient()
    except Exception as e:
        reason = f"Unable to connect to Docker daemon: {e}"
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    # The yield sits outside the probe's try/except so only client construction is treated as "Docker missing".
    try:
        yield client
    finally:
        # Release the client's underlying session/adapters once the test session is done with it.
        client.close()
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
@pytest.fixture(name="restore_environ")
def restore_environ_fixture():
    """
    Yield ``os.environ`` and undo the test's changes to it afterwards.

    On teardown, every variable that existed beforehand is reset to its saved value (re-created if the test
    deleted it), and every variable the test added is removed.
    """
    snapshot = os.environ.copy()
    yield os.environ

    os.environ.update(snapshot)

    # Collect the additions first: deleting from os.environ while iterating over it directly is not safe.
    added_names = [name for name in os.environ if name not in snapshot]
    for name in added_names:
        del os.environ[name]
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
@pytest.fixture(name="root_repo_dir", scope='session')
def root_repo_dir_fixture() -> Path:
    """
    Return the repository root directory as found by ``nat.test.utils.locate_repo_root``.
    """
    from nat.test.utils import locate_repo_root as _locate_repo_root

    repo_root = _locate_repo_root()
    return repo_root
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
@pytest.fixture(name="examples_dir", scope='session')
def examples_dir_fixture(root_repo_dir: Path) -> Path:
    """
    Return the ``examples`` directory directly under the repository root.
    """
    return root_repo_dir.joinpath("examples")
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
@pytest.fixture(name="env_without_nat_log_level", scope='function')
def env_without_nat_log_level_fixture() -> dict[str, str]:
    """
    Return a copy of the current environment with ``NAT_LOG_LEVEL`` removed.

    The copy is a plain dict (for example, to pass as a subprocess ``env``); ``os.environ`` itself is not modified.
    """
    return {name: value for name, value in os.environ.items() if name != "NAT_LOG_LEVEL"}
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
@pytest.fixture(name="etcd_url", scope="session")
def etcd_url_fixture(fail_missing: bool) -> str:
    """
    To run these tests, an etcd server must be running.

    Returns the base URL built from ``NAT_CI_ETCD_HOST`` (default ``localhost``) and ``NAT_CI_ETCD_PORT``
    (default ``2379``), after a GET to ``<url>/health`` succeeds within 5 seconds. Otherwise the test is skipped,
    or a chained ``RuntimeError`` is raised when ``fail_missing`` is True.

    ``fail_missing`` deliberately has no default value: pytest does not inject fixtures into parameters that have
    defaults, so a default would silently disable the ``--fail_missing`` flag for this fixture.
    """
    import requests

    host = os.getenv("NAT_CI_ETCD_HOST", "localhost")
    port = os.getenv("NAT_CI_ETCD_PORT", "2379")
    url = f"http://{host}:{port}"
    health_url = f"{url}/health"

    try:
        response = requests.get(health_url, timeout=5)
        response.raise_for_status()
    except Exception as e:  # Not a bare except: let KeyboardInterrupt/SystemExit propagate.
        failure_reason = f"Unable to connect to etcd server at {url}: {e}"
        if fail_missing:
            raise RuntimeError(failure_reason) from e
        pytest.skip(reason=failure_reason)

    return url
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
@pytest.fixture(name="milvus_uri", scope="session")
def milvus_uri_fixture(etcd_url: str, fail_missing: bool) -> str:
    """
    To run these tests, a Milvus server must be running.

    Returns the URI built from ``NAT_CI_MILVUS_HOST`` (default ``localhost``) and ``NAT_CI_MILVUS_PORT``
    (default ``19530``) once a ``pymilvus.MilvusClient`` can be constructed for it. Otherwise (including when
    ``pymilvus`` is not installed) the test is skipped, or a chained ``RuntimeError`` is raised when
    ``fail_missing`` is True.

    Args:
        etcd_url: Requested only as a dependency so Milvus tests are skipped when etcd is unavailable
            (presumably because the Milvus deployment relies on etcd); the value is unused here.
        fail_missing: Deliberately has no default value, because pytest does not inject fixtures into parameters
            that have defaults and would otherwise ignore ``--fail_missing``.
    """
    host = os.getenv("NAT_CI_MILVUS_HOST", "localhost")
    port = os.getenv("NAT_CI_MILVUS_PORT", "19530")
    uri = f"http://{host}:{port}"
    try:
        from pymilvus import MilvusClient
        MilvusClient(uri=uri)
    except Exception as e:  # Not a bare except: let KeyboardInterrupt/SystemExit propagate.
        reason = f"Unable to connect to Milvus server at {uri}: {e}"
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    return uri
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
@pytest.fixture(name="populate_milvus", scope="session")
def populate_milvus_fixture(milvus_uri: str, root_repo_dir: Path):
    """
    Populate Milvus with some test data.

    Runs ``scripts/langchain_web_ingest.py`` (relative to the repository root) three times against ``milvus_uri``:

    1. with the script's default URLs and collection,
    2. with the MCP documentation URLs into the ``mcp_docs`` collection,
    3. with a Wikipedia page into the ``wikipedia_docs`` collection.

    Raises:
        subprocess.CalledProcessError: If any ingestion run exits with a non-zero status.
    """
    import sys

    populate_script = root_repo_dir / "scripts/langchain_web_ingest.py"

    def _ingest(urls: tuple[str, ...] = (), collection_name: str | None = None) -> None:
        # Run one ingestion, passing each URL as its own ``--urls`` argument.
        # sys.executable (not a bare "python" looked up on PATH) ensures the script runs with the same interpreter
        # and installed packages as the test session, e.g. inside a virtualenv that is not activated.
        command = [sys.executable, str(populate_script), "--milvus_uri", milvus_uri]
        for url in urls:
            command.extend(["--urls", url])
        if collection_name is not None:
            command.extend(["--collection_name", collection_name])
        subprocess.run(command, check=True)

    # Ingest default cuda docs
    _ingest()

    # Ingest MCP docs
    _ingest(urls=(
        "https://github.com/modelcontextprotocol/python-sdk",
        "https://modelcontextprotocol.io/introduction",
        "https://modelcontextprotocol.io/quickstart/server",
        "https://modelcontextprotocol.io/quickstart/client",
        "https://modelcontextprotocol.io/examples",
        "https://modelcontextprotocol.io/docs/concepts/architecture",
    ),
            collection_name="mcp_docs")

    # Ingest some wikipedia docs
    _ingest(urls=("https://en.wikipedia.org/wiki/Aardvark", ), collection_name="wikipedia_docs")
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
@pytest.fixture(name="require_nest_asyncio", scope="session")
def require_nest_asyncio_fixture():
    """
    Enable nested event loops by calling ``nest_asyncio.apply()`` once per session.

    Calling ``apply()`` again after a dependency has already done so is a no-op, so requesting this fixture is
    always safe.
    """
    import nest_asyncio as _nest_asyncio

    _nest_asyncio.apply()
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
@pytest.fixture(name="phoenix_url", scope="session")
def phoenix_url_fixture(fail_missing: bool) -> str:
    """
    To run these tests, a phoenix server must be running.
    The phoenix server can be started by running the following command:
    docker run -p 6006:6006 -p 4317:4317 arizephoenix/phoenix:latest

    The base URL is read from ``NAT_CI_PHOENIX_URL`` (default ``http://localhost:6006``) and returned once a GET to
    it succeeds within 5 seconds. Otherwise the test is skipped, or a chained ``RuntimeError`` is raised when
    ``fail_missing`` is True.
    """
    import requests

    url = os.getenv("NAT_CI_PHOENIX_URL", "http://localhost:6006")
    try:
        response = requests.get(url, timeout=5)
        response.raise_for_status()
    except Exception as e:
        reason = f"Unable to connect to Phoenix server at {url}: {e}"
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    return url
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
@pytest.fixture(name="phoenix_trace_url", scope="session")
def phoenix_trace_url_fixture(phoenix_url: str) -> str:
    """
    Return the Phoenix trace endpoint, i.e. the ``phoenix_url`` base URL with ``/v1/traces`` appended.

    Some tools take the base URL, while ``general.telemetry.tracing["phoenix"].endpoint`` expects this trace URL.
    """
    trace_path = "/v1/traces"
    return f"{phoenix_url}{trace_path}"
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
@pytest.fixture(name="redis_server", scope="session")
def fixture_redis_server(fail_missing: bool) -> Generator[dict[str, str | int]]:
    """
    Fixture to safely skip redis based tests if redis is not running.

    Connection settings come from ``NAT_CI_REDIS_HOST`` (default ``localhost``), ``NAT_CI_REDIS_PORT`` (``6379``),
    ``NAT_CI_REDIS_DB`` (``0``) and ``NAT_CI_REDIS_BUCKET_NAME`` (``test``). The server is probed with ``PING``
    and the settings dict is yielded. If ``redis`` is not installed or the probe fails, the test is skipped, or the
    original exception is re-raised when ``fail_missing`` is True.
    """
    host = os.environ.get("NAT_CI_REDIS_HOST", "localhost")
    port = int(os.environ.get("NAT_CI_REDIS_PORT", "6379"))
    db = int(os.environ.get("NAT_CI_REDIS_DB", "0"))
    bucket_name = os.environ.get("NAT_CI_REDIS_BUCKET_NAME", "test")

    try:
        import redis
        client = redis.Redis(host=host, port=port, db=db)
        try:
            if not client.ping():
                raise RuntimeError("Failed to connect to Redis")
        finally:
            # The client is only a connectivity probe; release its connections instead of holding them all session.
            client.close()
    except ImportError:
        if fail_missing:
            raise
        pytest.skip("redis not installed, skipping redis tests")
    except Exception as e:
        if fail_missing:
            raise
        pytest.skip(f"Error connecting to Redis server: {e}, skipping redis tests")

    # Yield outside the try so the except clauses above only cover the probe.
    yield {"host": host, "port": port, "db": db, "bucket_name": bucket_name}
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
@pytest_asyncio.fixture(name="mysql_server", scope="session")
async def fixture_mysql_server(fail_missing: bool) -> AsyncGenerator[dict[str, str | int]]:
    """
    Fixture to safely skip MySQL based tests if MySQL is not running.

    Connection settings come from ``NAT_CI_MYSQL_HOST`` (default ``127.0.0.1``), ``NAT_CI_MYSQL_PORT`` (``3306``),
    ``NAT_CI_MYSQL_USER`` (``root``), ``MYSQL_ROOT_PASSWORD`` (``my_password``) and ``NAT_CI_MYSQL_BUCKET_NAME``
    (``test``). A probe connection is opened with ``aiomysql`` and closed after the yielded settings are no longer
    needed. If ``aiomysql`` is missing or the connection fails, the test is skipped, or the original exception is
    re-raised when ``fail_missing`` is True.
    """
    env = os.environ
    host = env.get('NAT_CI_MYSQL_HOST', '127.0.0.1')
    port = int(env.get('NAT_CI_MYSQL_PORT', '3306'))
    user = env.get('NAT_CI_MYSQL_USER', 'root')
    password = env.get('MYSQL_ROOT_PASSWORD', 'my_password')
    bucket_name = env.get('NAT_CI_MYSQL_BUCKET_NAME', 'test')

    try:
        import aiomysql
        probe_connection = await aiomysql.connect(host=host, port=port, user=user, password=password)
        server_settings = {
            "host": host, "port": port, "username": user, "password": password, "bucket_name": bucket_name
        }
        yield server_settings
        probe_connection.close()
    except ImportError:
        if fail_missing:
            raise
        pytest.skip("aiomysql not installed, skipping MySQL tests")
    except Exception as e:
        if fail_missing:
            raise
        pytest.skip(f"Error connecting to MySQL server: {e}, skipping MySQL tests")
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
@pytest.fixture(name="minio_server", scope="session")
def minio_server_fixture(fail_missing: bool) -> Generator[dict[str, str | int]]:
    """
    Fixture to safely skip MinIO based tests if MinIO is not running.

    Connection settings come from ``NAT_CI_MINIO_HOST`` (default ``localhost``), ``NAT_CI_MINIO_PORT`` (``9000``),
    ``NAT_CI_MINIO_BUCKET_NAME`` (``test``), ``NAT_CI_MINIO_ACCESS_KEY_ID`` and ``NAT_CI_MINIO_SECRET_ACCESS_KEY``
    (both ``minioadmin``). The server is probed with an S3 ``list_buckets`` call through ``botocore``, and the
    settings (including the derived ``endpoint_url``) are yielded. If ``botocore`` is missing or the probe fails,
    the test is skipped, or the original exception is re-raised when ``fail_missing`` is True.
    """
    host = os.getenv("NAT_CI_MINIO_HOST", "localhost")
    port = int(os.getenv("NAT_CI_MINIO_PORT", "9000"))
    bucket_name = os.getenv("NAT_CI_MINIO_BUCKET_NAME", "test")
    aws_access_key_id = os.getenv("NAT_CI_MINIO_ACCESS_KEY_ID", "minioadmin")
    aws_secret_access_key = os.getenv("NAT_CI_MINIO_SECRET_ACCESS_KEY", "minioadmin")
    endpoint_url = f"http://{host}:{port}"

    minio_info = {
        "host": host,
        "port": port,
        "bucket_name": bucket_name,
        "endpoint_url": endpoint_url,
        "aws_access_key_id": aws_access_key_id,
        "aws_secret_access_key": aws_secret_access_key,
    }

    try:
        import botocore.session
        session = botocore.session.get_session()

        client = session.create_client("s3",
                                       aws_access_key_id=aws_access_key_id,
                                       aws_secret_access_key=aws_secret_access_key,
                                       endpoint_url=endpoint_url)
        client.list_buckets()
    except ImportError:
        if fail_missing:
            raise
        # Name the module that is actually imported above so the message points at the real missing dependency.
        pytest.skip("botocore (a dependency of aioboto3) not installed, skipping MinIO tests")
    except Exception as e:
        if fail_missing:
            raise
        pytest.skip(f"Error connecting to MinIO server: {e}, skipping MinIO tests")

    # Yield outside the try so the except clauses above only cover the connectivity probe.
    yield minio_info
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
@pytest.fixture(name="langfuse_bucket", scope="session")
def langfuse_bucket_fixture(fail_missing: bool, minio_server: dict[str, str | int]) -> Generator[str]:
    """
    Ensure the MinIO bucket used for Langfuse exists, then yield its name.

    The bucket name comes from ``NAT_CI_LANGFUSE_BUCKET`` (default ``langfuse``) and is created through ``botocore``
    with the credentials and endpoint from the ``minio_server`` fixture if it does not already exist. If
    ``botocore`` is missing or any S3 call fails, the test is skipped, or the original exception is re-raised when
    ``fail_missing`` is True.
    """
    bucket_name = os.getenv("NAT_CI_LANGFUSE_BUCKET", "langfuse")
    try:
        import botocore.session
        session = botocore.session.get_session()

        client = session.create_client("s3",
                                       aws_access_key_id=minio_server["aws_access_key_id"],
                                       aws_secret_access_key=minio_server["aws_secret_access_key"],
                                       endpoint_url=minio_server["endpoint_url"])

        existing_buckets = {bucket['Name'] for bucket in client.list_buckets()['Buckets']}
        if bucket_name not in existing_buckets:
            client.create_bucket(Bucket=bucket_name)
    except ImportError:
        if fail_missing:
            raise
        # Name the module that is actually imported above so the message points at the real missing dependency.
        pytest.skip("botocore (a dependency of aioboto3) not installed, skipping MinIO tests")
    except Exception as e:
        if fail_missing:
            raise
        pytest.skip(f"Error connecting to MinIO server: {e}, skipping MinIO tests")

    # Yield outside the try so the except clauses above only cover bucket setup.
    yield bucket_name
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
@pytest.fixture(name="langfuse_url", scope="session")
def langfuse_url_fixture(fail_missing: bool, langfuse_bucket: str) -> str:
    """
    To run these tests, a langfuse server must be running.

    Returns the base URL built from ``NAT_CI_LANGFUSE_HOST`` (default ``localhost``) and ``NAT_CI_LANGFUSE_PORT``
    (default ``3000``) once a GET to ``<url>/api/public/health`` succeeds within 5 seconds. Otherwise the test is
    skipped, or a chained ``RuntimeError`` is raised when ``fail_missing`` is True.

    Args:
        fail_missing: When True, an unreachable server raises instead of skipping.
        langfuse_bucket: Requested only so the Langfuse bucket exists before the server is used (presumably the
            server stores data in it); its value is unused here.
    """
    import requests

    host = os.getenv("NAT_CI_LANGFUSE_HOST", "localhost")
    port = int(os.getenv("NAT_CI_LANGFUSE_PORT", "3000"))
    url = f"http://{host}:{port}"
    health_endpoint = f"{url}/api/public/health"
    try:
        response = requests.get(health_endpoint, timeout=5)
        response.raise_for_status()
    except Exception as e:
        reason = f"Unable to connect to Langfuse server at {url}: {e}"
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    return url
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
@pytest.fixture(name="langfuse_trace_url", scope="session")
def langfuse_trace_url_fixture(langfuse_url: str) -> str:
    """
    Return the Langfuse OTel trace endpoint: the ``langfuse_url`` base URL with ``/api/public/otel/v1/traces``
    appended, as expected by ``general.telemetry.tracing["langfuse"].endpoint``.
    """
    otel_trace_path = "/api/public/otel/v1/traces"
    return f"{langfuse_url}{otel_trace_path}"
|
nat/test/register.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
# flake8: noqa
|
|
17
|
+
# isort:skip_file
|
|
18
|
+
|
|
19
|
+
# Import any providers which need to be automatically registered here
|
|
20
|
+
|
|
21
|
+
from . import embedder
|
|
22
|
+
from . import functions
|
|
23
|
+
from . import memory
|
|
24
|
+
from . import llm
|
|
25
|
+
from . import utils
|