nv-ingest-api 26.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nv-ingest-api might be problematic; see the advisory page on the package registry for more details.
- nv_ingest_api/__init__.py +3 -0
- nv_ingest_api/interface/__init__.py +218 -0
- nv_ingest_api/interface/extract.py +977 -0
- nv_ingest_api/interface/mutate.py +154 -0
- nv_ingest_api/interface/store.py +200 -0
- nv_ingest_api/interface/transform.py +382 -0
- nv_ingest_api/interface/utility.py +186 -0
- nv_ingest_api/internal/__init__.py +0 -0
- nv_ingest_api/internal/enums/__init__.py +3 -0
- nv_ingest_api/internal/enums/common.py +550 -0
- nv_ingest_api/internal/extract/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
- nv_ingest_api/internal/extract/docx/__init__.py +5 -0
- nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
- nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
- nv_ingest_api/internal/extract/html/__init__.py +3 -0
- nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
- nv_ingest_api/internal/extract/image/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
- nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
- nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
- nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
- nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
- nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
- nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
- nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
- nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
- nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
- nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
- nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
- nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
- nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
- nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
- nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
- nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
- nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
- nv_ingest_api/internal/meta/__init__.py +3 -0
- nv_ingest_api/internal/meta/udf.py +232 -0
- nv_ingest_api/internal/mutate/__init__.py +3 -0
- nv_ingest_api/internal/mutate/deduplicate.py +110 -0
- nv_ingest_api/internal/mutate/filter.py +133 -0
- nv_ingest_api/internal/primitives/__init__.py +0 -0
- nv_ingest_api/internal/primitives/control_message_task.py +16 -0
- nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
- nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
- nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
- nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
- nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
- nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
- nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
- nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
- nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
- nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
- nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
- nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
- nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
- nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
- nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
- nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
- nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
- nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
- nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
- nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
- nv_ingest_api/internal/schemas/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
- nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
- nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
- nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
- nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
- nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
- nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
- nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
- nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
- nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
- nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
- nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
- nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
- nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
- nv_ingest_api/internal/schemas/meta/udf.py +23 -0
- nv_ingest_api/internal/schemas/mixins.py +39 -0
- nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
- nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
- nv_ingest_api/internal/schemas/store/__init__.py +3 -0
- nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
- nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
- nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
- nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
- nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
- nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
- nv_ingest_api/internal/store/__init__.py +3 -0
- nv_ingest_api/internal/store/embed_text_upload.py +236 -0
- nv_ingest_api/internal/store/image_upload.py +251 -0
- nv_ingest_api/internal/transform/__init__.py +3 -0
- nv_ingest_api/internal/transform/caption_image.py +219 -0
- nv_ingest_api/internal/transform/embed_text.py +702 -0
- nv_ingest_api/internal/transform/split_text.py +182 -0
- nv_ingest_api/util/__init__.py +3 -0
- nv_ingest_api/util/control_message/__init__.py +0 -0
- nv_ingest_api/util/control_message/validators.py +47 -0
- nv_ingest_api/util/converters/__init__.py +0 -0
- nv_ingest_api/util/converters/bytetools.py +78 -0
- nv_ingest_api/util/converters/containers.py +65 -0
- nv_ingest_api/util/converters/datetools.py +90 -0
- nv_ingest_api/util/converters/dftools.py +127 -0
- nv_ingest_api/util/converters/formats.py +64 -0
- nv_ingest_api/util/converters/type_mappings.py +27 -0
- nv_ingest_api/util/dataloader/__init__.py +9 -0
- nv_ingest_api/util/dataloader/dataloader.py +409 -0
- nv_ingest_api/util/detectors/__init__.py +5 -0
- nv_ingest_api/util/detectors/language.py +38 -0
- nv_ingest_api/util/exception_handlers/__init__.py +0 -0
- nv_ingest_api/util/exception_handlers/converters.py +72 -0
- nv_ingest_api/util/exception_handlers/decorators.py +429 -0
- nv_ingest_api/util/exception_handlers/detectors.py +74 -0
- nv_ingest_api/util/exception_handlers/pdf.py +116 -0
- nv_ingest_api/util/exception_handlers/schemas.py +68 -0
- nv_ingest_api/util/image_processing/__init__.py +5 -0
- nv_ingest_api/util/image_processing/clustering.py +260 -0
- nv_ingest_api/util/image_processing/processing.py +177 -0
- nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
- nv_ingest_api/util/image_processing/transforms.py +850 -0
- nv_ingest_api/util/imports/__init__.py +3 -0
- nv_ingest_api/util/imports/callable_signatures.py +108 -0
- nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
- nv_ingest_api/util/introspection/__init__.py +3 -0
- nv_ingest_api/util/introspection/class_inspect.py +145 -0
- nv_ingest_api/util/introspection/function_inspect.py +65 -0
- nv_ingest_api/util/logging/__init__.py +0 -0
- nv_ingest_api/util/logging/configuration.py +102 -0
- nv_ingest_api/util/logging/sanitize.py +84 -0
- nv_ingest_api/util/message_brokers/__init__.py +3 -0
- nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
- nv_ingest_api/util/metadata/__init__.py +5 -0
- nv_ingest_api/util/metadata/aggregators.py +516 -0
- nv_ingest_api/util/multi_processing/__init__.py +8 -0
- nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
- nv_ingest_api/util/nim/__init__.py +161 -0
- nv_ingest_api/util/pdf/__init__.py +3 -0
- nv_ingest_api/util/pdf/pdfium.py +428 -0
- nv_ingest_api/util/schema/__init__.py +3 -0
- nv_ingest_api/util/schema/schema_validator.py +10 -0
- nv_ingest_api/util/service_clients/__init__.py +3 -0
- nv_ingest_api/util/service_clients/client_base.py +86 -0
- nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
- nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
- nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
- nv_ingest_api/util/string_processing/__init__.py +51 -0
- nv_ingest_api/util/string_processing/configuration.py +682 -0
- nv_ingest_api/util/string_processing/yaml.py +109 -0
- nv_ingest_api/util/system/__init__.py +0 -0
- nv_ingest_api/util/system/hardware_info.py +594 -0
- nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
- nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
- nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
- nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
- nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
- udfs/__init__.py +5 -0
- udfs/llm_summarizer_udf.py +259 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import math
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import multiprocessing as mp
|
|
11
|
+
from threading import Lock
|
|
12
|
+
from typing import Any, Callable, Optional
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SimpleFuture:
|
|
18
|
+
"""
|
|
19
|
+
A simplified future object that uses a multiprocessing Pipe to receive its result.
|
|
20
|
+
|
|
21
|
+
When the result() method is called, it blocks until the worker sends a tuple
|
|
22
|
+
(result, error) over the pipe.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
def __init__(self, parent_conn: mp.connection.Connection) -> None:
|
|
26
|
+
"""
|
|
27
|
+
Parameters
|
|
28
|
+
----------
|
|
29
|
+
parent_conn : mp.connection.Connection
|
|
30
|
+
The parent end of the multiprocessing Pipe used to receive the result.
|
|
31
|
+
"""
|
|
32
|
+
self._parent_conn: mp.connection.Connection = parent_conn
|
|
33
|
+
|
|
34
|
+
def result(self) -> Any:
|
|
35
|
+
"""
|
|
36
|
+
Retrieve the result from the future, blocking until it is available.
|
|
37
|
+
|
|
38
|
+
Returns
|
|
39
|
+
-------
|
|
40
|
+
Any
|
|
41
|
+
The result returned by the worker function.
|
|
42
|
+
|
|
43
|
+
Raises
|
|
44
|
+
------
|
|
45
|
+
Exception
|
|
46
|
+
If the worker function raised an exception, it is re-raised here.
|
|
47
|
+
"""
|
|
48
|
+
result, error = self._parent_conn.recv()
|
|
49
|
+
if error is not None:
|
|
50
|
+
raise error
|
|
51
|
+
return result
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ProcessWorkerPoolSingleton:
    """
    A singleton process worker pool using a dual-queue implementation.

    Instead of a global result queue, each submitted task gets its own Pipe.
    The submit_task() method returns a SimpleFuture, whose result() call blocks
    until the task completes.
    """

    _instance: Optional["ProcessWorkerPoolSingleton"] = None
    _lock: Lock = Lock()
    _total_workers: int = 0

    def __new__(cls) -> "ProcessWorkerPoolSingleton":
        """
        Create or return the singleton instance of ProcessWorkerPoolSingleton.

        Returns
        -------
        ProcessWorkerPoolSingleton
            The singleton instance.
        """
        logger.debug("Creating ProcessWorkerPoolSingleton instance...")
        with cls._lock:
            if cls._instance is None:
                max_worker_limit: int = int(os.environ.get("MAX_INGEST_PROCESS_WORKERS", -1))
                instance = super().__new__(cls)
                # Determine available CPU count using affinity if possible.
                # os.cpu_count() may return None; fall back to 1 so the
                # arithmetic below never sees None (previously a TypeError
                # on platforms without sched_getaffinity).
                available: int = (
                    len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)
                )
                # Use 40% of available CPUs, ensuring at least one worker.
                max_workers: int = math.floor(max(1, available * 0.4))
                if (max_worker_limit > 0) and (max_workers > max_worker_limit):
                    max_workers = max_worker_limit
                logger.debug("Creating ProcessWorkerPoolSingleton instance with max workers: %d", max_workers)
                instance._initialize(max_workers)
                logger.debug("ProcessWorkerPoolSingleton instance created: %s", instance)
                cls._instance = instance
            else:
                logger.debug("ProcessWorkerPoolSingleton instance already exists: %s", cls._instance)
        return cls._instance

    def _initialize(self, total_max_workers: int) -> None:
        """
        Initialize the worker pool with the specified number of worker processes.

        Parameters
        ----------
        total_max_workers : int
            The total number of worker processes to start.
        """
        self._total_workers = total_max_workers

        # "fork" is the historical default here; macOS requires "spawn"
        # (fork is unsafe with many system frameworks on Darwin).
        start_method = "fork"
        if sys.platform.lower() == "darwin":
            start_method = "spawn"
        # NOTE: annotated as BaseContext, the common ancestor of both fork
        # and spawn contexts (the previous ForkContext annotation was wrong
        # whenever "spawn" was selected).
        self._context: mp.context.BaseContext = mp.get_context(start_method)

        # Bounded task queue: maximum tasks queued = 2 * total_max_workers.
        self._task_queue: mp.Queue = self._context.Queue(maxsize=2 * total_max_workers)
        self._next_task_id: int = 0
        self._processes: list[mp.Process] = []
        logger.debug(
            "Initializing ProcessWorkerPoolSingleton with %d workers and queue size %d.",
            total_max_workers,
            2 * total_max_workers,
        )
        for i in range(total_max_workers):
            p: mp.Process = self._context.Process(target=self._worker, args=(self._task_queue,))
            p.start()
            self._processes.append(p)
            logger.debug("Started worker process %d/%d: PID %d", i + 1, total_max_workers, p.pid)
        logger.debug("Initialized with max workers: %d", total_max_workers)

    @staticmethod
    def _worker(task_queue: mp.Queue) -> None:
        """
        Worker process that continuously processes tasks from the task queue.

        Parameters
        ----------
        task_queue : mp.Queue
            The queue from which tasks are retrieved. A ``None`` item is the
            stop sentinel; anything else is a
            ``(task_id, process_fn, args, child_conn)`` tuple.
        """
        logger.debug("Worker process started: PID %d", os.getpid())
        while True:
            task = task_queue.get()
            if task is None:
                # Stop signal received; exit the loop.
                logger.debug("Worker process %d received stop signal.", os.getpid())
                break
            # Unpack task: (task_id, process_fn, args, child_conn)
            task_id, process_fn, args, child_conn = task
            try:
                result = process_fn(*args)
                child_conn.send((result, None))
            except Exception as e:
                logger.error("Task %d error in worker %d: %s", task_id, os.getpid(), e)
                child_conn.send((None, e))
            finally:
                # Always close the write end so the caller's recv() cannot
                # hang on a dangling connection.
                child_conn.close()

    def submit_task(self, process_fn: Callable, *args: Any) -> SimpleFuture:
        """
        Submits a task to the worker pool for asynchronous execution.

        If a single tuple is passed as the only argument, it is unpacked.

        Parameters
        ----------
        process_fn : Callable
            The function to be executed asynchronously.
        *args : Any
            The arguments to pass to the process function. If a single argument is a tuple,
            it will be unpacked as the function arguments.

        Returns
        -------
        SimpleFuture
            A future object that can be used to retrieve the result of the task.
        """
        # Unpack tuple if a single tuple argument is provided.
        if len(args) == 1 and isinstance(args[0], tuple):
            args = args[0]
        # Create the pipe from the pool's own context so the connection is
        # consistent with the start method chosen in _initialize (the
        # previous code used the default mp context unconditionally).
        parent_conn, child_conn = self._context.Pipe(duplex=False)
        task_id: int = self._next_task_id
        self._next_task_id += 1
        self._task_queue.put((task_id, process_fn, args, child_conn))
        return SimpleFuture(parent_conn)

    def close(self) -> None:
        """
        Closes the worker pool and terminates all worker processes.

        Sends a stop signal to each worker and waits for them to terminate.
        """
        logger.debug("Closing ProcessWorkerPoolSingleton...")
        # Send a stop signal (None) for each worker.
        for _ in range(self._total_workers):
            self._task_queue.put(None)
            logger.debug("Sent stop signal to worker.")
        # Wait for all processes to finish.
        for i, p in enumerate(self._processes):
            p.join()
            logger.debug("Worker process %d/%d joined: PID %d", i + 1, self._total_workers, p.pid)
        logger.debug("ProcessWorkerPoolSingleton closed.")
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from typing import Tuple, Optional
|
|
6
|
+
import re
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
from nv_ingest_api.internal.primitives.nim import NimClient
|
|
10
|
+
from nv_ingest_api.internal.primitives.nim.model_interface.text_embedding import EmbeddingModelInterface
|
|
11
|
+
from nv_ingest_api.internal.primitives.nim.nim_client import NimClientManager
|
|
12
|
+
from nv_ingest_api.internal.primitives.nim.nim_client import get_nim_client_manager
|
|
13
|
+
from nv_ingest_api.internal.primitives.nim.nim_model_interface import ModelInterface
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
__all__ = ["create_inference_client", "infer_microservice"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_inference_client(
    endpoints: Tuple[str, str],
    model_interface: ModelInterface,
    auth_token: Optional[str] = None,
    infer_protocol: Optional[str] = None,
    timeout: float = 120.0,
    max_retries: int = 10,
    **kwargs,
) -> NimClientManager:
    """
    Create a client for interfacing with a model inference server.

    Parameters
    ----------
    endpoints : tuple
        A tuple containing the gRPC and HTTP endpoints.
    model_interface : ModelInterface
        The model interface implementation to use.
    auth_token : str, optional
        Authorization token for HTTP requests (default: None).
    infer_protocol : str, optional
        The protocol to use ("grpc" or "http"). If not specified, it is inferred
        from the endpoints: gRPC wins when both are non-empty.
    timeout : float, optional
        The timeout for the request in seconds (default: 120.0).
    max_retries : int, optional
        The maximum number of retries for the request (default: 10).
    **kwargs : dict, optional
        Additional keyword arguments forwarded to ``NimClientManager.get_client``.

    Returns
    -------
    NimClientManager
        The client obtained from the shared manager via ``get_client``.
        NOTE(review): the annotation says NimClientManager but the value is
        whatever ``get_client`` returns (presumably a NimClient) — confirm.

    Raises
    ------
    ValueError
        If an invalid infer_protocol is specified, or no usable endpoint is
        provided from which to infer one.
    """

    grpc_endpoint, http_endpoint = endpoints

    # Infer the protocol from whichever endpoint is non-empty. Both endpoints
    # are now stripped before the check (previously only the gRPC endpoint
    # was, so a whitespace-only HTTP endpoint silently selected "http").
    if infer_protocol is None:
        if grpc_endpoint and grpc_endpoint.strip():
            infer_protocol = "grpc"
        elif http_endpoint and http_endpoint.strip():
            infer_protocol = "http"

    if infer_protocol not in ["grpc", "http"]:
        raise ValueError("Invalid infer_protocol specified. Must be 'grpc' or 'http'.")

    manager = get_nim_client_manager()
    client = manager.get_client(
        model_interface=model_interface,
        protocol=infer_protocol,
        endpoints=endpoints,
        auth_token=auth_token,
        timeout=timeout,
        max_retries=max_retries,
        **kwargs,
    )

    return client
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def infer_microservice(
    data,
    model_name: str = None,
    embedding_endpoint: str = None,
    nvidia_api_key: str = None,
    input_type: str = "passage",
    truncate: str = "END",
    batch_size: int = 8191,
    grpc: bool = False,
    input_names: list = None,
    output_names: list = None,
    dtypes: list = None,
):
    """
    This function takes the input data and creates a list of embeddings
    using the NVIDIA embedding microservice.

    Parameters
    ----------
    data : list
        The input data to be embedded. Either a list of strings, or a list of
        dicts carrying the text under ``["metadata"]["content"]``.
    model_name : str
        The name of the model to use.
    embedding_endpoint : str
        The endpoint of the embedding microservice.
    nvidia_api_key : str
        The API key for the NVIDIA embedding microservice.
    input_type : str
        The type of input to be embedded.
    truncate : str
        The truncation of the input data.
    batch_size : int
        The batch size of the input data.
    grpc : bool
        Whether to use gRPC or HTTP.
    input_names : list, optional
        The names of the input data (default: ``["text"]``).
    output_names : list, optional
        The names of the output data (default: ``["embeddings"]``).
    dtypes : list, optional
        The data types of the input data (default: ``["BYTES"]``).

    Returns
    -------
    list
        The list of embeddings. An empty input yields an empty list.
    """
    # Defaults are applied here rather than in the signature to avoid the
    # shared-mutable-default-argument pitfall.
    if input_names is None:
        input_names = ["text"]
    if output_names is None:
        output_names = ["embeddings"]
    if dtypes is None:
        dtypes = ["BYTES"]

    # Nothing to embed; short-circuit instead of raising IndexError on data[0].
    if not data:
        return []

    if isinstance(data[0], str):
        payload = {"prompts": data}
    else:
        payload = {"prompts": [res["metadata"]["content"] for res in data]}

    if grpc:
        # gRPC model names are sanitized to alphanumerics/underscores.
        model_name = re.sub(r"[^a-zA-Z0-9]", "_", model_name)
        client = NimClient(
            model_interface=EmbeddingModelInterface(),
            protocol="grpc",
            endpoints=(embedding_endpoint, None),
            auth_token=nvidia_api_key,
        )
        return client.infer(
            payload,
            model_name,
            parameters={"input_type": input_type, "truncate": truncate},
            dtypes=dtypes,
            input_names=input_names,
            batch_size=batch_size,
            output_names=output_names,
        )
    else:
        embedding_endpoint = f"{embedding_endpoint}/embeddings"
        client = NimClient(
            model_interface=EmbeddingModelInterface(),
            protocol="http",
            endpoints=(None, embedding_endpoint),
            auth_token=nvidia_api_key,
        )
        return client.infer(payload, model_name, input_type=input_type, truncate=truncate, batch_size=batch_size)
|