nvidia-haystack 0.1.8__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/CHANGELOG.md +13 -6
  2. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/PKG-INFO +12 -6
  3. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/README.md +10 -4
  4. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/pyproject.toml +33 -42
  5. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +6 -4
  6. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +6 -4
  7. nvidia_haystack-0.3.0/src/haystack_integrations/components/embedders/py.typed +0 -0
  8. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/components/generators/nvidia/__init__.py +2 -1
  9. nvidia_haystack-0.3.0/src/haystack_integrations/components/generators/nvidia/chat/chat_generator.py +133 -0
  10. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/components/generators/nvidia/generator.py +7 -5
  11. nvidia_haystack-0.3.0/src/haystack_integrations/components/generators/py.typed +0 -0
  12. nvidia_haystack-0.3.0/src/haystack_integrations/components/rankers/nvidia/py.typed +0 -0
  13. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/components/rankers/nvidia/ranker.py +4 -4
  14. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/utils/nvidia/__init__.py +2 -1
  15. nvidia_haystack-0.3.0/src/haystack_integrations/utils/nvidia/client.py +26 -0
  16. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/utils/nvidia/models.py +74 -71
  17. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/utils/nvidia/nim_backend.py +17 -11
  18. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/utils/nvidia/utils.py +15 -15
  19. nvidia_haystack-0.3.0/src/haystack_integrations/utils/py.typed +0 -0
  20. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/tests/conftest.py +1 -1
  21. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/tests/test_nim_backend.py +63 -5
  22. nvidia_haystack-0.3.0/tests/test_nvidia_chat_generator.py +379 -0
  23. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/tests/test_utils.py +7 -7
  24. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/.gitignore +0 -0
  25. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/LICENSE.txt +0 -0
  26. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/pydoc/config.yml +0 -0
  27. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/components/embedders/nvidia/__init__.py +0 -0
  28. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/components/embedders/nvidia/truncate.py +0 -0
  29. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/components/generators/nvidia/chat/__init__.py +0 -0
  30. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/components/rankers/nvidia/__init__.py +0 -0
  31. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/src/haystack_integrations/components/rankers/nvidia/truncate.py +0 -0
  32. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/tests/__init__.py +0 -0
  33. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/tests/test_base_url.py +0 -0
  34. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/tests/test_document_embedder.py +0 -0
  35. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/tests/test_embedding_truncate_mode.py +0 -0
  36. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/tests/test_generator.py +0 -0
  37. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/tests/test_ranker.py +0 -0
  38. {nvidia_haystack-0.1.8 → nvidia_haystack-0.3.0}/tests/test_text_embedder.py +0 -0
@@ -1,10 +1,20 @@
1
1
  # Changelog
2
2
 
3
- ## [integrations/nvidia-v0.1.7] - 2025-04-03
3
+ ## [integrations/nvidia-v0.2.0] - 2025-06-05
4
4
 
5
- ### 📚 Documentation
5
+ ### 🚀 Features
6
+
7
+ - Add NvidiaChatGenerator based on OpenAIChatGenerator (#1776)
8
+
9
+
10
+ ## [integrations/nvidia-v0.1.8] - 2025-05-28
11
+
12
+ ### 🌀 Miscellaneous
13
+
14
+ - Add pins for Nvidia (#1846)
15
+
16
+ ## [integrations/nvidia-v0.1.7] - 2025-04-03
6
17
 
7
- - Update changelog for integrations/nvidia (#1365)
8
18
 
9
19
  ### 🧪 Testing
10
20
 
@@ -30,9 +40,6 @@
30
40
 
31
41
  - Add nvidia latest embedding models (#1364)
32
42
 
33
- ### 📚 Documentation
34
-
35
- - Update changelog for integrations/nvidia (#1353)
36
43
 
37
44
  ## [integrations/nvidia-v0.1.5] - 2025-02-04
38
45
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-haystack
3
- Version: 0.1.8
3
+ Version: 0.3.0
4
4
  Project-URL: Documentation, https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme
5
5
  Project-URL: Issues, https://github.com/deepset-ai/haystack-core-integrations/issues
6
6
  Project-URL: Source, https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia
@@ -18,7 +18,7 @@ Classifier: Programming Language :: Python :: 3.13
18
18
  Classifier: Programming Language :: Python :: Implementation :: CPython
19
19
  Classifier: Programming Language :: Python :: Implementation :: PyPy
20
20
  Requires-Python: >=3.9
21
- Requires-Dist: haystack-ai
21
+ Requires-Dist: haystack-ai>=2.13.0
22
22
  Requires-Dist: requests>=2.25.0
23
23
  Requires-Dist: tqdm>=4.21.0
24
24
  Description-Content-Type: text/markdown
@@ -54,7 +54,7 @@ pip install hatch
54
54
  With `hatch` installed, to run all the tests:
55
55
 
56
56
  ```
57
- hatch run test
57
+ hatch run test:all
58
58
  ```
59
59
 
60
60
  > Note: integration tests will be skipped unless the env var NVIDIA_API_KEY is set. The api key needs to be valid
@@ -63,13 +63,19 @@ hatch run test
63
63
  To only run unit tests:
64
64
 
65
65
  ```
66
- hatch run test -m "not integration"
66
+ hatch run test:unit
67
67
  ```
68
68
 
69
- To run the linters `ruff` and `mypy`:
69
+ To format your code and perform linting using Ruff (with automatic fixes), run:
70
70
 
71
71
  ```
72
- hatch run lint:all
72
+ hatch run fmt
73
+ ```
74
+
75
+ To check for static type errors, run:
76
+
77
+ ```console
78
+ $ hatch run test:types
73
79
  ```
74
80
 
75
81
  ## License
@@ -29,7 +29,7 @@ pip install hatch
29
29
  With `hatch` installed, to run all the tests:
30
30
 
31
31
  ```
32
- hatch run test
32
+ hatch run test:all
33
33
  ```
34
34
 
35
35
  > Note: integration tests will be skipped unless the env var NVIDIA_API_KEY is set. The api key needs to be valid
@@ -38,13 +38,19 @@ hatch run test
38
38
  To only run unit tests:
39
39
 
40
40
  ```
41
- hatch run test -m "not integration"
41
+ hatch run test:unit
42
42
  ```
43
43
 
44
- To run the linters `ruff` and `mypy`:
44
+ To format your code and perform linting using Ruff (with automatic fixes), run:
45
45
 
46
46
  ```
47
- hatch run lint:all
47
+ hatch run fmt
48
+ ```
49
+
50
+ To check for static type errors, run:
51
+
52
+ ```console
53
+ $ hatch run test:types
48
54
  ```
49
55
 
50
56
  ## License
@@ -23,7 +23,7 @@ classifiers = [
23
23
  "Programming Language :: Python :: Implementation :: CPython",
24
24
  "Programming Language :: Python :: Implementation :: PyPy",
25
25
  ]
26
- dependencies = ["haystack-ai", "requests>=2.25.0", "tqdm>=4.21.0"]
26
+ dependencies = ["haystack-ai>=2.13.0", "requests>=2.25.0", "tqdm>=4.21.0"]
27
27
 
28
28
  [project.urls]
29
29
  Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme"
@@ -43,34 +43,42 @@ git_describe_command = 'git describe --tags --match="integrations/nvidia-v[0-9]*
43
43
 
44
44
  [tool.hatch.envs.default]
45
45
  installer = "uv"
46
- dependencies = [
47
- "coverage[toml]>=6.5",
48
- "pytest",
49
- "pytest-rerunfailures",
50
- "haystack-pydoc-tools",
51
- "requests_mock",
52
- ]
46
+ dependencies = ["haystack-pydoc-tools", "ruff"]
47
+
53
48
  [tool.hatch.envs.default.scripts]
54
- test = "pytest {args:tests}"
55
- test-cov = "coverage run -m pytest {args:tests}"
56
- test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x"
57
- cov-report = ["- coverage combine", "coverage report"]
58
- cov = ["test-cov", "cov-report"]
59
- cov-retry = ["test-cov-retry", "cov-report"]
60
49
  docs = ["pydoc-markdown pydoc/config.yml"]
50
+ fmt = "ruff check --fix {args} && ruff format {args}"
51
+ fmt-check = "ruff check {args} && ruff format --check {args}"
61
52
 
62
- [tool.hatch.envs.lint]
63
- installer = "uv"
64
- detached = true
65
- dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
66
- [tool.hatch.envs.lint.scripts]
67
- typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
68
- style = [
69
- "ruff check {args:}",
70
- "black --check --diff {args:.}",
53
+ [tool.hatch.envs.test]
54
+ dependencies = [
55
+ "pytest",
56
+ "pytest-asyncio",
57
+ "pytest-cov",
58
+ "pytest-rerunfailures",
59
+ "mypy",
60
+ "pip",
61
+ "requests_mock",
62
+ "pytz"
71
63
  ]
72
- fmt = ["black {args:.}", "ruff check --fix {args:}", "style"]
73
- all = ["style", "typing"]
64
+
65
+ [tool.hatch.envs.test.scripts]
66
+ unit = 'pytest -m "not integration" {args:tests}'
67
+ integration = 'pytest -m "integration" {args:tests}'
68
+ all = 'pytest {args:tests}'
69
+ cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
70
+
71
+ types = """mypy -p haystack_integrations.components.embedders.nvidia \
72
+ -p haystack_integrations.components.generators.nvidia \
73
+ -p haystack_integrations.components.rankers.nvidia \
74
+ -p haystack_integrations.utils.nvidia {args}"""
75
+
76
+ [tool.mypy]
77
+ install_types = true
78
+ non_interactive = true
79
+ check_untyped_defs = true
80
+ disallow_incomplete_defs = true
81
+
74
82
 
75
83
  [tool.black]
76
84
  target-version = ["py38"]
@@ -151,26 +159,9 @@ omit = ["*/tests/*", "*/__init__.py"]
151
159
  show_missing = true
152
160
  exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
153
161
 
154
-
155
- [[tool.mypy.overrides]]
156
- module = [
157
- "nvidia.*",
158
- "haystack.*",
159
- "haystack_integrations.*",
160
- "pytest.*",
161
- "numpy.*",
162
- "requests_mock.*",
163
- "pydantic.*",
164
- ]
165
- ignore_missing_imports = true
166
-
167
162
  [tool.pytest.ini_options]
168
163
  addopts = "--strict-markers"
169
164
  markers = [
170
165
  "integration: integration tests",
171
- "unit: unit tests",
172
- "embedders: embedders tests",
173
- "generators: generators tests",
174
- "chat_generators: chat_generators tests",
175
166
  ]
176
167
  log_cli = true
@@ -11,7 +11,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace
11
11
  from tqdm import tqdm
12
12
 
13
13
  from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode
14
- from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Model, NimBackend, url_validation
14
+ from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Client, Model, NimBackend, url_validation
15
15
 
16
16
  logger = logging.getLogger(__name__)
17
17
 
@@ -122,7 +122,9 @@ class NvidiaDocumentEmbedder:
122
122
  UserWarning,
123
123
  stacklevel=2,
124
124
  )
125
- self.model = self.backend.model = name
125
+ self.model = name
126
+ if self.backend:
127
+ self.backend.model = name
126
128
  else:
127
129
  error_message = "No locally hosted model was found."
128
130
  raise ValueError(error_message)
@@ -143,7 +145,7 @@ class NvidiaDocumentEmbedder:
143
145
  api_url=self.api_url,
144
146
  api_key=self.api_key,
145
147
  model_kwargs=model_kwargs,
146
- client=self.__class__.__name__,
148
+ client=Client.NVIDIA_DOCUMENT_EMBEDDER,
147
149
  timeout=self.timeout,
148
150
  )
149
151
  if not self.model and self.backend.model:
@@ -232,7 +234,7 @@ class NvidiaDocumentEmbedder:
232
234
  return all_embeddings, {"usage": {"prompt_tokens": usage_prompt_tokens, "total_tokens": usage_total_tokens}}
233
235
 
234
236
  @component.output_types(documents=List[Document], meta=Dict[str, Any])
235
- def run(self, documents: List[Document]):
237
+ def run(self, documents: List[Document]) -> Dict[str, Union[List[Document], Dict[str, Any]]]:
236
238
  """
237
239
  Embed a list of Documents.
238
240
 
@@ -10,7 +10,7 @@ from haystack import component, default_from_dict, default_to_dict, logging
10
10
  from haystack.utils import Secret, deserialize_secrets_inplace
11
11
 
12
12
  from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode
13
- from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Model, NimBackend, url_validation
13
+ from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Client, Model, NimBackend, url_validation
14
14
 
15
15
  logger = logging.getLogger(__name__)
16
16
 
@@ -112,7 +112,9 @@ class NvidiaTextEmbedder:
112
112
  UserWarning,
113
113
  stacklevel=2,
114
114
  )
115
- self.model = self.backend.model = name
115
+ self.model = name
116
+ if self.backend:
117
+ self.backend.model = name
116
118
  else:
117
119
  error_message = "No locally hosted model was found."
118
120
  raise ValueError(error_message)
@@ -134,7 +136,7 @@ class NvidiaTextEmbedder:
134
136
  api_key=self.api_key,
135
137
  model_kwargs=model_kwargs,
136
138
  timeout=self.timeout,
137
- client=self.__class__.__name__,
139
+ client=Client.NVIDIA_TEXT_EMBEDDER,
138
140
  )
139
141
  self._initialized = True
140
142
 
@@ -185,7 +187,7 @@ class NvidiaTextEmbedder:
185
187
  return default_from_dict(cls, data)
186
188
 
187
189
  @component.output_types(embedding=List[float], meta=Dict[str, Any])
188
- def run(self, text: str):
190
+ def run(self, text: str) -> Dict[str, Union[List[float], Dict[str, Any]]]:
189
191
  """
190
192
  Embed a string.
191
193
 
@@ -2,6 +2,7 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
+ from .chat.chat_generator import NvidiaChatGenerator
5
6
  from .generator import NvidiaGenerator
6
7
 
7
- __all__ = ["NvidiaGenerator"]
8
+ __all__ = ["NvidiaChatGenerator", "NvidiaGenerator"]
@@ -0,0 +1,133 @@
1
+ # SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import os
6
+ from typing import Any, Dict, List, Optional, Union
7
+
8
+ from haystack import component, default_to_dict, logging
9
+ from haystack.components.generators.chat import OpenAIChatGenerator
10
+ from haystack.dataclasses import StreamingCallbackT
11
+ from haystack.tools import Tool, Toolset, serialize_tools_or_toolset
12
+ from haystack.utils import serialize_callable
13
+ from haystack.utils.auth import Secret
14
+
15
+ from haystack_integrations.utils.nvidia import DEFAULT_API_URL
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
@component
class NvidiaChatGenerator(OpenAIChatGenerator):
    """
    Enables text generation using NVIDIA generative models.
    For supported models, see [NVIDIA Docs](https://build.nvidia.com/models).

    Users can pass any text generation parameters valid for the NVIDIA Chat Completion API
    directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
    parameter in `run` method.

    This component is a thin subclass of Haystack's `OpenAIChatGenerator`: it only changes the default
    API key, model and base URL, and restricts serialization to the parameters its own `__init__` accepts.

    This component uses the ChatMessage format for structuring both input and output,
    ensuring coherent and contextually relevant responses in chat-based text generation scenarios.
    Details on the ChatMessage format can be found in the
    [Haystack docs](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage)

    For more details on the parameters supported by the NVIDIA API, refer to the
    [NVIDIA Docs](https://build.nvidia.com/models).

    Usage example:
    ```python
    from haystack_integrations.components.generators.nvidia import NvidiaChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]

    client = NvidiaChatGenerator()
    response = client.run(messages)
    print(response)
    ```
    """

    def __init__(
        self,
        *,
        api_key: Secret = Secret.from_env_var("NVIDIA_API_KEY"),
        model: str = "meta/llama-3.1-8b-instruct",
        streaming_callback: Optional[StreamingCallbackT] = None,
        api_base_url: Optional[str] = os.getenv("NVIDIA_API_URL", DEFAULT_API_URL),
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Creates an instance of NvidiaChatGenerator.

        :param api_key:
            The NVIDIA API key. By default it is read from the `NVIDIA_API_KEY` environment variable.
        :param model:
            The name of the NVIDIA chat completion model to use.
        :param streaming_callback:
            A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param api_base_url:
            The NVIDIA API Base url. Defaults to the `NVIDIA_API_URL` environment variable, falling back to
            `DEFAULT_API_URL`. Note that this default is evaluated once, when the module is imported, so changing
            `NVIDIA_API_URL` afterwards has no effect unless `api_base_url` is passed explicitly.
        :param generation_kwargs:
            Other parameters to use for the model. These parameters are all sent directly to
            the NVIDIA API endpoint. See [NVIDIA API docs](https://docs.nvcf.nvidia.com/ai/generative-models/)
            for more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `stream`: Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent
                events as they become available, with the stream terminated by a data: [DONE] message.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
            list of `Tool` objects or a `Toolset` instance.
        :param timeout:
            The timeout for the NVIDIA API call. If not set, `None` is forwarded to `OpenAIChatGenerator`,
            which applies its own default (per the Haystack documentation: the `OPENAI_TIMEOUT` environment
            variable, or 30 seconds).
        :param max_retries:
            Maximum number of retries to contact NVIDIA after an internal error.
            If not set, `None` is forwarded to `OpenAIChatGenerator`, which applies its own default (per the
            Haystack documentation: the `OPENAI_MAX_RETRIES` environment variable, or 5). This class does not
            read an `NVIDIA_MAX_RETRIES` variable.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        """
        # The explicit two-argument super() is intentional (hence the UP008 suppression).
        # Presumably the @component decorator can re-create the class object, which would
        # break zero-argument super(). Confirm before switching to plain super().
        super(NvidiaChatGenerator, self).__init__(  # noqa: UP008
            api_key=api_key,
            model=model,
            streaming_callback=streaming_callback,
            api_base_url=api_base_url,
            generation_kwargs=generation_kwargs,
            tools=tools,
            timeout=timeout,
            max_retries=max_retries,
            http_client_kwargs=http_client_kwargs,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        Only the parameters accepted by this class's `__init__` are emitted. Parent-only options are therefore
        left out, so the dictionary can be fed back into this class's constructor. (Deserialization is not
        overridden here and is presumably inherited from `OpenAIChatGenerator`.)

        :returns:
            The serialized component as a dictionary.
        """
        # Callables are not serializable as-is; store the callback's import path instead.
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None

        return default_to_dict(
            self,
            model=self.model,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            generation_kwargs=self.generation_kwargs,
            api_key=self.api_key.to_dict(),
            tools=serialize_tools_or_toolset(self.tools),
            timeout=self.timeout,
            max_retries=self.max_retries,
            http_client_kwargs=self.http_client_kwargs,
        )
@@ -4,12 +4,12 @@
4
4
 
5
5
  import os
6
6
  import warnings
7
- from typing import Any, Dict, List, Optional
7
+ from typing import Any, Dict, List, Optional, Union
8
8
 
9
9
  from haystack import component, default_from_dict, default_to_dict
10
10
  from haystack.utils.auth import Secret, deserialize_secrets_inplace
11
11
 
12
- from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Model, NimBackend, is_hosted, url_validation
12
+ from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Client, Model, NimBackend, is_hosted, url_validation
13
13
 
14
14
 
15
15
  @component
@@ -104,7 +104,9 @@ class NvidiaGenerator:
104
104
  UserWarning,
105
105
  stacklevel=2,
106
106
  )
107
- self._model = self.backend.model = name
107
+ self._model = name
108
+ if self.backend:
109
+ self.backend.model = name
108
110
  else:
109
111
  error_message = "No locally hosted model was found."
110
112
  raise ValueError(error_message)
@@ -123,7 +125,7 @@ class NvidiaGenerator:
123
125
  api_key=self._api_key,
124
126
  model_kwargs=self._model_arguments,
125
127
  timeout=self.timeout,
126
- client=self.__class__.__name__,
128
+ client=Client.NVIDIA_GENERATOR,
127
129
  )
128
130
 
129
131
  if not self.is_hosted and not self._model:
@@ -169,7 +171,7 @@ class NvidiaGenerator:
169
171
  return default_from_dict(cls, data)
170
172
 
171
173
  @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
172
- def run(self, prompt: str):
174
+ def run(self, prompt: str) -> Dict[str, Union[List[str], List[Dict[str, Any]]]]:
173
175
  """
174
176
  Queries the model with the provided prompt.
175
177
 
@@ -10,7 +10,7 @@ from haystack import Document, component, default_from_dict, default_to_dict, lo
10
10
  from haystack.utils import Secret, deserialize_secrets_inplace
11
11
 
12
12
  from haystack_integrations.components.rankers.nvidia.truncate import RankerTruncateMode
13
- from haystack_integrations.utils.nvidia import DEFAULT_API_URL, NimBackend, is_hosted, url_validation
13
+ from haystack_integrations.utils.nvidia import DEFAULT_API_URL, Client, NimBackend, is_hosted, url_validation
14
14
 
15
15
  logger = logging.getLogger(__name__)
16
16
 
@@ -162,7 +162,7 @@ class NvidiaRanker:
162
162
  :raises ValueError: If the API key is required for hosted NVIDIA NIMs.
163
163
  """
164
164
  if not self._initialized:
165
- model_kwargs = {}
165
+ model_kwargs: Dict[str, Any] = {}
166
166
  if self.truncate is not None:
167
167
  model_kwargs.update(truncate=str(self.truncate))
168
168
  self.backend = NimBackend(
@@ -172,9 +172,9 @@ class NvidiaRanker:
172
172
  api_key=self.api_key,
173
173
  model_kwargs=model_kwargs,
174
174
  timeout=self.timeout,
175
- client=self.__class__.__name__,
175
+ client=Client.NVIDIA_RANKER,
176
176
  )
177
- if not self.is_hosted and not self._model:
177
+ if not self.is_hosted and not self.model:
178
178
  if self.backend.model:
179
179
  self.model = self.backend.model
180
180
  self._initialized = True
@@ -2,8 +2,9 @@
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
 
5
+ from .client import Client
5
6
  from .models import DEFAULT_API_URL, Model
6
7
  from .nim_backend import NimBackend
7
8
  from .utils import is_hosted, url_validation
8
9
 
9
- __all__ = ["DEFAULT_API_URL", "Model", "NimBackend", "is_hosted", "url_validation", "validate_hosted_model"]
10
+ __all__ = ["DEFAULT_API_URL", "Client", "Model", "NimBackend", "is_hosted", "url_validation", "validate_hosted_model"]
@@ -0,0 +1,26 @@
1
+ from enum import Enum
2
+
3
+
4
class Client(Enum):
    """
    Client to use for NVIDIA NIMs.

    Identifies which Haystack component is talking to the NIM backend. Each member's value is the
    component's class name as a string, and `str(member)` returns that value.
    """

    NVIDIA_GENERATOR = "NvidiaGenerator"
    NVIDIA_TEXT_EMBEDDER = "NvidiaTextEmbedder"
    NVIDIA_DOCUMENT_EMBEDDER = "NvidiaDocumentEmbedder"
    NVIDIA_RANKER = "NvidiaRanker"

    def __str__(self) -> str:
        """Convert a Client enum to a string."""
        return self.value

    @staticmethod
    def from_str(string: str) -> "Client":
        """
        Convert a string to a Client enum.

        :param string:
            The client name, e.g. `"NvidiaRanker"`. Matching is exact and case-sensitive.
        :returns:
            The `Client` member whose value equals `string`.
        :raises ValueError:
            If `string` does not match any member's value.
        """
        enum_map = {e.value: e for e in Client}
        client = enum_map.get(string)
        if client is None:
            msg = (
                f"Unknown client '{string}' to use for NVIDIA NIMs. "
                f"Supported clients are: {list(enum_map.keys())}"
            )
            raise ValueError(msg)
        return client