nv-ingest-api 2025.4.15.dev20250415__py3-none-any.whl → 2025.4.17.dev20250417__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nv-ingest-api might be problematic. Click here for more details.
- nv_ingest_api/__init__.py +3 -0
- nv_ingest_api/interface/__init__.py +215 -0
- nv_ingest_api/interface/extract.py +972 -0
- nv_ingest_api/interface/mutate.py +154 -0
- nv_ingest_api/interface/store.py +218 -0
- nv_ingest_api/interface/transform.py +382 -0
- nv_ingest_api/interface/utility.py +200 -0
- nv_ingest_api/internal/enums/__init__.py +3 -0
- nv_ingest_api/internal/enums/common.py +494 -0
- nv_ingest_api/internal/extract/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/__init__.py +3 -0
- nv_ingest_api/internal/extract/audio/audio_extraction.py +149 -0
- nv_ingest_api/internal/extract/docx/__init__.py +5 -0
- nv_ingest_api/internal/extract/docx/docx_extractor.py +205 -0
- nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +122 -0
- nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +895 -0
- nv_ingest_api/internal/extract/image/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/chart_extractor.py +353 -0
- nv_ingest_api/internal/extract/image/image_extractor.py +204 -0
- nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
- nv_ingest_api/internal/extract/image/image_helpers/common.py +403 -0
- nv_ingest_api/internal/extract/image/infographic_extractor.py +253 -0
- nv_ingest_api/internal/extract/image/table_extractor.py +344 -0
- nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
- nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
- nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
- nv_ingest_api/internal/extract/pdf/engines/llama.py +243 -0
- nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +597 -0
- nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +146 -0
- nv_ingest_api/internal/extract/pdf/engines/pdfium.py +603 -0
- nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
- nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
- nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
- nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
- nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
- nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +799 -0
- nv_ingest_api/internal/extract/pptx/pptx_extractor.py +187 -0
- nv_ingest_api/internal/mutate/__init__.py +3 -0
- nv_ingest_api/internal/mutate/deduplicate.py +110 -0
- nv_ingest_api/internal/mutate/filter.py +133 -0
- nv_ingest_api/internal/primitives/__init__.py +0 -0
- nv_ingest_api/{primitives → internal/primitives}/control_message_task.py +4 -0
- nv_ingest_api/{primitives → internal/primitives}/ingest_control_message.py +5 -2
- nv_ingest_api/internal/primitives/nim/__init__.py +8 -0
- nv_ingest_api/internal/primitives/nim/default_values.py +15 -0
- nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
- nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
- nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
- nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
- nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +275 -0
- nv_ingest_api/internal/primitives/nim/model_interface/nemoretriever_parse.py +238 -0
- nv_ingest_api/internal/primitives/nim/model_interface/paddle.py +462 -0
- nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
- nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +132 -0
- nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +152 -0
- nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1400 -0
- nv_ingest_api/internal/primitives/nim/nim_client.py +344 -0
- nv_ingest_api/internal/primitives/nim/nim_model_interface.py +81 -0
- nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
- nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
- nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
- nv_ingest_api/internal/primitives/tracing/tagging.py +197 -0
- nv_ingest_api/internal/schemas/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
- nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +130 -0
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +135 -0
- nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +124 -0
- nv_ingest_api/internal/schemas/extract/extract_image_schema.py +124 -0
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +128 -0
- nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +218 -0
- nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +124 -0
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +129 -0
- nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
- nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +23 -0
- nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
- nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
- nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
- nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +237 -0
- nv_ingest_api/internal/schemas/meta/metadata_schema.py +221 -0
- nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
- nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
- nv_ingest_api/internal/schemas/store/__init__.py +3 -0
- nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
- nv_ingest_api/internal/schemas/store/store_image_schema.py +30 -0
- nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
- nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +15 -0
- nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +25 -0
- nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +22 -0
- nv_ingest_api/internal/store/__init__.py +3 -0
- nv_ingest_api/internal/store/embed_text_upload.py +236 -0
- nv_ingest_api/internal/store/image_upload.py +232 -0
- nv_ingest_api/internal/transform/__init__.py +3 -0
- nv_ingest_api/internal/transform/caption_image.py +205 -0
- nv_ingest_api/internal/transform/embed_text.py +496 -0
- nv_ingest_api/internal/transform/split_text.py +157 -0
- nv_ingest_api/util/__init__.py +0 -0
- nv_ingest_api/util/control_message/__init__.py +0 -0
- nv_ingest_api/util/control_message/validators.py +47 -0
- nv_ingest_api/util/converters/__init__.py +0 -0
- nv_ingest_api/util/converters/bytetools.py +78 -0
- nv_ingest_api/util/converters/containers.py +65 -0
- nv_ingest_api/util/converters/datetools.py +90 -0
- nv_ingest_api/util/converters/dftools.py +127 -0
- nv_ingest_api/util/converters/formats.py +64 -0
- nv_ingest_api/util/converters/type_mappings.py +27 -0
- nv_ingest_api/util/detectors/__init__.py +5 -0
- nv_ingest_api/util/detectors/language.py +38 -0
- nv_ingest_api/util/exception_handlers/__init__.py +0 -0
- nv_ingest_api/util/exception_handlers/converters.py +72 -0
- nv_ingest_api/util/exception_handlers/decorators.py +223 -0
- nv_ingest_api/util/exception_handlers/detectors.py +74 -0
- nv_ingest_api/util/exception_handlers/pdf.py +116 -0
- nv_ingest_api/util/exception_handlers/schemas.py +68 -0
- nv_ingest_api/util/image_processing/__init__.py +5 -0
- nv_ingest_api/util/image_processing/clustering.py +260 -0
- nv_ingest_api/util/image_processing/processing.py +179 -0
- nv_ingest_api/util/image_processing/table_and_chart.py +449 -0
- nv_ingest_api/util/image_processing/transforms.py +407 -0
- nv_ingest_api/util/logging/__init__.py +0 -0
- nv_ingest_api/util/logging/configuration.py +31 -0
- nv_ingest_api/util/message_brokers/__init__.py +3 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +435 -0
- nv_ingest_api/util/metadata/__init__.py +5 -0
- nv_ingest_api/util/metadata/aggregators.py +469 -0
- nv_ingest_api/util/multi_processing/__init__.py +8 -0
- nv_ingest_api/util/multi_processing/mp_pool_singleton.py +194 -0
- nv_ingest_api/util/nim/__init__.py +56 -0
- nv_ingest_api/util/pdf/__init__.py +3 -0
- nv_ingest_api/util/pdf/pdfium.py +427 -0
- nv_ingest_api/util/schema/__init__.py +0 -0
- nv_ingest_api/util/schema/schema_validator.py +10 -0
- nv_ingest_api/util/service_clients/__init__.py +3 -0
- nv_ingest_api/util/service_clients/client_base.py +72 -0
- nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
- nv_ingest_api/util/service_clients/redis/__init__.py +0 -0
- nv_ingest_api/util/service_clients/redis/redis_client.py +334 -0
- nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
- nv_ingest_api/util/service_clients/rest/rest_client.py +398 -0
- nv_ingest_api/util/string_processing/__init__.py +51 -0
- {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/METADATA +1 -1
- nv_ingest_api-2025.4.17.dev20250417.dist-info/RECORD +152 -0
- nv_ingest_api-2025.4.15.dev20250415.dist-info/RECORD +0 -9
- /nv_ingest_api/{primitives → internal}/__init__.py +0 -0
- {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/WHEEL +0 -0
- {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/licenses/LICENSE +0 -0
- {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,367 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import io
|
|
6
|
+
import base64
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Any
|
|
9
|
+
from typing import List
|
|
10
|
+
from typing import Optional
|
|
11
|
+
from typing import Tuple
|
|
12
|
+
|
|
13
|
+
import grpc
|
|
14
|
+
import numpy as np
|
|
15
|
+
import riva.client
|
|
16
|
+
from scipy.io import wavfile
|
|
17
|
+
|
|
18
|
+
from nv_ingest_api.internal.primitives.tracing.tagging import traceable_func
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import librosa
|
|
22
|
+
except ImportError:
|
|
23
|
+
librosa = None
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ParakeetClient:
    """
    A simple interface for handling inference with a Parakeet model (e.g., speech, audio-related).

    The client wraps a Riva ASR service object: audio is base64-decoded, converted to mono WAV
    and submitted for offline recognition.
    """

    def __init__(
        self,
        endpoint: str,
        auth_token: Optional[str] = None,
        function_id: Optional[str] = None,
        use_ssl: Optional[bool] = None,
        ssl_cert: Optional[str] = None,
    ):
        """
        Initialize the ParakeetClient.

        Parameters
        ----------
        endpoint : str
            The URL of the Parakeet service endpoint.
        auth_token : Optional[str], default=None
            The authentication token for accessing the service. When set, it is sent as an
            ``("authorization", "Bearer <token>")`` metadata entry.
        function_id : Optional[str], default=None
            The NVCF function ID for invoking the service. When set, it is sent as a
            ``("function-id", <function_id>)`` metadata entry.
        use_ssl : Optional[bool], default=None
            Whether to use SSL for the connection. When None, SSL is enabled only if the
            endpoint contains "grpc.nvcf.nvidia.com" and a function ID was provided.
        ssl_cert : Optional[str], default=None
            Path to the SSL certificate if required.
        """
        self.endpoint = endpoint
        self.auth_token = auth_token
        self.function_id = function_id
        if use_ssl is None:
            # Auto-detect: NVCF-hosted endpoints invoked with a function ID use SSL.
            self.use_ssl = bool(("grpc.nvcf.nvidia.com" in self.endpoint) and self.function_id)
        else:
            self.use_ssl = use_ssl
        self.ssl_cert = ssl_cert

        # Metadata entries are only added for the credentials that were actually supplied.
        self.auth_metadata = []
        if self.auth_token:
            self.auth_metadata.append(("authorization", f"Bearer {self.auth_token}"))
        if self.function_id:
            self.auth_metadata.append(("function-id", self.function_id))

        # Create authentication and ASR service objects.
        self._auth = riva.client.Auth(self.ssl_cert, self.use_ssl, self.endpoint, self.auth_metadata)
        self._asr_service = riva.client.ASRService(self._auth)

    @traceable_func(trace_name="{stage_name}::{model_name}")
    def infer(self, data: dict, model_name: str, **kwargs) -> Any:
        """
        Perform inference using the specified model and input data.

        Parameters
        ----------
        data : dict
            The input data for inference. It is forwarded unchanged to `transcribe` as its
            `audio_content` argument, which base64-decodes it.
        model_name : str
            The model name. Not used by the body; it appears in the trace name template of
            the decorator.
        kwargs : dict
            Additional parameters for inference. Not used by the body.

        Returns
        -------
        Any
            The overall transcript produced by `process_transcription_response`, or None if
            `transcribe` returned None.

        Raises
        ------
        grpc.RpcError
            Propagated from `transcribe` when the recognition call fails.
        """
        response = self.transcribe(data)
        if response is None:
            return None
        # Only the full transcript is returned; the per-sentence segments are discarded.
        _, transcript = process_transcription_response(response)
        logger.debug("Processing Parakeet inference results (pass-through).")

        return transcript

    def transcribe(
        self,
        audio_content: str,
        language_code: str = "en-US",
        automatic_punctuation: bool = True,
        word_time_offsets: bool = True,
        max_alternatives: int = 1,
        profanity_filter: bool = False,
        verbatim_transcripts: bool = True,
        speaker_diarization: bool = False,
        boosted_lm_words: Optional[List[str]] = None,
        boosted_lm_score: float = 0.0,
        diarization_max_speakers: int = 0,
        start_history: float = 0.0,
        start_threshold: float = 0.0,
        stop_history: float = 0.0,
        stop_history_eou: bool = False,
        stop_threshold: float = 0.0,
        stop_threshold_eou: bool = False,
    ):
        """
        Transcribe an audio file using Riva ASR.

        Parameters
        ----------
        audio_content : str
            Base64-encoded audio content to be transcribed.
        language_code : str, default="en-US"
            The language code for transcription.
        automatic_punctuation : bool, default=True
            Whether to enable automatic punctuation in the transcript.
        word_time_offsets : bool, default=True
            Whether to include word-level timestamps in the transcript.
        max_alternatives : int, default=1
            The maximum number of alternative transcripts to return.
        profanity_filter : bool, default=False
            Whether to filter out profanity from the transcript.
        verbatim_transcripts : bool, default=True
            Whether to return verbatim transcripts without normalization.
        speaker_diarization : bool, default=False
            Whether to enable speaker diarization.
        boosted_lm_words : Optional[List[str]], default=None
            A list of words to boost for language modeling. None is treated as an empty list.
        boosted_lm_score : float, default=0.0
            The boosting score for language model words.
        diarization_max_speakers : int, default=0
            The maximum number of speakers to differentiate in speaker diarization.
        start_history : float, default=0.0
            History window size for endpoint detection.
        start_threshold : float, default=0.0
            The threshold for starting speech detection.
        stop_history : float, default=0.0
            History window size for stopping speech detection.
        stop_history_eou : bool, default=False
            Whether to use an end-of-utterance flag for stopping detection.
        stop_threshold : float, default=0.0
            The threshold for stopping speech detection.
        stop_threshold_eou : bool, default=False
            Whether to use an end-of-utterance flag for stop threshold.

        Returns
        -------
        riva.client.RecognitionResponse
            The response returned by the ASR service's offline recognition call.

        Raises
        ------
        grpc.RpcError
            If the recognition call fails. The error is logged and re-raised.
        """
        # Build the recognition configuration.
        recognition_config = riva.client.RecognitionConfig(
            language_code=language_code,
            max_alternatives=max_alternatives,
            profanity_filter=profanity_filter,
            enable_automatic_punctuation=automatic_punctuation,
            verbatim_transcripts=verbatim_transcripts,
            enable_word_time_offsets=word_time_offsets,
        )

        # Add additional configuration parameters.
        riva.client.add_word_boosting_to_config(
            recognition_config,
            boosted_lm_words or [],
            boosted_lm_score,
        )
        riva.client.add_speaker_diarization_to_config(
            recognition_config,
            speaker_diarization,
            diarization_max_speakers,
        )
        riva.client.add_endpoint_parameters_to_config(
            recognition_config,
            start_history,
            start_threshold,
            stop_history,
            stop_history_eou,
            stop_threshold,
            stop_threshold_eou,
        )
        audio_bytes = base64.b64decode(audio_content)
        mono_audio_bytes = convert_to_mono_wav(audio_bytes)

        # Perform offline recognition; failures are logged and propagated to the caller.
        try:
            return self._asr_service.offline_recognize(mono_audio_bytes, recognition_config)
        except grpc.RpcError as e:
            # Not every RpcError is guaranteed to expose details(); fall back to the exception
            # itself so the handler can never mask the original error with an AttributeError.
            details_fn = getattr(e, "details", None)
            details = details_fn() if callable(details_fn) else e
            logger.exception("Error transcribing audio file: %s", details)
            raise
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def convert_to_mono_wav(audio_bytes):
    """
    Re-encode raw audio bytes as a mono, 16-bit PCM WAV file held in memory.

    The audio is loaded with Librosa at 44100 Hz as a single channel, rescaled so that its
    peak amplitude is 0.9 (skipped for all-zero audio), converted to int16 and written out
    with SciPy.

    Parameters
    ----------
    audio_bytes : bytes
        The raw audio data in bytes.

    Returns
    -------
    bytes
        The processed audio in mono WAV format.

    Raises
    ------
    ImportError
        If Librosa is not installed.
    """
    if librosa is None:
        raise ImportError(
            "Librosa is required for audio processing. "
            "If you are running this code with the ingest container, it can be installed by setting "
            "the environment variable. INSTALL_AUDIO_EXTRACTION_DEPS=true"
        )

    # Decode straight from an in-memory buffer; mono=True collapses the audio to one channel.
    samples, rate = librosa.load(io.BytesIO(audio_bytes), sr=44100, mono=True)

    # Rescale to a 0.9 peak so the int16 conversion below has headroom; silence is left as-is.
    peak = np.max(np.abs(samples))
    if peak > 0:
        samples = samples / peak * 0.9

    # Map the float samples onto the 16-bit PCM range.
    pcm16 = (samples * 32767).astype(np.int16)

    # Serialise the WAV container into memory and hand back its full contents.
    wav_buffer = io.BytesIO()
    wavfile.write(wav_buffer, rate, pcm16)
    return wav_buffer.getvalue()
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def process_transcription_response(response):
    """
    Turn a Riva transcription response (a protobuf message) into sentence segments and a
    full transcript.

    Only the first alternative of each result is used; results without alternatives are
    skipped. A segment is closed whenever a word ends with ".", "?" or "!"; any trailing
    words that are not terminated by punctuation form a final segment.

    Parameters:
        response: The Riva transcription response message.

    Returns:
        segments (list): Each segment is a dict with keys "start", "end", and "text".
        final_transcript (str): The overall transcript (all words joined by single spaces).
    """
    # Flatten the word entries of the top alternative of every non-empty result.
    all_words = [
        info
        for result in response.results
        if result.alternatives
        for info in result.alternatives[0].words
    ]

    final_transcript = " ".join(info.word for info in all_words)

    sentence_enders = (".", "?", "!")
    segments = []
    pending = []
    start = None
    end = None

    for info in all_words:
        token = info.word
        # The first word seen after a reset fixes the segment start time.
        start = info.start_time if start is None else start
        end = info.end_time
        pending.append(token)

        # Sentence-final punctuation flushes the accumulated words as one segment.
        if token and token.endswith(sentence_enders):
            segments.append({"start": start, "end": end, "text": " ".join(pending)})
            pending = []
            start = None
            end = None

    # Words left over without closing punctuation still form a segment.
    if pending:
        segments.append({"start": start, "end": end, "text": " ".join(pending)})

    return segments, final_transcript
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def create_audio_inference_client(
    endpoints: Tuple[str, str],
    infer_protocol: Optional[str] = None,
    auth_token: Optional[str] = None,
    function_id: Optional[str] = None,
    use_ssl: bool = False,
    ssl_cert: Optional[str] = None,
):
    """
    Build a ParakeetClient for talking to an audio model inference server over gRPC.

    Parameters
    ----------
    endpoints : tuple
        A two-element tuple of (gRPC endpoint, HTTP endpoint). Only the gRPC endpoint is used.
    infer_protocol : str, optional
        The protocol to use. When omitted and the gRPC endpoint is non-blank, it is treated
        as "grpc". "http" is rejected because HTTP is not supported for audio inference.
    auth_token : str, optional
        Authorization token for authentication (default: None).
    function_id : str, optional
        NVCF function ID of the invocation (default: None).
    use_ssl : bool, optional
        Whether to use SSL for secure communication (default: False).
    ssl_cert : str, optional
        Path to the SSL certificate file if `use_ssl` is enabled (default: None).

    Returns
    -------
    ParakeetClient
        The initialized ParakeetClient, constructed with the gRPC endpoint.

    Raises
    ------
    ValueError
        If `infer_protocol` is "http".
    """
    grpc_endpoint, _http_endpoint = endpoints

    # Fall back to gRPC only when the caller left the protocol unspecified and a
    # non-blank gRPC endpoint is available.
    if infer_protocol is None:
        if grpc_endpoint and grpc_endpoint.strip():
            infer_protocol = "grpc"

    if infer_protocol == "http":
        raise ValueError("`http` endpoints are not supported for audio. Use `grpc`.")

    client_options = dict(
        auth_token=auth_token,
        function_id=function_id,
        use_ssl=use_ssl,
        ssl_cert=ssl_cert,
    )
    return ParakeetClient(grpc_endpoint, **client_options)
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
from nv_ingest_api.internal.primitives.nim import ModelInterface
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Assume ModelInterface is defined elsewhere in the project.
|
|
11
|
+
class EmbeddingModelInterface(ModelInterface):
    """
    Model interface for an embedding endpoint.

    Only the HTTP protocol is handled: text prompts are batched into JSON payloads and the
    embeddings are pulled out of the decoded response.
    """

    def name(self) -> str:
        """
        Return the name of this model interface.
        """
        return "Embedding"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate the input and normalise 'prompts' to a list.

        A non-list 'prompts' value is wrapped in a single-element list. The input dictionary
        is modified in place and returned.

        Raises
        ------
        KeyError
            If the 'prompts' key is missing.
        """
        if "prompts" not in data:
            raise KeyError("Input data must include 'prompts'.")

        prompts = data["prompts"]
        if not isinstance(prompts, list):
            data["prompts"] = [prompts]
        return data

    def format_input(
        self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
    ) -> Tuple[List[Any], List[Dict[str, Any]]]:
        """
        Build one request payload per batch of prompts, together with matching batch data
        that keeps the prompts in their original order.

        Parameters
        ----------
        data : dict
            The input data containing "prompts" (a list of text prompts). A missing key is
            treated as an empty list.
        protocol : str
            Only "http" is supported.
        max_batch_size : int
            Maximum number of prompts per payload.
        kwargs : dict
            Additional parameters: model_name, encoding_format (default "float"),
            input_type (default "query"), and truncate (default "NONE").

        Returns
        -------
        tuple
            A tuple (payloads, batch_data_list) where:
              - payloads is a list of payload dictionaries, one per batch.
              - batch_data_list is a list of dictionaries containing the key "prompts"
                corresponding to each batch.

        Raises
        ------
        ValueError
            If `protocol` is not "http".
        """
        if protocol != "http":
            raise ValueError("EmbeddingModelInterface only supports HTTP protocol.")

        prompts = data.get("prompts", [])

        payloads = []
        batch_data_list = []
        # Walk the prompts in consecutive windows of at most max_batch_size items.
        for offset in range(0, len(prompts), max_batch_size):
            batch = prompts[offset : offset + max_batch_size]
            payloads.append(
                {
                    "model": kwargs.get("model_name"),
                    "input": batch,
                    "encoding_format": kwargs.get("encoding_format", "float"),
                    "extra_body": {
                        "input_type": kwargs.get("input_type", "query"),
                        "truncate": kwargs.get("truncate", "NONE"),
                    },
                }
            )
            batch_data_list.append({"prompts": batch})

        return payloads, batch_data_list

    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
        """
        Extract the embeddings from an HTTP response of the embedding endpoint.

        Parameters
        ----------
        response : Any
            The raw HTTP response (assumed to be already decoded as JSON).
        protocol : str
            Only "http" is supported.
        data : dict, optional
            The original input data. Not used by this implementation.
        kwargs : dict
            Additional keyword arguments. Not used by this implementation.

        Returns
        -------
        list
            For a dict response, the 'embedding' value of every item under "data" (None for
            items without one). For any other response, a single-element list holding its
            string form.

        Raises
        ------
        ValueError
            If `protocol` is not "http".
        RuntimeError
            If the response is a dict whose "data" entry is missing or empty.
        """
        if protocol != "http":
            raise ValueError("EmbeddingModelInterface only supports HTTP protocol.")

        # Anything that is not a JSON object is passed through as its string representation.
        if not isinstance(response, dict):
            return [str(response)]

        embeddings = response.get("data")
        if not embeddings:
            raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")

        # Each item in embeddings is expected to have an 'embedding' field.
        return [item.get("embedding") for item in embeddings]

    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
        """
        Return the parsed output unchanged; it is expected to be a list of embeddings.

        Returns
        -------
        list
            The processed list of embeddings.
        """
        return output
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from typing import Dict, Any, Optional, Tuple, List
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
from nv_ingest_api.internal.primitives.nim import ModelInterface
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class VLMModelInterface(ModelInterface):
    """
    An interface for handling inference with a VLM model endpoint (e.g., NVIDIA LLaMA-based VLM).
    This implementation supports HTTP inference with one or more base64-encoded images and a caption prompt.
    """

    # Caption used when a choice carries no usable message content.
    _DEFAULT_CAPTION = "No caption returned"

    def name(self) -> str:
        """
        Return the name of this model interface.
        """
        return "VLM"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Prepare input data for VLM inference. Accepts either a single base64 image or a list of images.
        Ensures that a 'prompt' is provided.

        Parameters
        ----------
        data : dict
            Input data holding "base64_images" (a list) or "base64_image" (a single image), plus "prompt".

        Returns
        -------
        dict
            The same ``data`` object. When only "base64_image" was supplied, a "base64_images" key
            holding a one-element list is added to it in place.

        Raises
        ------
        KeyError
            If neither "base64_image" nor "base64_images" is provided or if "prompt" is missing.
        ValueError
            If "base64_images" exists but is not a list.
        """
        # Allow either a single image with "base64_image" or multiple images with "base64_images".
        if "base64_images" in data:
            if not isinstance(data["base64_images"], list):
                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
        elif "base64_image" in data:
            # Convert a single image into a list.
            data["base64_images"] = [data["base64_image"]]
        else:
            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")

        if "prompt" not in data:
            raise KeyError("Input data must include 'prompt'.")
        return data

    def format_input(
        self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
    ) -> Tuple[List[Any], List[Dict[str, Any]]]:
        """
        Format the input payload for the VLM endpoint. This method constructs one payload per batch,
        where each payload includes one message per image in the batch.
        Additionally, it returns batch data that preserves the original order of images by including
        the list of base64 images and the prompt for each batch.

        Parameters
        ----------
        data : dict
            The input data containing "base64_images" (a list of base64-encoded images) and "prompt".
        protocol : str
            Only "http" is supported.
        max_batch_size : int
            Maximum number of images per payload. Must be at least 1.
        kwargs : dict
            Additional parameters including model_name, max_tokens, temperature, top_p, and stream.

        Returns
        -------
        tuple
            A tuple (payloads, batch_data_list) where:
              - payloads is a list of JSON-serializable payload dictionaries.
              - batch_data_list is a list of dictionaries containing the keys "base64_images" and "prompt"
                corresponding to each batch.
            Both lists are empty when "base64_images" is missing or empty.

        Raises
        ------
        ValueError
            If ``protocol`` is not "http" or ``max_batch_size`` is smaller than 1.
        KeyError
            If "prompt" is missing from ``data``.
        """
        if protocol != "http":
            raise ValueError("VLMModelInterface only supports HTTP protocol.")
        # A zero step makes range() fail with an unrelated message and a negative step
        # would silently yield no batches (dropping every image), so reject both up front.
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be a positive integer, got {max_batch_size!r}.")

        images = data.get("base64_images", [])
        prompt = data["prompt"]

        # Split the images into consecutive batches of at most max_batch_size, preserving order.
        batches = [images[start : start + max_batch_size] for start in range(0, len(images), max_batch_size)]

        payloads = []
        batch_data_list = []
        for batch in batches:
            # Create one message per image in the batch.
            messages = [
                {"role": "user", "content": f'{prompt} <img src="data:image/png;base64,{img}" />'} for img in batch
            ]
            payload = {
                "model": kwargs.get("model_name"),
                "messages": messages,
                "max_tokens": kwargs.get("max_tokens", 512),
                "temperature": kwargs.get("temperature", 1.0),
                "top_p": kwargs.get("top_p", 1.0),
                "stream": kwargs.get("stream", False),
            }
            payloads.append(payload)
            batch_data_list.append({"base64_images": batch, "prompt": prompt})
        return payloads, batch_data_list

    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
        """
        Parse the HTTP response from the VLM endpoint. Expects a response structure with a "choices" key.

        Parameters
        ----------
        response : Any
            The raw HTTP response (assumed to be already decoded as JSON).
        protocol : str
            Only "http" is supported.
        data : dict, optional
            The original input data.
        kwargs : dict
            Additional keyword arguments.

        Returns
        -------
        list
            A list of generated captions extracted from the response, one per choice. A choice without
            a usable "message" mapping or without a "content" key yields "No caption returned".
            If ``response`` is not a dict, a one-element list with its string representation is returned.

        Raises
        ------
        ValueError
            If ``protocol`` is not "http".
        RuntimeError
            If ``response`` is a dict whose "choices" key is missing or empty.
        """
        if protocol != "http":
            raise ValueError("VLMModelInterface only supports HTTP protocol.")

        if not isinstance(response, dict):
            # If response is not a dict, return its string representation in a list.
            return [str(response)]

        choices = response.get("choices", [])
        if not choices:
            raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")
        # Return a list of captions, one per choice.
        return [self._caption_from_choice(choice) for choice in choices]

    @classmethod
    def _caption_from_choice(cls, choice: Any) -> Any:
        """
        Extract the caption from a single entry of the response's "choices" list.

        Tolerates a choice that is not a dict and a "message" value that is null or not a dict,
        returning the default caption instead of raising AttributeError.
        """
        message = choice.get("message") if isinstance(choice, dict) else None
        if not isinstance(message, dict):
            return cls._DEFAULT_CAPTION
        return message.get("content", cls._DEFAULT_CAPTION)

    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
        """
        Process inference results for the VLM model.
        For this implementation, the output is expected to be a list of captions.

        Returns
        -------
        list
            The processed list of captions (``output`` is returned unchanged).
        """
        return output
|