nvidia-nat-test 1.2.0__py3-none-any.whl → 1.4.0a20251212__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-test might be problematic. See the package's registry page for more details.

nat/test/register.py CHANGED
@@ -13,7 +13,6 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- # pylint: disable=unused-import
17
16
  # flake8: noqa
18
17
  # isort:skip_file
19
18
 
@@ -22,3 +21,5 @@
22
21
  from . import embedder
23
22
  from . import functions
24
23
  from . import memory
24
+ from . import llm
25
+ from . import utils
nat/test/tool_test_runner.py CHANGED
@@ -14,18 +14,42 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import asyncio
17
+ import inspect
17
18
  import logging
18
19
  import typing
20
+ from collections.abc import Sequence
19
21
  from contextlib import asynccontextmanager
20
22
  from unittest.mock import AsyncMock
21
23
  from unittest.mock import MagicMock
22
24
 
25
+ from nat.authentication.interfaces import AuthProviderBase
23
26
  from nat.builder.builder import Builder
24
27
  from nat.builder.function import Function
28
+ from nat.builder.function import FunctionGroup
25
29
  from nat.builder.function_info import FunctionInfo
26
30
  from nat.cli.type_registry import GlobalTypeRegistry
31
+ from nat.data_models.authentication import AuthProviderBaseConfig
32
+ from nat.data_models.component_ref import MiddlewareRef
33
+ from nat.data_models.embedder import EmbedderBaseConfig
34
+ from nat.data_models.finetuning import TrainerAdapterConfig
35
+ from nat.data_models.finetuning import TrainerConfig
36
+ from nat.data_models.finetuning import TrajectoryBuilderConfig
27
37
  from nat.data_models.function import FunctionBaseConfig
38
+ from nat.data_models.function import FunctionGroupBaseConfig
39
+ from nat.data_models.function_dependencies import FunctionDependencies
40
+ from nat.data_models.llm import LLMBaseConfig
41
+ from nat.data_models.memory import MemoryBaseConfig
42
+ from nat.data_models.middleware import FunctionMiddlewareBaseConfig
28
43
  from nat.data_models.object_store import ObjectStoreBaseConfig
44
+ from nat.data_models.retriever import RetrieverBaseConfig
45
+ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
46
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
47
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
48
+ from nat.finetuning.interfaces.finetuning_runner import Trainer
49
+ from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
50
+ from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
51
+ from nat.memory.interfaces import MemoryEditor
52
+ from nat.middleware import FunctionMiddleware
29
53
  from nat.object_store.interfaces import ObjectStore
30
54
  from nat.runtime.loader import PluginTypes
31
55
  from nat.runtime.loader import discover_and_register_plugins
@@ -46,6 +70,10 @@ class MockBuilder(Builder):
46
70
  """Add a mock function that returns a fixed response."""
47
71
  self._mocks[name] = mock_response
48
72
 
73
+ def mock_function_group(self, name: str, mock_response: typing.Any):
74
+ """Add a mock function group that returns a fixed response."""
75
+ self._mocks[name] = mock_response
76
+
49
77
  def mock_llm(self, name: str, mock_response: typing.Any):
50
78
  """Add a mock LLM that returns a fixed response."""
51
79
  self._mocks[f"llm_{name}"] = mock_response
@@ -70,14 +98,28 @@ class MockBuilder(Builder):
70
98
  """Add a mock TTC strategy that returns a fixed response."""
71
99
  self._mocks[f"ttc_strategy_{name}"] = mock_response
72
100
 
73
- async def add_ttc_strategy(self, name: str, config):
101
+ def mock_auth_provider(self, name: str, mock_response: typing.Any):
102
+ """Add a mock auth provider that returns a fixed response."""
103
+ self._mocks[f"auth_provider_{name}"] = mock_response
104
+
105
+ def mock_trainer(self, name: str, mock_response: typing.Any):
106
+ """Add a mock trainer that returns a fixed response."""
107
+ self._mocks[f"trainer_{name}"] = mock_response
108
+
109
+ def mock_trainer_adapter(self, name: str, mock_response: typing.Any):
110
+ """Add a mock trainer adapter that returns a fixed response."""
111
+ self._mocks[f"trainer_adapter_{name}"] = mock_response
112
+
113
+ def mock_trajectory_builder(self, name: str, mock_response: typing.Any):
114
+ """Add a mock trajectory builder that returns a fixed response."""
115
+ self._mocks[f"trajectory_builder_{name}"] = mock_response
116
+
117
+ async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
74
118
  """Mock implementation (no‑op)."""
75
119
  pass
76
120
 
77
- async def get_ttc_strategy(self,
78
- strategy_name: str,
79
- pipeline_type: typing.Any = None,
80
- stage_type: typing.Any = None):
121
+ async def get_ttc_strategy(self, strategy_name: str, pipeline_type: PipelineTypeEnum,
122
+ stage_type: StageTypeEnum) -> typing.Any:
81
123
  """Return a mock TTC strategy if one is configured."""
82
124
  key = f"ttc_strategy_{strategy_name}"
83
125
  if key in self._mocks:
@@ -90,16 +132,29 @@ class MockBuilder(Builder):
90
132
 
91
133
  async def get_ttc_strategy_config(self,
92
134
  strategy_name: str,
93
- pipeline_type: typing.Any = None,
94
- stage_type: typing.Any = None):
135
+ pipeline_type: PipelineTypeEnum,
136
+ stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
95
137
  """Mock implementation."""
138
+ return TTCStrategyBaseConfig()
139
+
140
+ async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> None:
141
+ """Mock implementation (no‑op)."""
96
142
  pass
97
143
 
144
+ async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
145
+ """Return a mock auth provider if one is configured."""
146
+ key = f"auth_provider_{auth_provider_name}"
147
+ if key in self._mocks:
148
+ mock_auth = MagicMock()
149
+ mock_auth.authenticate = AsyncMock(return_value=self._mocks[key])
150
+ return mock_auth
151
+ raise ValueError(f"Auth provider '{auth_provider_name}' not mocked. Use mock_auth_provider() to add it.")
152
+
98
153
  async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
99
154
  """Mock implementation - not used in tool testing."""
100
155
  raise NotImplementedError("Mock implementation does not support add_function")
101
156
 
102
- def get_function(self, name: str) -> Function:
157
+ async def get_function(self, name: str) -> Function:
103
158
  """Return a mock function if one is configured."""
104
159
  if name in self._mocks:
105
160
  mock_fn = AsyncMock()
@@ -109,25 +164,49 @@ class MockBuilder(Builder):
109
164
 
110
165
  def get_function_config(self, name: str) -> FunctionBaseConfig:
111
166
  """Mock implementation."""
112
- pass
167
+ return FunctionBaseConfig()
168
+
169
+ async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
170
+ """Mock implementation - not used in tool testing."""
171
+ raise NotImplementedError("Mock implementation does not support add_function_group")
172
+
173
+ async def get_function_group(self, name: str) -> FunctionGroup:
174
+ """Return a mock function group if one is configured."""
175
+ if name in self._mocks:
176
+ mock_fn_group = MagicMock(spec=FunctionGroup)
177
+ mock_fn_group.ainvoke = AsyncMock(return_value=self._mocks[name])
178
+ return mock_fn_group
179
+ raise ValueError(f"Function group '{name}' not mocked. Use mock_function_group() to add it.")
180
+
181
+ def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
182
+ """Mock implementation."""
183
+ return FunctionGroupBaseConfig()
113
184
 
114
185
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
115
186
  """Mock implementation."""
116
- pass
187
+ mock_fn = AsyncMock()
188
+ mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
189
+ return mock_fn
117
190
 
118
191
  def get_workflow(self) -> Function:
119
192
  """Mock implementation."""
120
- pass
193
+ mock_fn = AsyncMock()
194
+ mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
195
+ return mock_fn
121
196
 
122
197
  def get_workflow_config(self) -> FunctionBaseConfig:
123
198
  """Mock implementation."""
124
- pass
199
+ return FunctionBaseConfig()
125
200
 
126
- def get_tool(self, fn_name: str, wrapper_type):
201
+ async def get_tools(self, tool_names: Sequence[str], wrapper_type) -> list[typing.Any]:
202
+ """Mock implementation."""
203
+ return []
204
+
205
+ async def get_tool(self, fn_name: str, wrapper_type) -> typing.Any:
127
206
  """Mock implementation."""
128
207
  pass
129
208
 
130
- async def add_llm(self, name: str, config):
209
+ async def add_llm(self, name: str, config) -> None:
131
210
  """Mock implementation."""
132
211
  pass
133
212
 
@@ -141,11 +220,11 @@ class MockBuilder(Builder):
141
220
  return mock_llm
142
221
  raise ValueError(f"LLM '{llm_name}' not mocked. Use mock_llm() to add it.")
143
222
 
144
- def get_llm_config(self, llm_name: str):
223
+ def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
145
224
  """Mock implementation."""
146
- pass
225
+ return LLMBaseConfig()
147
226
 
148
- async def add_embedder(self, name: str, config):
227
+ async def add_embedder(self, name: str, config) -> None:
149
228
  """Mock implementation."""
150
229
  pass
151
230
 
@@ -159,15 +238,14 @@ class MockBuilder(Builder):
159
238
  return mock_embedder
160
239
  raise ValueError(f"Embedder '{embedder_name}' not mocked. Use mock_embedder() to add it.")
161
240
 
162
- def get_embedder_config(self, embedder_name: str):
241
+ def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
163
242
  """Mock implementation."""
164
- pass
243
+ return EmbedderBaseConfig()
165
244
 
166
- async def add_memory_client(self, name: str, config):
167
- """Mock implementation."""
168
- pass
245
+ async def add_memory_client(self, name: str, config) -> MemoryEditor:
246
+ return MagicMock(spec=MemoryEditor)
169
247
 
170
- def get_memory_client(self, memory_name: str):
248
+ async def get_memory_client(self, memory_name: str) -> MemoryEditor:
171
249
  """Return a mock memory client if one is configured."""
172
250
  key = f"memory_{memory_name}"
173
251
  if key in self._mocks:
@@ -177,11 +255,11 @@ class MockBuilder(Builder):
177
255
  return mock_memory
178
256
  raise ValueError(f"Memory client '{memory_name}' not mocked. Use mock_memory_client() to add it.")
179
257
 
180
- def get_memory_client_config(self, memory_name: str):
258
+ def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
181
259
  """Mock implementation."""
182
- pass
260
+ return MemoryBaseConfig()
183
261
 
184
- async def add_retriever(self, name: str, config):
262
+ async def add_retriever(self, name: str, config) -> None:
185
263
  """Mock implementation."""
186
264
  pass
187
265
 
@@ -194,13 +272,13 @@ class MockBuilder(Builder):
194
272
  return mock_retriever
195
273
  raise ValueError(f"Retriever '{retriever_name}' not mocked. Use mock_retriever() to add it.")
196
274
 
197
- async def get_retriever_config(self, retriever_name: str):
275
+ async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
198
276
  """Mock implementation."""
199
- pass
277
+ return RetrieverBaseConfig()
200
278
 
201
- async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
279
+ async def add_object_store(self, name: str, config: ObjectStoreBaseConfig) -> ObjectStore:
202
280
  """Mock implementation for object store."""
203
- pass
281
+ return MagicMock(spec=ObjectStore)
204
282
 
205
283
  async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
206
284
  """Return a mock object store client if one is configured."""
@@ -216,7 +294,7 @@ class MockBuilder(Builder):
216
294
 
217
295
  def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
218
296
  """Mock implementation for object store config."""
219
- pass
297
+ return ObjectStoreBaseConfig()
220
298
 
221
299
  def get_user_manager(self):
222
300
  """Mock implementation."""
@@ -224,9 +302,81 @@ class MockBuilder(Builder):
224
302
  mock_user.get_id = MagicMock(return_value="test_user")
225
303
  return mock_user
226
304
 
227
- def get_function_dependencies(self, fn_name: str):
305
+ def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
228
306
  """Mock implementation."""
229
- pass
307
+ return FunctionDependencies()
308
+
309
+ def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
310
+ """Mock implementation."""
311
+ return FunctionDependencies()
312
+
313
+ async def get_middleware(self, middleware_name: str | MiddlewareRef) -> FunctionMiddleware:
314
+ """Mock implementation."""
315
+ return FunctionMiddleware()
316
+
317
+ def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> FunctionMiddlewareBaseConfig:
318
+ """Mock implementation."""
319
+ return FunctionMiddlewareBaseConfig()
320
+
321
+ async def add_middleware(self, name: str | MiddlewareRef,
322
+ config: FunctionMiddlewareBaseConfig) -> FunctionMiddleware:
323
+ """Mock implementation."""
324
+ return FunctionMiddleware()
325
+
326
+ async def add_trainer(self, name: str, config: TrainerConfig) -> Trainer:
327
+ """Mock implementation."""
328
+ return MagicMock(spec=Trainer)
329
+
330
+ async def get_trainer(self,
331
+ trainer_name: str,
332
+ trajectory_builder: TrajectoryBuilder,
333
+ trainer_adapter: TrainerAdapter) -> Trainer:
334
+ """Return a mock trainer if one is configured."""
335
+ key = f"trainer_{trainer_name}"
336
+ if key in self._mocks:
337
+ mock_trainer = MagicMock()
338
+ mock_trainer.train = AsyncMock(return_value=self._mocks[key])
339
+ return mock_trainer
340
+ raise ValueError(f"Trainer '{trainer_name}' not mocked. Use mock_trainer() to add it.")
341
+
342
+ async def get_trainer_config(self, trainer_name: str) -> TrainerConfig:
343
+ """Mock implementation."""
344
+ return TrainerConfig()
345
+
346
+ async def add_trainer_adapter(self, name: str, config: TrainerAdapterConfig) -> TrainerAdapter:
347
+ """Mock implementation."""
348
+ return MagicMock(spec=TrainerAdapter)
349
+
350
+ async def get_trainer_adapter(self, trainer_adapter_name: str) -> TrainerAdapter:
351
+ """Return a mock trainer adapter if one is configured."""
352
+ key = f"trainer_adapter_{trainer_adapter_name}"
353
+ if key in self._mocks:
354
+ mock_adapter = MagicMock()
355
+ mock_adapter.adapt = AsyncMock(return_value=self._mocks[key])
356
+ return mock_adapter
357
+ raise ValueError(f"Trainer adapter '{trainer_adapter_name}' not mocked. Use mock_trainer_adapter() to add it.")
358
+
359
+ async def get_trainer_adapter_config(self, trainer_adapter_name: str) -> TrainerAdapterConfig:
360
+ """Mock implementation."""
361
+ return TrainerAdapterConfig()
362
+
363
+ async def add_trajectory_builder(self, name: str, config: TrajectoryBuilderConfig) -> TrajectoryBuilder:
364
+ """Mock implementation."""
365
+ return MagicMock(spec=TrajectoryBuilder)
366
+
367
+ async def get_trajectory_builder(self, trajectory_builder_name: str) -> TrajectoryBuilder:
368
+ """Return a mock trajectory builder if one is configured."""
369
+ key = f"trajectory_builder_{trajectory_builder_name}"
370
+ if key in self._mocks:
371
+ mock_builder = MagicMock()
372
+ mock_builder.build = AsyncMock(return_value=self._mocks[key])
373
+ return mock_builder
374
+ raise ValueError(
375
+ f"Trajectory builder '{trajectory_builder_name}' not mocked. Use mock_trajectory_builder() to add it.")
376
+
377
+ async def get_trajectory_builder_config(self, trajectory_builder_name: str) -> TrajectoryBuilderConfig:
378
+ """Mock implementation."""
379
+ return TrajectoryBuilderConfig()
230
380
 
231
381
 
232
382
  class ToolTestRunner:
@@ -322,15 +472,19 @@ class ToolTestRunner:
322
472
 
323
473
  # Execute the tool
324
474
  if input_data is not None:
325
- if asyncio.iscoroutinefunction(tool_function):
475
+ if isinstance(tool_function, Function):
476
+ result = await tool_function.ainvoke(input_data)
477
+ elif asyncio.iscoroutinefunction(tool_function):
326
478
  result = await tool_function(input_data)
327
479
  else:
328
480
  result = tool_function(input_data)
481
+ elif isinstance(tool_function, Function):
482
+ # Function objects require input, so pass None if no input_data
483
+ result = await tool_function.ainvoke(None)
484
+ elif asyncio.iscoroutinefunction(tool_function):
485
+ result = await tool_function()
329
486
  else:
330
- if asyncio.iscoroutinefunction(tool_function):
331
- result = await tool_function()
332
- else:
333
- result = tool_function()
487
+ result = tool_function()
334
488
 
335
489
  # Assert expected output if provided
336
490
  if expected_output is not None:
@@ -402,8 +556,8 @@ class ToolTestRunner:
402
556
  elif isinstance(tool_result, FunctionInfo):
403
557
  if tool_result.single_fn:
404
558
  tool_function = tool_result.single_fn
405
- elif tool_result.streaming_fn:
406
- tool_function = tool_result.streaming_fn
559
+ elif tool_result.stream_fn:
560
+ tool_function = tool_result.stream_fn
407
561
  else:
408
562
  raise ValueError("Tool function not found in FunctionInfo")
409
563
  elif callable(tool_result):
@@ -413,15 +567,17 @@ class ToolTestRunner:
413
567
 
414
568
  # Execute the tool
415
569
  if input_data is not None:
416
- if asyncio.iscoroutinefunction(tool_function):
417
- result = await tool_function(input_data)
570
+ if isinstance(tool_function, Function):
571
+ result = await tool_function.ainvoke(input_data)
418
572
  else:
419
- result = tool_function(input_data)
573
+ maybe_result = tool_function(input_data)
574
+ result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
575
+ elif isinstance(tool_function, Function):
576
+ # Function objects require input, so pass None if no input_data
577
+ result = await tool_function.ainvoke(None)
420
578
  else:
421
- if asyncio.iscoroutinefunction(tool_function):
422
- result = await tool_function()
423
- else:
424
- result = tool_function()
579
+ maybe_result = tool_function()
580
+ result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
425
581
 
426
582
  # Assert expected output if provided
427
583
  if expected_output is not None:
nat/test/utils.py ADDED
@@ -0,0 +1,148 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib.resources
17
+ import inspect
18
+ import json
19
+ import subprocess
20
+ import typing
21
+ from contextlib import asynccontextmanager
22
+ from pathlib import Path
23
+
24
+ if typing.TYPE_CHECKING:
25
+ from collections.abc import AsyncIterator
26
+
27
+ from httpx import AsyncClient
28
+
29
+ from nat.data_models.config import Config
30
+ from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker
31
+ from nat.utils.type_utils import StrPath
32
+
33
+
34
def locate_repo_root() -> Path:
    """Return the top-level directory of the git work tree containing the current directory.

    Runs ``git rev-parse --show-toplevel`` (in the process's current working directory) and
    converts its stripped stdout into a ``Path``.

    Raises:
        AssertionError: if the git command exits with a non-zero status; the message carries git's stderr.
    """
    git_command = ["git", "rev-parse", "--show-toplevel"]
    completed = subprocess.run(git_command, check=False, capture_output=True, text=True)
    assert completed.returncode == 0, f"Failed to get git root: {completed.stderr}"
    top_level = completed.stdout.strip()
    return Path(top_level)
38
+
39
+
40
def locate_example_src_dir(example_config_class: type) -> Path:
    """
    Locate the example src directory for an example's config class.

    Resolves the module that defines ``example_config_class`` and returns the directory of that
    module's package (its ``__package__``) via ``importlib.resources.files``. For a regular on-disk
    package this is a ``pathlib.Path``; in general it is an ``importlib.resources`` Traversable.

    Args:
        example_config_class: A class defined inside the example's package.

    Returns:
        The directory of the package containing ``example_config_class``.

    Raises:
        ValueError: if the defining module cannot be found or does not belong to a package.
    """
    module = inspect.getmodule(example_config_class)
    # getmodule() returns None when the class's module is not importable/registered.
    package_name = getattr(module, "__package__", None)
    # Guard explicitly: on Python 3.12+ ``importlib.resources.files(None)`` silently resolves the
    # *caller's* package (i.e. this helper's own package), which would return the wrong directory.
    if not package_name:
        raise ValueError(f"Could not determine the package that defines {example_config_class!r}")
    return importlib.resources.files(package_name)
46
+
47
+
48
def locate_example_dir(example_config_class: type) -> Path:
    """
    Return the root directory of the example that defines ``example_config_class``.

    This is the directory two levels above the package source directory returned by
    ``locate_example_src_dir`` (presumably an ``<example>/src/<package>`` layout — confirm).
    """
    package_dir = locate_example_src_dir(example_config_class)
    return package_dir.parent.parent
55
+
56
+
57
def locate_example_config(example_config_class: type,
                          config_file: str = "config.yml",
                          assert_exists: bool = True) -> Path:
    """
    Locate a config file belonging to the example that defines ``example_config_class``.

    The file is looked up in the ``configs`` directory inside the example's *package source*
    directory (as returned by ``locate_example_src_dir``), not in the example's root directory.

    Args:
        example_config_class: A class defined inside the example's package.
        config_file: Name of the config file inside ``configs``.
        assert_exists: When true, assert that the resolved file exists.

    Returns:
        The absolute path ``<package src dir>/configs/<config_file>``.

    Raises:
        AssertionError: if ``assert_exists`` is true and the file does not exist.
    """
    package_src_dir = locate_example_src_dir(example_config_class)
    config_path = package_src_dir.joinpath("configs", config_file).absolute()
    if assert_exists:
        assert config_path.exists(), f"Config file {config_path} does not exist"

    return config_path
69
+
70
+
71
async def run_workflow(*,
                       config: "Config | None" = None,
                       config_file: "StrPath | None" = None,
                       question: str,
                       expected_answer: str,
                       assert_expected_answer: bool = True,
                       **kwargs) -> str:
    """
    Run a workflow on ``question`` and, optionally, check the answer it produces.

    Thin test wrapper around ``nat.utils.run_workflow``: ``question`` is forwarded as ``prompt``,
    the result is always requested as ``str`` (``to_type=str``), and any extra keyword arguments
    are passed through unchanged.

    Args:
        config: Workflow configuration object (alternative to ``config_file``).
        config_file: Path to a workflow configuration file (alternative to ``config``).
        question: Prompt sent to the workflow.
        expected_answer: Text expected to appear in the result.
        assert_expected_answer: When true, check ``expected_answer`` against the result.

    Returns:
        The workflow result as a string.

    Raises:
        AssertionError: if ``assert_expected_answer`` is true and ``expected_answer`` is not a
            case-insensitive substring of the result.
    """
    from nat.utils import run_workflow as nat_run_workflow

    answer = await nat_run_workflow(config=config, config_file=config_file, prompt=question, to_type=str, **kwargs)

    if not assert_expected_answer:
        return answer

    contains_expected = expected_answer.lower() in answer.lower()
    assert contains_expected, f"Expected '{expected_answer}' in '{answer}'"
    return answer
90
+
91
+
92
@asynccontextmanager
async def build_nat_client(
        config: "Config",
        worker_class: "type[FastApiFrontEndPluginWorker] | None" = None) -> "AsyncIterator[AsyncClient]":
    """
    Async context manager yielding an HTTP test client bound in-process to a NAT FastAPI app.

    The app is built by instantiating ``worker_class`` (``FastApiFrontEndPluginWorker`` when not
    given) with ``config`` and calling its ``build_app()``. The app's lifespan is driven by
    ``asgi_lifespan.LifespanManager`` for as long as the client is in use, and requests are routed
    through ``httpx.ASGITransport`` against the base URL ``http://test``.

    Args:
        config: The NAT configuration to build the application from.
        worker_class: Optional front-end worker class; defaults to FastApiFrontEndPluginWorker.

    Yields:
        An ``httpx.AsyncClient`` wired to the application.
    """
    from asgi_lifespan import LifespanManager
    from httpx import ASGITransport
    from httpx import AsyncClient

    from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker

    chosen_worker_cls = FastApiFrontEndPluginWorker if worker_class is None else worker_class
    fastapi_app = chosen_worker_cls(config).build_app()

    # Items are entered left to right, so the transport/client are created only after the
    # lifespan startup has run, and torn down before lifespan shutdown.
    async with LifespanManager(fastapi_app), \
            AsyncClient(transport=ASGITransport(app=fastapi_app), base_url="http://test") as client:
        yield client
124
+
125
+
126
+ def validate_workflow_output(workflow_output_file: Path) -> None:
127
+ """
128
+ Validate the contents of the workflow output file.
129
+ WIP: output format should be published as a schema and this validation should be done against that schema.
130
+ """
131
+ # Ensure the workflow_output.json file was created
132
+ assert workflow_output_file.exists(), "The workflow_output.json file was not created"
133
+
134
+ # Read and validate the workflow_output.json file
135
+ try:
136
+ with open(workflow_output_file, encoding="utf-8") as f:
137
+ result_json = json.load(f)
138
+ except json.JSONDecodeError as err:
139
+ raise RuntimeError("Failed to parse workflow_output.json as valid JSON") from err
140
+
141
+ assert isinstance(result_json, list), "The workflow_output.json file is not a list"
142
+ assert len(result_json) > 0, "The workflow_output.json file is empty"
143
+ assert isinstance(result_json[0], dict), "The workflow_output.json file is not a list of dictionaries"
144
+
145
+ # Ensure required keys exist
146
+ required_keys = ["id", "question", "answer", "generated_answer", "intermediate_steps"]
147
+ for key in required_keys:
148
+ assert all(item.get(key) for item in result_json), f"The '{key}' key is missing in workflow_output.json"
@@ -1,14 +1,25 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat-test
3
- Version: 1.2.0
3
+ Version: 1.4.0a20251212
4
4
  Summary: Testing utilities for NeMo Agent toolkit
5
+ Author: NVIDIA Corporation
6
+ Maintainer: NVIDIA Corporation
7
+ License: Apache-2.0
8
+ Project-URL: documentation, https://docs.nvidia.com/nemo/agent-toolkit/latest/
9
+ Project-URL: source, https://github.com/NVIDIA/NeMo-Agent-Toolkit
5
10
  Keywords: ai,rag,agents
6
11
  Classifier: Programming Language :: Python
7
- Requires-Python: <3.13,>=3.11
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Classifier: Programming Language :: Python :: 3.13
15
+ Requires-Python: <3.14,>=3.11
8
16
  Description-Content-Type: text/markdown
9
- Requires-Dist: nvidia-nat==v1.2.0
17
+ License-File: LICENSE-3rd-party.txt
18
+ License-File: LICENSE.md
19
+ Requires-Dist: nvidia-nat==v1.4.0a20251212
10
20
  Requires-Dist: langchain-community~=0.3
11
21
  Requires-Dist: pytest~=8.3
22
+ Dynamic: license-file
12
23
 
13
24
  <!--
14
25
  SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
@@ -0,0 +1,18 @@
1
+ nat/meta/pypi.md,sha256=LLKJHg5oN1-M9Pqfk3Bmphkk4O2TFsyiixuK5T0Y-gw,1100
2
+ nat/test/__init__.py,sha256=_RnTJnsUucHvla_nYKqD4O4g8Bz0tcuDRzWk1bEhcy0,875
3
+ nat/test/embedder.py,sha256=ClDyK1kna4hCBSlz71gK1B-ZjlwcBHTDQRekoNM81Bs,1809
4
+ nat/test/functions.py,sha256=ZxXVzfaLBGOpR5qtmMrKU7q-M9-vVGGj3Xi5mrw4vHY,3557
5
+ nat/test/llm.py,sha256=dbFoWFrSAlUoKm6QGfS4VJdrhgxwkXzm1oaFd6K7jnM,9926
6
+ nat/test/memory.py,sha256=xki_A2yiMhEZuQk60K7t04QRqf32nQqnfzD5Iv7fkvw,1456
7
+ nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2zQ,4281
8
+ nat/test/plugin.py,sha256=HF25W2YPTiXaoIJggnZTstiTMaspQckvL_thQSseDEc,32434
9
+ nat/test/register.py,sha256=o1BEA5fyxyFyCxXhQ6ArmtuNpgRyTEfvw6HdBgECPLI,897
10
+ nat/test/tool_test_runner.py,sha256=O4FyZMlf-oc1XYbuAFxLzn9_zWi8TzWigFyg-Px7xBc,26126
11
+ nat/test/utils.py,sha256=GyhxIZ1CcUPcc8RMRyCzpHBEwVifeqiGxT3c9Pp0KAU,5774
12
+ nvidia_nat_test-1.4.0a20251212.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
13
+ nvidia_nat_test-1.4.0a20251212.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
14
+ nvidia_nat_test-1.4.0a20251212.dist-info/METADATA,sha256=lqR6z09voBpFqNfGMydq_0tDeACzfQVCkSf5muBIwO4,1925
15
+ nvidia_nat_test-1.4.0a20251212.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
+ nvidia_nat_test-1.4.0a20251212.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
17
+ nvidia_nat_test-1.4.0a20251212.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
18
+ nvidia_nat_test-1.4.0a20251212.dist-info/RECORD,,