nvidia-nat-test 1.4.0a20260117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/meta/pypi.md +23 -0
- nat/test/__init__.py +23 -0
- nat/test/embedder.py +44 -0
- nat/test/functions.py +99 -0
- nat/test/llm.py +244 -0
- nat/test/memory.py +41 -0
- nat/test/object_store_tests.py +117 -0
- nat/test/plugin.py +890 -0
- nat/test/register.py +25 -0
- nat/test/tool_test_runner.py +612 -0
- nat/test/utils.py +215 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/METADATA +46 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/RECORD +18 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/WHEEL +5 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/entry_points.txt +5 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/licenses/LICENSE-3rd-party.txt +5478 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/licenses/LICENSE.md +201 -0
- nvidia_nat_test-1.4.0a20260117.dist-info/top_level.txt +1 -0
nat/test/plugin.py
ADDED
|
@@ -0,0 +1,890 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import random
|
|
18
|
+
import subprocess
|
|
19
|
+
import time
|
|
20
|
+
import types
|
|
21
|
+
import typing
|
|
22
|
+
from collections.abc import AsyncGenerator
|
|
23
|
+
from collections.abc import Generator
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
|
|
26
|
+
import pytest
|
|
27
|
+
import pytest_asyncio
|
|
28
|
+
|
|
29
|
+
if typing.TYPE_CHECKING:
|
|
30
|
+
import galileo.log_streams
|
|
31
|
+
import galileo.projects
|
|
32
|
+
import langsmith.client
|
|
33
|
+
|
|
34
|
+
from docker.client import DockerClient
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def pytest_addoption(parser: pytest.Parser):
    """
    Register the opt-in command line flags for tests that are disabled by default.

    Every flag is a boolean ``store_true`` switch:
      * ``--run_integration`` - run tests marked ``integration``
      * ``--run_slow`` - run tests marked ``slow``
      * ``--fail_missing`` - fail, rather than skip, tests whose dependencies are unavailable
    """
    flag_specs = (
        ("--run_integration",
         "run_integration",
         ("Run integrations tests that would otherwise be skipped. "
          "This will call out to external services instead of using mocks")),
        ("--run_slow", "run_slow", "Run end to end tests that would otherwise be skipped"),
        ("--fail_missing",
         "fail_missing",
         ("Tests requiring unmet dependencies are normally skipped. "
          "Setting this flag will instead cause them to be reported as a failure")),
    )

    for flag_name, destination, help_text in flag_specs:
        parser.addoption(flag_name, action="store_true", dest=destination, help=help_text)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def pytest_runtest_setup(item):
    """
    Skip tests carrying an opt-in marker unless the matching command line flag was given.

    ``integration`` tests require ``--run_integration``; ``slow`` tests require ``--run_slow``.
    """
    gated_markers = (
        ("--run_integration", "integration",
         "Skipping integration tests by default. Use --run_integration to enable"),
        ("--run_slow", "slow", "Skipping slow tests by default. Use --run_slow to enable"),
    )

    for option_name, marker_name, skip_message in gated_markers:
        if item.config.getoption(option_name):
            # The user opted in to this category of tests.
            continue
        if item.get_closest_marker(marker_name) is not None:
            pytest.skip(skip_message)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@pytest.fixture(name="register_components", scope="session", autouse=True)
def register_components_fixture():
    """
    Discover and register every NAT plugin once per test session (applied automatically).

    Imports are not reloaded between tests, so every component that registers itself as a side effect of being
    imported has to be registered before any test runs.
    """
    from nat.runtime.loader import PluginTypes, discover_and_register_plugins

    discover_and_register_plugins(PluginTypes.ALL)

    # NOTE(review): test-only components live in nat.test.register, which is not imported explicitly here;
    # presumably plugin discovery picks it up through its entry point - confirm.
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@pytest.fixture(name="module_registry", scope="module", autouse=True)
def module_registry_fixture():
    """
    Yield the registry produced by ``GlobalTypeRegistry.push()`` for the current test module.

    Used automatically at module scope so that registrations made by one module do not leak into the next; the
    ``push()`` context is exited once the module's tests have finished.
    """
    from nat.cli.type_registry import GlobalTypeRegistry

    registry_context = GlobalTypeRegistry.push()
    with registry_context as scoped_registry:
        yield scoped_registry
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@pytest.fixture(name="registry", scope="function", autouse=True)
def function_registry_fixture():
    """
    Yield the registry produced by ``GlobalTypeRegistry.push()`` for the current test function.

    Used automatically at function scope so that registrations made by one test do not leak into the next; the
    ``push()`` context is exited when the test finishes.
    """
    from nat.cli.type_registry import GlobalTypeRegistry

    registry_context = GlobalTypeRegistry.push()
    with registry_context as scoped_registry:
        yield scoped_registry
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@pytest.fixture(scope="session", name="fail_missing")
def fail_missing_fixture(pytestconfig: pytest.Config) -> bool:
    """
    Return the value of the ``--fail_missing`` command line flag.

    When False, tests requiring unmet dependencies are skipped; when True, they fail instead.
    """
    # There is no teardown, so return the value directly; this also matches the ``-> bool`` annotation.
    return pytestconfig.getoption("fail_missing")
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def require_env_variables(varnames: list[str], reason: str, fail_missing: bool = False) -> dict[str, str]:
    """
    Look up the given environment variables and return their values.

    Args:
        varnames: Names of the environment variables that must all be set.
        reason: Message used for the skip, or for the ``RuntimeError`` when ``fail_missing`` is True.
        fail_missing: When True, a missing variable raises ``RuntimeError``; otherwise the current test is
            skipped via ``pytest.skip``.

    Returns:
        A dict mapping each name in ``varnames`` to its value. A variable set to the empty string counts as set.

    Raises:
        RuntimeError: If any variable is missing and ``fail_missing`` is True (chained from the ``KeyError``,
            which names the missing variable).
    """
    env_variables: dict[str, str] = {}
    try:
        for varname in varnames:
            env_variables[varname] = os.environ[varname]
    except KeyError as e:
        if fail_missing:
            raise RuntimeError(reason) from e

        # pytest.skip raises, so the return below is only reached when every variable was found.
        pytest.skip(reason=reason)

    return env_variables
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@pytest.fixture(name="openai_api_key", scope='session')
def openai_api_key_fixture(fail_missing: bool):
    """
    Yield ``{"OPENAI_API_KEY": <value>}`` for OpenAI integration tests.

    Skips the test when the variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    skip_reason = "openai integration tests require the `OPENAI_API_KEY` environment variable to be defined."
    credentials = require_env_variables(varnames=["OPENAI_API_KEY"], reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@pytest.fixture(name="nvidia_api_key", scope='session')
def nvidia_api_key_fixture(fail_missing: bool):
    """
    Yield ``{"NVIDIA_API_KEY": <value>}`` for Nvidia integration tests.

    Skips the test when the variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    skip_reason = "Nvidia integration tests require the `NVIDIA_API_KEY` environment variable to be defined."
    credentials = require_env_variables(varnames=["NVIDIA_API_KEY"], reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@pytest.fixture(name="serp_api_key", scope='session')
def serp_api_key_fixture(fail_missing: bool):
    """
    Yield ``{"SERP_API_KEY": <value>}`` for SERP API (serpapi.com) integration tests.

    Skips the test when the variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    skip_reason = "SERP integration tests require the `SERP_API_KEY` environment variable to be defined."
    credentials = require_env_variables(varnames=["SERP_API_KEY"], reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
@pytest.fixture(name="serperdev", scope='session')
def serperdev_api_key_fixture(fail_missing: bool):
    """
    Yield ``{"SERPERDEV_API_KEY": <value>}`` for Serper Dev API (https://serper.dev) integration tests.

    Skips the test when the variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    skip_reason = "SERPERDEV integration tests require the `SERPERDEV_API_KEY` environment variable to be defined."
    credentials = require_env_variables(varnames=["SERPERDEV_API_KEY"],
                                        reason=skip_reason,
                                        fail_missing=fail_missing)
    yield credentials
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@pytest.fixture(name="tavily_api_key", scope='session')
def tavily_api_key_fixture(fail_missing: bool):
    """
    Yield ``{"TAVILY_API_KEY": <value>}`` for Tavily integration tests.

    Skips the test when the variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    skip_reason = "Tavily integration tests require the `TAVILY_API_KEY` environment variable to be defined."
    credentials = require_env_variables(varnames=["TAVILY_API_KEY"], reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@pytest.fixture(name="mem0_api_key", scope='session')
def mem0_api_key_fixture(fail_missing: bool):
    """
    Yield ``{"MEM0_API_KEY": <value>}`` for Mem0 integration tests.

    Skips the test when the variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    skip_reason = "Mem0 integration tests require the `MEM0_API_KEY` environment variable to be defined."
    credentials = require_env_variables(varnames=["MEM0_API_KEY"], reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@pytest.fixture(name="aws_keys", scope='session')
def aws_keys_fixture(fail_missing: bool):
    """
    Yield a dict with ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY`` for AWS integration tests.

    Skips the test when either variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    required_names = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]
    skip_reason = ("AWS integration tests require the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment "
                   "variables to be "
                   "defined.")
    credentials = require_env_variables(varnames=required_names, reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@pytest.fixture(name="azure_openai_keys", scope='session')
def azure_openai_keys_fixture(fail_missing: bool):
    """
    Yield a dict with ``AZURE_OPENAI_API_KEY`` and ``AZURE_OPENAI_ENDPOINT`` for Azure OpenAI integration tests.

    Skips the test when either variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    required_names = ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT"]
    skip_reason = ("Azure integration tests require the `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_ENDPOINT` "
                   "environment "
                   "variables to be defined.")
    credentials = require_env_variables(varnames=required_names, reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
@pytest.fixture(name="langfuse_keys", scope='session')
def langfuse_keys_fixture(fail_missing: bool):
    """
    Yield a dict with ``LANGFUSE_PUBLIC_KEY`` and ``LANGFUSE_SECRET_KEY`` for Langfuse integration tests.

    Skips the test when either variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    required_names = ["LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"]
    skip_reason = ("Langfuse integration tests require the `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` "
                   "environment "
                   "variables to be defined.")
    credentials = require_env_variables(varnames=required_names, reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
@pytest.fixture(name="wandb_api_key", scope='session')
def wandb_api_key_fixture(fail_missing: bool):
    """
    Yield ``{"WANDB_API_KEY": <value>}`` for Weights & Biases integration tests.

    Skips the test when the variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    skip_reason = ("Weights & Biases integration tests require the `WANDB_API_KEY` environment variable to be "
                   "defined.")
    credentials = require_env_variables(varnames=["WANDB_API_KEY"], reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@pytest.fixture(name="weave", scope='session')
def require_weave_fixture(fail_missing: bool) -> types.ModuleType:
    """
    Import and return the ``weave`` module for Weave based tests.

    Note this only checks that ``weave`` can be imported. Any exception raised by the import causes the test to be
    skipped, or a ``RuntimeError`` to be raised when ``fail_missing`` is True.
    """
    try:
        import weave as weave_module
    except Exception as import_error:
        reason = "Weave must be installed to run weave based tests"
        if fail_missing:
            raise RuntimeError(reason) from import_error
        pytest.skip(reason=reason)
    else:
        return weave_module
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@pytest.fixture(name="langsmith_api_key", scope='session')
def langsmith_api_key_fixture(fail_missing: bool):
    """
    Yield ``{"LANGSMITH_API_KEY": <value>}`` for LangSmith integration tests.

    Skips the test when the variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    skip_reason = "LangSmith integration tests require the `LANGSMITH_API_KEY` environment variable to be defined."
    credentials = require_env_variables(varnames=["LANGSMITH_API_KEY"],
                                        reason=skip_reason,
                                        fail_missing=fail_missing)
    yield credentials
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@pytest.fixture(name="langsmith_client")
def langsmith_client_fixture(langsmith_api_key: str, fail_missing: bool) -> "langsmith.client.Client":
    """
    Return a ``langsmith.client.Client`` constructed with default arguments.

    Depends on ``langsmith_api_key`` so the test is skipped (or fails) when ``LANGSMITH_API_KEY`` is missing.
    If the ``langsmith`` package cannot be imported, the test is skipped, or a ``RuntimeError`` (chained from the
    ``ImportError``) is raised when ``fail_missing`` is True.
    """
    # Only guard the import: an ImportError raised while constructing the client is a real error and should
    # not be reported as "package not installed".
    try:
        import langsmith.client
    except ImportError as e:
        reason = "LangSmith integration tests require the `langsmith` package to be installed."
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    return langsmith.client.Client()
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@pytest.fixture(name="project_name")
def project_name_fixture() -> str:
    """
    Build a per-test project name of the form ``nat-e2e-test-<time.time()>-<random.random()>``.

    The current timestamp plus a random float make collisions between test runs unlikely.
    """
    timestamp = time.time()
    salt = random.random()
    return f"nat-e2e-test-{timestamp}-{salt}"
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
@pytest.fixture(name="langsmith_project_name")
def langsmith_project_name_fixture(langsmith_client: "langsmith.client.Client", project_name: str) -> Generator[str]:
    """
    Create a LangSmith project named ``project_name``, yield the name, then delete the project.
    """
    name = project_name
    langsmith_client.create_project(name)

    yield name

    # Teardown: remove the project created above.
    langsmith_client.delete_project(project_name=name)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@pytest.fixture(name="galileo_api_key", scope='session')
def galileo_api_key_fixture(fail_missing: bool):
    """
    Yield ``{"GALILEO_API_KEY": <value>}`` for Galileo integration tests.

    Skips the test when the variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    skip_reason = "Galileo integration tests require the `GALILEO_API_KEY` environment variable to be defined."
    credentials = require_env_variables(varnames=["GALILEO_API_KEY"], reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
@pytest.fixture(name="galileo_project")
def galileo_project_fixture(galileo_api_key: str, fail_missing: bool,
                           project_name: str) -> Generator["galileo.projects.Project"]:
    """
    Creates a unique Galileo project and deletes it after the test run.

    Depends on ``galileo_api_key`` so the test is skipped (or fails) when ``GALILEO_API_KEY`` is missing. If the
    ``galileo`` package cannot be imported, the test is skipped, or a ``RuntimeError`` (chained from the
    ``ImportError``) is raised when ``fail_missing`` is True.
    """
    # Guard only the import. Creating/deleting the project (including teardown after the yield) must not be
    # turned into a "package not installed" skip if it happens to raise ImportError.
    try:
        import galileo.projects
    except ImportError as e:
        reason = "Galileo integration tests require the `galileo` package to be installed."
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    project = galileo.projects.create_project(name=project_name)
    yield project

    galileo.projects.delete_project(id=project.id)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
@pytest.fixture(name="galileo_log_stream")
def galileo_log_stream_fixture(galileo_project: "galileo.projects.Project") -> "galileo.log_streams.LogStream":
    """
    Create a log stream named ``"test"`` inside ``galileo_project`` and return it.

    No explicit cleanup is performed; the log stream is removed together with its project when the
    ``galileo_project`` fixture deletes it.
    """
    import galileo.log_streams as log_streams

    return log_streams.create_log_stream(project_id=galileo_project.id, name="test")
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
@pytest.fixture(name="catalyst_keys", scope='session')
def catalyst_keys_fixture(fail_missing: bool):
    """
    Yield a dict with ``CATALYST_ACCESS_KEY`` and ``CATALYST_SECRET_KEY`` for RagaAI Catalyst integration tests.

    Skips the test when either variable is unset, or raises ``RuntimeError`` when ``fail_missing`` is True.
    """
    required_names = ["CATALYST_ACCESS_KEY", "CATALYST_SECRET_KEY"]
    skip_reason = ("Catalyst integration tests require the `CATALYST_ACCESS_KEY` and `CATALYST_SECRET_KEY` "
                   "environment "
                   "variables to be defined.")
    credentials = require_env_variables(varnames=required_names, reason=skip_reason, fail_missing=fail_missing)
    yield credentials
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
@pytest.fixture(name="catalyst_project_name")
def catalyst_project_name_fixture(catalyst_keys) -> str:
    """
    Return the Catalyst project used by the tests: ``NAT_CI_CATALYST_PROJECT_NAME`` or ``"nat-e2e"`` by default.

    ``catalyst_keys`` is requested only so that missing Catalyst credentials skip (or fail) the test first.
    """
    env_override = os.environ.get("NAT_CI_CATALYST_PROJECT_NAME")
    return "nat-e2e" if env_override is None else env_override
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
@pytest.fixture(name="catalyst_dataset_name")
def catalyst_dataset_name_fixture(catalyst_project_name: str, project_name: str) -> Generator[str]:
    """
    Yield a unique Catalyst dataset name, then delete that dataset if the test created it.

    We can't create and delete Catalyst projects, but we can create and delete datasets, so a unique dataset name
    is derived from ``project_name`` and used inside the shared ``catalyst_project_name`` project.
    """
    # NOTE(review): '.' is replaced with '-', presumably because Catalyst dataset names don't accept '.' - confirm.
    dataset_name = project_name.replace('.', '-')
    yield dataset_name

    from ragaai_catalyst import Dataset

    ds = Dataset(catalyst_project_name)
    # Only delete the dataset if it actually exists (the test may not have created it).
    if dataset_name in ds.list_datasets():
        ds.delete_dataset(dataset_name)
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
@pytest.fixture(name="require_docker", scope='session')
def require_docker_fixture(fail_missing: bool) -> Generator["DockerClient"]:
    """
    Use for integration tests that require Docker to be running.

    Yields a ``docker.client.DockerClient``. If the ``docker`` package cannot be imported or the client cannot be
    constructed, the test is skipped, or a ``RuntimeError`` is raised when ``fail_missing`` is True. The client is
    closed when the session ends.
    """
    try:
        from docker.client import DockerClient
        client = DockerClient()
    except Exception as e:
        reason = f"Unable to connect to Docker daemon: {e}"
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    # The yield stays outside the try block so that errors during teardown are not reported as
    # "Unable to connect to Docker daemon" skips.
    yield client

    # Release the client's connections once the session is over.
    client.close()
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
@pytest.fixture(name="restore_environ")
def restore_environ_fixture():
    """
    Yield ``os.environ`` and restore it to its pre-test contents afterwards.

    Variables the test modified or deleted are set back to their original values, and variables the test added
    are removed.
    """
    snapshot = os.environ.copy()
    yield os.environ

    # Re-apply every variable that existed before the test.
    os.environ.update(snapshot)

    # Drop variables introduced by the test; collect the names first because os.environ is mutated below.
    added_names = [name for name in os.environ if name not in snapshot]
    for name in added_names:
        del os.environ[name]
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
@pytest.fixture(name="root_repo_dir", scope='session')
def root_repo_dir_fixture() -> Path:
    """Return the repository root directory as determined by ``nat.test.utils.locate_repo_root``."""
    from nat.test.utils import locate_repo_root

    repo_root = locate_repo_root()
    return repo_root
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
@pytest.fixture(name="examples_dir", scope='session')
def examples_dir_fixture(root_repo_dir: Path) -> Path:
    """Return the ``examples`` directory located directly beneath the repository root."""
    examples_path = root_repo_dir.joinpath("examples")
    return examples_path
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
@pytest.fixture(name="env_without_nat_log_level", scope='function')
def env_without_nat_log_level_fixture() -> dict[str, str]:
    """
    Return a copy of the current environment with ``NAT_LOG_LEVEL`` removed (if present).

    ``os.environ`` itself is left untouched.
    """
    environment = dict(os.environ)
    if "NAT_LOG_LEVEL" in environment:
        del environment["NAT_LOG_LEVEL"]
    return environment
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
@pytest.fixture(name="etcd_url", scope="session")
def etcd_url_fixture(fail_missing: bool = False) -> str:
    """
    To run these tests, an etcd server must be running.

    Returns ``http://<NAT_CI_ETCD_HOST>:<NAT_CI_ETCD_PORT>`` (defaults ``localhost`` / ``2379``) after a successful
    ``GET <url>/health``. If the health check fails, the test is skipped, or a ``RuntimeError`` (chained from the
    underlying error) is raised when ``fail_missing`` is True.
    """
    import requests

    host = os.getenv("NAT_CI_ETCD_HOST", "localhost")
    port = os.getenv("NAT_CI_ETCD_PORT", "2379")
    url = f"http://{host}:{port}"
    health_url = f"{url}/health"

    # Catch Exception rather than using a bare except, so KeyboardInterrupt/SystemExit are not turned into a skip.
    try:
        response = requests.get(health_url, timeout=5)
        response.raise_for_status()
    except Exception as e:
        failure_reason = f"Unable to connect to etcd server at {url}"
        if fail_missing:
            raise RuntimeError(failure_reason) from e
        pytest.skip(reason=failure_reason)

    return url
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
@pytest.fixture(name="milvus_uri", scope="session")
def milvus_uri_fixture(etcd_url: str, fail_missing: bool = False) -> str:
    """
    To run these tests, a Milvus server must be running.

    Returns ``http://<NAT_CI_MILVUS_HOST>:<NAT_CI_MILVUS_PORT>`` (defaults ``localhost`` / ``19530``). If
    ``pymilvus`` cannot be imported or a ``MilvusClient`` cannot be constructed for that URI, the test is skipped,
    or a ``RuntimeError`` (chained from the underlying error) is raised when ``fail_missing`` is True.
    ``etcd_url`` is requested presumably so the etcd availability check runs first.
    """
    host = os.getenv("NAT_CI_MILVUS_HOST", "localhost")
    port = os.getenv("NAT_CI_MILVUS_PORT", "19530")
    uri = f"http://{host}:{port}"

    # Catch Exception rather than using a bare except, so KeyboardInterrupt/SystemExit are not turned into a skip.
    try:
        from pymilvus import MilvusClient

        # The client is only constructed as a reachability probe; the instance is discarded.
        MilvusClient(uri=uri)
    except Exception as e:
        reason = f"Unable to connect to Milvus server at {uri}"
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    return uri
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
@pytest.fixture(name="populate_milvus", scope="session")
def populate_milvus_fixture(milvus_uri: str, root_repo_dir: Path):
    """
    Populate Milvus with some test data.

    Runs ``scripts/langchain_web_ingest.py`` (relative to the repository root) three times against ``milvus_uri``:
    once with no ``--urls``/``--collection_name`` arguments (the default cuda docs), once for MCP docs into the
    ``mcp_docs`` collection, and once for a Wikipedia page into the ``wikipedia_docs`` collection. Every run uses
    ``check=True``, so a failing ingestion raises ``subprocess.CalledProcessError``.
    """
    import sys

    populate_script = root_repo_dir / "scripts/langchain_web_ingest.py"

    ingestion_jobs: list[list[str]] = [
        # Ingest default cuda docs
        [],
        # Ingest MCP docs
        [
            "--urls",
            "https://github.com/modelcontextprotocol/python-sdk",
            "--urls",
            "https://modelcontextprotocol.io/introduction",
            "--urls",
            "https://modelcontextprotocol.io/quickstart/server",
            "--urls",
            "https://modelcontextprotocol.io/quickstart/client",
            "--urls",
            "https://modelcontextprotocol.io/examples",
            "--urls",
            "https://modelcontextprotocol.io/docs/concepts/architecture",
            "--collection_name",
            "mcp_docs",
        ],
        # Ingest some wikipedia docs
        [
            "--urls",
            "https://en.wikipedia.org/wiki/Aardvark",
            "--collection_name",
            "wikipedia_docs",
        ],
    ]

    for extra_args in ingestion_jobs:
        # Use the interpreter running pytest instead of whatever "python" resolves to on PATH, so the script runs
        # in the same environment as the tests.
        subprocess.run([sys.executable, str(populate_script), "--milvus_uri", milvus_uri, *extra_args], check=True)
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
@pytest.fixture(name="require_nest_asyncio", scope="session", autouse=True)
def require_nest_asyncio_fixture():
    """
    Apply the ``nest_asyncio2`` patch once per session so nested event loops are allowed.

    Calling ``nest_asyncio2.apply()`` more than once is a no-op. The ``nest_asyncio2`` patch must be applied before
    the older ``nest_asyncio`` patch, so any library that applies ``nest_asyncio`` on import has to be imported
    lazily. ``error_on_mispatched=True`` makes a mis-ordered patch raise ``RuntimeError``, which is re-raised here
    with guidance.
    """
    import nest_asyncio2

    misordered_patch_message = (
        "nest_asyncio2 fixture called but asyncio is already patched, most likely this is due to the nest_asyncio "
        "being applied first, which is not compatible with Python 3.12+. Please ensure that any libraries which "
        "apply nest_asyncio on import are lazily imported.")

    try:
        nest_asyncio2.apply(error_on_mispatched=True)
    except RuntimeError as patch_error:
        raise RuntimeError(misordered_patch_message) from patch_error
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
@pytest.fixture(name="phoenix_url", scope="session")
def phoenix_url_fixture(fail_missing: bool) -> str:
    """
    To run these tests, a phoenix server must be running.
    The phoenix server can be started by running the following command:
    docker run -p 6006:6006 -p 4317:4317 arizephoenix/phoenix:latest

    Returns ``NAT_CI_PHOENIX_URL`` (default ``http://localhost:6006``) after a successful ``GET`` against it. If the
    request fails, the test is skipped, or a ``RuntimeError`` (chained from the underlying error) is raised when
    ``fail_missing`` is True.
    """
    import requests

    url = os.getenv("NAT_CI_PHOENIX_URL", "http://localhost:6006")
    try:
        response = requests.get(url, timeout=5)
        response.raise_for_status()
    except Exception as e:
        reason = f"Unable to connect to Phoenix server at {url}: {e}"
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    return url
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
@pytest.fixture(name="phoenix_trace_url", scope="session")
def phoenix_trace_url_fixture(phoenix_url: str) -> str:
    """
    Return the Phoenix trace endpoint: ``<phoenix_url>/v1/traces``.

    Some tools take the base URL from the ``phoenix_url`` fixture, while
    ``general.telemetry.tracing["phoenix"].endpoint`` expects this trace URL instead.
    """
    trace_suffix = "/v1/traces"
    return phoenix_url + trace_suffix
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
@pytest.fixture(name="redis_server", scope="session")
def fixture_redis_server(fail_missing: bool) -> Generator[dict[str, str | int]]:
    """
    Fixture to safely skip redis based tests if redis is not running.

    Connection settings come from ``NAT_CI_REDIS_HOST`` (default ``localhost``), ``NAT_CI_REDIS_PORT`` (``6379``),
    ``NAT_CI_REDIS_DB`` (``0``), ``REDIS_PASSWORD`` (``redis``) and ``NAT_CI_REDIS_BUCKET_NAME`` (``test``). The
    server is pinged once; on success a dict with ``host``, ``port``, ``db``, ``bucket_name`` and ``password`` is
    yielded. If ``redis`` is not installed or the ping fails, the test is skipped, or the original exception is
    re-raised when ``fail_missing`` is True.
    """
    host = os.environ.get("NAT_CI_REDIS_HOST", "localhost")
    port = int(os.environ.get("NAT_CI_REDIS_PORT", "6379"))
    db = int(os.environ.get("NAT_CI_REDIS_DB", "0"))
    password = os.environ.get("REDIS_PASSWORD", "redis")
    bucket_name = os.environ.get("NAT_CI_REDIS_BUCKET_NAME", "test")

    try:
        import redis
    except ImportError:
        if fail_missing:
            raise
        pytest.skip("redis not installed, skipping redis tests")

    try:
        client = redis.Redis(host=host, port=port, db=db, password=password)
        try:
            if not client.ping():
                raise RuntimeError("Failed to connect to Redis")
        finally:
            # The client is only a connectivity probe; release its connections instead of leaking them.
            client.close()
    except Exception as e:
        if fail_missing:
            raise
        pytest.skip(f"Error connecting to Redis server: {e}, skipping redis tests")

    # Yield outside the try blocks so teardown-time errors are not converted into skips.
    yield {"host": host, "port": port, "db": db, "bucket_name": bucket_name, "password": password}
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
@pytest_asyncio.fixture(name="mysql_server", scope="session")
async def fixture_mysql_server(fail_missing: bool) -> AsyncGenerator[dict[str, str | int]]:
    """
    Fixture to safely skip MySQL based tests if MySQL is not running.

    Connection settings come from ``NAT_CI_MYSQL_HOST`` (default ``127.0.0.1``), ``NAT_CI_MYSQL_PORT`` (``3306``),
    ``NAT_CI_MYSQL_USER`` (``root``), ``MYSQL_ROOT_PASSWORD`` (``my_password``) and ``NAT_CI_MYSQL_BUCKET_NAME``
    (``test``). A connection is opened for the session and closed at teardown; the yielded dict holds ``host``,
    ``port``, ``username``, ``password`` and ``bucket_name``. If ``aiomysql`` is not installed or the connection
    fails, the test is skipped, or the original exception is re-raised when ``fail_missing`` is True.
    """
    host = os.environ.get('NAT_CI_MYSQL_HOST', '127.0.0.1')
    port = int(os.environ.get('NAT_CI_MYSQL_PORT', '3306'))
    user = os.environ.get('NAT_CI_MYSQL_USER', 'root')
    password = os.environ.get('MYSQL_ROOT_PASSWORD', 'my_password')
    bucket_name = os.environ.get('NAT_CI_MYSQL_BUCKET_NAME', 'test')

    try:
        import aiomysql
    except ImportError:
        if fail_missing:
            raise
        pytest.skip("aiomysql not installed, skipping MySQL tests")

    try:
        conn = await aiomysql.connect(host=host, port=port, user=user, password=password)
    except Exception as e:
        if fail_missing:
            raise
        pytest.skip(f"Error connecting to MySQL server: {e}, skipping MySQL tests")

    # Keep the yield and teardown outside the try blocks: an error while closing the connection must surface as
    # a teardown error, not trigger pytest.skip during teardown.
    yield {"host": host, "port": port, "username": user, "password": password, "bucket_name": bucket_name}
    conn.close()
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
@pytest.fixture(name="minio_server", scope="session")
def minio_server_fixture(fail_missing: bool) -> Generator[dict[str, str | int]]:
    """
    Fixture to safely skip MinIO based tests if MinIO is not running.

    Settings come from ``NAT_CI_MINIO_HOST`` (default ``localhost``), ``NAT_CI_MINIO_PORT`` (``9000``),
    ``NAT_CI_MINIO_BUCKET_NAME`` (``test``), ``NAT_CI_MINIO_ACCESS_KEY_ID`` and ``NAT_CI_MINIO_SECRET_ACCESS_KEY``
    (both ``minioadmin``). Reachability is checked with an S3 ``list_buckets`` call through ``botocore``. If
    ``botocore`` is not installed or the call fails, the test is skipped, or the original exception is re-raised
    when ``fail_missing`` is True.
    """
    host = os.getenv("NAT_CI_MINIO_HOST", "localhost")
    port = int(os.getenv("NAT_CI_MINIO_PORT", "9000"))
    bucket_name = os.getenv("NAT_CI_MINIO_BUCKET_NAME", "test")
    aws_access_key_id = os.getenv("NAT_CI_MINIO_ACCESS_KEY_ID", "minioadmin")
    aws_secret_access_key = os.getenv("NAT_CI_MINIO_SECRET_ACCESS_KEY", "minioadmin")
    endpoint_url = f"http://{host}:{port}"

    minio_info = {
        "host": host,
        "port": port,
        "bucket_name": bucket_name,
        "endpoint_url": endpoint_url,
        "aws_access_key_id": aws_access_key_id,
        "aws_secret_access_key": aws_secret_access_key,
    }

    try:
        import botocore.session
    except ImportError:
        if fail_missing:
            raise
        # The dependency actually imported here is botocore (not aioboto3).
        pytest.skip("botocore not installed, skipping MinIO tests")

    try:
        session = botocore.session.get_session()
        client = session.create_client("s3",
                                       aws_access_key_id=aws_access_key_id,
                                       aws_secret_access_key=aws_secret_access_key,
                                       endpoint_url=endpoint_url)
        client.list_buckets()
    except Exception as e:
        if fail_missing:
            raise
        pytest.skip(f"Error connecting to MinIO server: {e}, skipping MinIO tests")

    # Yield outside the try blocks so teardown-time errors are not converted into skips.
    yield minio_info
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
@pytest.fixture(name="langfuse_bucket", scope="session")
def langfuse_bucket_fixture(fail_missing: bool, minio_server: dict[str, str | int]) -> Generator[str]:
    """
    Ensure the MinIO bucket used by Langfuse exists and yield its name.

    The bucket name comes from ``NAT_CI_LANGFUSE_BUCKET`` (default ``langfuse``). It is created on the MinIO
    server described by ``minio_server`` when missing, and is not deleted afterwards. If ``botocore`` is not
    installed or the S3 calls fail, the test is skipped, or the original exception is re-raised when
    ``fail_missing`` is True.
    """
    bucket_name = os.getenv("NAT_CI_LANGFUSE_BUCKET", "langfuse")

    try:
        import botocore.session
    except ImportError:
        if fail_missing:
            raise
        # The dependency actually imported here is botocore (not aioboto3).
        pytest.skip("botocore not installed, skipping MinIO tests")

    try:
        session = botocore.session.get_session()
        client = session.create_client("s3",
                                       aws_access_key_id=minio_server["aws_access_key_id"],
                                       aws_secret_access_key=minio_server["aws_secret_access_key"],
                                       endpoint_url=minio_server["endpoint_url"])

        buckets = client.list_buckets()
        bucket_names = [b['Name'] for b in buckets['Buckets']]
        if bucket_name not in bucket_names:
            client.create_bucket(Bucket=bucket_name)
    except Exception as e:
        if fail_missing:
            raise
        pytest.skip(f"Error connecting to MinIO server: {e}, skipping MinIO tests")

    # Yield outside the try blocks so teardown-time errors are not converted into skips.
    yield bucket_name
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
@pytest.fixture(name="langfuse_url", scope="session")
def langfuse_url_fixture(fail_missing: bool, langfuse_bucket: str) -> str:
    """
    To run these tests, a langfuse server must be running.

    Returns ``http://<NAT_CI_LANGFUSE_HOST>:<NAT_CI_LANGFUSE_PORT>`` (defaults ``localhost`` / ``3000``) after a
    successful ``GET <url>/api/public/health``. ``langfuse_bucket`` is requested so its MinIO bucket exists first.
    If the health check fails, the test is skipped, or a ``RuntimeError`` (chained from the underlying error) is
    raised when ``fail_missing`` is True.
    """
    import requests

    host = os.getenv("NAT_CI_LANGFUSE_HOST", "localhost")
    port = int(os.getenv("NAT_CI_LANGFUSE_PORT", "3000"))
    url = f"http://{host}:{port}"
    health_endpoint = f"{url}/api/public/health"
    try:
        response = requests.get(health_endpoint, timeout=5)
        response.raise_for_status()
    except Exception as e:
        reason = f"Unable to connect to Langfuse server at {url}: {e}"
        if fail_missing:
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    return url
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
@pytest.fixture(name="langfuse_trace_url", scope="session")
def langfuse_trace_url_fixture(langfuse_url: str) -> str:
    """
    Return the Langfuse OTEL trace endpoint: ``<langfuse_url>/api/public/otel/v1/traces``.

    ``langfuse_url`` is the base URL, while ``general.telemetry.tracing["langfuse"].endpoint`` expects this trace URL.
    """
    trace_suffix = "/api/public/otel/v1/traces"
    return langfuse_url + trace_suffix
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
@pytest.fixture(name="oauth2_server_url", scope="session")
def oauth2_server_url_fixture(fail_missing: bool) -> str:
    """
    Provide the base URL of a running OAuth2 server.

    To run these tests, an oauth2 server must be running. The address is built from the ``NAT_CI_OAUTH2_HOST``
    (default ``localhost``) and ``NAT_CI_OAUTH2_PORT`` (default ``5001``) environment variables, and the server is
    considered available when a GET on its root URL returns a non-error status.

    Args:
        fail_missing: When true, an unreachable server is a hard error instead of a skip.

    Returns:
        The base URL, e.g. ``http://localhost:5001``.

    Raises:
        RuntimeError: If the server cannot be reached and ``fail_missing`` is set; otherwise the test is skipped.
    """
    import requests

    host = os.getenv("NAT_CI_OAUTH2_HOST", "localhost")
    port = int(os.getenv("NAT_CI_OAUTH2_PORT", "5001"))
    url = f"http://{host}:{port}"
    try:
        response = requests.get(url, timeout=5)
        response.raise_for_status()
    except Exception as e:
        reason = f"Unable to connect to OAuth2 server at {url}: {e}"
        if fail_missing:
            # Chain the original error so the root cause stays visible in the traceback
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)

    return url
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
@pytest.fixture(name="oauth2_client_credentials", scope="session")
def oauth2_client_credentials_fixture(oauth2_server_url: str, fail_missing: bool) -> dict[str, typing.Any]:
    """
    Fixture to provide OAuth2 client credentials for testing.

    Simulates the steps a user would take in a web browser to create a new OAuth2 client as documented in:
    examples/front_ends/simple_auth/README.md

    The user name is taken from ``NAT_CI_OAUTH2_CLIENT_USERNAME`` (default ``"Testy Testerson"``).

    Args:
        oauth2_server_url: Base URL of the OAuth2 server.
        fail_missing: When true, any failure is a hard error instead of a skip.

    Returns:
        A dict with the keys ``id``, ``secret``, ``username``, ``url`` and ``cookies`` (the session cookies returned
        when the user was created).

    Raises:
        RuntimeError: If the client cannot be created (including a missing ``requests``/``bs4`` install or
            unparseable credentials) and ``fail_missing`` is set; otherwise the test is skipped.
    """

    def _value_after(tag) -> str | None:
        # The credential value is the plain text node that follows its ``<strong>`` label; anything else
        # (no sibling, or another element) is treated as "not found" rather than crashing.
        sibling = tag.next_sibling
        if isinstance(sibling, str):
            return sibling.strip()
        return None

    try:
        import requests
        from bs4 import BeautifulSoup
        username = os.getenv("NAT_CI_OAUTH2_CLIENT_USERNAME", "Testy Testerson")

        # This post request responds with a cookie that we need for future requests and a 302 redirect, the response
        # for the redirected url doesn't contain the cookie, so we disable the redirect here to capture the cookie
        user_create_response = requests.post(oauth2_server_url,
                                             data=[("username", username)],
                                             headers={"Content-Type": "application/x-www-form-urlencoded"},
                                             allow_redirects=False,
                                             timeout=5)
        user_create_response.raise_for_status()
        cookies = user_create_response.cookies

        client_create_response = requests.post(f"{oauth2_server_url}/create_client",
                                               cookies=cookies,
                                               headers={"Content-Type": "application/x-www-form-urlencoded"},
                                               data=[
                                                   ("client_name", "test"),
                                                   ("client_uri", "https://test.com"),
                                                   ("scope", "openid profile email"),
                                                   ("redirect_uri", "http://localhost:8000/auth/redirect"),
                                                   ("grant_type", "authorization_code\nrefresh_token"),
                                                   ("response_type", "code"),
                                                   ("token_endpoint_auth_method", "client_secret_post"),
                                               ],
                                               timeout=5)
        client_create_response.raise_for_status()

        # Unfortunately the response is HTML so we need to parse it to get the client ID and secret, which are not
        # locatable via ID tags
        soup = BeautifulSoup(client_create_response.text, 'html.parser')
        client_id = None
        client_secret = None
        for tag in soup.find_all('strong'):
            # get_text() also copes with labels containing nested markup, where joining ``tag.contents`` would fail
            label = tag.get_text()
            if client_id is None and "client_id:" in label:
                client_id = _value_after(tag)
            elif client_secret is None and "client_secret:" in label:
                client_secret = _value_after(tag)

            if client_id is not None and client_secret is not None:
                break

        # Explicit check instead of ``assert`` so validation still happens when Python runs with ``-O``
        if client_id is None or client_secret is None:
            raise RuntimeError("Failed to parse client credentials from response")

        return {
            "id": client_id,
            "secret": client_secret,
            "username": username,
            "url": oauth2_server_url,
            "cookies": cookies
        }

    except Exception as e:
        reason = f"Unable to create OAuth2 client: {e}"
        if fail_missing:
            # Chain the original error so the root cause stays visible in the traceback
            raise RuntimeError(reason) from e
        pytest.skip(reason=reason)
|
|
811
|
+
|
|
812
|
+
|
|
813
|
+
@pytest.fixture(name="local_sandbox_url", scope="session")
def local_sandbox_url_fixture(fail_missing: bool) -> str:
    """
    Check if sandbox server is running before running tests.

    The URL comes from ``NAT_CI_SANDBOX_URL`` (default ``http://127.0.0.1:6000``). The server is considered available
    when a GET on that URL returns a non-error status.

    Args:
        fail_missing: When true, an unreachable server is a hard error instead of a skip.

    Returns:
        The sandbox URL, exactly as read from the environment.

    Raises:
        RuntimeError: If the server cannot be reached and ``fail_missing`` is set; otherwise the test is skipped.
    """
    import requests
    url = os.environ.get("NAT_CI_SANDBOX_URL", "http://127.0.0.1:6000")
    try:
        response = requests.get(url, timeout=5)
        response.raise_for_status()
    except Exception as e:
        reason = (f"Sandbox server is not running at {url}. "
                  "Please start it with: cd src/nat/tool/code_execution/local_sandbox && ./start_local_sandbox.sh")
        if fail_missing:
            # Chain the original error so the actual connection/HTTP failure is not lost
            raise RuntimeError(reason) from e
        pytest.skip(reason)

    return url
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
@pytest.fixture(name="sandbox_config", scope="session")
def sandbox_config_fixture(local_sandbox_url: str) -> dict[str, typing.Any]:
    """
    Build the settings dict used by sandbox tests.

    ``execute_url`` is the sandbox URL (without trailing slashes) plus ``/execute``. ``timeout`` is read from the
    ``SANDBOX_TIMEOUT`` environment variable (default 30), and ``connection_timeout`` is fixed at 5.
    """
    trimmed_base = local_sandbox_url.rstrip('/')
    execution_timeout = int(os.environ.get("SANDBOX_TIMEOUT", "30"))
    return dict(base_url=local_sandbox_url,
                execute_url=f"{trimmed_base}/execute",
                timeout=execution_timeout,
                connection_timeout=5)
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
@pytest.fixture(name="piston_url", scope="session")
def piston_url_fixture(fail_missing: bool) -> str:
    """
    Verify that a Piston server is running and has the required python version installed.

    The API URL comes from ``NAT_CI_PISTON_URL`` (default ``http://localhost:2000/api/v2``, trailing slashes removed).
    The required Python version comes from ``NAT_CI_PISTON_PYTHON_VERSION`` (default ``3.10.0``). If that runtime is
    not listed by ``/runtimes``, the fixture asks the server to install it through ``/packages``.

    Args:
        fail_missing: When true, any failure is a hard error instead of a skip.

    Returns:
        The normalized Piston API URL.

    Raises:
        RuntimeError: If the server is unreachable or the runtime install fails and ``fail_missing`` is set;
            otherwise the test is skipped.
    """
    import requests

    url = os.environ.get("NAT_CI_PISTON_URL", "http://localhost:2000/api/v2")
    url = url.rstrip('/')

    # This is the version of Python used in `src/nat/tool/code_execution/code_sandbox.py`
    python_version = os.environ.get("NAT_CI_PISTON_PYTHON_VERSION", "3.10.0")
    try:
        # A non-error response from the runtimes listing means the server is running
        response = requests.get(f"{url}/runtimes", timeout=30)
        response.raise_for_status()

        # Check if the required python version is installed
        installed = any(runtime["language"] == "python" and runtime["version"] == python_version
                        for runtime in response.json())

        if not installed:
            # Install the required python version
            response = requests.post(f"{url}/packages",
                                     json={
                                         "language": "python", "version": python_version
                                     },
                                     timeout=60)
            response.raise_for_status()
    except Exception as e:
        reason = (f"Piston server is not running at {url}. "
                  "Please start it along with the other integration services by running: "
                  "docker compose -f tests/test_data/docker-compose.services.yml up -d")
        if fail_missing:
            # Chain the original error; it distinguishes "server down" from "runtime install failed"
            raise RuntimeError(reason) from e
        pytest.skip(reason)

    return url
|
|
876
|
+
|
|
877
|
+
|
|
878
|
+
@pytest.fixture(autouse=True, scope="session")
def import_adk_early():
    """
    Pre-import Google ADK at session start to sidestep its slow-import problem.

    See https://github.com/google/adk-python/issues/2433: importing ADK up front takes roughly 8 seconds, whereas the
    same import deferred until the `packages/nvidia_nat_adk/tests` run takes roughly 70 seconds.

    ADK is an optional dependency, so an ImportError is deliberately ignored.
    """
    try:
        import google.adk  # noqa: F401
    except ImportError:
        # Optional dependency not installed: nothing to warm up.
        return
|