nvidia-nat-test 1.3.0a20251108__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-test might be problematic. Click here for more details.

@@ -0,0 +1,516 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import asyncio
17
+ import inspect
18
+ import logging
19
+ import typing
20
+ from collections.abc import Sequence
21
+ from contextlib import asynccontextmanager
22
+ from unittest.mock import AsyncMock
23
+ from unittest.mock import MagicMock
24
+
25
+ from nat.authentication.interfaces import AuthProviderBase
26
+ from nat.builder.builder import Builder
27
+ from nat.builder.function import Function
28
+ from nat.builder.function import FunctionGroup
29
+ from nat.builder.function_info import FunctionInfo
30
+ from nat.cli.type_registry import GlobalTypeRegistry
31
+ from nat.data_models.authentication import AuthProviderBaseConfig
32
+ from nat.data_models.embedder import EmbedderBaseConfig
33
+ from nat.data_models.function import FunctionBaseConfig
34
+ from nat.data_models.function import FunctionGroupBaseConfig
35
+ from nat.data_models.function_dependencies import FunctionDependencies
36
+ from nat.data_models.llm import LLMBaseConfig
37
+ from nat.data_models.memory import MemoryBaseConfig
38
+ from nat.data_models.object_store import ObjectStoreBaseConfig
39
+ from nat.data_models.retriever import RetrieverBaseConfig
40
+ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
41
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
42
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
43
+ from nat.memory.interfaces import MemoryEditor
44
+ from nat.object_store.interfaces import ObjectStore
45
+ from nat.runtime.loader import PluginTypes
46
+ from nat.runtime.loader import discover_and_register_plugins
47
+
48
# Module-level logger named after this module; no code in this module currently logs through it.
logger = logging.getLogger(__name__)
49
+
50
+
51
class MockBuilder(Builder):
    """
    A lightweight mock builder for tool testing that provides minimal dependencies.

    Register canned responses with the ``mock_*`` methods. The matching ``get_*`` method then returns a
    ``MagicMock``/``AsyncMock`` whose commonly used call methods return that response.

    Responses are stored under a ``(kind, name)`` key. Components of different kinds (for example a function
    and a function group, or a function and an LLM) may therefore share a name without overwriting each other.

    Requesting a component that was never mocked raises ``ValueError``. The ``add_*`` and ``*_config`` methods
    are stubs and do not consult the registered responses.
    """

    def __init__(self):
        # Not read anywhere in this class; kept so code that references the attribute keeps working.
        self._functions = {}
        # Canned responses keyed by ``(kind, name)``.
        self._mocks: dict[tuple[str, str], typing.Any] = {}

    def _set_mock(self, kind: str, name: str, mock_response: typing.Any) -> None:
        """Store ``mock_response`` for the component ``name`` of the given ``kind``."""
        self._mocks[(kind, name)] = mock_response

    def _require_mock(self, kind: str, name: str, error_message: str) -> typing.Any:
        """
        Return the response stored for ``(kind, name)``.

        Raises:
            ValueError: with ``error_message`` when nothing was registered for that kind and name.
        """
        key = (kind, name)
        # Membership test rather than ``.get`` so a deliberately mocked ``None`` response is still honoured.
        if key not in self._mocks:
            raise ValueError(error_message)
        return self._mocks[key]

    def mock_function(self, name: str, mock_response: typing.Any):
        """Add a mock function that returns a fixed response."""
        self._set_mock("function", name, mock_response)

    def mock_function_group(self, name: str, mock_response: typing.Any):
        """Add a mock function group that returns a fixed response."""
        self._set_mock("function_group", name, mock_response)

    def mock_llm(self, name: str, mock_response: typing.Any):
        """Add a mock LLM that returns a fixed response."""
        self._set_mock("llm", name, mock_response)

    def mock_embedder(self, name: str, mock_response: typing.Any):
        """Add a mock embedder that returns a fixed response."""
        self._set_mock("embedder", name, mock_response)

    def mock_memory_client(self, name: str, mock_response: typing.Any):
        """Add a mock memory client that returns a fixed response."""
        self._set_mock("memory", name, mock_response)

    def mock_retriever(self, name: str, mock_response: typing.Any):
        """Add a mock retriever that returns a fixed response."""
        self._set_mock("retriever", name, mock_response)

    def mock_object_store(self, name: str, mock_response: typing.Any):
        """Add a mock object store that returns a fixed response."""
        self._set_mock("object_store", name, mock_response)

    def mock_ttc_strategy(self, name: str, mock_response: typing.Any):
        """Add a mock TTC strategy that returns a fixed response."""
        self._set_mock("ttc_strategy", name, mock_response)

    def mock_auth_provider(self, name: str, mock_response: typing.Any):
        """Add a mock auth provider that returns a fixed response."""
        self._set_mock("auth_provider", name, mock_response)

    async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
        """Mock implementation (no‑op)."""

    async def get_ttc_strategy(self, strategy_name: str, pipeline_type: PipelineTypeEnum,
                               stage_type: StageTypeEnum) -> typing.Any:
        """
        Return a mock TTC strategy if one is configured.

        The returned mock's ``invoke`` (sync) and ``ainvoke`` (async) both return the registered response.
        ``pipeline_type`` and ``stage_type`` are ignored.
        """
        response = self._require_mock(
            "ttc_strategy", strategy_name,
            f"TTC strategy '{strategy_name}' not mocked. Use mock_ttc_strategy() to add it.")
        mock_strategy = MagicMock()
        # Provide common callable patterns used in tests
        mock_strategy.invoke = MagicMock(return_value=response)
        mock_strategy.ainvoke = AsyncMock(return_value=response)
        return mock_strategy

    async def get_ttc_strategy_config(self,
                                      strategy_name: str,
                                      pipeline_type: PipelineTypeEnum,
                                      stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
        """Mock implementation."""
        return TTCStrategyBaseConfig()

    async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> None:
        """Mock implementation (no‑op)."""

    async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
        """Return a mock auth provider whose async ``authenticate`` returns the registered response."""
        response = self._require_mock(
            "auth_provider", auth_provider_name,
            f"Auth provider '{auth_provider_name}' not mocked. Use mock_auth_provider() to add it.")
        mock_auth = MagicMock()
        mock_auth.authenticate = AsyncMock(return_value=response)
        return mock_auth

    async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
        """Mock implementation - not used in tool testing."""
        raise NotImplementedError("Mock implementation does not support add_function")

    async def get_function(self, name: str) -> Function:
        """Return a mock function whose async ``ainvoke`` returns the registered response."""
        response = self._require_mock("function", name,
                                      f"Function '{name}' not mocked. Use mock_function() to add it.")
        mock_fn = AsyncMock()
        mock_fn.ainvoke = AsyncMock(return_value=response)
        return mock_fn

    def get_function_config(self, name: str) -> FunctionBaseConfig:
        """Mock implementation."""
        return FunctionBaseConfig()

    async def add_function_group(self, name: str, config: FunctionGroupBaseConfig) -> FunctionGroup:
        """Mock implementation - not used in tool testing."""
        raise NotImplementedError("Mock implementation does not support add_function_group")

    async def get_function_group(self, name: str) -> FunctionGroup:
        """
        Return a mock function group if one is configured.

        The mock is ``spec``'d on ``FunctionGroup``, and its async ``ainvoke`` returns the registered response.
        """
        response = self._require_mock(
            "function_group", name,
            f"Function group '{name}' not mocked. Use mock_function_group() to add it.")
        mock_fn_group = MagicMock(spec=FunctionGroup)
        mock_fn_group.ainvoke = AsyncMock(return_value=response)
        return mock_fn_group

    def get_function_group_config(self, name: str) -> FunctionGroupBaseConfig:
        """Mock implementation."""
        return FunctionGroupBaseConfig()

    async def set_workflow(self, config: FunctionBaseConfig) -> Function:
        """Mock implementation: return a mock whose async ``ainvoke`` returns ``"mock_workflow_result"``."""
        mock_fn = AsyncMock()
        mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
        return mock_fn

    def get_workflow(self) -> Function:
        """Mock implementation: return a mock whose async ``ainvoke`` returns ``"mock_workflow_result"``."""
        mock_fn = AsyncMock()
        mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
        return mock_fn

    def get_workflow_config(self) -> FunctionBaseConfig:
        """Mock implementation."""
        return FunctionBaseConfig()

    async def get_tools(self, tool_names: Sequence[str], wrapper_type) -> list[typing.Any]:
        """Mock implementation: always an empty list."""
        return []

    async def get_tool(self, fn_name: str, wrapper_type) -> typing.Any:
        """Mock implementation: always ``None``."""

    async def add_llm(self, name: str, config) -> None:
        """Mock implementation."""

    async def get_llm(self, llm_name: str, wrapper_type):
        """
        Return a mock LLM if one is configured.

        The mock's ``invoke`` (sync) and ``ainvoke`` (async) both return the registered response.
        ``wrapper_type`` is ignored.
        """
        response = self._require_mock("llm", llm_name, f"LLM '{llm_name}' not mocked. Use mock_llm() to add it.")
        mock_llm = MagicMock()
        mock_llm.invoke = MagicMock(return_value=response)
        mock_llm.ainvoke = AsyncMock(return_value=response)
        return mock_llm

    def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
        """Mock implementation."""
        return LLMBaseConfig()

    async def add_embedder(self, name: str, config) -> None:
        """Mock implementation."""

    async def get_embedder(self, embedder_name: str, wrapper_type):
        """
        Return a mock embedder if one is configured.

        The mock's sync ``embed_query`` and ``embed_documents`` both return the registered response.
        """
        response = self._require_mock(
            "embedder", embedder_name,
            f"Embedder '{embedder_name}' not mocked. Use mock_embedder() to add it.")
        mock_embedder = MagicMock()
        mock_embedder.embed_query = MagicMock(return_value=response)
        mock_embedder.embed_documents = MagicMock(return_value=response)
        return mock_embedder

    def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
        """Mock implementation."""
        return EmbedderBaseConfig()

    async def add_memory_client(self, name: str, config) -> MemoryEditor:
        """Mock implementation: return a fresh ``MemoryEditor``-spec'd mock (registered mocks are not used)."""
        return MagicMock(spec=MemoryEditor)

    async def get_memory_client(self, memory_name: str) -> MemoryEditor:
        """Return a mock memory client whose async ``add`` and ``search`` return the registered response."""
        response = self._require_mock(
            "memory", memory_name,
            f"Memory client '{memory_name}' not mocked. Use mock_memory_client() to add it.")
        mock_memory = MagicMock()
        mock_memory.add = AsyncMock(return_value=response)
        mock_memory.search = AsyncMock(return_value=response)
        return mock_memory

    def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
        """Mock implementation."""
        return MemoryBaseConfig()

    async def add_retriever(self, name: str, config) -> None:
        """Mock implementation."""

    async def get_retriever(self, retriever_name: str, wrapper_type=None):
        """Return a mock retriever whose async ``retrieve`` returns the registered response."""
        response = self._require_mock(
            "retriever", retriever_name,
            f"Retriever '{retriever_name}' not mocked. Use mock_retriever() to add it.")
        mock_retriever = MagicMock()
        mock_retriever.retrieve = AsyncMock(return_value=response)
        return mock_retriever

    async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
        """Mock implementation."""
        return RetrieverBaseConfig()

    async def add_object_store(self, name: str, config: ObjectStoreBaseConfig) -> ObjectStore:
        """Mock implementation for object store."""
        return MagicMock(spec=ObjectStore)

    async def get_object_store_client(self, object_store_name: str) -> ObjectStore:
        """
        Return a mock object store client if one is configured.

        The async ``put_object``, ``get_object``, ``delete_object`` and ``list_objects`` all return the registered
        response.
        """
        response = self._require_mock(
            "object_store", object_store_name,
            f"Object store '{object_store_name}' not mocked. Use mock_object_store() to add it.")
        mock_object_store = MagicMock()
        mock_object_store.put_object = AsyncMock(return_value=response)
        mock_object_store.get_object = AsyncMock(return_value=response)
        mock_object_store.delete_object = AsyncMock(return_value=response)
        mock_object_store.list_objects = AsyncMock(return_value=response)
        return mock_object_store

    def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
        """Mock implementation for object store config."""
        return ObjectStoreBaseConfig()

    def get_user_manager(self):
        """Mock implementation: a user manager whose ``get_id()`` returns ``"test_user"``."""
        mock_user = MagicMock()
        mock_user.get_id = MagicMock(return_value="test_user")
        return mock_user

    def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
        """Mock implementation."""
        return FunctionDependencies()

    def get_function_group_dependencies(self, fn_name: str) -> FunctionDependencies:
        """Mock implementation."""
        return FunctionDependencies()
291
+
292
+
293
class ToolTestRunner:
    """
    A test runner that enables isolated testing of NAT tools without requiring
    full workflow setup, LLMs, or complex dependencies.

    Usage:
        runner = ToolTestRunner()

        # Test a tool with minimal setup (an empty MockBuilder is supplied automatically)
        result = await runner.test_tool(
            config_type=MyToolConfig,
            config_params={"param1": "value1"},
            input_data="test input"
        )

        # Test a tool with mocked dependencies. ``test_tool`` always builds its own empty
        # MockBuilder, so the configured builder must be passed to ``test_tool_with_builder``.
        async with runner.with_mocks() as mock_builder:
            mock_builder.mock_llm("my_llm", "mocked response")
            result = await runner.test_tool_with_builder(
                config_type=MyToolConfig,
                builder=mock_builder,
                config_params={"llm_name": "my_llm"},
                input_data="test input"
            )
    """

    def __init__(self):
        self._ensure_plugins_loaded()

    def _ensure_plugins_loaded(self):
        """Ensure all plugins are loaded for tool registration."""
        discover_and_register_plugins(PluginTypes.CONFIG_OBJECT)

    async def test_tool(self,
                        config_type: type[FunctionBaseConfig],
                        config_params: dict[str, typing.Any] | None = None,
                        input_data: typing.Any = None,
                        expected_output: typing.Any = None,
                        **kwargs) -> typing.Any:
        """
        Test a tool in isolation with minimal setup.

        The tool is built with a fresh, empty ``MockBuilder``. If the tool requests any dependency, that request
        raises ``ValueError``. Use ``test_tool_with_builder`` to supply mocked dependencies.

        Args:
            config_type: The tool configuration class
            config_params: Parameters to pass to the config constructor
            input_data: Input data to pass to the tool
            expected_output: Expected output for assertion (optional)
            **kwargs: Accepted for backward compatibility; currently ignored

        Returns:
            The tool's output

        Raises:
            AssertionError: If expected_output is provided and doesn't match
            ValueError: If tool registration or execution fails
        """
        return await self._run_tool(config_type, MockBuilder(), config_params, input_data, expected_output)

    @asynccontextmanager
    async def with_mocks(self):
        """
        Context manager that provides a mock builder for setting up dependencies.

        Usage:
            async with runner.with_mocks() as mock_builder:
                mock_builder.mock_llm("my_llm", "mocked response")
                result = await runner.test_tool_with_builder(
                    config_type=MyToolConfig,
                    builder=mock_builder,
                    input_data="test input"
                )
        """
        # The builder holds no resources, so there is nothing to clean up on exit.
        yield MockBuilder()

    async def test_tool_with_builder(
        self,
        config_type: type[FunctionBaseConfig],
        builder: MockBuilder,
        config_params: dict[str, typing.Any] | None = None,
        input_data: typing.Any = None,
        expected_output: typing.Any = None,
    ) -> typing.Any:
        """
        Test a tool with a pre-configured mock builder.

        Args:
            config_type: The tool configuration class
            builder: Pre-configured MockBuilder with mocked dependencies
            config_params: Parameters to pass to the config constructor
            input_data: Input data to pass to the tool
            expected_output: Expected output for assertion (optional)

        Returns:
            The tool's output

        Raises:
            AssertionError: If expected_output is provided and doesn't match
            ValueError: If tool registration or execution fails
        """
        return await self._run_tool(config_type, builder, config_params, input_data, expected_output)

    async def _run_tool(self,
                        config_type: type[FunctionBaseConfig],
                        builder: MockBuilder,
                        config_params: dict[str, typing.Any] | None,
                        input_data: typing.Any,
                        expected_output: typing.Any) -> typing.Any:
        """Build the tool registered for ``config_type`` with ``builder``, invoke it, and optionally check the result."""
        config = config_type(**(config_params or {}))
        tool_registration = self._get_tool_registration(config_type)

        async with tool_registration.build_fn(config, builder) as tool_result:
            tool_function = self._resolve_tool_function(tool_result)
            result = await self._invoke_tool_function(tool_function, input_data)

            # Checked while the tool's build context is still open, so a mismatch propagates through it.
            if expected_output is not None:
                assert result == expected_output, f"Expected {expected_output}, got {result}"

            return result

    @staticmethod
    def _get_tool_registration(config_type: type[FunctionBaseConfig]):
        """Look up the function registration for ``config_type``; raise ``ValueError`` if it is not registered."""
        registry = GlobalTypeRegistry.get()
        try:
            return registry.get_function(config_type)
        except KeyError as err:
            raise ValueError(
                f"Tool {config_type} is not registered. Make sure it's imported and registered with @register_function."
            ) from err

    @staticmethod
    def _resolve_tool_function(tool_result: typing.Any) -> typing.Any:
        """
        Pick the callable to invoke from whatever the tool's build function yielded.

        The result is a ``Function`` as-is, the ``single_fn`` (else ``stream_fn``) of a ``FunctionInfo``, or any
        other callable as-is.
        """
        if isinstance(tool_result, Function):
            return tool_result
        if isinstance(tool_result, FunctionInfo):
            if tool_result.single_fn:
                return tool_result.single_fn
            if tool_result.stream_fn:
                return tool_result.stream_fn
            raise ValueError("Tool function not found in FunctionInfo")
        if callable(tool_result):
            return tool_result
        raise ValueError(f"Unexpected tool result type: {type(tool_result)}")

    @staticmethod
    async def _invoke_tool_function(tool_function: typing.Any, input_data: typing.Any) -> typing.Any:
        """
        Call the tool once and return its result.

        A ``Function`` is invoked via ``ainvoke``. It always receives an argument, ``None`` when no input was
        given. Other callables receive ``input_data`` only when it is not ``None``.

        Any awaitable result is awaited. This covers coroutine functions and sync callables that return
        coroutines. A streaming function's return value is not awaitable (presumably an async generator), so it
        is returned unconsumed.
        """
        if isinstance(tool_function, Function):
            return await tool_function.ainvoke(input_data)

        maybe_result = tool_function() if input_data is None else tool_function(input_data)
        if inspect.isawaitable(maybe_result):
            return await maybe_result
        return maybe_result
498
+
499
+
500
@asynccontextmanager
async def with_mocked_dependencies():
    """
    Convenience context manager that yields a new ``ToolTestRunner`` together with a fresh ``MockBuilder``.

    Usage:
        async with with_mocked_dependencies() as (runner, mock_builder):
            mock_builder.mock_llm("my_llm", "mocked response")
            result = await runner.test_tool_with_builder(
                config_type=MyToolConfig,
                builder=mock_builder,
                input_data="test input"
            )
    """
    test_runner = ToolTestRunner()
    mocks_context = test_runner.with_mocks()
    async with mocks_context as builder:
        yield test_runner, builder
nat/test/utils.py ADDED
@@ -0,0 +1,155 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib.resources
17
+ import inspect
18
+ import json
19
+ import subprocess
20
+ import typing
21
+ from contextlib import asynccontextmanager
22
+ from pathlib import Path
23
+
24
+ if typing.TYPE_CHECKING:
25
+ from collections.abc import AsyncIterator
26
+
27
+ from httpx import AsyncClient
28
+
29
+ from nat.data_models.config import Config
30
+ from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker
31
+ from nat.utils.type_utils import StrPath
32
+
33
+
34
def locate_repo_root() -> Path:
    """
    Return the top-level directory of the git work tree, as reported by ``git rev-parse --show-toplevel``.

    The command runs in the current working directory. A non-zero exit status fails an assertion whose message
    includes git's stderr.
    """
    completed = subprocess.run(["git", "rev-parse", "--show-toplevel"],
                               check=False,
                               capture_output=True,
                               text=True)
    assert completed.returncode == 0, f"Failed to get git root: {completed.stderr}"
    top_level = completed.stdout.strip()
    return Path(top_level)
38
+
39
+
40
def locate_example_src_dir(example_config_class: type) -> Path:
    """
    Locate the example src directory for an example's config class.

    Returns ``importlib.resources.files`` for the package that contains the module defining
    ``example_config_class``.

    Raises:
        ValueError: If the defining module is not loaded, or if that module is not part of a package.
    """
    module = inspect.getmodule(example_config_class)
    if module is None:
        raise ValueError(f"Unable to determine the module that defines {example_config_class!r}")

    package_name = module.__package__
    # Reject both empty cases. ``files("")`` fails with a generic "Empty module name" error, and on
    # Python 3.12+ ``files(None)`` silently falls back to the *calling* module, which would point at the
    # wrong directory.
    if not package_name:
        raise ValueError(
            f"{example_config_class!r} is defined in module '{module.__name__}', which is not part of a package")

    return importlib.resources.files(package_name)
46
+
47
+
48
def locate_example_dir(example_config_class: type) -> Path:
    """
    Locate the example directory for an example's config class.

    The result is two directory levels above the package returned by ``locate_example_src_dir``. This assumes
    the usual ``<example>/src/<package>`` layout.
    """
    package_dir = locate_example_src_dir(example_config_class)
    return package_dir.parent.parent
55
+
56
+
57
def locate_example_config(example_config_class: type,
                          config_file: str = "config.yml",
                          assert_exists: bool = True) -> Path:
    """
    Locate a config file for the example that defines ``example_config_class``.

    The file is looked up as ``configs/<config_file>`` inside the example's source package, i.e. the directory
    returned by ``locate_example_src_dir``, not the example root. The returned path is made absolute.

    Args:
        example_config_class: A config class defined inside the example's package.
        config_file: Name of the file within the package's ``configs`` directory.
        assert_exists: When true, assert that the resolved file exists.
    """
    package_dir = locate_example_src_dir(example_config_class)
    candidate = package_dir.joinpath("configs", config_file).absolute()

    if assert_exists:
        assert candidate.exists(), f"Config file {candidate} does not exist"

    return candidate
69
+
70
+
71
+ async def run_workflow(
72
+ *,
73
+ config: "Config | None" = None,
74
+ config_file: "StrPath | None" = None,
75
+ question: str,
76
+ expected_answer: str,
77
+ assert_expected_answer: bool = True,
78
+ ) -> str:
79
+ from nat.builder.workflow_builder import WorkflowBuilder
80
+ from nat.runtime.loader import load_config
81
+ from nat.runtime.session import SessionManager
82
+
83
+ if config is None:
84
+ assert config_file is not None, "Either config_file or config must be provided"
85
+ assert Path(config_file).exists(), f"Config file {config_file} does not exist"
86
+ config = load_config(config_file)
87
+
88
+ async with WorkflowBuilder.from_config(config=config) as workflow_builder:
89
+ workflow = SessionManager(await workflow_builder.build())
90
+ async with workflow.run(question) as runner:
91
+ result = await runner.result(to_type=str)
92
+
93
+ if assert_expected_answer:
94
+ assert expected_answer.lower() in result.lower(), f"Expected '{expected_answer}' in '{result}'"
95
+
96
+ return result
97
+
98
+
99
@asynccontextmanager
async def build_nat_client(
        config: "Config",
        worker_class: "type[FastApiFrontEndPluginWorker] | None" = None) -> "AsyncIterator[AsyncClient]":
    """
    Yield an ``httpx.AsyncClient`` bound to a NAT FastAPI app, for use in tests.

    The app is produced by ``worker_class(config).build_app()``. ``FastApiFrontEndPluginWorker`` is used when
    ``worker_class`` is ``None``. The app is kept inside an ``asgi_lifespan.LifespanManager`` for the whole
    context. The client sends requests through ``httpx.ASGITransport`` using the base URL ``http://test``.

    Args:
        config: The NAT configuration to use for building the client.
        worker_class: Optional worker class to use. Defaults to FastApiFrontEndPluginWorker.

    Yields:
        An AsyncClient instance configured for testing.
    """
    from asgi_lifespan import LifespanManager
    from httpx import ASGITransport
    from httpx import AsyncClient

    from nat.front_ends.fastapi.fastapi_front_end_plugin_worker import FastApiFrontEndPluginWorker

    chosen_worker_class = FastApiFrontEndPluginWorker if worker_class is None else worker_class
    app = chosen_worker_class(config).build_app()

    # Multi-item ``async with`` enters left to right, exactly like the nested form: the lifespan starts first,
    # then the transport/client are created and entered.
    async with LifespanManager(app), AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        yield client
131
+
132
+
133
def validate_workflow_output(workflow_output_file: Path) -> None:
    """
    Validate the contents of the workflow output file.

    The file must hold a non-empty JSON list of objects. Every object must contain each of ``id``,
    ``question``, ``answer``, ``generated_answer`` and ``intermediate_steps`` with a truthy value.

    WIP: output format should be published as a schema and this validation should be done against that schema.

    Raises:
        AssertionError: If the file is missing or its contents do not meet the requirements above.
        RuntimeError: If the file is not valid JSON.
    """
    # Ensure the workflow_output.json file was created
    assert workflow_output_file.exists(), "The workflow_output.json file was not created"

    # Read and validate the workflow_output.json file
    try:
        with open(workflow_output_file, encoding="utf-8") as f:
            result_json = json.load(f)
    except json.JSONDecodeError as err:
        raise RuntimeError("Failed to parse workflow_output.json as valid JSON") from err

    assert isinstance(result_json, list), "The workflow_output.json file is not a list"
    assert len(result_json) > 0, "The workflow_output.json file is empty"
    # Check every entry, not only the first. Otherwise a malformed later entry crashes the key checks below
    # with an AttributeError instead of failing with a clear message.
    assert all(isinstance(item, dict) for item in result_json), \
        "The workflow_output.json file is not a list of dictionaries"

    # Ensure required keys exist and carry a non-empty value
    required_keys = ["id", "question", "answer", "generated_answer", "intermediate_steps"]
    for key in required_keys:
        for index, item in enumerate(result_json):
            assert key in item, f"The '{key}' key is missing in workflow_output.json"
            assert item[key], f"The '{key}' key is empty in entry {index} of workflow_output.json"