nvidia-nat-test 1.4.0a20251013__py3-none-any.whl → 1.4.0a20251125__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nvidia-nat-test might be problematic. Click here for more details.

nat/test/llm.py CHANGED
@@ -96,6 +96,14 @@ async def test_llm_langchain(config: TestLLMConfig, builder: Builder):
96
96
  await chooser.async_sleep()
97
97
  yield chooser.next_response()
98
98
 
99
+ def bind_tools(self, tools: Any, **_kwargs: Any) -> "LangChainTestLLM":
100
+ """Bind tools to the LLM. Returns self to maintain fluent interface."""
101
+ return self
102
+
103
+ def bind(self, **_kwargs: Any) -> "LangChainTestLLM":
104
+ """Bind additional parameters to the LLM. Returns self to maintain fluent interface."""
105
+ return self
106
+
99
107
  yield LangChainTestLLM()
100
108
 
101
109
 
nat/test/plugin.py CHANGED
@@ -14,13 +14,23 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
+ import random
17
18
  import subprocess
19
+ import time
20
+ import types
18
21
  import typing
22
+ from collections.abc import AsyncGenerator
23
+ from collections.abc import Generator
19
24
  from pathlib import Path
20
25
 
21
26
  import pytest
27
+ import pytest_asyncio
22
28
 
23
29
  if typing.TYPE_CHECKING:
30
+ import galileo.log_streams
31
+ import galileo.projects
32
+ import langsmith.client
33
+
24
34
  from docker.client import DockerClient
25
35
 
26
36
 
@@ -216,10 +226,158 @@ def azure_openai_keys_fixture(fail_missing: bool):
216
226
  yield require_env_variables(
217
227
  varnames=["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT"],
218
228
  reason="Azure integration tests require the `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_ENDPOINT` environment "
219
- "variable to be defined.",
229
+ "variables to be defined.",
230
+ fail_missing=fail_missing)
231
+
232
+
233
+ @pytest.fixture(name="langfuse_keys", scope='session')
234
+ def langfuse_keys_fixture(fail_missing: bool):
235
+ """
236
+ Use for integration tests that require Langfuse credentials.
237
+ """
238
+ yield require_env_variables(
239
+ varnames=["LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"],
240
+ reason="Langfuse integration tests require the `LANGFUSE_PUBLIC_KEY` and `LANGFUSE_SECRET_KEY` environment "
241
+ "variables to be defined.",
242
+ fail_missing=fail_missing)
243
+
244
+
245
+ @pytest.fixture(name="wandb_api_key", scope='session')
246
+ def wandb_api_key_fixture(fail_missing: bool):
247
+ """
248
+ Use for integration tests that require a Weights & Biases API key.
249
+ """
250
+ yield require_env_variables(
251
+ varnames=["WANDB_API_KEY"],
252
+ reason="Weights & Biases integration tests require the `WANDB_API_KEY` environment variable to be defined.",
253
+ fail_missing=fail_missing)
254
+
255
+
256
+ @pytest.fixture(name="weave", scope='session')
257
+ def require_weave_fixture(fail_missing: bool) -> types.ModuleType:
258
+ """
259
+ Use for integration tests that require Weave to be running.
260
+ """
261
+ try:
262
+ import weave
263
+ return weave
264
+ except Exception as e:
265
+ reason = "Weave must be installed to run weave based tests"
266
+ if fail_missing:
267
+ raise RuntimeError(reason) from e
268
+ pytest.skip(reason=reason)
269
+
270
+
271
+ @pytest.fixture(name="langsmith_api_key", scope='session')
272
+ def langsmith_api_key_fixture(fail_missing: bool):
273
+ """
274
+ Use for integration tests that require a LangSmith API key.
275
+ """
276
+ yield require_env_variables(
277
+ varnames=["LANGSMITH_API_KEY"],
278
+ reason="LangSmith integration tests require the `LANGSMITH_API_KEY` environment variable to be defined.",
279
+ fail_missing=fail_missing)
280
+
281
+
282
+ @pytest.fixture(name="langsmith_client")
283
+ def langsmith_client_fixture(langsmith_api_key: str, fail_missing: bool) -> "langsmith.client.Client":
284
+ try:
285
+ import langsmith.client
286
+ client = langsmith.client.Client()
287
+ return client
288
+ except ImportError:
289
+ reason = "LangSmith integration tests require the `langsmith` package to be installed."
290
+ if fail_missing:
291
+ raise RuntimeError(reason)
292
+ pytest.skip(reason=reason)
293
+
294
+
295
+ @pytest.fixture(name="project_name")
296
+ def project_name_fixture() -> str:
297
+ # Create a unique project name for each test run
298
+ return f"nat-e2e-test-{time.time()}-{random.random()}"
299
+
300
+
301
+ @pytest.fixture(name="langsmith_project_name")
302
+ def langsmith_project_name_fixture(langsmith_client: "langsmith.client.Client", project_name: str) -> Generator[str]:
303
+ langsmith_client.create_project(project_name)
304
+ yield project_name
305
+
306
+ langsmith_client.delete_project(project_name=project_name)
307
+
308
+
309
+ @pytest.fixture(name="galileo_api_key", scope='session')
310
+ def galileo_api_key_fixture(fail_missing: bool):
311
+ """
312
+ Use for integration tests that require a Galileo API key.
313
+ """
314
+ yield require_env_variables(
315
+ varnames=["GALILEO_API_KEY"],
316
+ reason="Galileo integration tests require the `GALILEO_API_KEY` environment variable to be defined.",
317
+ fail_missing=fail_missing)
318
+
319
+
320
+ @pytest.fixture(name="galileo_project")
321
+ def galileo_project_fixture(galileo_api_key: str, fail_missing: bool,
322
+ project_name: str) -> Generator["galileo.projects.Project"]:
323
+ """
324
+ Creates a unique Galileo project and deletes it after the test run.
325
+ """
326
+ try:
327
+ import galileo.projects
328
+ project = galileo.projects.create_project(name=project_name)
329
+ yield project
330
+
331
+ galileo.projects.delete_project(id=project.id)
332
+ except ImportError as e:
333
+ reason = "Galileo integration tests require the `galileo` package to be installed."
334
+ if fail_missing:
335
+ raise RuntimeError(reason) from e
336
+ pytest.skip(reason=reason)
337
+
338
+
339
+ @pytest.fixture(name="galileo_log_stream")
340
+ def galileo_log_stream_fixture(galileo_project: "galileo.projects.Project") -> "galileo.log_streams.LogStream":
341
+ """
342
+ Creates a Galileo log stream for integration tests.
343
+
344
+ The log stream is automatically deleted when the associated project is deleted.
345
+ """
346
+ import galileo.log_streams
347
+ return galileo.log_streams.create_log_stream(project_id=galileo_project.id, name="test")
348
+
349
+
350
+ @pytest.fixture(name="catalyst_keys", scope='session')
351
+ def catalyst_keys_fixture(fail_missing: bool):
352
+ """
353
+ Use for integration tests that require RagaAI Catalyst credentials.
354
+ """
355
+ yield require_env_variables(
356
+ varnames=["CATALYST_ACCESS_KEY", "CATALYST_SECRET_KEY"],
357
+ reason="Catalyst integration tests require the `CATALYST_ACCESS_KEY` and `CATALYST_SECRET_KEY` environment "
358
+ "variables to be defined.",
220
359
  fail_missing=fail_missing)
221
360
 
222
361
 
362
+ @pytest.fixture(name="catalyst_project_name")
363
+ def catalyst_project_name_fixture(catalyst_keys) -> str:
364
+ return os.environ.get("NAT_CI_CATALYST_PROJECT_NAME", "nat-e2e")
365
+
366
+
367
+ @pytest.fixture(name="catalyst_dataset_name")
368
+ def catalyst_dataset_name_fixture(catalyst_project_name: str, project_name: str) -> str:
369
+ """
370
+ We can't create and delete projects, but we can create and delete datasets, so use a unique dataset name
371
+ """
372
+ dataset_name = project_name.replace('.', '-')
373
+ yield dataset_name
374
+
375
+ from ragaai_catalyst import Dataset
376
+ ds = Dataset(catalyst_project_name)
377
+ if dataset_name in ds.list_datasets():
378
+ ds.delete_dataset(dataset_name)
379
+
380
+
223
381
  @pytest.fixture(name="require_docker", scope='session')
224
382
  def require_docker_fixture(fail_missing: bool) -> "DockerClient":
225
383
  """
@@ -256,8 +414,20 @@ def root_repo_dir_fixture() -> Path:
256
414
  return locate_repo_root()
257
415
 
258
416
 
259
- @pytest.fixture(name="require_etcd", scope="session")
260
- def require_etcd_fixture(fail_missing: bool = False) -> bool:
417
+ @pytest.fixture(name="examples_dir", scope='session')
418
+ def examples_dir_fixture(root_repo_dir: Path) -> Path:
419
+ return root_repo_dir / "examples"
420
+
421
+
422
+ @pytest.fixture(name="env_without_nat_log_level", scope='function')
423
+ def env_without_nat_log_level_fixture() -> dict[str, str]:
424
+ env = os.environ.copy()
425
+ env.pop("NAT_LOG_LEVEL", None)
426
+ return env
427
+
428
+
429
+ @pytest.fixture(name="etcd_url", scope="session")
430
+ def etcd_url_fixture(fail_missing: bool = False) -> str:
261
431
  """
262
432
  To run these tests, an etcd server must be running
263
433
  """
@@ -265,21 +435,22 @@ def require_etcd_fixture(fail_missing: bool = False) -> bool:
265
435
 
266
436
  host = os.getenv("NAT_CI_ETCD_HOST", "localhost")
267
437
  port = os.getenv("NAT_CI_ETCD_PORT", "2379")
268
- health_url = f"http://{host}:{port}/health"
438
+ url = f"http://{host}:{port}"
439
+ health_url = f"{url}/health"
269
440
 
270
441
  try:
271
442
  response = requests.get(health_url, timeout=5)
272
443
  response.raise_for_status()
273
- return True
444
+ return url
274
445
  except: # noqa: E722
275
- failure_reason = f"Unable to connect to etcd server at {health_url}"
446
+ failure_reason = f"Unable to connect to etcd server at {url}"
276
447
  if fail_missing:
277
448
  raise RuntimeError(failure_reason)
278
449
  pytest.skip(reason=failure_reason)
279
450
 
280
451
 
281
452
  @pytest.fixture(name="milvus_uri", scope="session")
282
- def milvus_uri_fixture(require_etcd: bool, fail_missing: bool = False) -> str:
453
+ def milvus_uri_fixture(etcd_url: str, fail_missing: bool = False) -> str:
283
454
  """
284
455
  To run these tests, a Milvus server must be running
285
456
  """
@@ -343,3 +514,361 @@ def populate_milvus_fixture(milvus_uri: str, root_repo_dir: Path):
343
514
  "wikipedia_docs"
344
515
  ],
345
516
  check=True)
517
+
518
+
519
+ @pytest.fixture(name="require_nest_asyncio", scope="session", autouse=True)
520
+ def require_nest_asyncio_fixture():
521
+ """
522
+ Some tests require the nest_asyncio2 patch to be applied to allow nested event loops, calling
523
+ `nest_asyncio2.apply()` more than once is a no-op. However we need to ensure that the nest_asyncio2 patch is
524
+ applied prior to the older nest_asyncio patch is applied. Requiring us to ensure that any library which will apply
525
+ the patch on import is lazily imported.
526
+ """
527
+ import nest_asyncio2
528
+ try:
529
+ nest_asyncio2.apply(error_on_mispatched=True)
530
+ except RuntimeError as e:
531
+ raise RuntimeError(
532
+ "nest_asyncio2 fixture called but asyncio is already patched, most likely this is due to the nest_asyncio "
533
+ "being applied first, which is not compatible with Python 3.12+. Please ensure that any libraries which "
534
+ "apply nest_asyncio on import are lazily imported.") from e
535
+
536
+
537
+ @pytest.fixture(name="phoenix_url", scope="session")
538
+ def phoenix_url_fixture(fail_missing: bool) -> str:
539
+ """
540
+ To run these tests, a phoenix server must be running.
541
+ The phoenix server can be started by running the following command:
542
+ docker run -p 6006:6006 -p 4317:4317 arizephoenix/phoenix:latest
543
+ """
544
+ import requests
545
+
546
+ url = os.getenv("NAT_CI_PHOENIX_URL", "http://localhost:6006")
547
+ try:
548
+ response = requests.get(url, timeout=5)
549
+ response.raise_for_status()
550
+
551
+ return url
552
+ except Exception as e:
553
+ reason = f"Unable to connect to Phoenix server at {url}: {e}"
554
+ if fail_missing:
555
+ raise RuntimeError(reason)
556
+ pytest.skip(reason=reason)
557
+
558
+
559
+ @pytest.fixture(name="phoenix_trace_url", scope="session")
560
+ def phoenix_trace_url_fixture(phoenix_url: str) -> str:
561
+ """
562
+ Some of our tools expect the base url provided by the phoenix_url fixture, however the
563
+ general.telemetry.tracing["phoenix"].endpoint expects the trace url which is what this fixture provides.
564
+ """
565
+ return f"{phoenix_url}/v1/traces"
566
+
567
+
568
+ @pytest.fixture(name="redis_server", scope="session")
569
+ def fixture_redis_server(fail_missing: bool) -> Generator[dict[str, str | int]]:
570
+ """Fixture to safely skip redis based tests if redis is not running"""
571
+ host = os.environ.get("NAT_CI_REDIS_HOST", "localhost")
572
+ port = int(os.environ.get("NAT_CI_REDIS_PORT", "6379"))
573
+ db = int(os.environ.get("NAT_CI_REDIS_DB", "0"))
574
+ password = os.environ.get("REDIS_PASSWORD", "redis")
575
+ bucket_name = os.environ.get("NAT_CI_REDIS_BUCKET_NAME", "test")
576
+
577
+ try:
578
+ import redis
579
+ client = redis.Redis(host=host, port=port, db=db, password=password)
580
+ if not client.ping():
581
+ raise RuntimeError("Failed to connect to Redis")
582
+ yield {"host": host, "port": port, "db": db, "bucket_name": bucket_name, "password": password}
583
+ except ImportError:
584
+ if fail_missing:
585
+ raise
586
+ pytest.skip("redis not installed, skipping redis tests")
587
+ except Exception as e:
588
+ if fail_missing:
589
+ raise
590
+ pytest.skip(f"Error connecting to Redis server: {e}, skipping redis tests")
591
+
592
+
593
+ @pytest_asyncio.fixture(name="mysql_server", scope="session")
594
+ async def fixture_mysql_server(fail_missing: bool) -> AsyncGenerator[dict[str, str | int]]:
595
+ """Fixture to safely skip MySQL based tests if MySQL is not running"""
596
+ host = os.environ.get('NAT_CI_MYSQL_HOST', '127.0.0.1')
597
+ port = int(os.environ.get('NAT_CI_MYSQL_PORT', '3306'))
598
+ user = os.environ.get('NAT_CI_MYSQL_USER', 'root')
599
+ password = os.environ.get('MYSQL_ROOT_PASSWORD', 'my_password')
600
+ bucket_name = os.environ.get('NAT_CI_MYSQL_BUCKET_NAME', 'test')
601
+ try:
602
+ import aiomysql
603
+ conn = await aiomysql.connect(host=host, port=port, user=user, password=password)
604
+ yield {"host": host, "port": port, "username": user, "password": password, "bucket_name": bucket_name}
605
+ conn.close()
606
+ except ImportError:
607
+ if fail_missing:
608
+ raise
609
+ pytest.skip("aiomysql not installed, skipping MySQL tests")
610
+ except Exception as e:
611
+ if fail_missing:
612
+ raise
613
+ pytest.skip(f"Error connecting to MySQL server: {e}, skipping MySQL tests")
614
+
615
+
616
+ @pytest.fixture(name="minio_server", scope="session")
617
+ def minio_server_fixture(fail_missing: bool) -> Generator[dict[str, str | int]]:
618
+ """Fixture to safely skip MinIO based tests if MinIO is not running"""
619
+ host = os.getenv("NAT_CI_MINIO_HOST", "localhost")
620
+ port = int(os.getenv("NAT_CI_MINIO_PORT", "9000"))
621
+ bucket_name = os.getenv("NAT_CI_MINIO_BUCKET_NAME", "test")
622
+ aws_access_key_id = os.getenv("NAT_CI_MINIO_ACCESS_KEY_ID", "minioadmin")
623
+ aws_secret_access_key = os.getenv("NAT_CI_MINIO_SECRET_ACCESS_KEY", "minioadmin")
624
+ endpoint_url = f"http://{host}:{port}"
625
+
626
+ minio_info = {
627
+ "host": host,
628
+ "port": port,
629
+ "bucket_name": bucket_name,
630
+ "endpoint_url": endpoint_url,
631
+ "aws_access_key_id": aws_access_key_id,
632
+ "aws_secret_access_key": aws_secret_access_key,
633
+ }
634
+
635
+ try:
636
+ import botocore.session
637
+ session = botocore.session.get_session()
638
+
639
+ client = session.create_client("s3",
640
+ aws_access_key_id=aws_access_key_id,
641
+ aws_secret_access_key=aws_secret_access_key,
642
+ endpoint_url=endpoint_url)
643
+ client.list_buckets()
644
+ yield minio_info
645
+ except ImportError:
646
+ if fail_missing:
647
+ raise
648
+ pytest.skip("aioboto3 not installed, skipping MinIO tests")
649
+ except Exception as e:
650
+ if fail_missing:
651
+ raise
652
+ else:
653
+ pytest.skip(f"Error connecting to MinIO server: {e}, skipping MinIO tests")
654
+
655
+
656
+ @pytest.fixture(name="langfuse_bucket", scope="session")
657
+ def langfuse_bucket_fixture(fail_missing: bool, minio_server: dict[str, str | int]) -> Generator[str]:
658
+
659
+ bucket_name = os.getenv("NAT_CI_LANGFUSE_BUCKET", "langfuse")
660
+ try:
661
+ import botocore.session
662
+ session = botocore.session.get_session()
663
+
664
+ client = session.create_client("s3",
665
+ aws_access_key_id=minio_server["aws_access_key_id"],
666
+ aws_secret_access_key=minio_server["aws_secret_access_key"],
667
+ endpoint_url=minio_server["endpoint_url"])
668
+
669
+ buckets = client.list_buckets()
670
+ bucket_names = [b['Name'] for b in buckets['Buckets']]
671
+ if bucket_name not in bucket_names:
672
+ client.create_bucket(Bucket=bucket_name)
673
+
674
+ yield bucket_name
675
+ except ImportError:
676
+ if fail_missing:
677
+ raise
678
+ pytest.skip("aioboto3 not installed, skipping MinIO tests")
679
+ except Exception as e:
680
+ if fail_missing:
681
+ raise
682
+ else:
683
+ pytest.skip(f"Error connecting to MinIO server: {e}, skipping MinIO tests")
684
+
685
+
686
+ @pytest.fixture(name="langfuse_url", scope="session")
687
+ def langfuse_url_fixture(fail_missing: bool, langfuse_bucket: str) -> str:
688
+ """
689
+ To run these tests, a langfuse server must be running.
690
+ """
691
+ import requests
692
+
693
+ host = os.getenv("NAT_CI_LANGFUSE_HOST", "localhost")
694
+ port = int(os.getenv("NAT_CI_LANGFUSE_PORT", "3000"))
695
+ url = f"http://{host}:{port}"
696
+ health_endpoint = f"{url}/api/public/health"
697
+ try:
698
+ response = requests.get(health_endpoint, timeout=5)
699
+ response.raise_for_status()
700
+
701
+ return url
702
+ except Exception as e:
703
+ reason = f"Unable to connect to Langfuse server at {url}: {e}"
704
+ if fail_missing:
705
+ raise RuntimeError(reason)
706
+ pytest.skip(reason=reason)
707
+
708
+
709
+ @pytest.fixture(name="langfuse_trace_url", scope="session")
710
+ def langfuse_trace_url_fixture(langfuse_url: str) -> str:
711
+ """
712
+ The langfuse_url fixture provides the base url, however the general.telemetry.tracing["langfuse"].endpoint expects
713
+ the trace url which is what this fixture provides.
714
+ """
715
+ return f"{langfuse_url}/api/public/otel/v1/traces"
716
+
717
+
718
+ @pytest.fixture(name="oauth2_server_url", scope="session")
719
+ def oauth2_server_url_fixture(fail_missing: bool) -> str:
720
+ """
721
+ To run these tests, an oauth2 server must be running.
722
+ """
723
+ import requests
724
+
725
+ host = os.getenv("NAT_CI_OAUTH2_HOST", "localhost")
726
+ port = int(os.getenv("NAT_CI_OAUTH2_PORT", "5001"))
727
+ url = f"http://{host}:{port}"
728
+ try:
729
+ response = requests.get(url, timeout=5)
730
+ response.raise_for_status()
731
+
732
+ return url
733
+ except Exception as e:
734
+ reason = f"Unable to connect to OAuth2 server at {url}: {e}"
735
+ if fail_missing:
736
+ raise RuntimeError(reason)
737
+ pytest.skip(reason=reason)
738
+
739
+
740
+ @pytest.fixture(name="oauth2_client_credentials", scope="session")
741
+ def oauth2_client_credentials_fixture(oauth2_server_url: str, fail_missing: bool) -> dict[str, typing.Any]:
742
+ """
743
+ Fixture to provide OAuth2 client credentials for testing
744
+
745
+ Simulates the steps a user would take in a web browser to create a new OAuth2 client as documented in:
746
+ examples/front_ends/simple_auth/README.md
747
+ """
748
+
749
+ try:
750
+ import requests
751
+ from bs4 import BeautifulSoup
752
+ username = os.getenv("NAT_CI_OAUTH2_CLIENT_USERNAME", "Testy Testerson")
753
+
754
+ # This post request responds with a cookie that we need for future requests and a 302 redirect, the response
755
+ # for the redirected url doesn't contain the cookie, so we disable the redirect here to capture the cookie
756
+ user_create_response = requests.post(oauth2_server_url,
757
+ data=[("username", username)],
758
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
759
+ allow_redirects=False,
760
+ timeout=5)
761
+ user_create_response.raise_for_status()
762
+ cookies = user_create_response.cookies
763
+
764
+ client_create_response = requests.post(f"{oauth2_server_url}/create_client",
765
+ cookies=cookies,
766
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
767
+ data=[
768
+ ("client_name", "test"),
769
+ ("client_uri", "https://test.com"),
770
+ ("scope", "openid profile email"),
771
+ ("redirect_uri", "http://localhost:8000/auth/redirect"),
772
+ ("grant_type", "authorization_code\nrefresh_token"),
773
+ ("response_type", "code"),
774
+ ("token_endpoint_auth_method", "client_secret_post"),
775
+ ],
776
+ timeout=5)
777
+ client_create_response.raise_for_status()
778
+
779
+ # Unfortunately the response is HTML so we need to parse it to get the client ID and secret, which are not
780
+ # locatable via ID tags
781
+ soup = BeautifulSoup(client_create_response.text, 'html.parser')
782
+ strong_tags = soup.find_all('strong')
783
+ i = 0
784
+ client_id = None
785
+ client_secret = None
786
+ while i < len(strong_tags) and None in (client_id, client_secret):
787
+ tag = strong_tags[i]
788
+ contents = "".join(tag.contents)
789
+ if client_id is None and "client_id:" in contents:
790
+ client_id = tag.next_sibling.strip()
791
+ elif client_secret is None and "client_secret:" in contents:
792
+ client_secret = tag.next_sibling.strip()
793
+
794
+ i += 1
795
+
796
+ assert client_id is not None and client_secret is not None, "Failed to parse client credentials from response"
797
+
798
+ return {
799
+ "id": client_id,
800
+ "secret": client_secret,
801
+ "username": username,
802
+ "url": oauth2_server_url,
803
+ "cookies": cookies
804
+ }
805
+
806
+ except Exception as e:
807
+ reason = f"Unable to create OAuth2 client: {e}"
808
+ if fail_missing:
809
+ raise RuntimeError(reason)
810
+ pytest.skip(reason=reason)
811
+
812
+
813
+ @pytest.fixture(name="local_sandbox_url", scope="session")
814
+ def local_sandbox_url_fixture(fail_missing: bool) -> str:
815
+ """Check if sandbox server is running before running tests."""
816
+ import requests
817
+ url = os.environ.get("NAT_CI_SANDBOX_URL", "http://127.0.0.1:6000")
818
+ try:
819
+ response = requests.get(url, timeout=5)
820
+ response.raise_for_status()
821
+ return url
822
+ except Exception:
823
+ reason = (f"Sandbox server is not running at {url}. "
824
+ "Please start it with: cd src/nat/tool/code_execution/local_sandbox && ./start_local_sandbox.sh")
825
+ if fail_missing:
826
+ raise RuntimeError(reason)
827
+ pytest.skip(reason)
828
+
829
+
830
+ @pytest.fixture(name="sandbox_config", scope="session")
831
+ def sandbox_config_fixture(local_sandbox_url: str) -> dict[str, typing.Any]:
832
+ """Configuration for sandbox testing."""
833
+ return {
834
+ "base_url": local_sandbox_url,
835
+ "execute_url": f"{local_sandbox_url.rstrip('/')}/execute",
836
+ "timeout": int(os.environ.get("SANDBOX_TIMEOUT", "30")),
837
+ "connection_timeout": 5
838
+ }
839
+
840
+
841
+ @pytest.fixture(name="piston_url", scope="session")
842
+ def piston_url_fixture(fail_missing: bool) -> str:
843
+ """
844
+ Configuration for piston testing.
845
+
846
+ The public piston server limits usage to five requests per minute.
847
+ """
848
+ import requests
849
+ url = os.environ.get("NAT_CI_PISTON_URL", "https://emkc.org/api/v2/piston")
850
+ try:
851
+ response = requests.get(f"{url.rstrip('/')}/runtimes", timeout=30)
852
+ response.raise_for_status()
853
+ return url
854
+ except Exception:
855
+ reason = (f"Piston server is not running at {url}. "
856
+ "Please start it with: cd src/nat/tool/code_execution/local_sandbox && ./start_local_sandbox.sh")
857
+ if fail_missing:
858
+ raise RuntimeError(reason)
859
+ pytest.skip(reason)
860
+
861
+
862
+ @pytest.fixture(autouse=True, scope="session")
863
+ def import_adk_early():
864
+ """
865
+ Import ADK early to work-around slow import issue (https://github.com/google/adk-python/issues/2433),
866
+ when ADK is imported early it takes about 8 seconds, however if we wait until the `packages/nvidia_nat_adk/tests`
867
+ run the same import will take about 70 seconds.
868
+
869
+ Since ADK is an optional dependency, we will ignore any import errors.
870
+ """
871
+ try:
872
+ import google.adk # noqa: F401
873
+ except ImportError:
874
+ pass
nat/test/tool_test_runner.py CHANGED
@@ -29,18 +29,21 @@ from nat.builder.function import FunctionGroup
29
29
  from nat.builder.function_info import FunctionInfo
30
30
  from nat.cli.type_registry import GlobalTypeRegistry
31
31
  from nat.data_models.authentication import AuthProviderBaseConfig
32
+ from nat.data_models.component_ref import MiddlewareRef
32
33
  from nat.data_models.embedder import EmbedderBaseConfig
33
34
  from nat.data_models.function import FunctionBaseConfig
34
35
  from nat.data_models.function import FunctionGroupBaseConfig
35
36
  from nat.data_models.function_dependencies import FunctionDependencies
36
37
  from nat.data_models.llm import LLMBaseConfig
37
38
  from nat.data_models.memory import MemoryBaseConfig
39
+ from nat.data_models.middleware import FunctionMiddlewareBaseConfig
38
40
  from nat.data_models.object_store import ObjectStoreBaseConfig
39
41
  from nat.data_models.retriever import RetrieverBaseConfig
40
42
  from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
41
43
  from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
42
44
  from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
43
45
  from nat.memory.interfaces import MemoryEditor
46
+ from nat.middleware import FunctionMiddleware
44
47
  from nat.object_store.interfaces import ObjectStore
45
48
  from nat.runtime.loader import PluginTypes
46
49
  from nat.runtime.loader import discover_and_register_plugins
@@ -289,6 +292,19 @@ class MockBuilder(Builder):
289
292
  """Mock implementation."""
290
293
  return FunctionDependencies()
291
294
 
295
+ async def get_middleware(self, middleware_name: str | MiddlewareRef) -> FunctionMiddleware:
296
+ """Mock implementation."""
297
+ return FunctionMiddleware()
298
+
299
+ def get_middleware_config(self, middleware_name: str | MiddlewareRef) -> FunctionMiddlewareBaseConfig:
300
+ """Mock implementation."""
301
+ return FunctionMiddlewareBaseConfig()
302
+
303
+ async def add_middleware(self, name: str | MiddlewareRef,
304
+ config: FunctionMiddlewareBaseConfig) -> FunctionMiddleware:
305
+ """Mock implementation."""
306
+ return FunctionMiddleware()
307
+
292
308
 
293
309
  class ToolTestRunner:
294
310
  """
nat/test/utils.py CHANGED
@@ -15,6 +15,7 @@
15
15
 
16
16
  import importlib.resources
17
17
  import inspect
18
+ import json
18
19
  import subprocess
19
20
  import typing
20
21
  from contextlib import asynccontextmanager
@@ -67,25 +68,20 @@ def locate_example_config(example_config_class: type,
67
68
  return config_path
68
69
 
69
70
 
70
- async def run_workflow(
71
- config_file: "StrPath | None",
72
- question: str,
73
- expected_answer: str,
74
- assert_expected_answer: bool = True,
75
- config: "Config | None" = None,
76
- ) -> str:
77
- from nat.builder.workflow_builder import WorkflowBuilder
78
- from nat.runtime.loader import load_config
79
- from nat.runtime.session import SessionManager
80
-
81
- if config is None:
82
- assert config_file is not None, "Either config_file or config must be provided"
83
- config = load_config(config_file)
71
+ async def run_workflow(*,
72
+ config: "Config | None" = None,
73
+ config_file: "StrPath | None" = None,
74
+ question: str,
75
+ expected_answer: str,
76
+ assert_expected_answer: bool = True,
77
+ **kwargs) -> str:
78
+ """
79
+ Test specific wrapper for `nat.utils.run_workflow` to run a workflow with a question and validate the expected
80
+ answer. This variant always sets the result type to `str`.
81
+ """
82
+ from nat.utils import run_workflow as nat_run_workflow
84
83
 
85
- async with WorkflowBuilder.from_config(config=config) as workflow_builder:
86
- workflow = SessionManager(await workflow_builder.build())
87
- async with workflow.run(question) as runner:
88
- result = await runner.result(to_type=str)
84
+ result = await nat_run_workflow(config=config, config_file=config_file, prompt=question, to_type=str, **kwargs)
89
85
 
90
86
  if assert_expected_answer:
91
87
  assert expected_answer.lower() in result.lower(), f"Expected '{expected_answer}' in '{result}'"
@@ -125,3 +121,28 @@ async def build_nat_client(
125
121
  async with LifespanManager(app):
126
122
  async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
127
123
  yield client
124
+
125
+
126
+ def validate_workflow_output(workflow_output_file: Path) -> None:
127
+ """
128
+ Validate the contents of the workflow output file.
129
+ WIP: output format should be published as a schema and this validation should be done against that schema.
130
+ """
131
+ # Ensure the workflow_output.json file was created
132
+ assert workflow_output_file.exists(), "The workflow_output.json file was not created"
133
+
134
+ # Read and validate the workflow_output.json file
135
+ try:
136
+ with open(workflow_output_file, encoding="utf-8") as f:
137
+ result_json = json.load(f)
138
+ except json.JSONDecodeError as err:
139
+ raise RuntimeError("Failed to parse workflow_output.json as valid JSON") from err
140
+
141
+ assert isinstance(result_json, list), "The workflow_output.json file is not a list"
142
+ assert len(result_json) > 0, "The workflow_output.json file is empty"
143
+ assert isinstance(result_json[0], dict), "The workflow_output.json file is not a list of dictionaries"
144
+
145
+ # Ensure required keys exist
146
+ required_keys = ["id", "question", "answer", "generated_answer", "intermediate_steps"]
147
+ for key in required_keys:
148
+ assert all(item.get(key) for item in result_json), f"The '{key}' key is missing in workflow_output.json"
nvidia_nat_test-1.4.0a20251125.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat-test
3
- Version: 1.4.0a20251013
3
+ Version: 1.4.0a20251125
4
4
  Summary: Testing utilities for NeMo Agent toolkit
5
5
  Author: NVIDIA Corporation
6
6
  Maintainer: NVIDIA Corporation
@@ -16,7 +16,7 @@ Requires-Python: <3.14,>=3.11
16
16
  Description-Content-Type: text/markdown
17
17
  License-File: LICENSE-3rd-party.txt
18
18
  License-File: LICENSE.md
19
- Requires-Dist: nvidia-nat==v1.4.0a20251013
19
+ Requires-Dist: nvidia-nat==v1.4.0a20251125
20
20
  Requires-Dist: langchain-community~=0.3
21
21
  Requires-Dist: pytest~=8.3
22
22
  Dynamic: license-file
nvidia_nat_test-1.4.0a20251125.dist-info/RECORD CHANGED
@@ -2,17 +2,17 @@ nat/meta/pypi.md,sha256=LLKJHg5oN1-M9Pqfk3Bmphkk4O2TFsyiixuK5T0Y-gw,1100
2
2
  nat/test/__init__.py,sha256=_RnTJnsUucHvla_nYKqD4O4g8Bz0tcuDRzWk1bEhcy0,875
3
3
  nat/test/embedder.py,sha256=ClDyK1kna4hCBSlz71gK1B-ZjlwcBHTDQRekoNM81Bs,1809
4
4
  nat/test/functions.py,sha256=ZxXVzfaLBGOpR5qtmMrKU7q-M9-vVGGj3Xi5mrw4vHY,3557
5
- nat/test/llm.py,sha256=f6bz6arAQjhjuOKFrLfu_U1LbiyFzQmpM-q8b-WKSrU,9550
5
+ nat/test/llm.py,sha256=dbFoWFrSAlUoKm6QGfS4VJdrhgxwkXzm1oaFd6K7jnM,9926
6
6
  nat/test/memory.py,sha256=xki_A2yiMhEZuQk60K7t04QRqf32nQqnfzD5Iv7fkvw,1456
7
7
  nat/test/object_store_tests.py,sha256=PyJioOtoSzILPq6LuD-sOZ_89PIcgXWZweoHBQpK2zQ,4281
8
- nat/test/plugin.py,sha256=b9DsqeRDYrBA00egilznvNpr_lQmdnkUQilsWX07mTA,11688
8
+ nat/test/plugin.py,sha256=HF25W2YPTiXaoIJggnZTstiTMaspQckvL_thQSseDEc,32434
9
9
  nat/test/register.py,sha256=o1BEA5fyxyFyCxXhQ6ArmtuNpgRyTEfvw6HdBgECPLI,897
10
- nat/test/tool_test_runner.py,sha256=SxavwXHkvCQDl_PUiiiqgvGfexKJJTeBdI5i1qk6AzI,21712
11
- nat/test/utils.py,sha256=wXa9uH7-_HH7eg0bKpBrlVhffYrc2-F2MYc5ZBwSbAQ,4593
12
- nvidia_nat_test-1.4.0a20251013.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
13
- nvidia_nat_test-1.4.0a20251013.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
14
- nvidia_nat_test-1.4.0a20251013.dist-info/METADATA,sha256=L4psga_VNLK3KqM1vdVY39t9NDQOgiVxqBJLFpMGZOI,1925
15
- nvidia_nat_test-1.4.0a20251013.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
- nvidia_nat_test-1.4.0a20251013.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
17
- nvidia_nat_test-1.4.0a20251013.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
18
- nvidia_nat_test-1.4.0a20251013.dist-info/RECORD,,
10
+ nat/test/tool_test_runner.py,sha256=WDwIRo3160raBoEkj1-MgnLSCaaF2Ud_cARRIM3Qdag,22463
11
+ nat/test/utils.py,sha256=GyhxIZ1CcUPcc8RMRyCzpHBEwVifeqiGxT3c9Pp0KAU,5774
12
+ nvidia_nat_test-1.4.0a20251125.dist-info/licenses/LICENSE-3rd-party.txt,sha256=fOk5jMmCX9YoKWyYzTtfgl-SUy477audFC5hNY4oP7Q,284609
13
+ nvidia_nat_test-1.4.0a20251125.dist-info/licenses/LICENSE.md,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
14
+ nvidia_nat_test-1.4.0a20251125.dist-info/METADATA,sha256=diTRLTBEHiMoMiOAkuUZQcB2wXXZiTDChOp6ImRnqXQ,1925
15
+ nvidia_nat_test-1.4.0a20251125.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
+ nvidia_nat_test-1.4.0a20251125.dist-info/entry_points.txt,sha256=7dOP9XB6iMDqvav3gYx9VWUwA8RrFzhbAa8nGeC8e4Y,99
17
+ nvidia_nat_test-1.4.0a20251125.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
18
+ nvidia_nat_test-1.4.0a20251125.dist-info/RECORD,,