nvidia-nat-test 1.4.0a20260117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nat/test/utils.py ADDED
@@ -0,0 +1,215 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import importlib.resources
18
+ import inspect
19
+ import json
20
+ import subprocess
21
+ import time
22
+ import typing
23
+ from contextlib import asynccontextmanager
24
+ from pathlib import Path
25
+
26
+ if typing.TYPE_CHECKING:
27
+ from collections.abc import AsyncIterator
28
+
29
+ from httpx import AsyncClient
30
+
31
+ from nat.data_models.config import Config
32
+ from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker
33
+ from nat.utils.type_utils import StrPath
34
+
35
+
36
def locate_repo_root() -> Path:
    """
    Return the top-level directory of the git checkout that contains the current working directory.

    Runs ``git rev-parse --show-toplevel`` (in the current working directory) and fails the calling test through an
    ``assert`` when git exits with a non-zero status.
    """
    git_cmd = ["git", "rev-parse", "--show-toplevel"]
    completed = subprocess.run(git_cmd, check=False, capture_output=True, text=True)

    git_succeeded = completed.returncode == 0
    assert git_succeeded, f"Failed to get git root: {completed.stderr}"

    top_level = completed.stdout.strip()
    return Path(top_level)
40
+
41
+
42
def locate_example_src_dir(example_config_class: type) -> Path:
    """
    Locate the example src directory for an example's config class.

    The result is the directory of the package containing the module in which ``example_config_class`` is defined,
    as resolved by `importlib.resources.files`. Note that `importlib.resources.files` is documented to return a
    Traversable; it is a concrete `pathlib.Path` only for packages stored as regular files on disk.

    Args:
        example_config_class: A class defined inside the example's package.

    Returns:
        The package directory.

    Raises:
        ValueError: If the module defining ``example_config_class`` cannot be found, or if that module is a top-level
            module that does not belong to a package.
    """
    module = inspect.getmodule(example_config_class)
    if module is None:
        # getmodule() returns None when the class's __module__ is not loaded; fail with a clear message instead of
        # an AttributeError on None.
        raise ValueError(f"Unable to determine the module that defines {example_config_class!r}")

    package_name = module.__package__
    if not package_name:
        # A top-level module has an empty __package__; importlib.resources.files("") would fail with an unhelpful
        # error, so report the real problem here.
        raise ValueError(f"Module {module.__name__!r} defining {example_config_class!r} is not part of a package")

    return importlib.resources.files(package_name)
48
+
49
+
50
def locate_example_dir(example_config_class: type) -> Path:
    """
    Locate the example directory for an example's config class.

    This is the directory two levels above the package directory returned by `locate_example_src_dir`
    (presumably an ``<example>/src/<package>`` layout -- confirm against the example tree).
    """
    return locate_example_src_dir(example_config_class).parent.parent
57
+
58
+
59
def locate_example_config(example_config_class: type,
                          config_file: str = "config.yml",
                          assert_exists: bool = True) -> Path:
    """
    Locate a config file for an example's config class.

    The file is looked up as ``configs/<config_file>`` inside the package directory returned by
    `locate_example_src_dir`.

    Args:
        example_config_class: A class defined inside the example's package.
        config_file: Name of the file inside the ``configs`` directory.
        assert_exists: When True, fail the calling test (via ``assert``) if the file does not exist.

    Returns:
        The absolute path of the config file.
    """
    package_dir = locate_example_src_dir(example_config_class)
    config_path = package_dir.joinpath("configs", config_file).absolute()

    if not assert_exists:
        return config_path

    assert config_path.exists(), f"Config file {config_path} does not exist"
    return config_path
71
+
72
+
73
async def run_workflow(*,
                       config: "Config | None" = None,
                       config_file: "StrPath | None" = None,
                       question: str,
                       expected_answer: str,
                       assert_expected_answer: bool = True,
                       **kwargs) -> str:
    """
    Test-oriented wrapper around `nat.utils.run_workflow`.

    Runs the workflow with ``question`` as the prompt, always requesting a ``str`` result. It then optionally checks
    (case-insensitively) that ``expected_answer`` appears in that result.

    Args:
        config: Workflow configuration object, forwarded as ``config``.
        config_file: Workflow configuration file, forwarded as ``config_file``.
        question: Prompt sent to the workflow.
        expected_answer: Substring expected in the result when ``assert_expected_answer`` is True.
        assert_expected_answer: When False, the result is returned without being checked.
        **kwargs: Extra keyword arguments forwarded unchanged to `nat.utils.run_workflow`.

    Returns:
        The workflow result as a string.
    """
    from nat.utils import run_workflow as nat_run_workflow

    result = await nat_run_workflow(prompt=question,
                                    config=config,
                                    config_file=config_file,
                                    to_type=str,
                                    **kwargs)

    if not assert_expected_answer:
        return result

    assert expected_answer.lower() in result.lower(), f"Expected '{expected_answer}' in '{result}'"
    return result
92
+
93
+
94
async def _wait_for_exit(proc: subprocess.Popen, timeout: float, poll_interval: float = 0.1) -> bool:
    """
    Poll ``proc`` without blocking the event loop until it exits or ``timeout`` seconds elapse.

    Returns:
        True if the process has exited, False if it is still running after ``timeout``.
    """
    deadline = time.monotonic() + timeout
    while proc.poll() is None:
        if time.monotonic() >= deadline:
            return False
        await asyncio.sleep(poll_interval)
    return True


async def _stop_process(proc: subprocess.Popen, grace_period: float = 5.0) -> None:
    """
    Stop ``proc`` with `Popen.terminate`, escalating to `Popen.kill` if it is still running after ``grace_period``.

    Does nothing if the process has already exited.
    """
    if proc.poll() is not None:
        return
    proc.terminate()
    if not await _wait_for_exit(proc, grace_period):
        proc.kill()
        await _wait_for_exit(proc, grace_period)


def _extract_response_text(response_payload: dict) -> str:
    """
    Extract the generated text from a ``/generate`` response payload.

    If ``payload['value']`` is a string, it is returned as-is. If it is a dict, the ``message.content`` of each entry
    in ``value['choices']`` is joined with newlines. Anything else yields an empty string.
    """
    response_value = response_payload.get('value')
    if isinstance(response_value, str):
        return response_value
    if not isinstance(response_value, dict):
        return ""

    choices = response_value.get('choices') or []
    return "\n".join((choice.get('message') or {}).get('content') or '' for choice in choices)


async def serve_workflow(*,
                         config_path: Path,
                         question: str,
                         expected_answer: str,
                         assert_expected_answer: bool = True,
                         port: int = 8000,
                         pipeline_timeout: int = 60,
                         request_timeout: int = 30) -> dict:
    """
    Execute a workflow using `nat serve`, and issue a POST request to the `/generate` endpoint with the given question.

    Intended to be analogous to `run_workflow` but for the REST API serving mode. The server runs as a subprocess.
    The request is retried until the server answers or ``pipeline_timeout`` seconds elapse. The server is always
    stopped before this function returns or raises.

    Args:
        config_path: Workflow config file passed to ``nat serve --config_file``.
        question: Content of the single user message sent to ``/generate``.
        expected_answer: Substring expected (case-insensitively) in the generated text.
        assert_expected_answer: When False, the generated text is not checked.
        port: Port passed to ``nat serve --port``; requests are sent to ``http://localhost:<port>``.
        pipeline_timeout: Seconds to keep retrying the request while the server starts up.
        request_timeout: Per-request timeout in seconds, passed to ``requests.post``.

    Returns:
        The decoded JSON payload of the ``/generate`` response.

    Raises:
        AssertionError: If the server exits before answering, or does not answer before the deadline. Also raised if
            the answer does not contain ``expected_answer``, or the server cannot be stopped.
        requests.HTTPError: If the endpoint answers with an error status code.
    """
    import tempfile

    import requests

    workflow_url = f"http://localhost:{port}"
    workflow_cmd = ["nat", "serve", "--port", str(port), "--config_file", str(config_path.absolute())]

    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as log_dir:
        # Capture server output in a file rather than an unread pipe. A chatty server would otherwise fill the pipe
        # buffer and stall. Reading the pipe of a still-running server also never returns, which used to hang the
        # test when the deadline was exceeded.
        log_path = Path(log_dir) / "nat_serve.log"
        with open(log_path, "w", encoding="utf-8") as log_file:
            # The child keeps its own copy of the descriptor, so closing ours right away is safe.
            proc = subprocess.Popen(workflow_cmd, stdout=log_file, stderr=subprocess.STDOUT)

        def read_server_log() -> str:
            # Separate read handle: never blocks and does not disturb the child's write position.
            return log_path.read_text(encoding="utf-8", errors="replace")

        try:
            deadline = time.monotonic() + pipeline_timeout  # timeout waiting for the workflow to respond
            response = None
            while response is None and time.monotonic() < deadline:
                # Fail fast instead of retrying until the deadline if the server died during startup.
                assert proc.poll() is None, f"NAT server process exited before responding: {read_server_log()}"
                try:
                    # requests is synchronous; run it in a worker thread so the event loop is not blocked.
                    response = await asyncio.to_thread(requests.post,
                                                       url=f"{workflow_url}/generate",
                                                       json={"messages": [{
                                                           "role": "user", "content": question
                                                       }]},
                                                       timeout=request_timeout)
                except requests.RequestException:
                    # Typically the server is not accepting connections yet; retry shortly.
                    await asyncio.sleep(0.1)

            assert response is not None, f"deadline exceeded waiting for workflow response: {read_server_log()}"
            response.raise_for_status()
            response_payload = response.json()
            response_text = _extract_response_text(response_payload)

            if assert_expected_answer:
                assert expected_answer.lower() in response_text.lower(), \
                    f"Unexpected response: {response.text}"
        finally:
            # Teardown
            await _stop_process(proc)

    # Checked outside `finally` so a teardown failure never masks an earlier, more informative error.
    assert proc.poll() is not None, "NAT server process failed to terminate"

    return response_payload
157
+
158
+
159
@asynccontextmanager
async def build_nat_client(
        config: "Config",
        worker_class: "type[FastApiFrontEndPluginWorker] | None" = None) -> "AsyncIterator[AsyncClient]":
    """
    Yield an in-process HTTP client for a NAT FastAPI application, for use in tests.

    The application is created with ``worker_class(config).build_app()``. Its lifespan is run with
    `asgi_lifespan.LifespanManager`. Requests are routed to it through `httpx.ASGITransport`, with base URL
    ``http://test``, so no network server is started.

    Args:
        config: The NAT configuration the application is built from.
        worker_class: Worker class used to build the application; `FastApiFrontEndPluginWorker` when omitted.

    Yields:
        An `httpx.AsyncClient` bound to the application. It is closed, and the lifespan exited, when the context ends.
    """
    from asgi_lifespan import LifespanManager
    from httpx import ASGITransport
    from httpx import AsyncClient

    from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker

    chosen_worker_class = FastApiFrontEndPluginWorker if worker_class is None else worker_class
    app = chosen_worker_class(config).build_app()

    # The client (and its transport) is only created once the lifespan context has been entered.
    async with LifespanManager(app), AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        yield client
191
+
192
+
193
def validate_workflow_output(workflow_output_file: Path) -> None:
    """
    Validate the contents of the workflow output file.

    The file must exist and contain a non-empty JSON list of objects. Every object must have a non-empty (truthy)
    value for each of ``id``, ``question``, ``answer``, ``generated_answer`` and ``intermediate_steps``.
    WIP: output format should be published as a schema and this validation should be done against that schema.

    Args:
        workflow_output_file: Path of the workflow output JSON file.

    Raises:
        AssertionError: If any of the structural checks above fails.
        RuntimeError: If the file is not valid JSON.
    """
    # Ensure the workflow_output.json file was created
    assert workflow_output_file.exists(), "The workflow_output.json file was not created"

    # Read and validate the workflow_output.json file
    try:
        with open(workflow_output_file, encoding="utf-8") as f:
            result_json = json.load(f)
    except json.JSONDecodeError as err:
        raise RuntimeError("Failed to parse workflow_output.json as valid JSON") from err

    assert isinstance(result_json, list), "The workflow_output.json file is not a list"
    assert len(result_json) > 0, "The workflow_output.json file is empty"
    # Check every entry, not only the first; otherwise a later non-dict entry crashes below with an AttributeError.
    assert all(isinstance(item, dict) for item in result_json), \
        "The workflow_output.json file is not a list of dictionaries"

    # Ensure required keys exist and carry a non-empty value
    required_keys = ["id", "question", "answer", "generated_answer", "intermediate_steps"]
    for key in required_keys:
        for index, item in enumerate(result_json):
            assert item.get(key), f"The '{key}' key is missing or empty in entry {index} of workflow_output.json"
@@ -0,0 +1,46 @@
1
+ Metadata-Version: 2.4
2
+ Name: nvidia-nat-test
3
+ Version: 1.4.0a20260117
4
+ Summary: Testing utilities for NeMo Agent toolkit
5
+ Author: NVIDIA Corporation
6
+ Maintainer: NVIDIA Corporation
7
+ License: Apache-2.0
8
+ Project-URL: documentation, https://docs.nvidia.com/nemo/agent-toolkit/latest/
9
+ Project-URL: source, https://github.com/NVIDIA/NeMo-Agent-Toolkit
10
+ Keywords: ai,rag,agents
11
+ Classifier: Programming Language :: Python
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Classifier: Programming Language :: Python :: 3.13
15
+ Requires-Python: <3.14,>=3.11
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE-3rd-party.txt
18
+ License-File: LICENSE.md
19
+ Requires-Dist: nvidia-nat==v1.4.0a20260117
20
+ Requires-Dist: langchain-community~=0.3
21
+ Requires-Dist: pytest~=8.3
22
+ Dynamic: license-file
23
+
24
+ <!--
25
+ SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
26
+ SPDX-License-Identifier: Apache-2.0
27
+
28
+ Licensed under the Apache License, Version 2.0 (the "License");
29
+ you may not use this file except in compliance with the License.
30
+ You may obtain a copy of the License at
31
+
32
+ http://www.apache.org/licenses/LICENSE-2.0
33
+
34
+ Unless required by applicable law or agreed to in writing, software
35
+ distributed under the License is distributed on an "AS IS" BASIS,
36
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
37
+ See the License for the specific language governing permissions and
38
+ limitations under the License.
39
+ -->
40
+
41
+ ![NVIDIA NeMo Agent Toolkit](https://media.githubusercontent.com/media/NVIDIA/NeMo-Agent-Toolkit/refs/heads/main/docs/source/_static/banner.png "NeMo Agent toolkit banner image")
42
+
43
+ # NVIDIA NeMo Agent Toolkit Subpackage
44
+ This is a subpackage for NeMo Agent toolkit test utilities.
45
+
46
+ For more information about the NVIDIA NeMo Agent toolkit, please visit the [NeMo Agent toolkit GitHub Repo](https://github.com/NVIDIA/NeMo-Agent-Toolkit).
@@ -0,0 +1,18 @@
1
+ nat/meta/pypi.md,sha256=Jtko_QA6tjRGYolUxQSRJhkYIxINaCBesEqUvhdHvjM,1105
2
+ nat/test/__init__.py,sha256=M6-7XytAvHeOTCU-myke2CFZW8vt72LckVvWMPnfv98,880
3
+ nat/test/embedder.py,sha256=0B-IyquRPFl6XJZs6dzinmEBbwNjYw8NFhzG2Mi1pag,1814
4
+ nat/test/functions.py,sha256=JQUX8JoukZo_xWoyzX6clWhAtaDPnBxAKfMkHGyHREs,3562
5
+ nat/test/llm.py,sha256=rszKlDmXFSqSLq2rvnq25oYgrd3A8UDIV2Da6akr11k,9947
6
+ nat/test/memory.py,sha256=85aAxGxYicEvkGJ9qd2-4KXkqQMke01IXvvmrKaUNd8,1461
7
+ nat/test/object_store_tests.py,sha256=-mpqm26tg-96J6E3YiHAvIEvPI7Ka07JAnWlOKc0SBg,4286
8
+ nat/test/plugin.py,sha256=2qSgFwyTTlNuBi7Oe_4PV0jMA_ho9AjLbvMeDtkDsb8,33193
9
+ nat/test/register.py,sha256=Os695g0fDxvSrs9sRo6NSRN488cRQpJHBvORN1p7ADc,897
10
+ nat/test/tool_test_runner.py,sha256=Suvrx6lR2A30WCCAhZuYvfR6Dud66i2GudJ-wAWhrtU,26327
11
+ nat/test/utils.py,sha256=vkqKgdzF4Ew7lNvhzLu3nyrIavXfPafhfyosuJ0oLrk,8511
12
+ nvidia_nat_test-1.4.0a20260117.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
13
+ nvidia_nat_test-1.4.0a20260117.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
14
+ nvidia_nat_test-1.4.0a20260117.dist-info/METADATA,sha256=nIIvdWtC8qDk0yfavjzvdGi1dBV0m9ku9abCLuzjeVA,1930
15
+ nvidia_nat_test-1.4.0a20260117.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
+ nvidia_nat_test-1.4.0a20260117.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
17
+ nvidia_nat_test-1.4.0a20260117.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
18
+ nvidia_nat_test-1.4.0a20260117.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,5 @@
1
+ [nat.components]
2
+ nvidia-nat-test = nat.test.register
3
+
4
+ [pytest11]
5
+ nvidia-nat-test = nat.test.plugin