nvidia-nat-test 1.3.0a20250827__py3-none-any.whl → 1.3.0a20250829__py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported package registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
@@ -14,18 +14,29 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import asyncio
17
+ import inspect
17
18
  import logging
18
19
  import typing
19
20
  from contextlib import asynccontextmanager
20
21
  from unittest.mock import AsyncMock
21
22
  from unittest.mock import MagicMock
22
23
 
24
+ from nat.authentication.interfaces import AuthProviderBase
23
25
  from nat.builder.builder import Builder
24
26
  from nat.builder.function import Function
25
27
  from nat.builder.function_info import FunctionInfo
26
28
  from nat.cli.type_registry import GlobalTypeRegistry
29
+ from nat.data_models.authentication import AuthProviderBaseConfig
30
+ from nat.data_models.embedder import EmbedderBaseConfig
27
31
  from nat.data_models.function import FunctionBaseConfig
32
+ from nat.data_models.function_dependencies import FunctionDependencies
33
+ from nat.data_models.llm import LLMBaseConfig
34
+ from nat.data_models.memory import MemoryBaseConfig
28
35
  from nat.data_models.object_store import ObjectStoreBaseConfig
36
+ from nat.data_models.retriever import RetrieverBaseConfig
37
+ from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
38
+ from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
39
+ from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
29
40
  from nat.object_store.interfaces import ObjectStore
30
41
  from nat.runtime.loader import PluginTypes
31
42
  from nat.runtime.loader import discover_and_register_plugins
@@ -70,14 +81,16 @@ class MockBuilder(Builder):
70
81
  """Add a mock TTC strategy that returns a fixed response."""
71
82
  self._mocks[f"ttc_strategy_{name}"] = mock_response
72
83
 
73
- async def add_ttc_strategy(self, name: str, config):
84
+ def mock_auth_provider(self, name: str, mock_response: typing.Any):
85
+ """Add a mock auth provider that returns a fixed response."""
86
+ self._mocks[f"auth_provider_{name}"] = mock_response
87
+
88
+ async def add_ttc_strategy(self, name: str, config: TTCStrategyBaseConfig) -> None:
74
89
  """Mock implementation (no‑op)."""
75
90
  pass
76
91
 
77
- async def get_ttc_strategy(self,
78
- strategy_name: str,
79
- pipeline_type: typing.Any = None,
80
- stage_type: typing.Any = None):
92
+ async def get_ttc_strategy(self, strategy_name: str, pipeline_type: PipelineTypeEnum,
93
+ stage_type: StageTypeEnum) -> typing.Any:
81
94
  """Return a mock TTC strategy if one is configured."""
82
95
  key = f"ttc_strategy_{strategy_name}"
83
96
  if key in self._mocks:
@@ -90,11 +103,24 @@ class MockBuilder(Builder):
90
103
 
91
104
  async def get_ttc_strategy_config(self,
92
105
  strategy_name: str,
93
- pipeline_type: typing.Any = None,
94
- stage_type: typing.Any = None):
106
+ pipeline_type: PipelineTypeEnum,
107
+ stage_type: StageTypeEnum) -> TTCStrategyBaseConfig:
95
108
  """Mock implementation."""
109
+ return TTCStrategyBaseConfig()
110
+
111
+ async def add_auth_provider(self, name: str, config: AuthProviderBaseConfig) -> None:
112
+ """Mock implementation (no‑op)."""
96
113
  pass
97
114
 
115
+ async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
116
+ """Return a mock auth provider if one is configured."""
117
+ key = f"auth_provider_{auth_provider_name}"
118
+ if key in self._mocks:
119
+ mock_auth = MagicMock()
120
+ mock_auth.authenticate = AsyncMock(return_value=self._mocks[key])
121
+ return mock_auth
122
+ raise ValueError(f"Auth provider '{auth_provider_name}' not mocked. Use mock_auth_provider() to add it.")
123
+
98
124
  async def add_function(self, name: str, config: FunctionBaseConfig) -> Function:
99
125
  """Mock implementation - not used in tool testing."""
100
126
  raise NotImplementedError("Mock implementation does not support add_function")
@@ -109,25 +135,29 @@ class MockBuilder(Builder):
109
135
 
110
136
  def get_function_config(self, name: str) -> FunctionBaseConfig:
111
137
  """Mock implementation."""
112
- pass
138
+ return FunctionBaseConfig()
113
139
 
114
140
  async def set_workflow(self, config: FunctionBaseConfig) -> Function:
115
141
  """Mock implementation."""
116
- pass
142
+ mock_fn = AsyncMock()
143
+ mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
144
+ return mock_fn
117
145
 
118
146
  def get_workflow(self) -> Function:
119
147
  """Mock implementation."""
120
- pass
148
+ mock_fn = AsyncMock()
149
+ mock_fn.ainvoke = AsyncMock(return_value="mock_workflow_result")
150
+ return mock_fn
121
151
 
122
152
  def get_workflow_config(self) -> FunctionBaseConfig:
123
153
  """Mock implementation."""
124
- pass
154
+ return FunctionBaseConfig()
125
155
 
126
156
  def get_tool(self, fn_name: str, wrapper_type):
127
157
  """Mock implementation."""
128
158
  pass
129
159
 
130
- async def add_llm(self, name: str, config):
160
+ async def add_llm(self, name: str, config) -> None:
131
161
  """Mock implementation."""
132
162
  pass
133
163
 
@@ -141,11 +171,11 @@ class MockBuilder(Builder):
141
171
  return mock_llm
142
172
  raise ValueError(f"LLM '{llm_name}' not mocked. Use mock_llm() to add it.")
143
173
 
144
- def get_llm_config(self, llm_name: str):
174
+ def get_llm_config(self, llm_name: str) -> LLMBaseConfig:
145
175
  """Mock implementation."""
146
- pass
176
+ return LLMBaseConfig()
147
177
 
148
- async def add_embedder(self, name: str, config):
178
+ async def add_embedder(self, name: str, config) -> None:
149
179
  """Mock implementation."""
150
180
  pass
151
181
 
@@ -159,11 +189,11 @@ class MockBuilder(Builder):
159
189
  return mock_embedder
160
190
  raise ValueError(f"Embedder '{embedder_name}' not mocked. Use mock_embedder() to add it.")
161
191
 
162
- def get_embedder_config(self, embedder_name: str):
192
+ def get_embedder_config(self, embedder_name: str) -> EmbedderBaseConfig:
163
193
  """Mock implementation."""
164
- pass
194
+ return EmbedderBaseConfig()
165
195
 
166
- async def add_memory_client(self, name: str, config):
196
+ async def add_memory_client(self, name: str, config) -> None:
167
197
  """Mock implementation."""
168
198
  pass
169
199
 
@@ -177,11 +207,11 @@ class MockBuilder(Builder):
177
207
  return mock_memory
178
208
  raise ValueError(f"Memory client '{memory_name}' not mocked. Use mock_memory_client() to add it.")
179
209
 
180
- def get_memory_client_config(self, memory_name: str):
210
+ def get_memory_client_config(self, memory_name: str) -> MemoryBaseConfig:
181
211
  """Mock implementation."""
182
- pass
212
+ return MemoryBaseConfig()
183
213
 
184
- async def add_retriever(self, name: str, config):
214
+ async def add_retriever(self, name: str, config) -> None:
185
215
  """Mock implementation."""
186
216
  pass
187
217
 
@@ -194,11 +224,11 @@ class MockBuilder(Builder):
194
224
  return mock_retriever
195
225
  raise ValueError(f"Retriever '{retriever_name}' not mocked. Use mock_retriever() to add it.")
196
226
 
197
- async def get_retriever_config(self, retriever_name: str):
227
+ async def get_retriever_config(self, retriever_name: str) -> RetrieverBaseConfig:
198
228
  """Mock implementation."""
199
- pass
229
+ return RetrieverBaseConfig()
200
230
 
201
- async def add_object_store(self, name: str, config: ObjectStoreBaseConfig):
231
+ async def add_object_store(self, name: str, config: ObjectStoreBaseConfig) -> None:
202
232
  """Mock implementation for object store."""
203
233
  pass
204
234
 
@@ -216,7 +246,7 @@ class MockBuilder(Builder):
216
246
 
217
247
  def get_object_store_config(self, object_store_name: str) -> ObjectStoreBaseConfig:
218
248
  """Mock implementation for object store config."""
219
- pass
249
+ return ObjectStoreBaseConfig()
220
250
 
221
251
  def get_user_manager(self):
222
252
  """Mock implementation."""
@@ -224,9 +254,9 @@ class MockBuilder(Builder):
224
254
  mock_user.get_id = MagicMock(return_value="test_user")
225
255
  return mock_user
226
256
 
227
- def get_function_dependencies(self, fn_name: str):
257
+ def get_function_dependencies(self, fn_name: str) -> FunctionDependencies:
228
258
  """Mock implementation."""
229
- pass
259
+ return FunctionDependencies()
230
260
 
231
261
 
232
262
  class ToolTestRunner:
@@ -322,10 +352,15 @@ class ToolTestRunner:
322
352
 
323
353
  # Execute the tool
324
354
  if input_data is not None:
325
- if asyncio.iscoroutinefunction(tool_function):
355
+ if isinstance(tool_function, Function):
356
+ result = await tool_function.ainvoke(input_data)
357
+ elif asyncio.iscoroutinefunction(tool_function):
326
358
  result = await tool_function(input_data)
327
359
  else:
328
360
  result = tool_function(input_data)
361
+ elif isinstance(tool_function, Function):
362
+ # Function objects require input, so pass None if no input_data
363
+ result = await tool_function.ainvoke(None)
329
364
  elif asyncio.iscoroutinefunction(tool_function):
330
365
  result = await tool_function()
331
366
  else:
@@ -401,8 +436,8 @@ class ToolTestRunner:
401
436
  elif isinstance(tool_result, FunctionInfo):
402
437
  if tool_result.single_fn:
403
438
  tool_function = tool_result.single_fn
404
- elif tool_result.streaming_fn:
405
- tool_function = tool_result.streaming_fn
439
+ elif tool_result.stream_fn:
440
+ tool_function = tool_result.stream_fn
406
441
  else:
407
442
  raise ValueError("Tool function not found in FunctionInfo")
408
443
  elif callable(tool_result):
@@ -412,14 +447,17 @@ class ToolTestRunner:
412
447
 
413
448
  # Execute the tool
414
449
  if input_data is not None:
415
- if asyncio.iscoroutinefunction(tool_function):
416
- result = await tool_function(input_data)
450
+ if isinstance(tool_function, Function):
451
+ result = await tool_function.ainvoke(input_data)
417
452
  else:
418
- result = tool_function(input_data)
419
- elif asyncio.iscoroutinefunction(tool_function):
420
- result = await tool_function()
453
+ maybe_result = tool_function(input_data)
454
+ result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
455
+ elif isinstance(tool_function, Function):
456
+ # Function objects require input, so pass None if no input_data
457
+ result = await tool_function.ainvoke(None)
421
458
  else:
422
- result = tool_function()
459
+ maybe_result = tool_function()
460
+ result = await maybe_result if inspect.isawaitable(maybe_result) else maybe_result
423
461
 
424
462
  # Assert expected output if provided
425
463
  if expected_output is not None:
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat-test
3
- Version: 1.3.0a20250827
3
+ Version: 1.3.0a20250829
4
4
  Summary: Testing utilities for NeMo Agent toolkit
5
5
  Keywords: ai,rag,agents
6
6
  Classifier: Programming Language :: Python
7
7
  Requires-Python: <3.13,>=3.11
8
8
  Description-Content-Type: text/markdown
9
- Requires-Dist: nvidia-nat==v1.3.0a20250827
9
+ Requires-Dist: nvidia-nat==v1.3.0a20250829
10
10
  Requires-Dist: langchain-community~=0.3
11
11
  Requires-Dist: pytest~=8.3
12
12
 
@@ -7,9 +7,9 @@ nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2z
7
7
  nat/test/plugin.py,sha256=sMZ7xupCgEpQCuwUksUDYMjbBj0VNlhR6SK5UcOrBzg,6953
8
8
  nat/test/register.py,sha256=fbCLr3E4u8PYMFUlkRNlg53Td2YJ80iQCyxpRIbGId4,859
9
9
  nat/test/test_env_fixtures.py,sha256=zGhFBiZmdDYuj8kOU__RL9LOrood3L58KG8OWXnyOjQ,2375
10
- nat/test/tool_test_runner.py,sha256=EP5Zo_YmfCOeeifzSv1ztx6xF-E-GVGYktN5hkZIBSc,17405
11
- nvidia_nat_test-1.3.0a20250827.dist-info/METADATA,sha256=dhwJgvCZzsH9kqhCygerDpW-9i1cxO3GfbZzR_8uxnA,1466
12
- nvidia_nat_test-1.3.0a20250827.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
13
- nvidia_nat_test-1.3.0a20250827.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
14
- nvidia_nat_test-1.3.0a20250827.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
15
- nvidia_nat_test-1.3.0a20250827.dist-info/RECORD,,
10
+ nat/test/tool_test_runner.py,sha256=2kCydvJ6LBZ3Lh04e5_Qg-8_kOpcIkNdwGeQOPtadek,20089
11
+ nvidia_nat_test-1.3.0a20250829.dist-info/METADATA,sha256=LgoT5OxCrGslq9U0rp6lPbeVovcKrqGb6nt9-rhNUYA,1466
12
+ nvidia_nat_test-1.3.0a20250829.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
13
+ nvidia_nat_test-1.3.0a20250829.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
14
+ nvidia_nat_test-1.3.0a20250829.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
15
+ nvidia_nat_test-1.3.0a20250829.dist-info/RECORD,,