nv-ingest-api 2025.10.22.dev20251022__py3-none-any.whl → 2025.11.2.dev20251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nv-ingest-api might be problematic. Click here for more details.
- nv_ingest_api/internal/primitives/nim/model_interface/parakeet.py +4 -0
- nv_ingest_api/internal/primitives/nim/nim_client.py +124 -14
- nv_ingest_api/internal/schemas/extract/extract_audio_schema.py +4 -2
- nv_ingest_api/internal/schemas/extract/extract_chart_schema.py +10 -1
- nv_ingest_api/internal/schemas/extract/extract_docx_schema.py +4 -2
- nv_ingest_api/internal/schemas/extract/extract_image_schema.py +4 -2
- nv_ingest_api/internal/schemas/extract/extract_infographic_schema.py +10 -1
- nv_ingest_api/internal/schemas/extract/extract_pdf_schema.py +6 -4
- nv_ingest_api/internal/schemas/extract/extract_pptx_schema.py +4 -2
- nv_ingest_api/internal/schemas/extract/extract_table_schema.py +9 -1
- nv_ingest_api/internal/schemas/meta/ingest_job_schema.py +39 -0
- nv_ingest_api/internal/schemas/meta/metadata_schema.py +9 -0
- nv_ingest_api/internal/schemas/mixins.py +39 -0
- nv_ingest_api/internal/schemas/transform/transform_text_embedding_schema.py +4 -0
- nv_ingest_api/internal/transform/embed_text.py +82 -0
- nv_ingest_api/util/dataloader/dataloader.py +20 -9
- nv_ingest_api/util/message_brokers/qos_scheduler.py +283 -0
- nv_ingest_api/util/message_brokers/simple_message_broker/simple_client.py +1 -0
- nv_ingest_api/util/multi_processing/mp_pool_singleton.py +8 -2
- nv_ingest_api/util/service_clients/redis/redis_client.py +160 -0
- {nv_ingest_api-2025.10.22.dev20251022.dist-info → nv_ingest_api-2025.11.2.dev20251102.dist-info}/METADATA +2 -1
- {nv_ingest_api-2025.10.22.dev20251022.dist-info → nv_ingest_api-2025.11.2.dev20251102.dist-info}/RECORD +25 -23
- {nv_ingest_api-2025.10.22.dev20251022.dist-info → nv_ingest_api-2025.11.2.dev20251102.dist-info}/WHEEL +0 -0
- {nv_ingest_api-2025.10.22.dev20251022.dist-info → nv_ingest_api-2025.11.2.dev20251102.dist-info}/licenses/LICENSE +0 -0
- {nv_ingest_api-2025.10.22.dev20251022.dist-info → nv_ingest_api-2025.11.2.dev20251102.dist-info}/top_level.txt +0 -0
|
@@ -355,6 +355,10 @@ def create_audio_inference_client(
|
|
|
355
355
|
if (infer_protocol is None) and (grpc_endpoint and grpc_endpoint.strip()):
|
|
356
356
|
infer_protocol = "grpc"
|
|
357
357
|
|
|
358
|
+
# Normalize protocol to lowercase for case-insensitive comparison
|
|
359
|
+
if infer_protocol:
|
|
360
|
+
infer_protocol = infer_protocol.lower()
|
|
361
|
+
|
|
358
362
|
if infer_protocol == "http":
|
|
359
363
|
raise ValueError("`http` endpoints are not supported for audio. Use `grpc`.")
|
|
360
364
|
|
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
import hashlib
|
|
6
6
|
import json
|
|
7
7
|
import logging
|
|
8
|
+
import re
|
|
8
9
|
import threading
|
|
9
10
|
import time
|
|
10
11
|
import queue
|
|
@@ -24,6 +25,12 @@ from nv_ingest_api.util.string_processing import generate_url
|
|
|
24
25
|
|
|
25
26
|
logger = logging.getLogger(__name__)
|
|
26
27
|
|
|
28
|
+
# Regex pattern to detect CUDA-related errors in Triton gRPC responses
|
|
29
|
+
CUDA_ERROR_REGEX = re.compile(
|
|
30
|
+
r"(illegal memory access|invalid argument|failed to (copy|load|perform) .*: .*|TritonModelException: failed to copy data: .*)", # noqa: E501
|
|
31
|
+
re.IGNORECASE,
|
|
32
|
+
)
|
|
33
|
+
|
|
27
34
|
# A simple structure to hold a request's data and its Future for the result
|
|
28
35
|
InferenceRequest = namedtuple("InferenceRequest", ["data", "future", "model_name", "dims", "kwargs"])
|
|
29
36
|
|
|
@@ -40,7 +47,7 @@ class NimClient:
|
|
|
40
47
|
endpoints: Tuple[str, str],
|
|
41
48
|
auth_token: Optional[str] = None,
|
|
42
49
|
timeout: float = 120.0,
|
|
43
|
-
max_retries: int =
|
|
50
|
+
max_retries: int = 10,
|
|
44
51
|
max_429_retries: int = 5,
|
|
45
52
|
enable_dynamic_batching: bool = False,
|
|
46
53
|
dynamic_batch_timeout: float = 0.1, # 100 milliseconds
|
|
@@ -60,11 +67,11 @@ class NimClient:
|
|
|
60
67
|
auth_token : str, optional
|
|
61
68
|
Authorization token for HTTP requests (default: None).
|
|
62
69
|
timeout : float, optional
|
|
63
|
-
Timeout for HTTP requests in seconds (default:
|
|
70
|
+
Timeout for HTTP requests in seconds (default: 120.0).
|
|
64
71
|
max_retries : int, optional
|
|
65
|
-
The maximum number of retries for non-429 server-side errors (default:
|
|
72
|
+
The maximum number of retries for non-429 server-side errors (default: 10).
|
|
66
73
|
max_429_retries : int, optional
|
|
67
|
-
The maximum number of retries specifically for 429 errors (default:
|
|
74
|
+
The maximum number of retries specifically for 429 errors (default: 5).
|
|
68
75
|
|
|
69
76
|
Raises
|
|
70
77
|
------
|
|
@@ -323,7 +330,7 @@ class NimClient:
|
|
|
323
330
|
|
|
324
331
|
outputs = [grpcclient.InferRequestedOutput(output_name) for output_name in output_names]
|
|
325
332
|
|
|
326
|
-
base_delay = 0
|
|
333
|
+
base_delay = 2.0
|
|
327
334
|
attempt = 0
|
|
328
335
|
retries_429 = 0
|
|
329
336
|
max_grpc_retries = self.max_429_retries
|
|
@@ -342,8 +349,58 @@ class NimClient:
|
|
|
342
349
|
return [response.as_numpy(output.name()) for output in outputs]
|
|
343
350
|
|
|
344
351
|
except grpcclient.InferenceServerException as e:
|
|
345
|
-
status = e.status()
|
|
346
|
-
|
|
352
|
+
status = str(e.status())
|
|
353
|
+
message = e.message()
|
|
354
|
+
|
|
355
|
+
# Handle CUDA memory errors
|
|
356
|
+
if status == "StatusCode.INTERNAL":
|
|
357
|
+
if CUDA_ERROR_REGEX.search(message):
|
|
358
|
+
logger.warning(
|
|
359
|
+
f"Received gRPC INTERNAL error with CUDA-related message for model '{model_name}'. "
|
|
360
|
+
f"Attempt {attempt + 1} of {self.max_retries}. Message (truncated): {message[:500]}"
|
|
361
|
+
)
|
|
362
|
+
if attempt >= self.max_retries - 1:
|
|
363
|
+
logger.error(f"Max retries exceeded for CUDA errors on model '{model_name}'.")
|
|
364
|
+
raise e
|
|
365
|
+
# Try to reload models before retrying
|
|
366
|
+
model_reload_succeeded = reload_models(client=self.client, client_timeout=self.timeout)
|
|
367
|
+
if not model_reload_succeeded:
|
|
368
|
+
logger.error(f"Failed to reload models for model '{model_name}'.")
|
|
369
|
+
else:
|
|
370
|
+
logger.warning(
|
|
371
|
+
f"Received gRPC INTERNAL error for model '{model_name}'. "
|
|
372
|
+
f"Attempt {attempt + 1} of {self.max_retries}. Message (truncated): {message[:500]}"
|
|
373
|
+
)
|
|
374
|
+
if attempt >= self.max_retries - 1:
|
|
375
|
+
logger.error(f"Max retries exceeded for INTERNAL error on model '{model_name}'.")
|
|
376
|
+
raise e
|
|
377
|
+
|
|
378
|
+
# Common retry logic for both CUDA and non-CUDA INTERNAL errors
|
|
379
|
+
backoff_time = base_delay * (2**attempt)
|
|
380
|
+
time.sleep(backoff_time)
|
|
381
|
+
attempt += 1
|
|
382
|
+
continue
|
|
383
|
+
|
|
384
|
+
# Handle errors that can occur after model reload (NOT_FOUND, model not loaded)
|
|
385
|
+
if status == "StatusCode.NOT_FOUND":
|
|
386
|
+
logger.warning(
|
|
387
|
+
f"Received gRPC {status} error for model '{model_name}'. "
|
|
388
|
+
f"Attempt {attempt + 1} of {self.max_retries}. Message: {message[:500]}"
|
|
389
|
+
)
|
|
390
|
+
if attempt >= self.max_retries - 1:
|
|
391
|
+
logger.error(f"Max retries exceeded for model not found errors on model '{model_name}'.")
|
|
392
|
+
raise e
|
|
393
|
+
|
|
394
|
+
# Retry with exponential backoff WITHOUT reloading
|
|
395
|
+
backoff_time = base_delay * (2**attempt)
|
|
396
|
+
logger.info(
|
|
397
|
+
f"Retrying after {backoff_time}s backoff for model not found error on model '{model_name}'."
|
|
398
|
+
)
|
|
399
|
+
time.sleep(backoff_time)
|
|
400
|
+
attempt += 1
|
|
401
|
+
continue
|
|
402
|
+
|
|
403
|
+
if status == "StatusCode.UNAVAILABLE" and "Exceeds maximum queue size".lower() in message.lower():
|
|
347
404
|
retries_429 += 1
|
|
348
405
|
logger.warning(
|
|
349
406
|
f"Received gRPC {status} for model '{model_name}'. "
|
|
@@ -357,13 +414,12 @@ class NimClient:
|
|
|
357
414
|
time.sleep(backoff_time)
|
|
358
415
|
continue
|
|
359
416
|
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
raise
|
|
417
|
+
# For other server-side errors (e.g., INVALID_ARGUMENT, etc.),
|
|
418
|
+
# fail fast as retrying will not help
|
|
419
|
+
logger.error(
|
|
420
|
+
f"Received non-retryable gRPC error {status} from Triton for model '{model_name}': {message}"
|
|
421
|
+
)
|
|
422
|
+
raise
|
|
367
423
|
|
|
368
424
|
except Exception as e:
|
|
369
425
|
# Catch any other unexpected exceptions (e.g., network issues not caught by Triton client)
|
|
@@ -681,3 +737,57 @@ class NimClientManager:
|
|
|
681
737
|
def get_nim_client_manager(*args, **kwargs) -> NimClientManager:
|
|
682
738
|
"""Returns the singleton instance of the NimClientManager."""
|
|
683
739
|
return NimClientManager(*args, **kwargs)
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
def reload_models(client: grpcclient.InferenceServerClient, exclude: list[str] = [], client_timeout: int = 120) -> bool:
|
|
743
|
+
"""
|
|
744
|
+
Reloads all models in the Triton server except for the models in the exclude list.
|
|
745
|
+
|
|
746
|
+
Parameters
|
|
747
|
+
----------
|
|
748
|
+
client : grpcclient.InferenceServerClient
|
|
749
|
+
The gRPC client connected to the Triton server.
|
|
750
|
+
exclude : list[str], optional
|
|
751
|
+
A list of model names to exclude from reloading.
|
|
752
|
+
client_timeout : int, optional
|
|
753
|
+
Timeout for client operations in seconds (default: 120).
|
|
754
|
+
|
|
755
|
+
Returns
|
|
756
|
+
-------
|
|
757
|
+
bool
|
|
758
|
+
True if all models were successfully reloaded, False otherwise.
|
|
759
|
+
"""
|
|
760
|
+
model_index = client.get_model_repository_index()
|
|
761
|
+
exclude = set(exclude)
|
|
762
|
+
names = [m.name for m in model_index.models if m.name not in exclude]
|
|
763
|
+
|
|
764
|
+
logger.info(f"Reloading {len(names)} model(s): {', '.join(names) if names else '(none)'}")
|
|
765
|
+
|
|
766
|
+
# 1) Unload
|
|
767
|
+
for name in names:
|
|
768
|
+
try:
|
|
769
|
+
client.unload_model(name)
|
|
770
|
+
except grpcclient.InferenceServerException as e:
|
|
771
|
+
msg = e.message()
|
|
772
|
+
if "explicit model load / unload" in msg.lower():
|
|
773
|
+
status = e.status()
|
|
774
|
+
logger.warning(
|
|
775
|
+
f"[SKIP Model Reload] Explicit model control disabled; cannot unload '{name}'. Status: {status}."
|
|
776
|
+
)
|
|
777
|
+
return False
|
|
778
|
+
logger.error(f"[ERROR] Failed to unload '{name}': {msg}")
|
|
779
|
+
return False
|
|
780
|
+
|
|
781
|
+
# 2) Load
|
|
782
|
+
for name in names:
|
|
783
|
+
client.load_model(name)
|
|
784
|
+
|
|
785
|
+
# 3) Readiness check
|
|
786
|
+
for name in names:
|
|
787
|
+
ready = client.is_model_ready(model_name=name, client_timeout=client_timeout)
|
|
788
|
+
if not ready:
|
|
789
|
+
logger.warning(f"[Warning] Triton Not ready: {name}")
|
|
790
|
+
return False
|
|
791
|
+
|
|
792
|
+
logger.info("✅ Reload of models complete.")
|
|
793
|
+
return True
|
|
@@ -10,10 +10,12 @@ from typing import Tuple
|
|
|
10
10
|
from pydantic import BaseModel, Field
|
|
11
11
|
from pydantic import root_validator
|
|
12
12
|
|
|
13
|
+
from nv_ingest_api.internal.schemas.mixins import LowercaseProtocolMixin
|
|
14
|
+
|
|
13
15
|
logger = logging.getLogger(__name__)
|
|
14
16
|
|
|
15
17
|
|
|
16
|
-
class AudioConfigSchema(BaseModel):
|
|
18
|
+
class AudioConfigSchema(LowercaseProtocolMixin):
|
|
17
19
|
"""
|
|
18
20
|
Configuration schema for audio extraction endpoints and options.
|
|
19
21
|
|
|
@@ -87,13 +89,13 @@ class AudioConfigSchema(BaseModel):
|
|
|
87
89
|
|
|
88
90
|
values[endpoint_name] = (grpc_service, http_service)
|
|
89
91
|
|
|
92
|
+
# Auto-infer protocol from endpoints if not specified
|
|
90
93
|
protocol_name = "audio_infer_protocol"
|
|
91
94
|
protocol_value = values.get(protocol_name)
|
|
92
95
|
|
|
93
96
|
if not protocol_value:
|
|
94
97
|
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
|
|
95
98
|
|
|
96
|
-
protocol_value = protocol_value.lower()
|
|
97
99
|
values[protocol_name] = protocol_value
|
|
98
100
|
|
|
99
101
|
return values
|
|
@@ -8,10 +8,12 @@ from typing import Tuple
|
|
|
8
8
|
|
|
9
9
|
from pydantic import field_validator, model_validator, ConfigDict, BaseModel, Field
|
|
10
10
|
|
|
11
|
+
from nv_ingest_api.internal.schemas.mixins import LowercaseProtocolMixin
|
|
12
|
+
|
|
11
13
|
logger = logging.getLogger(__name__)
|
|
12
14
|
|
|
13
15
|
|
|
14
|
-
class ChartExtractorConfigSchema(BaseModel):
|
|
16
|
+
class ChartExtractorConfigSchema(LowercaseProtocolMixin):
|
|
15
17
|
"""
|
|
16
18
|
Configuration schema for chart extraction service endpoints and options.
|
|
17
19
|
|
|
@@ -96,6 +98,13 @@ class ChartExtractorConfigSchema(BaseModel):
|
|
|
96
98
|
|
|
97
99
|
values[endpoint_name] = (grpc_service, http_service)
|
|
98
100
|
|
|
101
|
+
# Auto-infer protocol from endpoints if not specified
|
|
102
|
+
protocol_name = endpoint_name.replace("_endpoints", "_infer_protocol")
|
|
103
|
+
protocol_value = values.get(protocol_name)
|
|
104
|
+
if not protocol_value:
|
|
105
|
+
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
|
|
106
|
+
values[protocol_name] = protocol_value
|
|
107
|
+
|
|
99
108
|
return values
|
|
100
109
|
|
|
101
110
|
model_config = ConfigDict(extra="forbid")
|
|
@@ -9,10 +9,12 @@ from typing import Tuple
|
|
|
9
9
|
|
|
10
10
|
from pydantic import model_validator, ConfigDict, BaseModel, Field
|
|
11
11
|
|
|
12
|
+
from nv_ingest_api.internal.schemas.mixins import LowercaseProtocolMixin
|
|
13
|
+
|
|
12
14
|
logger = logging.getLogger(__name__)
|
|
13
15
|
|
|
14
16
|
|
|
15
|
-
class DocxConfigSchema(BaseModel):
|
|
17
|
+
class DocxConfigSchema(LowercaseProtocolMixin):
|
|
16
18
|
"""
|
|
17
19
|
Configuration schema for docx extraction endpoints and options.
|
|
18
20
|
|
|
@@ -85,11 +87,11 @@ class DocxConfigSchema(BaseModel):
|
|
|
85
87
|
|
|
86
88
|
values[endpoint_name] = (grpc_service, http_service)
|
|
87
89
|
|
|
90
|
+
# Auto-infer protocol from endpoints if not specified
|
|
88
91
|
protocol_name = f"{model_name}_infer_protocol"
|
|
89
92
|
protocol_value = values.get(protocol_name)
|
|
90
93
|
if not protocol_value:
|
|
91
94
|
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
|
|
92
|
-
protocol_value = protocol_value.lower()
|
|
93
95
|
values[protocol_name] = protocol_value
|
|
94
96
|
|
|
95
97
|
return values
|
|
@@ -9,10 +9,12 @@ from typing import Tuple
|
|
|
9
9
|
|
|
10
10
|
from pydantic import model_validator, ConfigDict, BaseModel, Field
|
|
11
11
|
|
|
12
|
+
from nv_ingest_api.internal.schemas.mixins import LowercaseProtocolMixin
|
|
13
|
+
|
|
12
14
|
logger = logging.getLogger(__name__)
|
|
13
15
|
|
|
14
16
|
|
|
15
|
-
class ImageConfigSchema(BaseModel):
|
|
17
|
+
class ImageConfigSchema(LowercaseProtocolMixin):
|
|
16
18
|
"""
|
|
17
19
|
Configuration schema for image extraction endpoints and options.
|
|
18
20
|
|
|
@@ -85,11 +87,11 @@ class ImageConfigSchema(BaseModel):
|
|
|
85
87
|
|
|
86
88
|
values[endpoint_name] = (grpc_service, http_service)
|
|
87
89
|
|
|
90
|
+
# Auto-infer protocol from endpoints if not specified
|
|
88
91
|
protocol_name = f"{model_name}_infer_protocol"
|
|
89
92
|
protocol_value = values.get(protocol_name)
|
|
90
93
|
if not protocol_value:
|
|
91
94
|
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
|
|
92
|
-
protocol_value = protocol_value.lower()
|
|
93
95
|
values[protocol_name] = protocol_value
|
|
94
96
|
|
|
95
97
|
return values
|
|
@@ -8,10 +8,12 @@ from typing import Tuple
|
|
|
8
8
|
|
|
9
9
|
from pydantic import field_validator, model_validator, ConfigDict, BaseModel, Field
|
|
10
10
|
|
|
11
|
+
from nv_ingest_api.internal.schemas.mixins import LowercaseProtocolMixin
|
|
12
|
+
|
|
11
13
|
logger = logging.getLogger(__name__)
|
|
12
14
|
|
|
13
15
|
|
|
14
|
-
class InfographicExtractorConfigSchema(BaseModel):
|
|
16
|
+
class InfographicExtractorConfigSchema(LowercaseProtocolMixin):
|
|
15
17
|
"""
|
|
16
18
|
Configuration schema for infographic extraction service endpoints and options.
|
|
17
19
|
|
|
@@ -89,6 +91,13 @@ class InfographicExtractorConfigSchema(BaseModel):
|
|
|
89
91
|
|
|
90
92
|
values[endpoint_name] = (grpc_service, http_service)
|
|
91
93
|
|
|
94
|
+
# Auto-infer protocol from endpoints if not specified
|
|
95
|
+
protocol_name = endpoint_name.replace("_endpoints", "_infer_protocol")
|
|
96
|
+
protocol_value = values.get(protocol_name)
|
|
97
|
+
if not protocol_value:
|
|
98
|
+
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
|
|
99
|
+
values[protocol_name] = protocol_value
|
|
100
|
+
|
|
92
101
|
return values
|
|
93
102
|
|
|
94
103
|
model_config = ConfigDict(extra="forbid")
|
|
@@ -9,10 +9,12 @@ from typing import Tuple
|
|
|
9
9
|
|
|
10
10
|
from pydantic import model_validator, ConfigDict, BaseModel, Field
|
|
11
11
|
|
|
12
|
+
from nv_ingest_api.internal.schemas.mixins import LowercaseProtocolMixin
|
|
13
|
+
|
|
12
14
|
logger = logging.getLogger(__name__)
|
|
13
15
|
|
|
14
16
|
|
|
15
|
-
class PDFiumConfigSchema(BaseModel):
|
|
17
|
+
class PDFiumConfigSchema(LowercaseProtocolMixin):
|
|
16
18
|
"""
|
|
17
19
|
Configuration schema for PDFium endpoints and options.
|
|
18
20
|
|
|
@@ -82,11 +84,11 @@ class PDFiumConfigSchema(BaseModel):
|
|
|
82
84
|
|
|
83
85
|
values[endpoint_name] = (grpc_service, http_service)
|
|
84
86
|
|
|
87
|
+
# Auto-infer protocol from endpoints if not specified
|
|
85
88
|
protocol_name = f"{model_name}_infer_protocol"
|
|
86
89
|
protocol_value = values.get(protocol_name)
|
|
87
90
|
if not protocol_value:
|
|
88
91
|
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
|
|
89
|
-
protocol_value = protocol_value.lower()
|
|
90
92
|
values[protocol_name] = protocol_value
|
|
91
93
|
|
|
92
94
|
return values
|
|
@@ -94,7 +96,7 @@ class PDFiumConfigSchema(BaseModel):
|
|
|
94
96
|
model_config = ConfigDict(extra="forbid")
|
|
95
97
|
|
|
96
98
|
|
|
97
|
-
class NemoRetrieverParseConfigSchema(BaseModel):
|
|
99
|
+
class NemoRetrieverParseConfigSchema(LowercaseProtocolMixin):
|
|
98
100
|
"""
|
|
99
101
|
Configuration schema for NemoRetrieverParse endpoints and options.
|
|
100
102
|
|
|
@@ -170,11 +172,11 @@ class NemoRetrieverParseConfigSchema(BaseModel):
|
|
|
170
172
|
|
|
171
173
|
values[endpoint_name] = (grpc_service, http_service)
|
|
172
174
|
|
|
175
|
+
# Auto-infer protocol from endpoints if not specified
|
|
173
176
|
protocol_name = f"{model_name}_infer_protocol"
|
|
174
177
|
protocol_value = values.get(protocol_name)
|
|
175
178
|
if not protocol_value:
|
|
176
179
|
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
|
|
177
|
-
protocol_value = protocol_value.lower()
|
|
178
180
|
values[protocol_name] = protocol_value
|
|
179
181
|
|
|
180
182
|
return values
|
|
@@ -9,10 +9,12 @@ from typing import Tuple
|
|
|
9
9
|
|
|
10
10
|
from pydantic import model_validator, ConfigDict, BaseModel, Field
|
|
11
11
|
|
|
12
|
+
from nv_ingest_api.internal.schemas.mixins import LowercaseProtocolMixin
|
|
13
|
+
|
|
12
14
|
logger = logging.getLogger(__name__)
|
|
13
15
|
|
|
14
16
|
|
|
15
|
-
class PPTXConfigSchema(BaseModel):
|
|
17
|
+
class PPTXConfigSchema(LowercaseProtocolMixin):
|
|
16
18
|
"""
|
|
17
19
|
Configuration schema for docx extraction endpoints and options.
|
|
18
20
|
|
|
@@ -85,11 +87,11 @@ class PPTXConfigSchema(BaseModel):
|
|
|
85
87
|
|
|
86
88
|
values[endpoint_name] = (grpc_service, http_service)
|
|
87
89
|
|
|
90
|
+
# Auto-infer protocol from endpoints if not specified
|
|
88
91
|
protocol_name = f"{model_name}_infer_protocol"
|
|
89
92
|
protocol_value = values.get(protocol_name)
|
|
90
93
|
if not protocol_value:
|
|
91
94
|
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
|
|
92
|
-
protocol_value = protocol_value.lower()
|
|
93
95
|
values[protocol_name] = protocol_value
|
|
94
96
|
|
|
95
97
|
return values
|
|
@@ -9,11 +9,12 @@ from typing import Tuple
|
|
|
9
9
|
|
|
10
10
|
from pydantic import field_validator, model_validator, ConfigDict, BaseModel, Field
|
|
11
11
|
|
|
12
|
+
from nv_ingest_api.internal.schemas.mixins import LowercaseProtocolMixin
|
|
12
13
|
|
|
13
14
|
logger = logging.getLogger(__name__)
|
|
14
15
|
|
|
15
16
|
|
|
16
|
-
class TableExtractorConfigSchema(BaseModel):
|
|
17
|
+
class TableExtractorConfigSchema(LowercaseProtocolMixin):
|
|
17
18
|
"""
|
|
18
19
|
Configuration schema for the table extraction stage settings.
|
|
19
20
|
|
|
@@ -91,6 +92,13 @@ class TableExtractorConfigSchema(BaseModel):
|
|
|
91
92
|
|
|
92
93
|
values[endpoint_name] = (grpc_service, http_service)
|
|
93
94
|
|
|
95
|
+
# Auto-infer protocol from endpoints if not specified
|
|
96
|
+
protocol_name = endpoint_name.replace("_endpoints", "_infer_protocol")
|
|
97
|
+
protocol_value = values.get(protocol_name)
|
|
98
|
+
if not protocol_value:
|
|
99
|
+
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
|
|
100
|
+
values[protocol_name] = protocol_value
|
|
101
|
+
|
|
94
102
|
return values
|
|
95
103
|
|
|
96
104
|
model_config = ConfigDict(extra="forbid")
|
|
@@ -43,6 +43,24 @@ class PdfConfigSchema(BaseModelNoExt):
|
|
|
43
43
|
split_page_count: Annotated[int, Field(ge=1)] = 32
|
|
44
44
|
|
|
45
45
|
|
|
46
|
+
class RoutingOptionsSchema(BaseModelNoExt):
|
|
47
|
+
# Queue routing hint for QoS scheduler
|
|
48
|
+
queue_hint: Optional[str] = None
|
|
49
|
+
|
|
50
|
+
@field_validator("queue_hint")
|
|
51
|
+
@classmethod
|
|
52
|
+
def validate_queue_hint(cls, v):
|
|
53
|
+
if v is None:
|
|
54
|
+
return v
|
|
55
|
+
if not isinstance(v, str):
|
|
56
|
+
raise ValueError("queue_hint must be a string")
|
|
57
|
+
s = v.lower()
|
|
58
|
+
allowed = {"default", "immediate", "micro", "small", "medium", "large"}
|
|
59
|
+
if s not in allowed:
|
|
60
|
+
raise ValueError("queue_hint must be one of: default, immediate, micro, small, medium, large")
|
|
61
|
+
return s
|
|
62
|
+
|
|
63
|
+
|
|
46
64
|
# Ingest Task Schemas
|
|
47
65
|
|
|
48
66
|
|
|
@@ -126,6 +144,8 @@ class IngestTaskEmbedSchema(BaseModelNoExt):
|
|
|
126
144
|
image_elements_modality: Optional[str] = None
|
|
127
145
|
structured_elements_modality: Optional[str] = None
|
|
128
146
|
audio_elements_modality: Optional[str] = None
|
|
147
|
+
custom_content_field: Optional[str] = None
|
|
148
|
+
result_target_field: Optional[str] = None
|
|
129
149
|
|
|
130
150
|
|
|
131
151
|
class IngestTaskVdbUploadSchema(BaseModelNoExt):
|
|
@@ -281,8 +301,27 @@ class IngestJobSchema(BaseModelNoExt):
|
|
|
281
301
|
job_id: Union[str, int]
|
|
282
302
|
tasks: List[IngestTaskSchema]
|
|
283
303
|
tracing_options: Optional[TracingOptionsSchema] = None
|
|
304
|
+
routing_options: Optional[RoutingOptionsSchema] = None
|
|
284
305
|
pdf_config: Optional[PdfConfigSchema] = None
|
|
285
306
|
|
|
307
|
+
@model_validator(mode="before")
|
|
308
|
+
@classmethod
|
|
309
|
+
def migrate_queue_hint(cls, values):
|
|
310
|
+
"""
|
|
311
|
+
Backward-compatibility shim: if a legacy client sends
|
|
312
|
+
tracing_options.queue_hint, move it into routing_options.queue_hint.
|
|
313
|
+
"""
|
|
314
|
+
try:
|
|
315
|
+
topt = values.get("tracing_options") or {}
|
|
316
|
+
ropt = values.get("routing_options") or {}
|
|
317
|
+
if isinstance(topt, dict) and "queue_hint" in topt and "queue_hint" not in ropt:
|
|
318
|
+
ropt["queue_hint"] = topt.pop("queue_hint")
|
|
319
|
+
values["routing_options"] = ropt
|
|
320
|
+
values["tracing_options"] = topt
|
|
321
|
+
except Exception:
|
|
322
|
+
pass
|
|
323
|
+
return values
|
|
324
|
+
|
|
286
325
|
|
|
287
326
|
# ------------------------------------------------------------------------------
|
|
288
327
|
# Utility Functions
|
|
@@ -352,6 +352,15 @@ class MetadataSchema(BaseModelNoExt):
|
|
|
352
352
|
raise_on_failure: bool = False
|
|
353
353
|
"""If True, indicates that processing should halt on failure."""
|
|
354
354
|
|
|
355
|
+
total_pages: Optional[int] = None
|
|
356
|
+
"""Total number of pages in the source document (V2 API)."""
|
|
357
|
+
|
|
358
|
+
original_source_id: Optional[str] = None
|
|
359
|
+
"""The original source identifier before any splitting or chunking (V2 API)."""
|
|
360
|
+
|
|
361
|
+
original_source_name: Optional[str] = None
|
|
362
|
+
"""The original source name before any splitting or chunking (V2 API)."""
|
|
363
|
+
|
|
355
364
|
custom_content: Optional[Dict[str, Any]] = None
|
|
356
365
|
|
|
357
366
|
@model_validator(mode="before")
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Shared mixins for Pydantic schemas.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Any
|
|
10
|
+
from pydantic import BaseModel, field_validator
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LowercaseProtocolMixin(BaseModel):
|
|
14
|
+
"""
|
|
15
|
+
Mixin that automatically lowercases any field ending with '_infer_protocol'.
|
|
16
|
+
|
|
17
|
+
This ensures case-insensitive handling of protocol values (e.g., "HTTP" -> "http").
|
|
18
|
+
Apply this mixin to any schema that has protocol fields to normalize user input.
|
|
19
|
+
|
|
20
|
+
Examples
|
|
21
|
+
--------
|
|
22
|
+
>>> class MyConfigSchema(LowercaseProtocolMixin):
|
|
23
|
+
... yolox_infer_protocol: str = ""
|
|
24
|
+
... ocr_infer_protocol: str = ""
|
|
25
|
+
>>>
|
|
26
|
+
>>> config = MyConfigSchema(yolox_infer_protocol="GRPC", ocr_infer_protocol="HTTP")
|
|
27
|
+
>>> config.yolox_infer_protocol
|
|
28
|
+
'grpc'
|
|
29
|
+
>>> config.ocr_infer_protocol
|
|
30
|
+
'http'
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
@field_validator("*", mode="before")
|
|
34
|
+
@classmethod
|
|
35
|
+
def _lowercase_protocol_fields(cls, v: Any, info):
|
|
36
|
+
"""Lowercase any field ending with '_infer_protocol'."""
|
|
37
|
+
if info.field_name.endswith("_infer_protocol") and v is not None:
|
|
38
|
+
return str(v).strip().lower()
|
|
39
|
+
return v
|
|
@@ -7,6 +7,8 @@ import logging
|
|
|
7
7
|
|
|
8
8
|
from pydantic import ConfigDict, BaseModel, Field, model_validator, field_validator
|
|
9
9
|
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
10
12
|
from nv_ingest_api.util.logging.configuration import LogLevel
|
|
11
13
|
|
|
12
14
|
logger = logging.getLogger(__name__)
|
|
@@ -26,6 +28,8 @@ class TextEmbeddingSchema(BaseModel):
|
|
|
26
28
|
image_elements_modality: str = Field(default="text")
|
|
27
29
|
structured_elements_modality: str = Field(default="text")
|
|
28
30
|
audio_elements_modality: str = Field(default="text")
|
|
31
|
+
custom_content_field: Optional[str] = None
|
|
32
|
+
result_target_field: Optional[str] = None
|
|
29
33
|
|
|
30
34
|
model_config = ConfigDict(extra="forbid")
|
|
31
35
|
|