nvidia-haystack 0.1.7__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/CHANGELOG.md +28 -0
  2. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/PKG-INFO +4 -4
  3. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/pyproject.toml +4 -1
  4. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/components/generators/nvidia/__init__.py +2 -1
  5. nvidia_haystack-0.2.0/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py +133 -0
  6. nvidia_haystack-0.2.0/tests/test_nvidia_chat_generator.py +379 -0
  7. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/.gitignore +0 -0
  8. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/LICENSE.txt +0 -0
  9. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/README.md +0 -0
  10. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/pydoc/config.yml +0 -0
  11. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/components/embedders/nvidia/__init__.py +0 -0
  12. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +0 -0
  13. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +0 -0
  14. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/components/embedders/nvidia/truncate.py +0 -0
  15. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/components/generators/nvidia/chat/__init__.py +0 -0
  16. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/components/generators/nvidia/generator.py +0 -0
  17. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/components/rankers/nvidia/__init__.py +0 -0
  18. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/components/rankers/nvidia/ranker.py +0 -0
  19. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/components/rankers/nvidia/truncate.py +0 -0
  20. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/utils/nvidia/__init__.py +0 -0
  21. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/utils/nvidia/models.py +0 -0
  22. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/utils/nvidia/nim_backend.py +0 -0
  23. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/src/haystack_integrations/utils/nvidia/utils.py +0 -0
  24. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/tests/__init__.py +0 -0
  25. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/tests/conftest.py +0 -0
  26. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/tests/test_base_url.py +0 -0
  27. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/tests/test_document_embedder.py +0 -0
  28. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/tests/test_embedding_truncate_mode.py +0 -0
  29. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/tests/test_generator.py +0 -0
  30. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/tests/test_nim_backend.py +0 -0
  31. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/tests/test_ranker.py +0 -0
  32. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/tests/test_text_embedder.py +0 -0
  33. {nvidia_haystack-0.1.7 → nvidia_haystack-0.2.0}/tests/test_utils.py +0 -0
@@ -1,11 +1,39 @@
1
1
  # Changelog
2
2
 
3
+ ## [integrations/nvidia-v0.1.8] - 2025-05-28
4
+
5
+ ### 🌀 Miscellaneous
6
+
7
+ - Add pins for Nvidia (#1846)
8
+
9
+ ## [integrations/nvidia-v0.1.7] - 2025-04-03
10
+
11
+
12
+ ### 🧪 Testing
13
+
14
+ - Reduce Nvidia API calls in integration tests (#1432)
15
+ - Add test cases for all utils methods for Nvidia integration (#1458)
16
+ - Add unit tests for Nvidia NimBackend (#1546)
17
+
18
+ ### ⚙️ CI
19
+
20
+ - Review testing workflows (#1541)
21
+
22
+ ### 🧹 Chores
23
+
24
+ - Remove Python 3.8 support (#1421)
25
+
26
+ ### 🌀 Miscellaneous
27
+
28
+ - Fix: nvidia-haystack remove init files to make them namespace packages (#1594)
29
+
3
30
  ## [integrations/nvidia-v0.1.6] - 2025-02-11
4
31
 
5
32
  ### 🚀 Features
6
33
 
7
34
  - Add nvidia latest embedding models (#1364)
8
35
 
36
+
9
37
  ## [integrations/nvidia-v0.1.5] - 2025-02-04
10
38
 
11
39
  ### 🌀 Miscellaneous
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-haystack
3
- Version: 0.1.7
3
+ Version: 0.2.0
4
4
  Project-URL: Documentation, https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme
5
5
  Project-URL: Issues, https://github.com/deepset-ai/haystack-core-integrations/issues
6
6
  Project-URL: Source, https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia
@@ -18,9 +18,9 @@ Classifier: Programming Language :: Python :: 3.13
18
18
  Classifier: Programming Language :: Python :: Implementation :: CPython
19
19
  Classifier: Programming Language :: Python :: Implementation :: PyPy
20
20
  Requires-Python: >=3.9
21
- Requires-Dist: haystack-ai
22
- Requires-Dist: requests
23
- Requires-Dist: tqdm
21
+ Requires-Dist: haystack-ai>=2.13.0
22
+ Requires-Dist: requests>=2.25.0
23
+ Requires-Dist: tqdm>=4.21.0
24
24
  Description-Content-Type: text/markdown
25
25
 
26
26
  # nvidia-haystack
@@ -23,7 +23,7 @@ classifiers = [
23
23
  "Programming Language :: Python :: Implementation :: CPython",
24
24
  "Programming Language :: Python :: Implementation :: PyPy",
25
25
  ]
26
- dependencies = ["haystack-ai", "requests", "tqdm"]
26
+ dependencies = ["haystack-ai>=2.13.0", "requests>=2.25.0", "tqdm>=4.21.0"]
27
27
 
28
28
  [project.urls]
29
29
  Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme"
@@ -46,6 +46,8 @@ installer = "uv"
46
46
  dependencies = [
47
47
  "coverage[toml]>=6.5",
48
48
  "pytest",
49
+ "pytest-asyncio",
50
+ "pytz",
49
51
  "pytest-rerunfailures",
50
52
  "haystack-pydoc-tools",
51
53
  "requests_mock",
@@ -160,6 +162,7 @@ module = [
160
162
  "pytest.*",
161
163
  "numpy.*",
162
164
  "requests_mock.*",
165
+ "openai.*",
163
166
  "pydantic.*",
164
167
  ]
165
168
  ignore_missing_imports = true
@@ -2,6 +2,7 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
+ from .chat.chat_generator import NvidiaChatGenerator
5
6
  from .generator import NvidiaGenerator
6
7
 
7
- __all__ = ["NvidiaGenerator"]
8
+ __all__ = ["NvidiaChatGenerator", "NvidiaGenerator"]
@@ -0,0 +1,133 @@
1
+ # SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import os
6
+ from typing import Any, Dict, List, Optional, Union
7
+
8
+ from haystack import component, default_to_dict, logging
9
+ from haystack.components.generators.chat import OpenAIChatGenerator
10
+ from haystack.dataclasses import StreamingCallbackT
11
+ from haystack.tools import Tool, Toolset, serialize_tools_or_toolset
12
+ from haystack.utils import serialize_callable
13
+ from haystack.utils.auth import Secret
14
+
15
+ from haystack_integrations.utils.nvidia import DEFAULT_API_URL
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ @component
21
+ class NvidiaChatGenerator(OpenAIChatGenerator):
22
+ """
23
+ Enables text generation using NVIDIA generative models.
24
+ For supported models, see [NVIDIA Docs](https://build.nvidia.com/models).
25
+
26
+ Users can pass any text generation parameters valid for the NVIDIA Chat Completion API
27
+ directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
28
+ parameter in `run` method.
29
+
30
+ This component uses the ChatMessage format for structuring both input and output,
31
+ ensuring coherent and contextually relevant responses in chat-based text generation scenarios.
32
+ Details on the ChatMessage format can be found in the
33
+ [Haystack docs](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage)
34
+
35
+ For more details on the parameters supported by the NVIDIA API, refer to the
36
+ [NVIDIA Docs](https://build.nvidia.com/models).
37
+
38
+ Usage example:
39
+ ```python
40
+ from haystack_integrations.components.generators.nvidia import NvidiaChatGenerator
41
+ from haystack.dataclasses import ChatMessage
42
+
43
+ messages = [ChatMessage.from_user("What's Natural Language Processing?")]
44
+
45
+ client = NvidiaChatGenerator()
46
+ response = client.run(messages)
47
+ print(response)
48
+ ```
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ *,
54
+ api_key: Secret = Secret.from_env_var("NVIDIA_API_KEY"),
55
+ model: str = "meta/llama-3.1-8b-instruct",
56
+ streaming_callback: Optional[StreamingCallbackT] = None,
57
+ api_base_url: Optional[str] = os.getenv("NVIDIA_API_URL", DEFAULT_API_URL),
58
+ generation_kwargs: Optional[Dict[str, Any]] = None,
59
+ tools: Optional[Union[List[Tool], Toolset]] = None,
60
+ timeout: Optional[float] = None,
61
+ max_retries: Optional[int] = None,
62
+ http_client_kwargs: Optional[Dict[str, Any]] = None,
63
+ ):
64
+ """
65
+ Creates an instance of NvidiaChatGenerator.
66
+
67
+ :param api_key:
68
+ The NVIDIA API key.
69
+ :param model:
70
+ The name of the NVIDIA chat completion model to use.
71
+ :param streaming_callback:
72
+ A callback function that is called when a new token is received from the stream.
73
+ The callback function accepts StreamingChunk as an argument.
74
+ :param api_base_url:
75
+ The NVIDIA API Base url.
76
+ :param generation_kwargs:
77
+ Other parameters to use for the model. These parameters are all sent directly to
78
+ the NVIDIA API endpoint. See [NVIDIA API docs](https://docs.nvcf.nvidia.com/ai/generative-models/)
79
+ for more details.
80
+ Some of the supported parameters:
81
+ - `max_tokens`: The maximum number of tokens the output text can have.
82
+ - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
83
+ Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
84
+ - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
85
+ considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
86
+ comprising the top 10% probability mass are considered.
87
+ - `stream`: Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent
88
+ events as they become available, with the stream terminated by a data: [DONE] message.
89
+ :param tools:
90
+ A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
91
+ list of `Tool` objects or a `Toolset` instance.
92
+ :param timeout:
93
+ The timeout for the NVIDIA API call.
94
+ :param max_retries:
95
+ Maximum number of retries to contact NVIDIA after an internal error.
96
+ If not set, it defaults to either the `NVIDIA_MAX_RETRIES` environment variable, or set to 5.
97
+ :param http_client_kwargs:
98
+ A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
99
+ For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
100
+ """
101
+ super(NvidiaChatGenerator, self).__init__( # noqa: UP008
102
+ api_key=api_key,
103
+ model=model,
104
+ streaming_callback=streaming_callback,
105
+ api_base_url=api_base_url,
106
+ generation_kwargs=generation_kwargs,
107
+ tools=tools,
108
+ timeout=timeout,
109
+ max_retries=max_retries,
110
+ http_client_kwargs=http_client_kwargs,
111
+ )
112
+
113
+ def to_dict(self) -> Dict[str, Any]:
114
+ """
115
+ Serialize this component to a dictionary.
116
+
117
+ :returns:
118
+ The serialized component as a dictionary.
119
+ """
120
+ callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
121
+
122
+ return default_to_dict(
123
+ self,
124
+ model=self.model,
125
+ streaming_callback=callback_name,
126
+ api_base_url=self.api_base_url,
127
+ generation_kwargs=self.generation_kwargs,
128
+ api_key=self.api_key.to_dict(),
129
+ tools=serialize_tools_or_toolset(self.tools),
130
+ timeout=self.timeout,
131
+ max_retries=self.max_retries,
132
+ http_client_kwargs=self.http_client_kwargs,
133
+ )
@@ -0,0 +1,379 @@
1
+ # SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import os
6
+ from datetime import datetime
7
+ from unittest.mock import AsyncMock, patch
8
+
9
+ import pytest
10
+ import pytz
11
+ from haystack.components.generators.utils import print_streaming_chunk
12
+ from haystack.dataclasses import ChatMessage, StreamingChunk
13
+ from haystack.tools import Tool
14
+ from haystack.utils.auth import Secret
15
+ from openai import AsyncOpenAI, OpenAIError
16
+ from openai.types.chat import ChatCompletion, ChatCompletionMessage
17
+ from openai.types.chat.chat_completion import Choice
18
+
19
+ from haystack_integrations.components.generators.nvidia.chat.chat_generator import NvidiaChatGenerator
20
+ from haystack_integrations.utils.nvidia.models import DEFAULT_API_URL
21
+
22
+
23
+ @pytest.fixture
24
+ def chat_messages():
25
+ return [
26
+ ChatMessage.from_system("You are a helpful assistant"),
27
+ ChatMessage.from_user("What's the capital of France"),
28
+ ]
29
+
30
+
31
+ def weather(city: str):
32
+ """Get weather for a given city."""
33
+ return f"The weather in {city} is sunny and 32°C"
34
+
35
+
36
+ @pytest.fixture
37
+ def tools():
38
+ tool_parameters = {
39
+ "type": "object",
40
+ "properties": {"city": {"type": "string"}},
41
+ "required": ["city"],
42
+ }
43
+ tool = Tool(
44
+ name="weather",
45
+ description="useful to determine the weather in a given location",
46
+ parameters=tool_parameters,
47
+ function=weather,
48
+ )
49
+
50
+ return [tool]
51
+
52
+
53
+ @pytest.fixture
54
+ def mock_chat_completion():
55
+ """
56
+ Mock the OpenAI API completion response and reuse it for tests
57
+ """
58
+ with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
59
+ completion = ChatCompletion(
60
+ id="foo",
61
+ model="meta/llama-3.1-8b-instruct",
62
+ object="chat.completion",
63
+ choices=[
64
+ Choice(
65
+ finish_reason="stop",
66
+ logprobs=None,
67
+ index=0,
68
+ message=ChatCompletionMessage(content="Hello world!", role="assistant"),
69
+ )
70
+ ],
71
+ created=int(datetime.now(tz=pytz.timezone("UTC")).timestamp()),
72
+ usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
73
+ )
74
+
75
+ mock_chat_completion_create.return_value = completion
76
+ yield mock_chat_completion_create
77
+
78
+
79
+ @pytest.fixture
80
+ def mock_async_chat_completion():
81
+ """
82
+ Mock the Async OpenAI API completion response and reuse it for async tests
83
+ """
84
+ with patch(
85
+ "openai.resources.chat.completions.AsyncCompletions.create",
86
+ new_callable=AsyncMock,
87
+ ) as mock_chat_completion_create:
88
+ completion = ChatCompletion(
89
+ id="foo",
90
+ model="meta/llama-3.1-8b-instruct",
91
+ object="chat.completion",
92
+ choices=[
93
+ Choice(
94
+ finish_reason="stop",
95
+ logprobs=None,
96
+ index=0,
97
+ message=ChatCompletionMessage(content="Hello world!", role="assistant"),
98
+ )
99
+ ],
100
+ created=int(datetime.now(tz=pytz.timezone("UTC")).timestamp()),
101
+ usage={
102
+ "prompt_tokens": 57,
103
+ "completion_tokens": 40,
104
+ "total_tokens": 97,
105
+ },
106
+ )
107
+ # For async mocks, the return value should be awaitable
108
+ mock_chat_completion_create.return_value = completion
109
+ yield mock_chat_completion_create
110
+
111
+
112
+ class TestNvidiaChatGenerator:
113
+ def test_init_default(self, monkeypatch):
114
+ monkeypatch.setenv("NVIDIA_API_KEY", "test-api-key")
115
+ component = NvidiaChatGenerator()
116
+ assert component.client.api_key == "test-api-key"
117
+ assert component.model == "meta/llama-3.1-8b-instruct"
118
+ assert component.streaming_callback is None
119
+ assert not component.generation_kwargs
120
+
121
+ def test_init_fail_wo_api_key(self, monkeypatch):
122
+ monkeypatch.delenv("NVIDIA_API_KEY", raising=False)
123
+ with pytest.raises(ValueError, match="None of the .* environment variables are set"):
124
+ NvidiaChatGenerator()
125
+
126
+ def test_init_with_parameters(self):
127
+ component = NvidiaChatGenerator(
128
+ api_key=Secret.from_token("test-api-key"),
129
+ model="meta/llama-3.1-8b-instruct",
130
+ streaming_callback=print_streaming_chunk,
131
+ api_base_url="test-base-url",
132
+ generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
133
+ )
134
+ assert component.client.api_key == "test-api-key"
135
+ assert component.model == "meta/llama-3.1-8b-instruct"
136
+ assert component.streaming_callback is print_streaming_chunk
137
+ assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
138
+
139
+ def test_to_dict_default(self, monkeypatch):
140
+ monkeypatch.setenv("NVIDIA_API_KEY", "test-api-key")
141
+ component = NvidiaChatGenerator()
142
+ data = component.to_dict()
143
+
144
+ assert (
145
+ data["type"] == "haystack_integrations.components.generators.nvidia.chat.chat_generator.NvidiaChatGenerator"
146
+ )
147
+
148
+ expected_params = {
149
+ "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"},
150
+ "model": "meta/llama-3.1-8b-instruct",
151
+ "streaming_callback": None,
152
+ "api_base_url": DEFAULT_API_URL,
153
+ "generation_kwargs": {},
154
+ "tools": None,
155
+ "timeout": None,
156
+ "max_retries": None,
157
+ "http_client_kwargs": None,
158
+ }
159
+
160
+ for key, value in expected_params.items():
161
+ assert data["init_parameters"][key] == value
162
+
163
+ def test_run(self, chat_messages, mock_chat_completion, monkeypatch): # noqa: ARG002
164
+ monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key")
165
+ component = NvidiaChatGenerator()
166
+ response = component.run(chat_messages)
167
+
168
+ # check that the component returns the correct ChatMessage response
169
+ assert isinstance(response, dict)
170
+ assert "replies" in response
171
+ assert isinstance(response["replies"], list)
172
+ assert len(response["replies"]) == 1
173
+ assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
174
+
175
+ def test_run_with_params(self, chat_messages, mock_chat_completion, monkeypatch):
176
+ monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key")
177
+ component = NvidiaChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5})
178
+ response = component.run(chat_messages)
179
+
180
+ # check that the component calls the OpenAI API with the correct parameters
181
+ _, kwargs = mock_chat_completion.call_args
182
+ assert kwargs["max_tokens"] == 10
183
+ assert kwargs["temperature"] == 0.5
184
+
185
+ # check that the component returns the correct response
186
+ assert isinstance(response, dict)
187
+ assert "replies" in response
188
+ assert isinstance(response["replies"], list)
189
+ assert len(response["replies"]) == 1
190
+ assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
191
+
192
+ def test_run_with_extra_body(self, chat_messages, mock_chat_completion, monkeypatch):
193
+ monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key")
194
+ extra_body = {
195
+ "guardrails": {"config_id": "demo-self-check-input-output"},
196
+ }
197
+ component = NvidiaChatGenerator(generation_kwargs={"extra_body": extra_body})
198
+ response = component.run(chat_messages)
199
+
200
+ # check that the component calls the OpenAI API with the correct parameters
201
+ _, kwargs = mock_chat_completion.call_args
202
+ assert kwargs["extra_body"] == extra_body
203
+ assert kwargs["model"] == "meta/llama-3.1-8b-instruct"
204
+ assert kwargs["messages"] == [
205
+ {"role": "system", "content": "You are a helpful assistant"},
206
+ {"role": "user", "content": "What's the capital of France"},
207
+ ]
208
+
209
+ # check that the component returns the correct response
210
+ assert isinstance(response, dict)
211
+ assert "replies" in response
212
+ assert isinstance(response["replies"], list)
213
+ assert len(response["replies"]) == 1
214
+ assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
215
+
216
+ @pytest.mark.skipif(
217
+ not os.environ.get("NVIDIA_API_KEY", None),
218
+ reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
219
+ )
220
+ @pytest.mark.integration
221
+ def test_live_run(self):
222
+ chat_messages = [ChatMessage.from_user("What's the capital of France")]
223
+ component = NvidiaChatGenerator()
224
+ results = component.run(chat_messages)
225
+ assert len(results["replies"]) == 1
226
+ message: ChatMessage = results["replies"][0]
227
+ assert "Paris" in message.text
228
+ assert "meta/llama-3.1-8b-instruct" in message.meta["model"]
229
+ assert message.meta["finish_reason"] == "stop"
230
+
231
+ @pytest.mark.skipif(
232
+ not os.environ.get("NVIDIA_API_KEY", None),
233
+ reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
234
+ )
235
+ @pytest.mark.integration
236
+ def test_live_run_wrong_model(self, chat_messages):
237
+ component = NvidiaChatGenerator(model="something-obviously-wrong")
238
+ with pytest.raises(OpenAIError):
239
+ component.run(chat_messages)
240
+
241
+ @pytest.mark.skipif(
242
+ not os.environ.get("NVIDIA_API_KEY", None),
243
+ reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
244
+ )
245
+ @pytest.mark.integration
246
+ def test_live_run_streaming(self):
247
+ class Callback:
248
+ def __init__(self):
249
+ self.responses = ""
250
+ self.counter = 0
251
+
252
+ def __call__(self, chunk: StreamingChunk) -> None:
253
+ self.counter += 1
254
+ self.responses += chunk.content if chunk.content else ""
255
+
256
+ callback = Callback()
257
+ component = NvidiaChatGenerator(streaming_callback=callback)
258
+ results = component.run([ChatMessage.from_user("What's the capital of France?")])
259
+
260
+ assert len(results["replies"]) == 1
261
+ message: ChatMessage = results["replies"][0]
262
+ assert "Paris" in message.text
263
+
264
+ assert "meta/llama-3.1-8b-instruct" in message.meta["model"]
265
+ assert message.meta["finish_reason"] == "stop"
266
+
267
+ assert callback.counter > 1
268
+ assert "Paris" in callback.responses
269
+
270
+
271
+ class TestNvidiaChatGeneratorAsync:
272
+ def test_init_default_async(self, monkeypatch):
273
+ monkeypatch.setenv("NVIDIA_API_KEY", "test-api-key")
274
+ component = NvidiaChatGenerator()
275
+
276
+ assert isinstance(component.async_client, AsyncOpenAI)
277
+ assert component.async_client.api_key == "test-api-key"
278
+ assert not component.generation_kwargs
279
+
280
+ @pytest.mark.asyncio
281
+ async def test_run_async(self, chat_messages, mock_async_chat_completion, monkeypatch): # noqa: ARG002
282
+ monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key")
283
+ component = NvidiaChatGenerator()
284
+ response = await component.run_async(chat_messages)
285
+
286
+ # check that the component returns the correct ChatMessage response
287
+ assert isinstance(response, dict)
288
+ assert "replies" in response
289
+ assert isinstance(response["replies"], list)
290
+ assert len(response["replies"]) == 1
291
+ assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
292
+
293
+ @pytest.mark.asyncio
294
+ async def test_run_async_with_params(self, chat_messages, mock_async_chat_completion, monkeypatch):
295
+ monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key")
296
+ component = NvidiaChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5})
297
+ response = await component.run_async(chat_messages)
298
+
299
+ # check that the component calls the OpenAI API with the correct parameters
300
+ _, kwargs = mock_async_chat_completion.call_args
301
+ assert kwargs["max_tokens"] == 10
302
+ assert kwargs["temperature"] == 0.5
303
+
304
+ # check that the component returns the correct response
305
+ assert isinstance(response, dict)
306
+ assert "replies" in response
307
+ assert isinstance(response["replies"], list)
308
+ assert len(response["replies"]) == 1
309
+ assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
310
+
311
+ @pytest.mark.asyncio
312
+ async def test_run_async_with_extra_body(self, chat_messages, mock_async_chat_completion, monkeypatch):
313
+ monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key")
314
+ extra_body = {
315
+ "guardrails": {"config_id": "demo-self-check-input-output"},
316
+ }
317
+ component = NvidiaChatGenerator(generation_kwargs={"extra_body": extra_body})
318
+ response = await component.run_async(chat_messages)
319
+
320
+ # check that the component calls the OpenAI API with the correct parameters
321
+ _, kwargs = mock_async_chat_completion.call_args
322
+ assert kwargs["extra_body"] == extra_body
323
+ assert kwargs["model"] == "meta/llama-3.1-8b-instruct"
324
+ assert kwargs["messages"] == [
325
+ {"role": "system", "content": "You are a helpful assistant"},
326
+ {"role": "user", "content": "What's the capital of France"},
327
+ ]
328
+
329
+ # check that the component returns the correct response
330
+ assert isinstance(response, dict)
331
+ assert "replies" in response
332
+ assert isinstance(response["replies"], list)
333
+ assert len(response["replies"]) == 1
334
+ assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
335
+
336
+ @pytest.mark.skipif(
337
+ not os.environ.get("NVIDIA_API_KEY", None),
338
+ reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
339
+ )
340
+ @pytest.mark.integration
341
+ @pytest.mark.asyncio
342
+ async def test_live_run_async(self):
343
+ chat_messages = [ChatMessage.from_user("What's the capital of France")]
344
+ component = NvidiaChatGenerator()
345
+ results = await component.run_async(chat_messages)
346
+ assert len(results["replies"]) == 1
347
+ message: ChatMessage = results["replies"][0]
348
+ assert "Paris" in message.text
349
+ assert "meta/llama-3.1-8b-instruct" in message.meta["model"]
350
+ assert message.meta["finish_reason"] == "stop"
351
+
352
+ @pytest.mark.skipif(
353
+ not os.environ.get("NVIDIA_API_KEY", None),
354
+ reason="Export an env var called NVIDIA_API_KEY containing the NVIDIA API key to run this test.",
355
+ )
356
+ @pytest.mark.integration
357
+ @pytest.mark.asyncio
358
+ async def test_live_run_streaming_async(self):
359
+ counter = 0
360
+ responses = ""
361
+
362
+ async def callback(chunk: StreamingChunk):
363
+ nonlocal counter
364
+ nonlocal responses
365
+ counter += 1
366
+ responses += chunk.content if chunk.content else ""
367
+
368
+ component = NvidiaChatGenerator(streaming_callback=callback)
369
+ results = await component.run_async([ChatMessage.from_user("What's the capital of France?")])
370
+
371
+ assert len(results["replies"]) == 1
372
+ message: ChatMessage = results["replies"][0]
373
+ assert "Paris" in message.text
374
+
375
+ assert "meta/llama-3.1-8b-instruct" in message.meta["model"]
376
+ assert message.meta["finish_reason"] == "stop"
377
+
378
+ assert counter > 1
379
+ assert "Paris" in responses