nv-ingest-api 26.1.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nv-ingest-api might be problematic. Click here for more details.

Files changed (177) hide show
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +218 -0
  3. nv_ingest_api/interface/extract.py +977 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +200 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +186 -0
  8. nv_ingest_api/internal/__init__.py +0 -0
  9. nv_ingest_api/internal/enums/__init__.py +3 -0
  10. nv_ingest_api/internal/enums/common.py +550 -0
  11. nv_ingest_api/internal/extract/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  13. nv_ingest_api/internal/extract/audio/audio_extraction.py +202 -0
  14. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  15. nv_ingest_api/internal/extract/docx/docx_extractor.py +232 -0
  16. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +127 -0
  19. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +971 -0
  20. nv_ingest_api/internal/extract/html/__init__.py +3 -0
  21. nv_ingest_api/internal/extract/html/html_extractor.py +84 -0
  22. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/chart_extractor.py +375 -0
  24. nv_ingest_api/internal/extract/image/image_extractor.py +208 -0
  25. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  26. nv_ingest_api/internal/extract/image/image_helpers/common.py +433 -0
  27. nv_ingest_api/internal/extract/image/infographic_extractor.py +290 -0
  28. nv_ingest_api/internal/extract/image/ocr_extractor.py +407 -0
  29. nv_ingest_api/internal/extract/image/table_extractor.py +391 -0
  30. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  31. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  32. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  33. nv_ingest_api/internal/extract/pdf/engines/llama.py +246 -0
  34. nv_ingest_api/internal/extract/pdf/engines/nemotron_parse.py +598 -0
  35. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +166 -0
  36. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +652 -0
  37. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  38. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  39. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  40. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  41. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  42. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +968 -0
  43. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +210 -0
  44. nv_ingest_api/internal/meta/__init__.py +3 -0
  45. nv_ingest_api/internal/meta/udf.py +232 -0
  46. nv_ingest_api/internal/mutate/__init__.py +3 -0
  47. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  48. nv_ingest_api/internal/mutate/filter.py +133 -0
  49. nv_ingest_api/internal/primitives/__init__.py +0 -0
  50. nv_ingest_api/internal/primitives/control_message_task.py +16 -0
  51. nv_ingest_api/internal/primitives/ingest_control_message.py +307 -0
  52. nv_ingest_api/internal/primitives/nim/__init__.py +9 -0
  53. nv_ingest_api/internal/primitives/nim/default_values.py +14 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +338 -0
  59. nv_ingest_api/internal/primitives/nim/model_interface/nemotron_parse.py +239 -0
  60. nv_ingest_api/internal/primitives/nim/model_interface/ocr.py +776 -0
  61. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  62. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +129 -0
  63. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +177 -0
  64. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1681 -0
  65. nv_ingest_api/internal/primitives/nim/nim_client.py +801 -0
  66. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +126 -0
  67. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  68. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  69. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  70. nv_ingest_api/internal/primitives/tracing/tagging.py +288 -0
  71. nv_ingest_api/internal/schemas/__init__.py +3 -0
  72. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  73. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +133 -0
  74. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +144 -0
  75. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +129 -0
  76. nv_ingest_api/internal/schemas/extract/extract_html_schema.py +34 -0
  77. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +126 -0
  78. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +137 -0
  79. nv_ingest_api/internal/schemas/extract/extract_ocr_schema.py +137 -0
  80. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +220 -0
  81. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +128 -0
  82. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +137 -0
  83. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +37 -0
  85. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  86. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  87. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  88. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  89. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +355 -0
  90. nv_ingest_api/internal/schemas/meta/metadata_schema.py +394 -0
  91. nv_ingest_api/internal/schemas/meta/udf.py +23 -0
  92. nv_ingest_api/internal/schemas/mixins.py +39 -0
  93. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  94. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  95. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  96. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  97. nv_ingest_api/internal/schemas/store/store_image_schema.py +45 -0
  98. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  99. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +36 -0
  100. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  101. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +48 -0
  102. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +24 -0
  103. nv_ingest_api/internal/store/__init__.py +3 -0
  104. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  105. nv_ingest_api/internal/store/image_upload.py +251 -0
  106. nv_ingest_api/internal/transform/__init__.py +3 -0
  107. nv_ingest_api/internal/transform/caption_image.py +219 -0
  108. nv_ingest_api/internal/transform/embed_text.py +702 -0
  109. nv_ingest_api/internal/transform/split_text.py +182 -0
  110. nv_ingest_api/util/__init__.py +3 -0
  111. nv_ingest_api/util/control_message/__init__.py +0 -0
  112. nv_ingest_api/util/control_message/validators.py +47 -0
  113. nv_ingest_api/util/converters/__init__.py +0 -0
  114. nv_ingest_api/util/converters/bytetools.py +78 -0
  115. nv_ingest_api/util/converters/containers.py +65 -0
  116. nv_ingest_api/util/converters/datetools.py +90 -0
  117. nv_ingest_api/util/converters/dftools.py +127 -0
  118. nv_ingest_api/util/converters/formats.py +64 -0
  119. nv_ingest_api/util/converters/type_mappings.py +27 -0
  120. nv_ingest_api/util/dataloader/__init__.py +9 -0
  121. nv_ingest_api/util/dataloader/dataloader.py +409 -0
  122. nv_ingest_api/util/detectors/__init__.py +5 -0
  123. nv_ingest_api/util/detectors/language.py +38 -0
  124. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  125. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  126. nv_ingest_api/util/exception_handlers/decorators.py +429 -0
  127. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  128. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  129. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  130. nv_ingest_api/util/image_processing/__init__.py +5 -0
  131. nv_ingest_api/util/image_processing/clustering.py +260 -0
  132. nv_ingest_api/util/image_processing/processing.py +177 -0
  133. nv_ingest_api/util/image_processing/table_and_chart.py +504 -0
  134. nv_ingest_api/util/image_processing/transforms.py +850 -0
  135. nv_ingest_api/util/imports/__init__.py +3 -0
  136. nv_ingest_api/util/imports/callable_signatures.py +108 -0
  137. nv_ingest_api/util/imports/dynamic_resolvers.py +158 -0
  138. nv_ingest_api/util/introspection/__init__.py +3 -0
  139. nv_ingest_api/util/introspection/class_inspect.py +145 -0
  140. nv_ingest_api/util/introspection/function_inspect.py +65 -0
  141. nv_ingest_api/util/logging/__init__.py +0 -0
  142. nv_ingest_api/util/logging/configuration.py +102 -0
  143. nv_ingest_api/util/logging/sanitize.py +84 -0
  144. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  145. nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
  146. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  147. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  148. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  149. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +455 -0
  150. nv_ingest_api/util/metadata/__init__.py +5 -0
  151. nv_ingest_api/util/metadata/aggregators.py +516 -0
  152. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  153. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +200 -0
  154. nv_ingest_api/util/nim/__init__.py +161 -0
  155. nv_ingest_api/util/pdf/__init__.py +3 -0
  156. nv_ingest_api/util/pdf/pdfium.py +428 -0
  157. nv_ingest_api/util/schema/__init__.py +3 -0
  158. nv_ingest_api/util/schema/schema_validator.py +10 -0
  159. nv_ingest_api/util/service_clients/__init__.py +3 -0
  160. nv_ingest_api/util/service_clients/client_base.py +86 -0
  161. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  162. nv_ingest_api/util/service_clients/redis/__init__.py +3 -0
  163. nv_ingest_api/util/service_clients/redis/redis_client.py +983 -0
  164. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  165. nv_ingest_api/util/service_clients/rest/rest_client.py +595 -0
  166. nv_ingest_api/util/string_processing/__init__.py +51 -0
  167. nv_ingest_api/util/string_processing/configuration.py +682 -0
  168. nv_ingest_api/util/string_processing/yaml.py +109 -0
  169. nv_ingest_api/util/system/__init__.py +0 -0
  170. nv_ingest_api/util/system/hardware_info.py +594 -0
  171. nv_ingest_api-26.1.0rc4.dist-info/METADATA +237 -0
  172. nv_ingest_api-26.1.0rc4.dist-info/RECORD +177 -0
  173. nv_ingest_api-26.1.0rc4.dist-info/WHEEL +5 -0
  174. nv_ingest_api-26.1.0rc4.dist-info/licenses/LICENSE +201 -0
  175. nv_ingest_api-26.1.0rc4.dist-info/top_level.txt +2 -0
  176. udfs/__init__.py +5 -0
  177. udfs/llm_summarizer_udf.py +259 -0
@@ -0,0 +1,367 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import io
6
+ import base64
7
+ import logging
8
+ from typing import Any
9
+ from typing import List
10
+ from typing import Optional
11
+ from typing import Tuple
12
+
13
+ import grpc
14
+ import numpy as np
15
+ import riva.client
16
+ from scipy.io import wavfile
17
+
18
+ from nv_ingest_api.internal.primitives.tracing.tagging import traceable_func
19
+
20
+ try:
21
+ import librosa
22
+ except ImportError:
23
+ librosa = None
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class ParakeetClient:
    """
    A simple interface for handling inference with a Parakeet model (e.g., speech, audio-related).

    Transcription goes through the Riva ASR client (``riva.client.ASRService``) using
    offline (non-streaming) recognition over gRPC.
    """

    def __init__(
        self,
        endpoint: str,
        auth_token: Optional[str] = None,
        function_id: Optional[str] = None,
        use_ssl: Optional[bool] = None,
        ssl_cert: Optional[str] = None,
    ):
        """
        Initialize the ParakeetClient.

        Parameters
        ----------
        endpoint : str
            The gRPC endpoint of the Parakeet service.
        auth_token : Optional[str], default=None
            If provided, sent as ``authorization: Bearer <token>`` request metadata.
        function_id : Optional[str], default=None
            The NVCF function ID; if provided, sent as ``function-id`` request metadata.
        use_ssl : Optional[bool], default=None
            Whether to use SSL for the connection. When ``None``, SSL is enabled only if
            ``endpoint`` contains ``"grpc.nvcf.nvidia.com"`` and ``function_id`` is set.
        ssl_cert : Optional[str], default=None
            Path to the SSL certificate if required.
        """
        self.endpoint = endpoint
        self.auth_token = auth_token
        self.function_id = function_id
        if use_ssl is None:
            # Auto-enable SSL only for hosted NVCF endpoints invoked by function ID.
            self.use_ssl = bool(("grpc.nvcf.nvidia.com" in self.endpoint) and self.function_id)
        else:
            self.use_ssl = use_ssl
        self.ssl_cert = ssl_cert

        # gRPC metadata handed to riva.client.Auth for every request.
        self.auth_metadata: List[Tuple[str, str]] = []
        if self.auth_token:
            self.auth_metadata.append(("authorization", f"Bearer {self.auth_token}"))
        if self.function_id:
            self.auth_metadata.append(("function-id", self.function_id))

        # Create authentication and ASR service objects.
        self._auth = riva.client.Auth(self.ssl_cert, self.use_ssl, self.endpoint, self.auth_metadata)
        self._asr_service = riva.client.ASRService(self._auth)

    @traceable_func(trace_name="{stage_name}::{model_name}")
    def infer(self, data: str, model_name: str, **kwargs) -> Any:
        """
        Transcribe base64-encoded audio and split the transcript into segments.

        Parameters
        ----------
        data : str
            Base64-encoded audio content, passed unchanged to :meth:`transcribe`.
        model_name : str
            The model name (used by the tracing decorator's trace name).
        kwargs : dict
            Not used by this method directly. The ``traceable_func`` trace name references
            ``stage_name``, which is presumably supplied here by callers.

        Returns
        -------
        Optional[Tuple[List[dict], str]]
            ``(segments, transcript)`` as produced by ``process_transcription_response``,
            or ``None`` if :meth:`transcribe` returned ``None``.

        Raises
        ------
        grpc.RpcError
            Propagated from :meth:`transcribe` when the ASR call fails.
        """
        response = self.transcribe(data)
        if response is None:
            return None
        segments, transcript = process_transcription_response(response)
        logger.debug("Processing Parakeet inference results (pass-through).")

        return segments, transcript

    def transcribe(
        self,
        audio_content: str,
        language_code: str = "en-US",
        automatic_punctuation: bool = True,
        word_time_offsets: bool = True,
        max_alternatives: int = 1,
        profanity_filter: bool = False,
        verbatim_transcripts: bool = True,
        speaker_diarization: bool = False,
        boosted_lm_words: Optional[List[str]] = None,
        boosted_lm_score: float = 0.0,
        diarization_max_speakers: int = 0,
        start_history: float = 0.0,
        start_threshold: float = 0.0,
        stop_history: float = 0.0,
        stop_history_eou: bool = False,
        stop_threshold: float = 0.0,
        stop_threshold_eou: bool = False,
    ):
        """
        Transcribe audio using Riva ASR offline recognition.

        The audio is base64-decoded and converted to a mono WAV (``convert_to_mono_wav``)
        before being sent.

        Parameters
        ----------
        audio_content : str
            Base64-encoded audio content to be transcribed.
        language_code : str, default="en-US"
            The language code for transcription.
        automatic_punctuation : bool, default=True
            Whether to enable automatic punctuation in the transcript.
        word_time_offsets : bool, default=True
            Whether to include word-level timestamps in the transcript.
        max_alternatives : int, default=1
            The maximum number of alternative transcripts to return.
        profanity_filter : bool, default=False
            Whether to filter out profanity from the transcript.
        verbatim_transcripts : bool, default=True
            Whether to return verbatim transcripts without normalization.
        speaker_diarization : bool, default=False
            Whether to enable speaker diarization.
        boosted_lm_words : Optional[List[str]], default=None
            A list of words to boost for language modeling.
        boosted_lm_score : float, default=0.0
            The boosting score for language model words.
        diarization_max_speakers : int, default=0
            The maximum number of speakers to differentiate in speaker diarization.
        start_history : float, default=0.0
            History window size for endpoint detection.
        start_threshold : float, default=0.0
            The threshold for starting speech detection.
        stop_history : float, default=0.0
            History window size for stopping speech detection.
        stop_history_eou : bool, default=False
            End-of-utterance stop-history setting, forwarded to
            ``riva.client.add_endpoint_parameters_to_config``.
        stop_threshold : float, default=0.0
            The threshold for stopping speech detection.
        stop_threshold_eou : bool, default=False
            End-of-utterance stop-threshold setting, forwarded to
            ``riva.client.add_endpoint_parameters_to_config``.

        Returns
        -------
        Any
            The response returned by ``ASRService.offline_recognize``.

        Raises
        ------
        grpc.RpcError
            If the recognition call fails; the error is logged and re-raised.
        """
        # Build the recognition configuration.
        recognition_config = riva.client.RecognitionConfig(
            language_code=language_code,
            max_alternatives=max_alternatives,
            profanity_filter=profanity_filter,
            enable_automatic_punctuation=automatic_punctuation,
            verbatim_transcripts=verbatim_transcripts,
            enable_word_time_offsets=word_time_offsets,
        )

        # Add additional configuration parameters.
        riva.client.add_word_boosting_to_config(
            recognition_config,
            boosted_lm_words or [],
            boosted_lm_score,
        )
        riva.client.add_speaker_diarization_to_config(
            recognition_config,
            speaker_diarization,
            diarization_max_speakers,
        )
        riva.client.add_endpoint_parameters_to_config(
            recognition_config,
            start_history,
            start_threshold,
            stop_history,
            stop_history_eou,
            stop_threshold,
            stop_threshold_eou,
        )
        audio_bytes = base64.b64decode(audio_content)
        mono_audio_bytes = convert_to_mono_wav(audio_bytes)

        # Perform offline (non-streaming) recognition and return the raw response.
        try:
            return self._asr_service.offline_recognize(mono_audio_bytes, recognition_config)
        except grpc.RpcError as e:
            logger.exception("Error transcribing audio file: %s", e.details())
            raise
212
+
213
+
214
def convert_to_mono_wav(audio_bytes):
    """
    Decode audio bytes and re-encode them as a mono, 16-bit PCM WAV file.

    The audio is decoded with ``librosa.load`` (resampled to 44.1 kHz, downmixed to
    mono), peak-normalized to 90% of full scale, converted to ``int16`` and written
    with ``scipy.io.wavfile.write``.

    Parameters
    ----------
    audio_bytes : bytes
        The raw encoded audio data.

    Returns
    -------
    bytes
        The processed audio as a complete mono WAV file (header included).

    Raises
    ------
    ImportError
        If ``librosa`` is not installed.
    """
    if librosa is None:
        raise ImportError("Librosa is required for audio processing. ")

    # librosa.load reads from a file-like object; mono=True downmixes and sr=44100 resamples.
    audio_data, sample_rate = librosa.load(io.BytesIO(audio_bytes), sr=44100, mono=True)

    # Librosa returns samples normalized between -1 and 1. Rescale so the loudest sample
    # sits at 90% of full scale. np.max on a zero-size array raises ValueError, so empty
    # decodes are treated as silence instead of crashing.
    peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
    if peak > 0:
        audio_data = audio_data / peak * 0.9

    # Convert to int16 format for 16-bit PCM WAV.
    audio_data_int16 = (audio_data * 32767).astype(np.int16)

    output_io = io.BytesIO()
    wavfile.write(output_io, sample_rate, audio_data_int16)
    return output_io.getvalue()
258
+
259
+
260
def process_transcription_response(response):
    """
    Flatten a Riva recognition response into a full transcript and sentence-like segments.

    Only the first alternative of each result is consulted; results without any
    alternatives are skipped. A segment is closed after every word whose last character
    is ``.``, ``?`` or ``!``. Any trailing words without such punctuation form one final
    segment.

    Parameters:
        response: The Riva transcription response message. Its ``results`` each carry
            ``alternatives``, whose ``words`` expose ``word``, ``start_time`` and ``end_time``.

    Returns:
        segments (list): Dicts with keys "start" (first word's start_time), "end" (last
            word's end_time) and "text" (the words joined by single spaces).
        final_transcript (str): Every word joined by single spaces.
    """
    terminators = {".", "?", "!"}

    # Collect words from the top alternative of each result, in order.
    all_words = [
        info
        for result in response.results
        if result.alternatives
        for info in result.alternatives[0].words
    ]

    def _as_segment(chunk):
        # chunk is a non-empty, ordered slice of word entries.
        return {
            "start": chunk[0].start_time,
            "end": chunk[-1].end_time,
            "text": " ".join(entry.word for entry in chunk),
        }

    segments = []
    window_start = 0
    for position, entry in enumerate(all_words):
        token = entry.word
        if token and token[-1] in terminators:
            segments.append(_as_segment(all_words[window_start : position + 1]))
            window_start = position + 1

    if window_start < len(all_words):
        segments.append(_as_segment(all_words[window_start:]))

    return segments, " ".join(entry.word for entry in all_words)
313
+
314
+
315
def create_audio_inference_client(
    endpoints: Tuple[str, str],
    infer_protocol: Optional[str] = None,
    auth_token: Optional[str] = None,
    function_id: Optional[str] = None,
    use_ssl: Optional[bool] = False,
    ssl_cert: Optional[str] = None,
):
    """
    Create a ParakeetClient for interfacing with an audio model inference server.

    Parameters
    ----------
    endpoints : tuple
        A ``(grpc_endpoint, http_endpoint)`` pair. Only the gRPC endpoint is used; the
        HTTP endpoint is accepted but ignored.
    infer_protocol : str, optional
        The protocol to use, compared case-insensitively. Only ``"grpc"`` is supported.
        If not specified (``None`` or blank), defaults to ``"grpc"`` when a non-blank
        gRPC endpoint is provided.
    auth_token : str, optional
        Authorization token for authentication (default: None).
    function_id : str, optional
        NVCF function ID of the invocation (default: None).
    use_ssl : bool or None, optional
        Forwarded to ``ParakeetClient`` (default: False). Pass ``None`` to let
        ``ParakeetClient`` decide SSL from the endpoint and function ID.
    ssl_cert : str, optional
        Path to the SSL certificate file if `use_ssl` is enabled (default: None).

    Returns
    -------
    ParakeetClient
        The initialized ParakeetClient configured for audio inference over gRPC.

    Raises
    ------
    ValueError
        If `infer_protocol` is ``"http"`` or any other value besides ``"grpc"``, or if
        it is not specified and no non-blank gRPC endpoint is provided.
    """
    grpc_endpoint, _ = endpoints  # the HTTP endpoint is not used for audio

    has_grpc_endpoint = bool(grpc_endpoint and grpc_endpoint.strip())

    # Normalize protocol (case/whitespace-insensitive); treat blank like "not specified".
    protocol = (infer_protocol or "").strip().lower()
    if not protocol and has_grpc_endpoint:
        protocol = "grpc"

    if protocol == "http":
        raise ValueError("`http` endpoints are not supported for audio. Use `grpc`.")
    if not protocol:
        raise ValueError(
            "No `infer_protocol` was specified and no gRPC endpoint was provided; "
            "audio inference requires a gRPC endpoint."
        )
    if protocol != "grpc":
        raise ValueError(f"Unsupported `infer_protocol` {infer_protocol!r} for audio. Use `grpc`.")

    return ParakeetClient(
        grpc_endpoint, auth_token=auth_token, function_id=function_id, use_ssl=use_ssl, ssl_cert=ssl_cert
    )
@@ -0,0 +1,129 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ from nv_ingest_api.internal.primitives.nim import ModelInterface
8
+ import numpy as np
9
+
10
+
11
class EmbeddingModelInterface(ModelInterface):
    """
    An interface for handling inference with an embedding model endpoint.

    Supports HTTP (JSON payloads) and gRPC (NumPy object-array payloads) inference for
    generating embeddings from text prompts.
    """

    def name(self) -> str:
        """
        Return the name of this model interface.
        """
        return "Embedding"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Normalize input so that ``"prompts"`` is a list of texts to embed.

        A non-list ``data["prompts"]`` is wrapped into a one-element list. Note that
        ``data`` itself is updated in place before the new dict is returned.

        Returns
        -------
        dict
            ``{"prompts": [...]}``.

        Raises
        ------
        KeyError
            If ``data`` has no ``"prompts"`` key.
        """
        if "prompts" not in data:
            raise KeyError("Input data must include 'prompts'.")
        if not isinstance(data["prompts"], list):
            data["prompts"] = [data["prompts"]]
        return {"prompts": data["prompts"]}

    def format_input(
        self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
    ) -> Tuple[List[Any], List[Dict[str, Any]]]:
        """
        Format the input payloads for the embedding endpoint, one payload per batch.

        Parameters
        ----------
        data : dict
            The input data containing "prompts" (a list of text prompts).
        protocol : str
            ``"http"`` or ``"grpc"``.
        max_batch_size : int
            Maximum number of prompts per payload.
        kwargs : dict
            HTTP only: model_name, encoding_format (default "float"), input_type
            (default "passage") and truncate (default "NONE").

        Returns
        -------
        tuple
            A tuple (payloads, batch_data_list) where:
            - payloads is a list of JSON-serializable dicts (HTTP) or NumPy object arrays
              of shape (len(batch), 1) holding UTF-8 encoded prompts (gRPC).
            - batch_data_list is a list of ``{"prompts": batch}`` dicts, in input order.

        Raises
        ------
        ValueError
            If ``protocol`` is neither ``"http"`` nor ``"grpc"``.
        """
        if protocol not in ("http", "grpc"):
            raise ValueError(f"EmbeddingModelInterface does not support protocol {protocol!r}; use 'http' or 'grpc'.")

        prompts = data["prompts"]
        batches = [prompts[i : i + max_batch_size] for i in range(0, len(prompts), max_batch_size)]

        payloads = []
        batch_data_list = []
        for batch in batches:
            if protocol == "http":
                payloads.append(
                    {
                        "model": kwargs.get("model_name"),
                        "input": batch,
                        "encoding_format": kwargs.get("encoding_format", "float"),
                        "input_type": kwargs.get("input_type", "passage"),
                        "truncate": kwargs.get("truncate", "NONE"),
                    }
                )
            else:
                # One row per prompt; each cell holds the UTF-8 encoded text.
                payloads.append(np.array([[text.encode("utf-8")] for text in batch], dtype=np.object_))
            batch_data_list.append({"prompts": batch})
        return payloads, batch_data_list

    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
        """
        Parse the response from the embedding endpoint.

        Parameters
        ----------
        response : Any
            HTTP: the decoded JSON response, expected to contain a non-empty "data" list.
            gRPC: an iterable of array-like results (each must support ``flatten()``).
        protocol : str
            ``"http"`` or ``"grpc"``.
        data : dict, optional
            The original input data (unused).
        kwargs : dict
            Additional keyword arguments (unused).

        Returns
        -------
        list
            HTTP: the "embedding" field of each item in "data" (``None`` where absent),
            or ``[str(response)]`` if the response is not a dict.
            gRPC: each result flattened.

        Raises
        ------
        RuntimeError
            If an HTTP dict response has a missing or empty "data" key.
        ValueError
            If ``protocol`` is neither ``"http"`` nor ``"grpc"``.
        """
        if protocol == "http":
            if not isinstance(response, dict):
                return [str(response)]
            embeddings = response.get("data")
            if not embeddings:
                raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")
            # Each item in embeddings is expected to have an 'embedding' field.
            return [item.get("embedding", None) for item in embeddings]
        if protocol == "grpc":
            return [res.flatten() for res in response]
        raise ValueError(f"EmbeddingModelInterface does not support protocol {protocol!r}; use 'http' or 'grpc'.")

    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
        """
        Process inference results for the embedding model.

        The parsed output (a list of embeddings) is returned unchanged.

        Returns
        -------
        list
            The processed list of embeddings.
        """
        return output
@@ -0,0 +1,177 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from typing import Dict, Any, Optional, Tuple, List
6
+
7
+ import logging
8
+
9
+ from nv_ingest_api.internal.primitives.nim import ModelInterface
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class VLMModelInterface(ModelInterface):
    """
    Model interface for a vision-language model (VLM) served over HTTP
    (e.g., NVIDIA LLaMA-based VLM).

    Each request holds an optional system message followed by one user message per
    image. Every user message pairs the caption prompt with a base64 PNG data URL.
    """

    def name(self) -> str:
        """
        Identify this interface as "VLM".
        """
        return "VLM"

    def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate VLM input and normalize images under the "base64_images" key.

        A "base64_images" list takes precedence. Otherwise a single "base64_image" is
        wrapped into a one-element "base64_images" list. The same dict is returned.

        Raises
        ------
        ValueError
            If "base64_images" is present but is not a list.
        KeyError
            If no image key is present, or if "prompt" is missing.
        """
        has_image_list = "base64_images" in data
        if has_image_list and not isinstance(data["base64_images"], list):
            raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
        if not has_image_list:
            if "base64_image" not in data:
                raise KeyError("Input data must include 'base64_image' or 'base64_images'.")
            data["base64_images"] = [data["base64_image"]]

        if "prompt" not in data:
            raise KeyError("Input data must include 'prompt'.")
        return data

    def format_input(
        self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
    ) -> Tuple[List[Any], List[Dict[str, Any]]]:
        """
        Build one chat-completions payload per batch of at most ``max_batch_size`` images.

        Parameters
        ----------
        data : dict
            Must contain "prompt". May contain "base64_images" (list) and "system_prompt".
        protocol : str
            Must be "http".
        max_batch_size : int
            Maximum number of images per payload.
        kwargs : dict
            model_name, max_tokens (default 512), temperature (default 1.0),
            top_p (default 1.0) and stream (default False).

        Returns
        -------
        tuple
            (payloads, batch_data_list). batch_data_list[i] holds the "base64_images" and
            "prompt" (plus "system_prompt" when set) that produced payloads[i].

        Raises
        ------
        ValueError
            If ``protocol`` is not "http".
        """
        if protocol != "http":
            raise ValueError("VLMModelInterface only supports HTTP protocol.")

        images = data.get("base64_images", [])
        prompt = data["prompt"]
        system_prompt = data.get("system_prompt")

        payloads = []
        batch_data_list = []
        for offset in range(0, len(images), max_batch_size):
            batch = images[offset : offset + max_batch_size]

            if system_prompt:
                messages = [{"role": "system", "content": system_prompt}]
            else:
                messages = []
                logger.debug("VLM: No system prompt provided, using default")

            for img in batch:
                text_part = {"type": "text", "text": f"{prompt}"}
                image_part = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}
                messages.append({"role": "user", "content": [text_part, image_part]})

            payloads.append(
                {
                    "model": kwargs.get("model_name"),
                    "messages": messages,
                    "max_tokens": kwargs.get("max_tokens", 512),
                    "temperature": kwargs.get("temperature", 1.0),
                    "top_p": kwargs.get("top_p", 1.0),
                    "stream": kwargs.get("stream", False),
                }
            )

            batch_data = {"base64_images": batch, "prompt": prompt}
            if system_prompt:
                batch_data["system_prompt"] = system_prompt
            batch_data_list.append(batch_data)

        return payloads, batch_data_list

    def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
        """
        Extract captions from a VLM HTTP response.

        Parameters
        ----------
        response : Any
            The decoded JSON response, expected to contain a non-empty "choices" list.
        protocol : str
            Must be "http".
        data : dict, optional
            The original input data (unused).
        kwargs : dict
            Additional keyword arguments (unused).

        Returns
        -------
        list
            One caption per choice, taken from ``message.content`` (with
            "No caption returned" when absent). Non-dict responses yield ``[str(response)]``.

        Raises
        ------
        ValueError
            If ``protocol`` is not "http".
        RuntimeError
            If a dict response has a missing or empty "choices" list.
        """
        if protocol != "http":
            raise ValueError("VLMModelInterface only supports HTTP protocol.")
        if not isinstance(response, dict):
            # Non-dict responses are surfaced verbatim as a single caption.
            return [str(response)]

        choices = response.get("choices", [])
        if not choices:
            raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")

        captions = []
        for choice in choices:
            message = choice.get("message", {})
            captions.append(message.get("content", "No caption returned"))
        return captions

    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
        """
        Return the parsed VLM output (a list of captions) unchanged.

        Returns
        -------
        list
            The processed list of captions.
        """
        return output