nv-ingest-api 26.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nv-ingest-api might be problematic. Click here for more details.

Files changed (177)
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +218 -0
  3. nv_ingest_api/interface/extract.py +977 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +200 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +186 -0
  8. nv_ingest_api/internal/__init__.py +0 -0
  9. nv_ingest_api/internal/enums/__init__.py +3 -0
  10. nv_ingest_api/internal/enums/common.py +550 -0
  11. nv_ingest_api/internal/extract/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  13. nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
  14. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  15. nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
  16. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
  19. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
  20. nv_ingest_api/internal/extract/html/__init__.py +3 -0
  21. nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
  22. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
  24. nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
  25. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  26. nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
  27. nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
  28. nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
  29. nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
  30. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  31. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  32. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  33. nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
  34. nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
  35. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
  36. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
  37. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  38. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  39. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  40. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  41. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  42. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
  43. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
  44. nv_ingest_api/internal/meta/__init__.py +3 -0
  45. nv_ingest_api/internal/meta/udf.py +232 -0
  46. nv_ingest_api/internal/mutate/__init__.py +3 -0
  47. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  48. nv_ingest_api/internal/mutate/filter.py +133 -0
  49. nv_ingest_api/internal/primitives/__init__.py +0 -0
  50. nv_ingest_api/internal/primitives/control_message_task.py +16 -0
  51. nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
  52. nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
  53. nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
  59. nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
  60. nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
  61. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  62. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
  63. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
  64. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
  65. nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
  66. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
  67. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  68. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  69. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  70. nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
  71. nv_ingest_api/internal/schemas/__init__.py +3 -0
  72. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  73. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
  74. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
  75. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
  76. nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
  77. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
  78. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
  79. nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
  80. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
  81. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
  82. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
  83. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
  85. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  86. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  87. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  88. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  89. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
  90. nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
  91. nv_ingest_api/internal/schemas/meta/udf.py +23 -0
  92. nv_ingest_api/internal/schemas/mixins.py +39 -0
  93. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  94. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  95. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  96. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  97. nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
  98. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  99. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
  100. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  101. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
  102. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
  103. nv_ingest_api/internal/store/__init__.py +3 -0
  104. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  105. nv_ingest_api/internal/store/image_upload.py +251 -0
  106. nv_ingest_api/internal/transform/__init__.py +3 -0
  107. nv_ingest_api/internal/transform/caption_image.py +219 -0
  108. nv_ingest_api/internal/transform/embed_text.py +702 -0
  109. nv_ingest_api/internal/transform/split_text.py +182 -0
  110. nv_ingest_api/util/__init__.py +3 -0
  111. nv_ingest_api/util/control_message/__init__.py +0 -0
  112. nv_ingest_api/util/control_message/validators.py +47 -0
  113. nv_ingest_api/util/converters/__init__.py +0 -0
  114. nv_ingest_api/util/converters/bytetools.py +78 -0
  115. nv_ingest_api/util/converters/containers.py +65 -0
  116. nv_ingest_api/util/converters/datetools.py +90 -0
  117. nv_ingest_api/util/converters/dftools.py +127 -0
  118. nv_ingest_api/util/converters/formats.py +64 -0
  119. nv_ingest_api/util/converters/type_mappings.py +27 -0
  120. nv_ingest_api/util/dataloader/__init__.py +9 -0
  121. nv_ingest_api/util/dataloader/dataloader.py +409 -0
  122. nv_ingest_api/util/detectors/__init__.py +5 -0
  123. nv_ingest_api/util/detectors/language.py +38 -0
  124. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  125. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  126. nv_ingest_api/util/exception_handlers/decorators.py +429 -0
  127. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  128. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  129. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  130. nv_ingest_api/util/image_processing/__init__.py +5 -0
  131. nv_ingest_api/util/image_processing/clustering.py +260 -0
  132. nv_ingest_api/util/image_processing/processing.py +177 -0
  133. nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
  134. nv_ingest_api/util/image_processing/transforms.py +850 -0
  135. nv_ingest_api/util/imports/__init__.py +3 -0
  136. nv_ingest_api/util/imports/callable_signatures.py +108 -0
  137. nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
  138. nv_ingest_api/util/introspection/__init__.py +3 -0
  139. nv_ingest_api/util/introspection/class_inspect.py +145 -0
  140. nv_ingest_api/util/introspection/function_inspect.py +65 -0
  141. nv_ingest_api/util/logging/__init__.py +0 -0
  142. nv_ingest_api/util/logging/configuration.py +102 -0
  143. nv_ingest_api/util/logging/sanitize.py +84 -0
  144. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  145. nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
  146. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  147. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  148. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  149. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
  150. nv_ingest_api/util/metadata/__init__.py +5 -0
  151. nv_ingest_api/util/metadata/aggregators.py +516 -0
  152. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  153. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
  154. nv_ingest_api/util/nim/__init__.py +161 -0
  155. nv_ingest_api/util/pdf/__init__.py +3 -0
  156. nv_ingest_api/util/pdf/pdfium.py +428 -0
  157. nv_ingest_api/util/schema/__init__.py +3 -0
  158. nv_ingest_api/util/schema/schema_validator.py +10 -0
  159. nv_ingest_api/util/service_clients/__init__.py +3 -0
  160. nv_ingest_api/util/service_clients/client_base.py +86 -0
  161. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  162. nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
  163. nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
  164. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  165. nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
  166. nv_ingest_api/util/string_processing/__init__.py +51 -0
  167. nv_ingest_api/util/string_processing/configuration.py +682 -0
  168. nv_ingest_api/util/string_processing/yaml.py +109 -0
  169. nv_ingest_api/util/system/__init__.py +0 -0
  170. nv_ingest_api/util/system/hardware_info.py +594 -0
  171. nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
  172. nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
  173. nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
  174. nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
  175. nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
  176. udfs/__init__.py +5 -0
  177. udfs/llm_summarizer_udf.py +259 -0
@@ -0,0 +1,200 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+
6
+ import logging
7
+ import math
8
+ import os
9
+ import sys
10
+ import multiprocessing as mp
11
+ from threading import Lock
12
+ from typing import Any, Callable, Optional
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class SimpleFuture:
    """
    A simplified future object that uses a multiprocessing Pipe to receive its result.

    The worker sends exactly one ``(result, error)`` tuple over the pipe. The first
    call to ``result()`` blocks until that tuple arrives, closes the receiving end of
    the pipe, and caches the outcome. Later calls return (or re-raise) the cached
    outcome immediately instead of blocking on a pipe that will never deliver a
    second message.
    """

    def __init__(self, parent_conn: "mp.connection.Connection") -> None:
        """
        Parameters
        ----------
        parent_conn : mp.connection.Connection
            The parent end of the multiprocessing Pipe used to receive the result.
        """
        self._parent_conn: "mp.connection.Connection" = parent_conn
        # Outcome cache, filled by the first result() call.
        self._done: bool = False
        self._result: Any = None
        self._error: Optional[BaseException] = None

    def result(self) -> Any:
        """
        Retrieve the result from the future, blocking until it is available.

        Returns
        -------
        Any
            The result returned by the worker function.

        Raises
        ------
        Exception
            If the worker function raised an exception, it is re-raised here.
        EOFError
            If the sending end of the pipe was closed before anything was sent.
        """
        if not self._done:
            try:
                outcome, failure = self._parent_conn.recv()
            except (EOFError, OSError) as exc:
                # The sender went away without reporting; remember that so every
                # call fails the same way instead of touching a closed handle.
                outcome, failure = None, exc
            finally:
                # Only one message is ever sent, so the handle is no longer needed.
                self._parent_conn.close()
            self._result, self._error, self._done = outcome, failure, True

        if self._error is not None:
            raise self._error
        return self._result
52
+
53
+
54
class ProcessWorkerPoolSingleton:
    """
    A singleton pool of worker processes fed from a single bounded task queue.

    Instead of a global result queue, each submitted task gets its own Pipe.
    The submit_task() method returns a SimpleFuture, whose result() call blocks
    until the task completes.

    The pool is sized at 40% of the CPUs available to this process (at least one
    worker), optionally capped by the ``MAX_INGEST_PROCESS_WORKERS`` environment
    variable. After close() the singleton slot is released, so the next
    instantiation builds a fresh pool instead of handing back a dead one.
    """

    _instance: Optional["ProcessWorkerPoolSingleton"] = None
    _lock: Lock = Lock()
    _total_workers: int = 0
    # Class-level default so the flag can be read on any instance.
    _closed: bool = False

    def __new__(cls) -> "ProcessWorkerPoolSingleton":
        """
        Create or return the singleton instance of ProcessWorkerPoolSingleton.

        Returns
        -------
        ProcessWorkerPoolSingleton
            The singleton instance.

        Raises
        ------
        ValueError
            If ``MAX_INGEST_PROCESS_WORKERS`` is set to a non-integer value.
        """
        logger.debug("Creating ProcessWorkerPoolSingleton instance...")
        with cls._lock:
            if cls._instance is None:
                max_worker_limit: int = int(os.environ.get("MAX_INGEST_PROCESS_WORKERS", -1))
                instance = super().__new__(cls)
                # Determine available CPU count using affinity if possible.
                if hasattr(os, "sched_getaffinity"):
                    available: int = len(os.sched_getaffinity(0))
                else:
                    # os.cpu_count() returns None when the count cannot be determined.
                    available = os.cpu_count() or 1
                # Use 40% of available CPUs, ensuring at least one worker.
                max_workers: int = math.floor(max(1, available * 0.4))
                if (max_worker_limit > 0) and (max_workers > max_worker_limit):
                    max_workers = max_worker_limit
                logger.debug("Creating ProcessWorkerPoolSingleton instance with max workers: %d", max_workers)
                instance._initialize(max_workers)
                logger.debug("ProcessWorkerPoolSingleton instance created: %s", instance)
                cls._instance = instance
            else:
                logger.debug("ProcessWorkerPoolSingleton instance already exists: %s", cls._instance)
            return cls._instance

    def _initialize(self, total_max_workers: int) -> None:
        """
        Initialize the worker pool with the specified number of worker processes.

        Parameters
        ----------
        total_max_workers : int
            The total number of worker processes to start.
        """
        self._total_workers = total_max_workers
        self._closed = False

        # Prefer "fork"; fall back to "spawn" on macOS and on any platform whose
        # interpreter does not offer "fork" at all.
        start_method = "fork"
        if sys.platform.lower() == "darwin" or start_method not in mp.get_all_start_methods():
            start_method = "spawn"
        self._context: mp.context.BaseContext = mp.get_context(start_method)

        # Bounded task queue: maximum tasks queued = 2 * total_max_workers.
        self._task_queue: mp.Queue = self._context.Queue(maxsize=2 * total_max_workers)
        self._next_task_id: int = 0
        self._processes: list[mp.Process] = []
        logger.debug(
            "Initializing ProcessWorkerPoolSingleton with %d workers and queue size %d.",
            total_max_workers,
            2 * total_max_workers,
        )
        for i in range(total_max_workers):
            p: mp.Process = self._context.Process(target=self._worker, args=(self._task_queue,))
            p.start()
            self._processes.append(p)
            logger.debug("Started worker process %d/%d: PID %d", i + 1, total_max_workers, p.pid)
        logger.debug("Initialized with max workers: %d", total_max_workers)

    @staticmethod
    def _report_failure(child_conn: "mp.connection.Connection", task_id: int, error: Exception) -> None:
        """
        Send a task failure to the parent without ever raising into the worker loop.

        Parameters
        ----------
        child_conn : mp.connection.Connection
            The sending end of the task's pipe.
        task_id : int
            Identifier of the failed task, used for logging only.
        error : Exception
            The exception to report.
        """
        try:
            child_conn.send((None, error))
            return
        except Exception as send_error:
            logger.error(
                "Task %d: worker %d could not send %s to the parent: %s",
                task_id,
                os.getpid(),
                type(error).__name__,
                send_error,
            )
        # The exception may simply be unpicklable; retry with a plain summary.
        try:
            child_conn.send((None, RuntimeError(f"{type(error).__name__}: {error}")))
        except Exception as send_error:
            # Nothing more can be done (e.g. the parent dropped its end); keep the worker alive.
            logger.error(
                "Task %d: worker %d could not report the failure at all: %s",
                task_id,
                os.getpid(),
                send_error,
            )

    @staticmethod
    def _worker(task_queue: mp.Queue) -> None:
        """
        Worker process that continuously processes tasks from the task queue.

        A failure while running a task or while sending its outcome is reported
        through the task's pipe when possible and never terminates the loop.

        Parameters
        ----------
        task_queue : mp.Queue
            The queue from which tasks are retrieved.
        """
        logger.debug("Worker process started: PID %d", os.getpid())
        while True:
            task = task_queue.get()
            if task is None:
                # Stop signal received; exit the loop.
                logger.debug("Worker process %d received stop signal.", os.getpid())
                break
            # Unpack task: (task_id, process_fn, args, child_conn)
            task_id, process_fn, args, child_conn = task
            try:
                result = process_fn(*args)
                child_conn.send((result, None))
            except Exception as e:
                logger.error("Task %d error in worker %d: %s", task_id, os.getpid(), e)
                ProcessWorkerPoolSingleton._report_failure(child_conn, task_id, e)
            finally:
                child_conn.close()

    def submit_task(self, process_fn: Callable, *args: Any) -> SimpleFuture:
        """
        Submits a task to the worker pool for asynchronous execution.

        If a single tuple is passed as the only argument, it is unpacked.
        The call blocks while the bounded task queue is full.

        Parameters
        ----------
        process_fn : Callable
            The function to be executed asynchronously.
        *args : Any
            The arguments to pass to the process function. If a single argument is a tuple,
            it will be unpacked as the function arguments.

        Returns
        -------
        SimpleFuture
            A future object that can be used to retrieve the result of the task.

        Raises
        ------
        RuntimeError
            If the pool has already been closed.
        """
        if self._closed:
            # No worker is left to consume the task; fail fast instead of hanging.
            raise RuntimeError("ProcessWorkerPoolSingleton is closed; no workers are available to run the task.")
        # Unpack tuple if a single tuple argument is provided.
        if len(args) == 1 and isinstance(args[0], tuple):
            args = args[0]
        parent_conn, child_conn = mp.Pipe(duplex=False)
        task_id: int = self._next_task_id
        self._next_task_id += 1
        self._task_queue.put((task_id, process_fn, args, child_conn))
        return SimpleFuture(parent_conn)

    def close(self) -> None:
        """
        Closes the worker pool and terminates all worker processes.

        Sends a stop signal to each worker and waits for them to terminate.
        Calling close() again is a no-op. The singleton slot is released so a
        later instantiation creates a new pool.
        """
        logger.debug("Closing ProcessWorkerPoolSingleton...")
        cls = type(self)
        with cls._lock:
            if self._closed:
                logger.debug("ProcessWorkerPoolSingleton already closed.")
                return
            self._closed = True
            if cls._instance is self:
                cls._instance = None
        # Send a stop signal (None) for each worker.
        for _ in range(self._total_workers):
            self._task_queue.put(None)
            logger.debug("Sent stop signal to worker.")
        # Wait for all processes to finish.
        for i, p in enumerate(self._processes):
            p.join()
            logger.debug("Worker process %d/%d joined: PID %d", i + 1, self._total_workers, p.pid)
        logger.debug("ProcessWorkerPoolSingleton closed.")
@@ -0,0 +1,161 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from typing import Tuple, Optional
6
+ import re
7
+ import logging
8
+
9
+ from nv_ingest_api.internal.primitives.nim import NimClient
10
+ from nv_ingest_api.internal.primitives.nim.model_interface.text_embedding import EmbeddingModelInterface
11
+ from nv_ingest_api.internal.primitives.nim.nim_client import NimClientManager
12
+ from nv_ingest_api.internal.primitives.nim.nim_client import get_nim_client_manager
13
+ from nv_ingest_api.internal.primitives.nim.nim_model_interface import ModelInterface
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ __all__ = ["create_inference_client", "infer_microservice"]
19
+
20
+
21
def create_inference_client(
    endpoints: Tuple[str, str],
    model_interface: "ModelInterface",
    auth_token: Optional[str] = None,
    infer_protocol: Optional[str] = None,
    timeout: float = 120.0,
    max_retries: int = 10,
    **kwargs,
) -> "NimClientManager":
    """
    Obtain an inference client for a model server from the shared NIM client manager.

    Parameters
    ----------
    endpoints : tuple
        A tuple containing the gRPC and HTTP endpoints.
    model_interface : ModelInterface
        The model interface implementation to use.
    auth_token : str, optional
        Authorization token for HTTP requests (default: None).
    infer_protocol : str, optional
        The protocol to use ("grpc" or "http"). If not specified, it is inferred from the
        endpoints: "grpc" when the gRPC endpoint is non-blank, otherwise "http" when the
        HTTP endpoint is non-blank.
    timeout : float, optional
        The timeout for the request in seconds (default: 120.0).
    max_retries : int, optional
        The maximum number of retries for the request (default: 10).
    **kwargs : dict, optional
        Additional keyword arguments forwarded to the manager's ``get_client`` call.

    Returns
    -------
    NimClientManager
        The object returned by ``get_client`` on the manager.
        NOTE(review): this is the value produced by ``manager.get_client(...)``, not the
        manager itself, so the annotation likely should be ``NimClient`` - confirm
        against ``NimClientManager.get_client``.

    Raises
    ------
    ValueError
        If infer_protocol is neither "grpc" nor "http", including the case where it was
        not given and neither endpoint is usable.
    """

    grpc_endpoint, http_endpoint = endpoints

    if infer_protocol is None:
        # Blank or whitespace-only endpoints are treated as absent for both protocols.
        if grpc_endpoint and grpc_endpoint.strip():
            infer_protocol = "grpc"
        elif http_endpoint and http_endpoint.strip():
            infer_protocol = "http"

    if infer_protocol not in ["grpc", "http"]:
        raise ValueError("Invalid infer_protocol specified. Must be 'grpc' or 'http'.")

    manager = get_nim_client_manager()
    client = manager.get_client(
        model_interface=model_interface,
        protocol=infer_protocol,
        endpoints=endpoints,
        auth_token=auth_token,
        timeout=timeout,
        max_retries=max_retries,
        **kwargs,
    )

    return client
83
+
84
+
85
def infer_microservice(
    data,
    model_name: Optional[str] = None,
    embedding_endpoint: Optional[str] = None,
    nvidia_api_key: Optional[str] = None,
    input_type: str = "passage",
    truncate: str = "END",
    batch_size: int = 8191,
    grpc: bool = False,
    input_names: Optional[list] = None,
    output_names: Optional[list] = None,
    dtypes: Optional[list] = None,
):
    """
    This function takes the input data and creates a list of embeddings
    using the NVIDIA embedding microservice.

    Parameters
    ----------
    data : list
        The input data to be embedded. Either a list of strings, or a list of
        records whose text is read from ``record["metadata"]["content"]``.
    model_name : str
        The name of the model to use. Required when ``grpc`` is True, where every
        non-alphanumeric character is replaced with an underscore.
    embedding_endpoint : str
        The endpoint of the embedding microservice. For HTTP, ``/embeddings`` is appended.
    nvidia_api_key : str
        The API key for the NVIDIA embedding microservice.
    input_type : str
        The type of input to be embedded.
    truncate : str
        The truncation of the input data.
    batch_size : int
        The batch size of the input data.
    grpc : bool
        Whether to use gRPC or HTTP.
    input_names : list, optional
        The names of the input data (gRPC only). Defaults to ``["text"]``.
    output_names : list, optional
        The names of the output data (gRPC only). Defaults to ``["embeddings"]``.
    dtypes : list, optional
        The data types of the input data (gRPC only). Defaults to ``["BYTES"]``.

    Returns
    -------
    list
        The list of embeddings. Empty input yields an empty list without
        contacting the service.
    """
    # Nothing to embed; also avoids indexing data[0] on an empty sequence.
    if len(data) == 0:
        return []

    # Fresh lists per call so a default can never be shared or mutated across calls.
    input_names = ["text"] if input_names is None else input_names
    output_names = ["embeddings"] if output_names is None else output_names
    dtypes = ["BYTES"] if dtypes is None else dtypes

    if isinstance(data[0], str):
        data = {"prompts": data}
    else:
        data = {"prompts": [res["metadata"]["content"] for res in data]}

    if grpc:
        model_name = re.sub(r"[^a-zA-Z0-9]", "_", model_name)
        client = NimClient(
            model_interface=EmbeddingModelInterface(),
            protocol="grpc",
            endpoints=(embedding_endpoint, None),
            auth_token=nvidia_api_key,
        )
        return client.infer(
            data,
            model_name,
            parameters={"input_type": input_type, "truncate": truncate},
            dtypes=dtypes,
            input_names=input_names,
            batch_size=batch_size,
            output_names=output_names,
        )

    embedding_endpoint = f"{embedding_endpoint}/embeddings"
    client = NimClient(
        model_interface=EmbeddingModelInterface(),
        protocol="http",
        endpoints=(None, embedding_endpoint),
        auth_token=nvidia_api_key,
    )
    return client.infer(data, model_name, input_type=input_type, truncate=truncate, batch_size=batch_size)
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0