nv-ingest-api 2025.4.16.dev20250416__py3-none-any.whl → 2025.4.17.dev20250417__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (153)
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +215 -0
  3. nv_ingest_api/interface/extract.py +972 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +218 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +200 -0
  8. nv_ingest_api/internal/enums/__init__.py +3 -0
  9. nv_ingest_api/internal/enums/common.py +494 -0
  10. nv_ingest_api/internal/extract/__init__.py +3 -0
  11. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/audio_extraction.py +149 -0
  13. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  14. nv_ingest_api/internal/extract/docx/docx_extractor.py +205 -0
  15. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  16. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +122 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +895 -0
  19. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  20. nv_ingest_api/internal/extract/image/chart_extractor.py +353 -0
  21. nv_ingest_api/internal/extract/image/image_extractor.py +204 -0
  22. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/image_helpers/common.py +403 -0
  24. nv_ingest_api/internal/extract/image/infographic_extractor.py +253 -0
  25. nv_ingest_api/internal/extract/image/table_extractor.py +344 -0
  26. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  27. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  28. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  29. nv_ingest_api/internal/extract/pdf/engines/llama.py +243 -0
  30. nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +597 -0
  31. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +146 -0
  32. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +603 -0
  33. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  34. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  35. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  36. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  37. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  38. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +799 -0
  39. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +187 -0
  40. nv_ingest_api/internal/mutate/__init__.py +3 -0
  41. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  42. nv_ingest_api/internal/mutate/filter.py +133 -0
  43. nv_ingest_api/internal/primitives/__init__.py +0 -0
  44. nv_ingest_api/{primitives → internal/primitives}/control_message_task.py +4 -0
  45. nv_ingest_api/{primitives → internal/primitives}/ingest_control_message.py +5 -2
  46. nv_ingest_api/internal/primitives/nim/__init__.py +8 -0
  47. nv_ingest_api/internal/primitives/nim/default_values.py +15 -0
  48. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  49. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  50. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  51. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  52. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +275 -0
  53. nv_ingest_api/internal/primitives/nim/model_interface/nemoretriever_parse.py +238 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/paddle.py +462 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +132 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +152 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1400 -0
  59. nv_ingest_api/internal/primitives/nim/nim_client.py +344 -0
  60. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +81 -0
  61. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  62. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  63. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  64. nv_ingest_api/internal/primitives/tracing/tagging.py +197 -0
  65. nv_ingest_api/internal/schemas/__init__.py +3 -0
  66. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  67. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +130 -0
  68. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +135 -0
  69. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +124 -0
  70. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +124 -0
  71. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +128 -0
  72. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +218 -0
  73. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +124 -0
  74. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +129 -0
  75. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  76. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +23 -0
  77. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  78. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  79. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  80. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  81. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +237 -0
  82. nv_ingest_api/internal/schemas/meta/metadata_schema.py +221 -0
  83. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  85. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  86. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  87. nv_ingest_api/internal/schemas/store/store_image_schema.py +30 -0
  88. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  89. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +15 -0
  90. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  91. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +25 -0
  92. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +22 -0
  93. nv_ingest_api/internal/store/__init__.py +3 -0
  94. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  95. nv_ingest_api/internal/store/image_upload.py +232 -0
  96. nv_ingest_api/internal/transform/__init__.py +3 -0
  97. nv_ingest_api/internal/transform/caption_image.py +205 -0
  98. nv_ingest_api/internal/transform/embed_text.py +496 -0
  99. nv_ingest_api/internal/transform/split_text.py +157 -0
  100. nv_ingest_api/util/__init__.py +0 -0
  101. nv_ingest_api/util/control_message/__init__.py +0 -0
  102. nv_ingest_api/util/control_message/validators.py +47 -0
  103. nv_ingest_api/util/converters/__init__.py +0 -0
  104. nv_ingest_api/util/converters/bytetools.py +78 -0
  105. nv_ingest_api/util/converters/containers.py +65 -0
  106. nv_ingest_api/util/converters/datetools.py +90 -0
  107. nv_ingest_api/util/converters/dftools.py +127 -0
  108. nv_ingest_api/util/converters/formats.py +64 -0
  109. nv_ingest_api/util/converters/type_mappings.py +27 -0
  110. nv_ingest_api/util/detectors/__init__.py +5 -0
  111. nv_ingest_api/util/detectors/language.py +38 -0
  112. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  113. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  114. nv_ingest_api/util/exception_handlers/decorators.py +223 -0
  115. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  116. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  117. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  118. nv_ingest_api/util/image_processing/__init__.py +5 -0
  119. nv_ingest_api/util/image_processing/clustering.py +260 -0
  120. nv_ingest_api/util/image_processing/processing.py +179 -0
  121. nv_ingest_api/util/image_processing/table_and_chart.py +449 -0
  122. nv_ingest_api/util/image_processing/transforms.py +407 -0
  123. nv_ingest_api/util/logging/__init__.py +0 -0
  124. nv_ingest_api/util/logging/configuration.py +31 -0
  125. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  126. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  127. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  128. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  129. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +435 -0
  130. nv_ingest_api/util/metadata/__init__.py +5 -0
  131. nv_ingest_api/util/metadata/aggregators.py +469 -0
  132. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  133. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +194 -0
  134. nv_ingest_api/util/nim/__init__.py +56 -0
  135. nv_ingest_api/util/pdf/__init__.py +3 -0
  136. nv_ingest_api/util/pdf/pdfium.py +427 -0
  137. nv_ingest_api/util/schema/__init__.py +0 -0
  138. nv_ingest_api/util/schema/schema_validator.py +10 -0
  139. nv_ingest_api/util/service_clients/__init__.py +3 -0
  140. nv_ingest_api/util/service_clients/client_base.py +72 -0
  141. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  142. nv_ingest_api/util/service_clients/redis/__init__.py +0 -0
  143. nv_ingest_api/util/service_clients/redis/redis_client.py +334 -0
  144. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  145. nv_ingest_api/util/service_clients/rest/rest_client.py +398 -0
  146. nv_ingest_api/util/string_processing/__init__.py +51 -0
  147. {nv_ingest_api-2025.4.16.dev20250416.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/METADATA +1 -1
  148. nv_ingest_api-2025.4.17.dev20250417.dist-info/RECORD +152 -0
  149. nv_ingest_api-2025.4.16.dev20250416.dist-info/RECORD +0 -9
  150. /nv_ingest_api/{primitives → internal}/__init__.py +0 -0
  151. {nv_ingest_api-2025.4.16.dev20250416.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/WHEEL +0 -0
  152. {nv_ingest_api-2025.4.16.dev20250416.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/licenses/LICENSE +0 -0
  153. {nv_ingest_api-2025.4.16.dev20250416.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/top_level.txt +0 -0
nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py
@@ -0,0 +1,367 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import io
+import base64
+import logging
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import grpc
+import numpy as np
+import riva.client
+from scipy.io import wavfile
+
+from nv_ingest_api.internal.primitives.tracing.tagging import traceable_func
+
+try:
+    import librosa
+except ImportError:
+    librosa = None
+
+logger = logging.getLogger(__name__)
+
+
+class ParakeetClient:
+    """
+    A simple interface for handling inference with a Parakeet model (e.g., speech and other audio tasks).
+    """
+
+    def __init__(
+        self,
+        endpoint: str,
+        auth_token: Optional[str] = None,
+        function_id: Optional[str] = None,
+        use_ssl: Optional[bool] = None,
+        ssl_cert: Optional[str] = None,
+    ):
+        """
+        Initialize the ParakeetClient.
+
+        Parameters
+        ----------
+        endpoint : str
+            The URL of the Parakeet service endpoint.
+        auth_token : Optional[str], default=None
+            The authentication token for accessing the service.
+        function_id : Optional[str], default=None
+            The NVCF function ID for invoking the service.
+        use_ssl : Optional[bool], default=None
+            Whether to use SSL for the connection. If None, SSL is enabled
+            automatically when the endpoint is an NVCF gRPC endpoint and a
+            function ID is provided.
+        ssl_cert : Optional[str], default=None
+            Path to the SSL certificate if required.
+        """
+        self.endpoint = endpoint
+        self.auth_token = auth_token
+        self.function_id = function_id
+        if use_ssl is None:
+            self.use_ssl = True if ("grpc.nvcf.nvidia.com" in self.endpoint) and self.function_id else False
+        else:
+            self.use_ssl = use_ssl
+        self.ssl_cert = ssl_cert
+
+        self.auth_metadata = []
+        if self.auth_token:
+            self.auth_metadata.append(("authorization", f"Bearer {self.auth_token}"))
+        if self.function_id:
+            self.auth_metadata.append(("function-id", self.function_id))
+
+        # Create authentication and ASR service objects.
+        self._auth = riva.client.Auth(self.ssl_cert, self.use_ssl, self.endpoint, self.auth_metadata)
+        self._asr_service = riva.client.ASRService(self._auth)
+
+    @traceable_func(trace_name="{stage_name}::{model_name}")
+    def infer(self, data: dict, model_name: str, **kwargs) -> Any:
+        """
+        Perform inference using the specified model and input data.
+
+        Parameters
+        ----------
+        data : dict
+            The input data for inference.
+        model_name : str
+            The model name.
+        kwargs : dict
+            Additional parameters for inference.
+
+        Returns
+        -------
+        Any
+            The transcript extracted from the transcription response, or None
+            if transcription fails.
+        """
+        response = self.transcribe(data)
+        if response is None:
+            return None
+        segments, transcript = process_transcription_response(response)
+        logger.debug("Processing Parakeet inference results (pass-through).")
+
+        return transcript
+
+    def transcribe(
+        self,
+        audio_content: str,
+        language_code: str = "en-US",
+        automatic_punctuation: bool = True,
+        word_time_offsets: bool = True,
+        max_alternatives: int = 1,
+        profanity_filter: bool = False,
+        verbatim_transcripts: bool = True,
+        speaker_diarization: bool = False,
+        boosted_lm_words: Optional[List[str]] = None,
+        boosted_lm_score: float = 0.0,
+        diarization_max_speakers: int = 0,
+        start_history: float = 0.0,
+        start_threshold: float = 0.0,
+        stop_history: float = 0.0,
+        stop_history_eou: bool = False,
+        stop_threshold: float = 0.0,
+        stop_threshold_eou: bool = False,
+    ):
+        """
+        Transcribe an audio file using Riva ASR.
+
+        Parameters
+        ----------
+        audio_content : str
+            Base64-encoded audio content to be transcribed.
+        language_code : str, default="en-US"
+            The language code for transcription.
+        automatic_punctuation : bool, default=True
+            Whether to enable automatic punctuation in the transcript.
+        word_time_offsets : bool, default=True
+            Whether to include word-level timestamps in the transcript.
+        max_alternatives : int, default=1
+            The maximum number of alternative transcripts to return.
+        profanity_filter : bool, default=False
+            Whether to filter out profanity from the transcript.
+        verbatim_transcripts : bool, default=True
+            Whether to return verbatim transcripts without normalization.
+        speaker_diarization : bool, default=False
+            Whether to enable speaker diarization.
+        boosted_lm_words : Optional[List[str]], default=None
+            A list of words to boost for language modeling.
+        boosted_lm_score : float, default=0.0
+            The boosting score for language model words.
+        diarization_max_speakers : int, default=0
+            The maximum number of speakers to differentiate in speaker diarization.
+        start_history : float, default=0.0
+            History window size for endpoint detection.
+        start_threshold : float, default=0.0
+            The threshold for starting speech detection.
+        stop_history : float, default=0.0
+            History window size for stopping speech detection.
+        stop_history_eou : bool, default=False
+            Whether to use an end-of-utterance flag for stopping detection.
+        stop_threshold : float, default=0.0
+            The threshold for stopping speech detection.
+        stop_threshold_eou : bool, default=False
+            Whether to use an end-of-utterance flag for the stop threshold.
+
+        Returns
+        -------
+        Optional[riva.client.RecognitionResponse]
+            The response containing the transcription results, or None if the
+            transcription fails.
+        """
+        # Build the recognition configuration.
+        recognition_config = riva.client.RecognitionConfig(
+            language_code=language_code,
+            max_alternatives=max_alternatives,
+            profanity_filter=profanity_filter,
+            enable_automatic_punctuation=automatic_punctuation,
+            verbatim_transcripts=verbatim_transcripts,
+            enable_word_time_offsets=word_time_offsets,
+        )
+
+        # Add additional configuration parameters.
+        riva.client.add_word_boosting_to_config(
+            recognition_config,
+            boosted_lm_words or [],
+            boosted_lm_score,
+        )
+        riva.client.add_speaker_diarization_to_config(
+            recognition_config,
+            speaker_diarization,
+            diarization_max_speakers,
+        )
+        riva.client.add_endpoint_parameters_to_config(
+            recognition_config,
+            start_history,
+            start_threshold,
+            stop_history,
+            stop_history_eou,
+            stop_threshold,
+            stop_threshold_eou,
+        )
+        audio_bytes = base64.b64decode(audio_content)
+        mono_audio_bytes = convert_to_mono_wav(audio_bytes)
+
+        # Perform offline recognition and return the response.
+        try:
+            response = self._asr_service.offline_recognize(mono_audio_bytes, recognition_config)
+            return response
+        except grpc.RpcError as e:
+            logger.exception(f"Error transcribing audio file: {e.details()}")
+            raise
+
+
+def convert_to_mono_wav(audio_bytes):
+    """
+    Convert an audio file to mono WAV format using Librosa and SciPy.
+
+    Parameters
+    ----------
+    audio_bytes : bytes
+        The raw audio data in bytes.
+
+    Returns
+    -------
+    bytes
+        The processed audio in mono WAV format.
+    """
+    if librosa is None:
+        raise ImportError(
+            "Librosa is required for audio processing. "
+            "If you are running this code with the ingest container, it can be installed by setting "
+            "the environment variable INSTALL_AUDIO_EXTRACTION_DEPS=true."
+        )
+
+    # Create a BytesIO object from the audio bytes.
+    byte_io = io.BytesIO(audio_bytes)
+
+    # Load the audio file with librosa; librosa.load converts to mono by default.
+    audio_data, sample_rate = librosa.load(byte_io, sr=44100, mono=True)
+
+    # Ensure audio is properly scaled for 16-bit PCM.
+    # Librosa normalizes the data between -1 and 1.
+    if np.max(np.abs(audio_data)) > 0:
+        audio_data = audio_data / np.max(np.abs(audio_data)) * 0.9
+
+    # Convert to int16 format for 16-bit PCM WAV.
+    audio_data_int16 = (audio_data * 32767).astype(np.int16)
+
+    # Create a BytesIO buffer to write the WAV file.
+    output_io = io.BytesIO()
+
+    # Write the WAV data using scipy.
+    wavfile.write(output_io, sample_rate, audio_data_int16)
+
+    # Reset the file pointer to the beginning and read all contents.
+    output_io.seek(0)
+    wav_bytes = output_io.read()
+
+    return wav_bytes
+
+
+def process_transcription_response(response):
+    """
+    Process a Riva transcription response (a protobuf message) to extract the
+    complete transcript and a list of punctuation-delimited segments.
+
+    Parameters
+    ----------
+    response : riva.client.RecognitionResponse
+        The Riva transcription response message.
+
+    Returns
+    -------
+    segments : list
+        Each segment is a dict with keys "start", "end", and "text".
+    final_transcript : str
+        The overall transcript.
+    """
+    words_list = []
+    # Iterate directly over the results.
+    for result in response.results:
+        # Ensure there is at least one alternative.
+        if not result.alternatives:
+            continue
+        alternative = result.alternatives[0]
+        # Each alternative has a repeated field "words".
+        for word_info in alternative.words:
+            words_list.append(word_info)
+
+    # Build the overall transcript by joining the word strings.
+    final_transcript = " ".join(word.word for word in words_list)
+
+    # Now, segment the transcript based on punctuation.
+    segments = []
+    current_words = []
+    segment_start = None
+    segment_end = None
+    punctuation_marks = {".", "?", "!"}
+
+    for word in words_list:
+        # Mark the start of a segment if not already set.
+        if segment_start is None:
+            segment_start = word.start_time
+        segment_end = word.end_time
+        current_words.append(word.word)
+
+        # End the segment when a word ends with punctuation.
+        if word.word and word.word[-1] in punctuation_marks:
+            segments.append({"start": segment_start, "end": segment_end, "text": " ".join(current_words)})
+            current_words = []
+            segment_start = None
+            segment_end = None
+
+    # Add any remaining words as a segment.
+    if current_words:
+        segments.append({"start": segment_start, "end": segment_end, "text": " ".join(current_words)})
+
+    return segments, final_transcript
+
+
+def create_audio_inference_client(
+    endpoints: Tuple[str, str],
+    infer_protocol: Optional[str] = None,
+    auth_token: Optional[str] = None,
+    function_id: Optional[str] = None,
+    use_ssl: bool = False,
+    ssl_cert: Optional[str] = None,
+):
+    """
+    Create a ParakeetClient for interfacing with an audio model inference server.
+
+    Parameters
+    ----------
+    endpoints : tuple
+        A tuple containing the gRPC and HTTP endpoints. Only the gRPC endpoint is used.
+    infer_protocol : str, optional
+        The protocol to use ("grpc" or "http").
+        If not specified, defaults to "grpc" if a valid gRPC endpoint is provided.
+        HTTP endpoints are not supported for audio inference.
+    auth_token : str, optional
+        Authorization token for authentication (default: None).
+    function_id : str, optional
+        NVCF function ID of the invocation (default: None).
+    use_ssl : bool, optional
+        Whether to use SSL for secure communication (default: False).
+    ssl_cert : str, optional
+        Path to the SSL certificate file if `use_ssl` is enabled (default: None).
+
+    Returns
+    -------
+    ParakeetClient
+        The initialized ParakeetClient configured for audio inference over gRPC.
+
+    Raises
+    ------
+    ValueError
+        If an invalid `infer_protocol` is specified or if an HTTP endpoint is provided.
+    """
+    grpc_endpoint, http_endpoint = endpoints
+
+    if (infer_protocol is None) and (grpc_endpoint and grpc_endpoint.strip()):
+        infer_protocol = "grpc"
+
+    if infer_protocol == "http":
+        raise ValueError("`http` endpoints are not supported for audio. Use `grpc`.")
+
+    return ParakeetClient(
+        grpc_endpoint, auth_token=auth_token, function_id=function_id, use_ssl=use_ssl, ssl_cert=ssl_cert
+    )
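
For orientation, a minimal usage sketch of the new audio path follows. It is not part of the release: the endpoint pair, API key, function ID, and audio file are placeholder assumptions, and it simply exercises the factory, transcription, and response-processing functions shown in the hunk above.

    import base64

    from nv_ingest_api.internal.primitives.nim.model_interface.parakeet import (
        create_audio_inference_client,
        process_transcription_response,
    )

    # Placeholder endpoint pair: only the gRPC slot is used for audio.
    client = create_audio_inference_client(
        ("grpc.nvcf.nvidia.com:443", ""),
        infer_protocol="grpc",
        auth_token="nvapi-...",        # assumed NVCF API key
        function_id="<function-id>",   # assumed NVCF function ID
    )

    # transcribe() expects base64-encoded audio content.
    with open("speech.wav", "rb") as f:  # assumed local audio file
        audio_b64 = base64.b64encode(f.read()).decode("utf-8")

    response = client.transcribe(audio_b64)
    if response is not None:
        segments, transcript = process_transcription_response(response)
        print(transcript)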
nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py
@@ -0,0 +1,132 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, List, Optional, Tuple
+
+from nv_ingest_api.internal.primitives.nim import ModelInterface
+
+
+class EmbeddingModelInterface(ModelInterface):
+    """
+    An interface for handling inference with an embedding model endpoint.
+    This implementation supports HTTP inference for generating embeddings from text prompts.
+    """
+
+    def name(self) -> str:
+        """
+        Return the name of this model interface.
+        """
+        return "Embedding"
+
+    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Prepare input data for embedding inference. Ensures that a 'prompts' key is provided
+        and that its value is a list.
+
+        Raises
+        ------
+        KeyError
+            If the 'prompts' key is missing.
+        """
+        if "prompts" not in data:
+            raise KeyError("Input data must include 'prompts'.")
+        # Ensure the prompts are in list format.
+        if not isinstance(data["prompts"], list):
+            data["prompts"] = [data["prompts"]]
+        return data
+
+    def format_input(
+        self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
+    ) -> Tuple[List[Any], List[Dict[str, Any]]]:
+        """
+        Format the input payload for the embedding endpoint. This method constructs one payload
+        per batch, where each payload includes a list of text prompts. It also returns batch
+        data that preserves the original order of prompts.
+
+        Parameters
+        ----------
+        data : dict
+            The input data containing "prompts" (a list of text prompts).
+        protocol : str
+            Only "http" is supported.
+        max_batch_size : int
+            Maximum number of prompts per payload.
+        kwargs : dict
+            Additional parameters including model_name, encoding_format, input_type, and truncate.
+
+        Returns
+        -------
+        tuple
+            A tuple (payloads, batch_data_list) where:
+            - payloads is a list of JSON-serializable payload dictionaries.
+            - batch_data_list is a list of dictionaries containing the key "prompts" corresponding to each batch.
+        """
+        if protocol != "http":
+            raise ValueError("EmbeddingModelInterface only supports HTTP protocol.")
+
+        prompts = data.get("prompts", [])
+
+        def chunk_list(lst, chunk_size):
+            return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+        batches = chunk_list(prompts, max_batch_size)
+        payloads = []
+        batch_data_list = []
+        for batch in batches:
+            payload = {
+                "model": kwargs.get("model_name"),
+                "input": batch,
+                "encoding_format": kwargs.get("encoding_format", "float"),
+                "extra_body": {
+                    "input_type": kwargs.get("input_type", "query"),
+                    "truncate": kwargs.get("truncate", "NONE"),
+                },
+            }
+            payloads.append(payload)
+            batch_data_list.append({"prompts": batch})
+        return payloads, batch_data_list
+
+    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+        """
+        Parse the HTTP response from the embedding endpoint. Expects a response structure with a "data" key.
+
+        Parameters
+        ----------
+        response : Any
+            The raw HTTP response (assumed to be already decoded as JSON).
+        protocol : str
+            Only "http" is supported.
+        data : dict, optional
+            The original input data.
+        kwargs : dict
+            Additional keyword arguments.
+
+        Returns
+        -------
+        list
+            A list of generated embeddings extracted from the response.
+        """
+        if protocol != "http":
+            raise ValueError("EmbeddingModelInterface only supports HTTP protocol.")
+        if isinstance(response, dict):
+            embeddings = response.get("data")
+            if not embeddings:
+                raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")
+            # Each item in embeddings is expected to have an 'embedding' field.
+            return [item.get("embedding", None) for item in embeddings]
+        else:
+            return [str(response)]
+
+    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
+        """
+        Process inference results for the embedding model.
+        For this implementation, the output is expected to be a list of embeddings.
+
+        Returns
+        -------
+        list
+            The processed list of embeddings.
+        """
+        return output
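
A small sketch of how this interface batches prompts and parses a decoded response, based only on the methods in the hunk above; the model name and batch size are illustrative assumptions.

    from nv_ingest_api.internal.primitives.nim.model_interface.text_embedding import (
        EmbeddingModelInterface,
    )

    iface = EmbeddingModelInterface()
    data = iface.prepare_data_for_inference(
        {"prompts": ["first passage", "second passage", "third passage"]}
    )

    # With max_batch_size=2, three prompts yield two payloads, in order.
    payloads, batch_data = iface.format_input(
        data,
        protocol="http",
        max_batch_size=2,
        model_name="nvidia/nv-embedqa-e5-v5",  # assumed model identifier
    )

    # A decoded JSON response with a "data" key parses to the raw vectors.
    vectors = iface.parse_output({"data": [{"embedding": [0.1, 0.2, 0.3]}]}, protocol="http")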
nv_ingest_api/internal/primitives/nim/model_interface/vlm.py
@@ -0,0 +1,152 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Dict, Any, Optional, Tuple, List
+
+import logging
+
+from nv_ingest_api.internal.primitives.nim import ModelInterface
+
+logger = logging.getLogger(__name__)
+
+
+class VLMModelInterface(ModelInterface):
+    """
+    An interface for handling inference with a VLM model endpoint (e.g., NVIDIA LLaMA-based VLM).
+    This implementation supports HTTP inference with one or more base64-encoded images and a caption prompt.
+    """
+
+    def name(self) -> str:
+        """
+        Return the name of this model interface.
+        """
+        return "VLM"
+
+    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Prepare input data for VLM inference. Accepts either a single base64 image or a list of images.
+        Ensures that a 'prompt' is provided.
+
+        Raises
+        ------
+        KeyError
+            If neither "base64_image" nor "base64_images" is provided or if "prompt" is missing.
+        ValueError
+            If "base64_images" exists but is not a list.
+        """
+        # Allow either a single image with "base64_image" or multiple images with "base64_images".
+        if "base64_images" in data:
+            if not isinstance(data["base64_images"], list):
+                raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
+        elif "base64_image" in data:
+            # Convert a single image into a list.
+            data["base64_images"] = [data["base64_image"]]
+        else:
+            raise KeyError("Input data must include 'base64_image' or 'base64_images'.")
+
+        if "prompt" not in data:
+            raise KeyError("Input data must include 'prompt'.")
+        return data
+
+    def format_input(
+        self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
+    ) -> Tuple[List[Any], List[Dict[str, Any]]]:
+        """
+        Format the input payload for the VLM endpoint. This method constructs one payload per batch,
+        where each payload includes one message per image in the batch. It also returns batch data
+        that preserves the original order of images by including the list of base64 images and the
+        prompt for each batch.
+
+        Parameters
+        ----------
+        data : dict
+            The input data containing "base64_images" (a list of base64-encoded images) and "prompt".
+        protocol : str
+            Only "http" is supported.
+        max_batch_size : int
+            Maximum number of images per payload.
+        kwargs : dict
+            Additional parameters including model_name, max_tokens, temperature, top_p, and stream.
+
+        Returns
+        -------
+        tuple
+            A tuple (payloads, batch_data_list) where:
+            - payloads is a list of JSON-serializable payload dictionaries.
+            - batch_data_list is a list of dictionaries containing the keys "base64_images" and "prompt"
+              corresponding to each batch.
+        """
+        if protocol != "http":
+            raise ValueError("VLMModelInterface only supports HTTP protocol.")
+
+        images = data.get("base64_images", [])
+        prompt = data["prompt"]
+
+        # Helper function to chunk the list into batches.
+        def chunk_list(lst, chunk_size):
+            return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+        batches = chunk_list(images, max_batch_size)
+        payloads = []
+        batch_data_list = []
+        for batch in batches:
+            # Create one message per image in the batch.
+            messages = [
+                {"role": "user", "content": f'{prompt} <img src="data:image/png;base64,{img}" />'} for img in batch
+            ]
+            payload = {
+                "model": kwargs.get("model_name"),
+                "messages": messages,
+                "max_tokens": kwargs.get("max_tokens", 512),
+                "temperature": kwargs.get("temperature", 1.0),
+                "top_p": kwargs.get("top_p", 1.0),
+                "stream": kwargs.get("stream", False),
+            }
+            payloads.append(payload)
+            batch_data_list.append({"base64_images": batch, "prompt": prompt})
+        return payloads, batch_data_list
+
+    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+        """
+        Parse the HTTP response from the VLM endpoint. Expects a response structure with a "choices" key.
+
+        Parameters
+        ----------
+        response : Any
+            The raw HTTP response (assumed to be already decoded as JSON).
+        protocol : str
+            Only "http" is supported.
+        data : dict, optional
+            The original input data.
+        kwargs : dict
+            Additional keyword arguments.
+
+        Returns
+        -------
+        list
+            A list of generated captions extracted from the response.
+        """
+        if protocol != "http":
+            raise ValueError("VLMModelInterface only supports HTTP protocol.")
+        if isinstance(response, dict):
+            choices = response.get("choices", [])
+            if not choices:
+                raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")
+            # Return a list of captions, one per choice.
+            return [choice.get("message", {}).get("content", "No caption returned") for choice in choices]
+        else:
+            # If response is not a dict, return its string representation in a list.
+            return [str(response)]
+
+    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
+        """
+        Process inference results for the VLM model.
+        For this implementation, the output is expected to be a list of captions.
+
+        Returns
+        -------
+        list
+            The processed list of captions.
+        """
+        return output
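
And a corresponding sketch for the VLM captioning interface, again using only the methods in the hunk above; the image string, prompt, and model name are placeholder assumptions.

    from nv_ingest_api.internal.primitives.nim.model_interface.vlm import VLMModelInterface

    iface = VLMModelInterface()
    # A single "base64_image" is normalized into a "base64_images" list.
    data = iface.prepare_data_for_inference(
        {"base64_image": "<base64-png>", "prompt": "Caption the content of this image:"}
    )

    payloads, batch_data = iface.format_input(
        data,
        protocol="http",
        max_batch_size=8,
        model_name="meta/llama-3.2-11b-vision-instruct",  # assumed model identifier
    )
    # Each payload carries one chat message per image, with the image embedded
    # in the message content as an <img> data-URI tag.

    captions = iface.parse_output(
        {"choices": [{"message": {"content": "A bar chart of quarterly revenue."}}]},
        protocol="http",
    )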