nv-ingest-api 2025.4.15.dev20250415__py3-none-any.whl → 2025.4.17.dev20250417__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nv-ingest-api might be problematic. Click here for more details.

Files changed (153)
  1. nv_ingest_api/__init__.py +3 -0
  2. nv_ingest_api/interface/__init__.py +215 -0
  3. nv_ingest_api/interface/extract.py +972 -0
  4. nv_ingest_api/interface/mutate.py +154 -0
  5. nv_ingest_api/interface/store.py +218 -0
  6. nv_ingest_api/interface/transform.py +382 -0
  7. nv_ingest_api/interface/utility.py +200 -0
  8. nv_ingest_api/internal/enums/__init__.py +3 -0
  9. nv_ingest_api/internal/enums/common.py +494 -0
  10. nv_ingest_api/internal/extract/__init__.py +3 -0
  11. nv_ingest_api/internal/extract/audio/__init__.py +3 -0
  12. nv_ingest_api/internal/extract/audio/audio_extraction.py +149 -0
  13. nv_ingest_api/internal/extract/docx/__init__.py +5 -0
  14. nv_ingest_api/internal/extract/docx/docx_extractor.py +205 -0
  15. nv_ingest_api/internal/extract/docx/engines/__init__.py +0 -0
  16. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/__init__.py +3 -0
  17. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docx_helper.py +122 -0
  18. nv_ingest_api/internal/extract/docx/engines/docxreader_helpers/docxreader.py +895 -0
  19. nv_ingest_api/internal/extract/image/__init__.py +3 -0
  20. nv_ingest_api/internal/extract/image/chart_extractor.py +353 -0
  21. nv_ingest_api/internal/extract/image/image_extractor.py +204 -0
  22. nv_ingest_api/internal/extract/image/image_helpers/__init__.py +3 -0
  23. nv_ingest_api/internal/extract/image/image_helpers/common.py +403 -0
  24. nv_ingest_api/internal/extract/image/infographic_extractor.py +253 -0
  25. nv_ingest_api/internal/extract/image/table_extractor.py +344 -0
  26. nv_ingest_api/internal/extract/pdf/__init__.py +3 -0
  27. nv_ingest_api/internal/extract/pdf/engines/__init__.py +19 -0
  28. nv_ingest_api/internal/extract/pdf/engines/adobe.py +484 -0
  29. nv_ingest_api/internal/extract/pdf/engines/llama.py +243 -0
  30. nv_ingest_api/internal/extract/pdf/engines/nemoretriever.py +597 -0
  31. nv_ingest_api/internal/extract/pdf/engines/pdf_helpers/__init__.py +146 -0
  32. nv_ingest_api/internal/extract/pdf/engines/pdfium.py +603 -0
  33. nv_ingest_api/internal/extract/pdf/engines/tika.py +96 -0
  34. nv_ingest_api/internal/extract/pdf/engines/unstructured_io.py +426 -0
  35. nv_ingest_api/internal/extract/pdf/pdf_extractor.py +74 -0
  36. nv_ingest_api/internal/extract/pptx/__init__.py +5 -0
  37. nv_ingest_api/internal/extract/pptx/engines/__init__.py +0 -0
  38. nv_ingest_api/internal/extract/pptx/engines/pptx_helper.py +799 -0
  39. nv_ingest_api/internal/extract/pptx/pptx_extractor.py +187 -0
  40. nv_ingest_api/internal/mutate/__init__.py +3 -0
  41. nv_ingest_api/internal/mutate/deduplicate.py +110 -0
  42. nv_ingest_api/internal/mutate/filter.py +133 -0
  43. nv_ingest_api/internal/primitives/__init__.py +0 -0
  44. nv_ingest_api/{primitives → internal/primitives}/control_message_task.py +4 -0
  45. nv_ingest_api/{primitives → internal/primitives}/ingest_control_message.py +5 -2
  46. nv_ingest_api/internal/primitives/nim/__init__.py +8 -0
  47. nv_ingest_api/internal/primitives/nim/default_values.py +15 -0
  48. nv_ingest_api/internal/primitives/nim/model_interface/__init__.py +3 -0
  49. nv_ingest_api/internal/primitives/nim/model_interface/cached.py +274 -0
  50. nv_ingest_api/internal/primitives/nim/model_interface/decorators.py +56 -0
  51. nv_ingest_api/internal/primitives/nim/model_interface/deplot.py +270 -0
  52. nv_ingest_api/internal/primitives/nim/model_interface/helpers.py +275 -0
  53. nv_ingest_api/internal/primitives/nim/model_interface/nemoretriever_parse.py +238 -0
  54. nv_ingest_api/internal/primitives/nim/model_interface/paddle.py +462 -0
  55. nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +367 -0
  56. nv_ingest_api/internal/primitives/nim/model_interface/text_embedding.py +132 -0
  57. nv_ingest_api/internal/primitives/nim/model_interface/vlm.py +152 -0
  58. nv_ingest_api/internal/primitives/nim/model_interface/yolox.py +1400 -0
  59. nv_ingest_api/internal/primitives/nim/nim_client.py +344 -0
  60. nv_ingest_api/internal/primitives/nim/nim_model_interface.py +81 -0
  61. nv_ingest_api/internal/primitives/tracing/__init__.py +0 -0
  62. nv_ingest_api/internal/primitives/tracing/latency.py +69 -0
  63. nv_ingest_api/internal/primitives/tracing/logging.py +96 -0
  64. nv_ingest_api/internal/primitives/tracing/tagging.py +197 -0
  65. nv_ingest_api/internal/schemas/__init__.py +3 -0
  66. nv_ingest_api/internal/schemas/extract/__init__.py +3 -0
  67. nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +130 -0
  68. nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +135 -0
  69. nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +124 -0
  70. nv_ingest_api/internal/schemas/extract/extract_image_schema.py +124 -0
  71. nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +128 -0
  72. nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +218 -0
  73. nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +124 -0
  74. nv_ingest_api/internal/schemas/extract/extract_table_schema.py +129 -0
  75. nv_ingest_api/internal/schemas/message_brokers/__init__.py +3 -0
  76. nv_ingest_api/internal/schemas/message_brokers/message_broker_client_schema.py +23 -0
  77. nv_ingest_api/internal/schemas/message_brokers/request_schema.py +34 -0
  78. nv_ingest_api/internal/schemas/message_brokers/response_schema.py +19 -0
  79. nv_ingest_api/internal/schemas/meta/__init__.py +3 -0
  80. nv_ingest_api/internal/schemas/meta/base_model_noext.py +11 -0
  81. nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +237 -0
  82. nv_ingest_api/internal/schemas/meta/metadata_schema.py +221 -0
  83. nv_ingest_api/internal/schemas/mutate/__init__.py +3 -0
  84. nv_ingest_api/internal/schemas/mutate/mutate_image_dedup_schema.py +16 -0
  85. nv_ingest_api/internal/schemas/store/__init__.py +3 -0
  86. nv_ingest_api/internal/schemas/store/store_embedding_schema.py +28 -0
  87. nv_ingest_api/internal/schemas/store/store_image_schema.py +30 -0
  88. nv_ingest_api/internal/schemas/transform/__init__.py +3 -0
  89. nv_ingest_api/internal/schemas/transform/transform_image_caption_schema.py +15 -0
  90. nv_ingest_api/internal/schemas/transform/transform_image_filter_schema.py +17 -0
  91. nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +25 -0
  92. nv_ingest_api/internal/schemas/transform/transform_text_splitter_schema.py +22 -0
  93. nv_ingest_api/internal/store/__init__.py +3 -0
  94. nv_ingest_api/internal/store/embed_text_upload.py +236 -0
  95. nv_ingest_api/internal/store/image_upload.py +232 -0
  96. nv_ingest_api/internal/transform/__init__.py +3 -0
  97. nv_ingest_api/internal/transform/caption_image.py +205 -0
  98. nv_ingest_api/internal/transform/embed_text.py +496 -0
  99. nv_ingest_api/internal/transform/split_text.py +157 -0
  100. nv_ingest_api/util/__init__.py +0 -0
  101. nv_ingest_api/util/control_message/__init__.py +0 -0
  102. nv_ingest_api/util/control_message/validators.py +47 -0
  103. nv_ingest_api/util/converters/__init__.py +0 -0
  104. nv_ingest_api/util/converters/bytetools.py +78 -0
  105. nv_ingest_api/util/converters/containers.py +65 -0
  106. nv_ingest_api/util/converters/datetools.py +90 -0
  107. nv_ingest_api/util/converters/dftools.py +127 -0
  108. nv_ingest_api/util/converters/formats.py +64 -0
  109. nv_ingest_api/util/converters/type_mappings.py +27 -0
  110. nv_ingest_api/util/detectors/__init__.py +5 -0
  111. nv_ingest_api/util/detectors/language.py +38 -0
  112. nv_ingest_api/util/exception_handlers/__init__.py +0 -0
  113. nv_ingest_api/util/exception_handlers/converters.py +72 -0
  114. nv_ingest_api/util/exception_handlers/decorators.py +223 -0
  115. nv_ingest_api/util/exception_handlers/detectors.py +74 -0
  116. nv_ingest_api/util/exception_handlers/pdf.py +116 -0
  117. nv_ingest_api/util/exception_handlers/schemas.py +68 -0
  118. nv_ingest_api/util/image_processing/__init__.py +5 -0
  119. nv_ingest_api/util/image_processing/clustering.py +260 -0
  120. nv_ingest_api/util/image_processing/processing.py +179 -0
  121. nv_ingest_api/util/image_processing/table_and_chart.py +449 -0
  122. nv_ingest_api/util/image_processing/transforms.py +407 -0
  123. nv_ingest_api/util/logging/__init__.py +0 -0
  124. nv_ingest_api/util/logging/configuration.py +31 -0
  125. nv_ingest_api/util/message_brokers/__init__.py +3 -0
  126. nv_ingest_api/util/message_brokers/simple_message_broker/__init__.py +9 -0
  127. nv_ingest_api/util/message_brokers/simple_message_broker/broker.py +465 -0
  128. nv_ingest_api/util/message_brokers/simple_message_broker/ordered_message_queue.py +71 -0
  129. nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +435 -0
  130. nv_ingest_api/util/metadata/__init__.py +5 -0
  131. nv_ingest_api/util/metadata/aggregators.py +469 -0
  132. nv_ingest_api/util/multi_processing/__init__.py +8 -0
  133. nv_ingest_api/util/multi_processing/mp_pool_singleton.py +194 -0
  134. nv_ingest_api/util/nim/__init__.py +56 -0
  135. nv_ingest_api/util/pdf/__init__.py +3 -0
  136. nv_ingest_api/util/pdf/pdfium.py +427 -0
  137. nv_ingest_api/util/schema/__init__.py +0 -0
  138. nv_ingest_api/util/schema/schema_validator.py +10 -0
  139. nv_ingest_api/util/service_clients/__init__.py +3 -0
  140. nv_ingest_api/util/service_clients/client_base.py +72 -0
  141. nv_ingest_api/util/service_clients/kafka/__init__.py +3 -0
  142. nv_ingest_api/util/service_clients/redis/__init__.py +0 -0
  143. nv_ingest_api/util/service_clients/redis/redis_client.py +334 -0
  144. nv_ingest_api/util/service_clients/rest/__init__.py +0 -0
  145. nv_ingest_api/util/service_clients/rest/rest_client.py +398 -0
  146. nv_ingest_api/util/string_processing/__init__.py +51 -0
  147. {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/METADATA +1 -1
  148. nv_ingest_api-2025.4.17.dev20250417.dist-info/RECORD +152 -0
  149. nv_ingest_api-2025.4.15.dev20250415.dist-info/RECORD +0 -9
  150. /nv_ingest_api/{primitives → internal}/__init__.py +0 -0
  151. {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/WHEEL +0 -0
  152. {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/licenses/LICENSE +0 -0
  153. {nv_ingest_api-2025.4.15.dev20250415.dist-info → nv_ingest_api-2025.4.17.dev20250417.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,465 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import uuid
6
+ import socket
7
+ import socketserver
8
+ import json
9
+ import logging
10
+ import threading
11
+ from typing import Optional
12
+
13
+ from pydantic import ValidationError
14
+
15
+ from nv_ingest_api.internal.schemas.message_brokers.request_schema import (
16
+ PushRequestSchema,
17
+ PopRequestSchema,
18
+ SizeRequestSchema,
19
+ )
20
+ from nv_ingest_api.internal.schemas.message_brokers.response_schema import ResponseSchema
21
+ from nv_ingest_api.util.message_brokers.simple_message_broker.ordered_message_queue import OrderedMessageQueue
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class SimpleMessageBrokerHandler(socketserver.BaseRequestHandler):
    """
    Per-connection request handler for the SimpleMessageBroker server.

    Every frame on the wire, in both directions, is an 8-byte big-endian length prefix followed by
    that many bytes of UTF-8 encoded JSON. ``handle`` reads a single command per connection (PING,
    PUSH, PUSH_FOR_NV_INGEST, POP or SIZE); PUSH and POP additionally run an ACK handshake on the
    same connection before the queue change becomes final.
    """

    def handle(self):
        """
        Read one framed request from the client, validate it, and dispatch it to a command handler.

        Failures are reported to the client instead of being raised: schema validation errors and
        any other exception raised while processing are caught here and sent back as a response with
        ``response_code=1`` and the exception text as ``response_reason``.
        """

        client_address = self.client_address

        data_bytes = None
        try:
            # NOTE(review): the length prefix is client-controlled and not capped, so a peer can ask
            # the server to buffer an arbitrarily large payload. Consider enforcing a maximum size.
            data_length_bytes = self._recv_exact(8)
            if not data_length_bytes:
                logger.debug("No data length received. Closing connection.")
                return
            data_length = int.from_bytes(data_length_bytes, "big")

            data_bytes = self._recv_exact(data_length)
            if not data_bytes:
                logger.debug("No data received. Closing connection.")
                return

            data = data_bytes.decode("utf-8").strip()
            request_data = json.loads(data)

            # Valid JSON is not necessarily an object; reject lists/scalars with a clear message
            # rather than failing on the ``.get`` calls below.
            if not isinstance(request_data, dict):
                response = ResponseSchema(response_code=1, response_reason="Request must be a JSON object")
                self._send_response(response)
                return

            command = request_data.get("command")
            if not command:
                response = ResponseSchema(response_code=1, response_reason="No command specified")
                self._send_response(response)
                return

            # Handle the PING command directly; it needs neither a queue nor schema validation.
            if command == "PING":
                self._handle_ping()
                return

            # Validate and extract common fields
            queue_name = request_data.get("queue_name")

            # Initialize the queue and its lock if necessary
            if queue_name:
                self.server._initialize_queue(queue_name)
                queue_lock = self.server.queue_locks[queue_name]
                queue = self.server.queues[queue_name]
            else:
                queue_lock = None
                queue = None  # For commands that don't require a queue

            # Dispatch to the appropriate handler
            if command == "PUSH":
                validated_data = PushRequestSchema(**request_data)
                self._handle_push(validated_data, transaction_id=str(uuid.uuid4()), queue=queue, queue_lock=queue_lock)
            elif command == "PUSH_FOR_NV_INGEST":
                validated_data = PushRequestSchema(**request_data)
                self._handle_push_for_nv_ingest(validated_data, queue=queue, queue_lock=queue_lock)
            elif command == "POP":
                validated_data = PopRequestSchema(**request_data)
                self._handle_pop(validated_data, queue=queue, queue_lock=queue_lock)
            elif command == "SIZE":
                validated_data = SizeRequestSchema(**request_data)
                response = self._size_of_queue(validated_data, queue, queue_lock)
                self._send_response(response)
            else:
                response = ResponseSchema(response_code=1, response_reason="Unknown command")
                self._send_response(response)

        except ValidationError as ve:
            response = ResponseSchema(response_code=1, response_reason=str(ve))
            self._send_response(response)
        except Exception as e:
            logger.error(f"Error processing command from {client_address}: {e}\n{data_bytes}")
            response = ResponseSchema(response_code=1, response_reason=str(e))
            # _send_response logs and swallows its own send errors (including a closed client
            # connection), so no extra guard is needed around this call.
            self._send_response(response)

    def _handle_ping(self):
        """
        Responds to a PING command with a PONG response.
        """

        response = ResponseSchema(response_code=0, response="PONG")
        self._send_response(response)

    def _handle_push(
        self, data: "PushRequestSchema", transaction_id: str, queue: "OrderedMessageQueue", queue_lock: threading.Lock
    ):
        """
        Handles a PUSH command to add a message to the specified queue.

        Parameters
        ----------
        data : PushRequestSchema
            The validated data for the PUSH command.
        transaction_id : str
            The unique transaction ID for the operation.
        queue : OrderedMessageQueue
            The queue where the message will be pushed.
        queue_lock : threading.Lock
            The lock object to ensure thread-safe access to the queue.
        """

        self._push_with_handshake(data.message, transaction_id, data.timeout, queue, queue_lock)

    def _handle_push_for_nv_ingest(
        self, data: "PushRequestSchema", queue: "OrderedMessageQueue", queue_lock: threading.Lock
    ):
        """
        Handles a PUSH_FOR_NV_INGEST command, which includes generating a unique job ID and
        updating the message payload accordingly.

        The message must be a JSON object; a fresh UUID is written to its ``job_id`` key and the
        same UUID is used as the transaction ID of the handshake.

        Parameters
        ----------
        data : PushRequestSchema
            The validated data for the PUSH_FOR_NV_INGEST command.
        queue : OrderedMessageQueue
            The queue where the message will be pushed.
        queue_lock : threading.Lock
            The lock object to ensure thread-safe access to the queue.
        """

        # Deserialize the message
        try:
            message_dict = json.loads(data.message)
        except json.JSONDecodeError:
            message_dict = None

        # Only a JSON object can carry the 'job_id' key; anything else is rejected up front.
        if not isinstance(message_dict, dict):
            response = ResponseSchema(response_code=1, response_reason="Invalid JSON message")
            self._send_response(response)
            return

        # Generate a UUID for 'job_id' and use it as transaction_id
        transaction_id = str(uuid.uuid4())
        message_dict["job_id"] = transaction_id

        # Re-serialize the message
        updated_message = json.dumps(message_dict)

        self._push_with_handshake(updated_message, transaction_id, data.timeout, queue, queue_lock)

    def _push_with_handshake(self, message, transaction_id, timeout, queue, queue_lock):
        """
        Run the PUSH handshake shared by PUSH and PUSH_FOR_NV_INGEST and store ``message`` on ACK.

        Parameters
        ----------
        message : str
            The (already serialized) message to store.
        transaction_id : str
            The transaction ID announced to the client and expected back in the ACK.
        timeout : float, optional
            How long to wait for the ACK.
        queue : OrderedMessageQueue
            The queue where the message will be pushed.
        queue_lock : threading.Lock
            The lock object to ensure thread-safe access to the queue.
        """

        with queue_lock:
            is_full = queue.full()

        if is_full:
            # Return failure response immediately (sent after releasing the lock so that network
            # I/O never blocks other clients of this queue).
            response = ResponseSchema(response_code=1, response_reason="Queue is full")
            self._send_response(response)
            return

        # Proceed with the 3-way handshake
        initial_response = ResponseSchema(
            response_code=0, response="Transaction initiated. Waiting for ACK.", transaction_id=transaction_id
        )
        self._send_response(initial_response)

        # Wait for ACK
        if not self._wait_for_ack(transaction_id, timeout):
            logger.debug(f"Transaction {transaction_id}: ACK not received. Discarding data.")
            final_response = ResponseSchema(
                response_code=1, response_reason="ACK not received.", transaction_id=transaction_id
            )
        else:
            # Perform the PUSH operation after ACK.
            # NOTE(review): capacity was checked before the handshake and is not re-checked here,
            # so concurrent pushers can exceed maxsize while their ACKs are pending.
            with queue_lock:
                queue.push(message)
            final_response = ResponseSchema(response_code=0, response="Data stored.", transaction_id=transaction_id)

        # Send final response
        self._send_response(final_response)

    def _handle_pop(self, data: "PopRequestSchema", queue: "OrderedMessageQueue", queue_lock: threading.Lock):
        """
        Handles a POP command to retrieve a message from the specified queue.

        The message is handed to the client and kept in flight until the client ACKs it. If the ACK
        does not arrive, or waiting for it fails for any reason, the message is returned to the
        queue so it is not lost.

        Parameters
        ----------
        data : PopRequestSchema
            The validated data for the POP command.
        queue : OrderedMessageQueue
            The queue from which the message will be retrieved.
        queue_lock : threading.Lock
            The lock object to ensure thread-safe access to the queue.
        """

        timeout = data.timeout
        transaction_id = str(uuid.uuid4())

        with queue_lock:
            # The emptiness check and the pop happen under the same lock, so pop() never blocks.
            is_empty = queue.empty()
            message = None if is_empty else queue.pop(transaction_id)

        if is_empty:
            # Return failure response immediately
            response = ResponseSchema(response_code=1, response_reason="Queue is empty")
            self._send_response(response)
            return

        # Proceed with the 3-way handshake
        initial_response = ResponseSchema(response_code=0, response=message, transaction_id=transaction_id)
        self._send_response(initial_response)

        # Wait for ACK. The in-flight entry is always settled in the ``finally`` block, even if
        # waiting raises, so a popped message can never be stranded in flight.
        acked = False
        try:
            acked = self._wait_for_ack(transaction_id, timeout)
        finally:
            with queue_lock:
                if acked:
                    queue.acknowledge(transaction_id)
                else:
                    queue.return_message(transaction_id)

        if acked:
            final_response = ResponseSchema(response_code=0, response="Data processed.", transaction_id=transaction_id)
        else:
            logger.debug(f"Transaction {transaction_id}: ACK not received. Returning data to queue.")
            final_response = ResponseSchema(
                response_code=1, response_reason="ACK not received.", transaction_id=transaction_id
            )

        # Send final response
        self._send_response(final_response)

    def _size_of_queue(
        self, data: "SizeRequestSchema", queue: "OrderedMessageQueue", queue_lock: threading.Lock
    ) -> "ResponseSchema":
        """
        Retrieves the size of the specified queue.

        Parameters
        ----------
        data : SizeRequestSchema
            The validated data for the SIZE command.
        queue : OrderedMessageQueue
            The queue whose size will be queried.
        queue_lock : threading.Lock
            The lock object to ensure thread-safe access to the queue.

        Returns
        -------
        ResponseSchema
            A response containing the size of the queue, as a string.
        """

        with queue_lock:
            size = queue.qsize()
        return ResponseSchema(response_code=0, response=str(size))

    def _wait_for_ack(self, transaction_id: str, timeout: Optional[float]) -> bool:
        """
        Waits for an acknowledgment (ACK) from the client for a specific transaction.

        The ACK is one framed JSON object of the form ``{"transaction_id": ..., "ack": true}``.
        Anything else (timeout, closed or failing connection, undecodable or malformed payload, a
        JSON value that is not an object, a mismatching transaction ID) counts as "no ACK".

        Parameters
        ----------
        transaction_id : str
            The unique transaction ID for the operation.
        timeout : float, optional
            The timeout period for waiting for the ACK.

        Returns
        -------
        bool
            True if the ACK is received, False otherwise.
        """

        try:
            self.request.settimeout(timeout)
            ack_length_bytes = self._recv_exact(8)
            if not ack_length_bytes:
                return False
            ack_length = int.from_bytes(ack_length_bytes, "big")
            ack_data_bytes = self._recv_exact(ack_length)
            if not ack_data_bytes:
                return False
            ack_response = json.loads(ack_data_bytes.decode("utf-8"))
        except (OSError, ValueError, OverflowError) as e:
            # OSError covers socket timeouts and every connection failure; ValueError covers both
            # invalid JSON and invalid UTF-8; OverflowError covers an absurd length prefix.
            logger.error(f"Error waiting for ACK: {e}")
            return False
        finally:
            # Restore blocking mode for whatever is sent or received next on this connection.
            self.request.settimeout(None)

        if not isinstance(ack_response, dict):
            return False
        return ack_response.get("transaction_id") == transaction_id and ack_response.get("ack") is True

    def _send_response(self, response: "ResponseSchema"):
        """
        Sends a response back to the client as a length-prefixed JSON frame.

        Sending is best-effort: any error (including a client that has already disconnected) is
        logged and swallowed, so this method does not raise.

        Parameters
        ----------
        response : ResponseSchema
            The response to send to the client.
        """

        try:
            response_json = response.model_dump_json().encode("utf-8")
            total_length = len(response_json)
            self.request.sendall(total_length.to_bytes(8, "big"))
            self.request.sendall(response_json)
        except BrokenPipeError as e:
            # The client went away; nothing more can be done for this connection.
            logger.error(f"BrokenPipeError while sending response: {e}")
        except Exception as e:
            logger.error(f"Unexpected error while sending response: {e}")

    def _recv_exact(self, num_bytes: int) -> Optional[bytes]:
        """
        Receives an exact number of bytes from the client connection.

        Parameters
        ----------
        num_bytes : int
            The number of bytes to receive.

        Returns
        -------
        Optional[bytes]
            The received bytes, or None if the connection is closed before ``num_bytes`` bytes
            arrived.
        """

        data = bytearray()
        while len(data) < num_bytes:
            # recv may return fewer bytes than requested; keep asking for the remainder.
            packet = self.request.recv(num_bytes - len(data))
            if not packet:
                return None
            data.extend(packet)
        return bytes(data)
387
+
388
+
389
class SimpleMessageBroker(socketserver.ThreadingMixIn, socketserver.TCPServer):
    """
    A threaded TCP message broker server that manages multiple named message queues and supports
    commands such as PUSH, POP, SIZE, and PING.

    At most one live instance exists per ``(host, port)`` pair: constructing the class again with
    the same address returns the existing server. Closing the server releases that slot.
    """

    allow_reuse_address = True
    _instances = {}  # Registry of live servers, keyed by (host, port)
    _instances_lock = threading.Lock()  # Guards _instances

    def __new__(cls, host: str, port: int, max_queue_size: int):
        """
        Ensures that only one instance of SimpleMessageBroker is created per host and port combination.

        Parameters
        ----------
        host : str
            The hostname or IP address for the server.
        port : int
            The port number for the server.
        max_queue_size : int
            The maximum size of each message queue. Not used when an existing instance is returned.

        Returns
        -------
        SimpleMessageBroker
            The singleton instance of the server for ``(host, port)``.
        """

        key = (host, port)
        with cls._instances_lock:
            instance = cls._instances.get(key)
            if instance is None:
                # Create a new instance and store it in the instances dictionary
                instance = super().__new__(cls)
                cls._instances[key] = instance
        return instance

    def __init__(self, host: str, port: int, max_queue_size: int):
        """
        Initializes the SimpleMessageBroker server, setting up message queues and locks.

        Python calls ``__init__`` every time the class is called, including when ``__new__``
        returned an already-initialized server; in that case this is a no-op and the arguments
        (notably ``max_queue_size``) are ignored.

        Parameters
        ----------
        host : str
            The hostname or IP address for the server.
        port : int
            The port number for the server.
        max_queue_size : int
            The maximum size of each message queue.
        """

        # Prevent __init__ from running multiple times on the same instance
        if getattr(self, "_initialized", False):
            return
        super().__init__((host, port), SimpleMessageBrokerHandler)
        self.max_queue_size = max_queue_size
        self.queues = {}
        self.queue_locks = {}  # Dictionary to hold locks for each queue
        self.lock = threading.Lock()  # Global lock to protect access to queues and locks
        self._initialized = True  # Flag to indicate initialization is complete

    def server_close(self):
        """
        Closes the server and removes it from the per-address instance registry.

        Without the eviction, a later ``SimpleMessageBroker(host, port, ...)`` call would be handed
        this closed server again instead of a fresh, bound one. This also runs when binding fails
        inside ``__init__`` (the base class closes the server on a failed bind), so a half-built
        instance does not stay registered.
        """

        try:
            super().server_close()
        finally:
            cls = type(self)
            with cls._instances_lock:
                stale_keys = [key for key, instance in cls._instances.items() if instance is self]
                for key in stale_keys:
                    del cls._instances[key]
            # Allow a full re-initialization should this object ever be constructed again.
            self._initialized = False

    def _initialize_queue(self, queue_name: str):
        """
        Initializes a new message queue with the specified name if it doesn't already exist.

        Parameters
        ----------
        queue_name : str
            The name of the queue to initialize.
        """

        with self.lock:
            if queue_name not in self.queues:
                self.queues[queue_name] = OrderedMessageQueue(maxsize=self.max_queue_size)
                self.queue_locks[queue_name] = threading.Lock()
@@ -0,0 +1,71 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+
6
+ import threading
7
+ import heapq
8
+
9
+
10
class OrderedMessageQueue:
    """
    A thread-safe message queue that preserves insertion order and tracks in-flight messages.

    Every pushed message receives a monotonically increasing index and messages are popped in
    index order. A popped message is held "in flight" under a transaction ID until it is either
    acknowledged (dropped for good) or returned (re-queued at its original position). In-flight
    messages still count towards ``maxsize``.

    Parameters
    ----------
    maxsize : int, optional
        Maximum number of queued plus in-flight messages. Zero or a negative value means
        unbounded. Note that ``push`` itself does not enforce the limit; callers are expected to
        check ``full()`` / ``can_push()`` first.
    """

    def __init__(self, maxsize=0):
        self.queue = []  # Heap of (index, message) tuples
        self.maxsize = maxsize
        self.next_index = 0  # Monotonically increasing message index
        self.in_flight = {}  # Mapping of transaction_id to (index, message)
        self.lock = threading.Lock()
        self.not_empty = threading.Condition(self.lock)
        self.not_full = threading.Condition(self.lock)

    def _is_full_locked(self):
        """Return True if the queue is at capacity. The caller must hold ``self.lock``."""
        return self.maxsize > 0 and (len(self.queue) + len(self.in_flight)) >= self.maxsize

    def can_push(self):
        """Check if the queue can accept more messages (the exact opposite of ``full()``)."""
        with self.lock:
            return not self._is_full_locked()

    def push(self, message):
        """
        Add a message to the queue after it has been acknowledged.

        Returns
        -------
        int
            The index assigned to the message.
        """
        with self.lock:
            index = self.next_index
            self.next_index += 1
            # Indices are unique, so the heap never falls back to comparing messages.
            heapq.heappush(self.queue, (index, message))
            self.not_empty.notify()
            return index

    def pop(self, transaction_id):
        """
        Pop the oldest message from the queue and mark it as in-flight under ``transaction_id``.

        Blocks, without a timeout, while the queue is empty.
        """
        with self.lock:
            while not self.queue:
                self.not_empty.wait()
            index, message = heapq.heappop(self.queue)
            self.in_flight[transaction_id] = (index, message)
            self.not_full.notify()
            return message

    def acknowledge(self, transaction_id):
        """Acknowledge that a message has been processed; unknown transaction IDs are ignored."""
        with self.lock:
            if transaction_id in self.in_flight:
                del self.in_flight[transaction_id]
                # Dropping an in-flight message is what actually frees capacity.
                self.not_full.notify()

    def return_message(self, transaction_id):
        """
        Return an unacknowledged message back to the queue at its original position.

        Unknown transaction IDs are ignored.
        """
        with self.lock:
            if transaction_id in self.in_flight:
                index, message = self.in_flight.pop(transaction_id)
                heapq.heappush(self.queue, (index, message))
                self.not_empty.notify()

    def qsize(self):
        """Get the number of messages currently in the queue (in-flight messages excluded)."""
        with self.lock:
            return len(self.queue)

    def empty(self):
        """Check if the queue is empty (in-flight messages are not considered)."""
        with self.lock:
            return not self.queue

    def full(self):
        """Check if the queue is full, counting both queued and in-flight messages."""
        with self.lock:
            return self._is_full_locked()