nv-ingest-client 2025.7.24.dev20250724__py3-none-any.whl → 2025.11.2.dev20251102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nv-ingest-client might be problematic. Further details are available on the package registry's advisory page.

Files changed (38)
  1. nv_ingest_client/cli/util/click.py +182 -30
  2. nv_ingest_client/cli/util/processing.py +0 -393
  3. nv_ingest_client/client/client.py +561 -207
  4. nv_ingest_client/client/ingest_job_handler.py +412 -0
  5. nv_ingest_client/client/interface.py +466 -59
  6. nv_ingest_client/client/util/processing.py +11 -1
  7. nv_ingest_client/nv_ingest_cli.py +58 -6
  8. nv_ingest_client/primitives/jobs/job_spec.py +32 -10
  9. nv_ingest_client/primitives/tasks/__init__.py +6 -4
  10. nv_ingest_client/primitives/tasks/audio_extraction.py +27 -23
  11. nv_ingest_client/primitives/tasks/caption.py +10 -16
  12. nv_ingest_client/primitives/tasks/chart_extraction.py +16 -10
  13. nv_ingest_client/primitives/tasks/dedup.py +12 -21
  14. nv_ingest_client/primitives/tasks/embed.py +37 -76
  15. nv_ingest_client/primitives/tasks/extract.py +68 -169
  16. nv_ingest_client/primitives/tasks/filter.py +22 -28
  17. nv_ingest_client/primitives/tasks/infographic_extraction.py +16 -13
  18. nv_ingest_client/primitives/tasks/split.py +17 -18
  19. nv_ingest_client/primitives/tasks/store.py +29 -29
  20. nv_ingest_client/primitives/tasks/task_base.py +1 -72
  21. nv_ingest_client/primitives/tasks/task_factory.py +10 -11
  22. nv_ingest_client/primitives/tasks/udf.py +349 -0
  23. nv_ingest_client/util/dataset.py +8 -2
  24. nv_ingest_client/util/document_analysis.py +314 -0
  25. nv_ingest_client/util/image_disk_utils.py +300 -0
  26. nv_ingest_client/util/transport.py +12 -6
  27. nv_ingest_client/util/util.py +66 -0
  28. nv_ingest_client/util/vdb/milvus.py +220 -75
  29. {nv_ingest_client-2025.7.24.dev20250724.dist-info → nv_ingest_client-2025.11.2.dev20251102.dist-info}/METADATA +1 -3
  30. nv_ingest_client-2025.11.2.dev20251102.dist-info/RECORD +55 -0
  31. nv_ingest_client/cli/util/tasks.py +0 -3
  32. nv_ingest_client/primitives/exceptions.py +0 -0
  33. nv_ingest_client/primitives/tasks/transform.py +0 -0
  34. nv_ingest_client-2025.7.24.dev20250724.dist-info/RECORD +0 -54
  35. {nv_ingest_client-2025.7.24.dev20250724.dist-info → nv_ingest_client-2025.11.2.dev20251102.dist-info}/WHEEL +0 -0
  36. {nv_ingest_client-2025.7.24.dev20250724.dist-info → nv_ingest_client-2025.11.2.dev20251102.dist-info}/entry_points.txt +0 -0
  37. {nv_ingest_client-2025.7.24.dev20250724.dist-info → nv_ingest_client-2025.11.2.dev20251102.dist-info}/licenses/LICENSE +0 -0
  38. {nv_ingest_client-2025.7.24.dev20250724.dist-info → nv_ingest_client-2025.11.2.dev20251102.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ import gzip
1
2
  import io
2
3
  import json
3
4
  import logging
@@ -6,6 +7,7 @@ import re
6
7
  from typing import Any
7
8
  from typing import Dict
8
9
  from typing import List
10
+ from typing import Optional
9
11
  from typing import Tuple
10
12
 
11
13
  try:
@@ -33,6 +35,7 @@ def save_document_results_to_jsonl(
33
35
  jsonl_output_filepath: str,
34
36
  original_source_name_for_log: str,
35
37
  ensure_parent_dir_exists: bool = True,
38
+ compression: Optional[str] = None,
36
39
  ) -> Tuple[int, Dict[str, str]]:
37
40
  """
38
41
  Saves a list of extraction items (for a single source document) to a JSON Lines file.
@@ -50,6 +53,13 @@ def save_document_results_to_jsonl(
50
53
  if parent_dir:
51
54
  os.makedirs(parent_dir, exist_ok=True)
52
55
 
56
+ if compression == "gzip":
57
+ open_func = gzip.open
58
+ elif compression is None:
59
+ open_func = open
60
+ else:
61
+ raise ValueError(f"Unsupported compression type: {compression}")
62
+
53
63
  with io.BytesIO() as buffer:
54
64
  for extraction_item in doc_response_data:
55
65
  if USING_ORJSON:
@@ -60,7 +70,7 @@ def save_document_results_to_jsonl(
60
70
 
61
71
  count_items_written = len(doc_response_data)
62
72
 
63
- with open(jsonl_output_filepath, "wb") as f_jsonl:
73
+ with open_func(jsonl_output_filepath, "wb") as f_jsonl:
64
74
  f_jsonl.write(full_byte_content)
65
75
 
66
76
  logger.info(
@@ -25,13 +25,14 @@ from nv_ingest_client.cli.util.click import click_match_and_validate_files
25
25
  from nv_ingest_client.cli.util.click import click_validate_batch_size
26
26
  from nv_ingest_client.cli.util.click import click_validate_file_exists
27
27
  from nv_ingest_client.cli.util.click import click_validate_task
28
- from nv_ingest_client.cli.util.processing import create_and_process_jobs
29
28
  from nv_ingest_client.cli.util.processing import report_statistics
30
29
  from nv_ingest_client.cli.util.system import configure_logging
31
30
  from nv_ingest_client.client import NvIngestClient
31
+ from nv_ingest_client.client.ingest_job_handler import IngestJobHandler
32
32
  from nv_ingest_client.util.dataset import get_dataset_files
33
33
  from nv_ingest_client.util.dataset import get_dataset_statistics
34
34
  from nv_ingest_client.util.system import ensure_directory_with_permissions
35
+ from nv_ingest_api.util.logging.sanitize import sanitize_for_logging
35
36
 
36
37
  try:
37
38
  NV_INGEST_VERSION = version("nv_ingest")
@@ -73,6 +74,12 @@ logger = logging.getLogger(__name__)
73
74
  @click.option("--client_host", default="localhost", help="DNS name or URL for the endpoint.")
74
75
  @click.option("--client_port", default=7670, type=int, help="Port for the client endpoint.")
75
76
  @click.option("--client_kwargs", help="Additional arguments to pass to the client.", default="{}")
77
+ @click.option(
78
+ "--api_version",
79
+ default="v1",
80
+ type=click.Choice(["v1", "v2"], case_sensitive=False),
81
+ help="API version to use (v1 or v2). V2 required for PDF split page count feature.",
82
+ )
76
83
  @click.option(
77
84
  "--client_type",
78
85
  default="rest",
@@ -118,6 +125,8 @@ Example:
118
125
  --task 'extract:{"document_type":"docx", "extract_text":true, "extract_images":true}'
119
126
  --task 'embed'
120
127
  --task 'caption:{}'
128
+ --pdf_split_page_count 64 # Configure PDF splitting (requires --api_version v2)
129
+ --api_version v2 # Use V2 API for PDF splitting support
121
130
 
122
131
  \b
123
132
  Tasks and Options:
@@ -169,6 +178,22 @@ Tasks and Options:
169
178
  - split_length (int): Segment length. No default.
170
179
  - split_overlap (int): Segment overlap. No default.
171
180
  \b
181
+ - udf: Executes user-defined functions (UDFs) for custom processing logic.
182
+ Options:
183
+ - udf_function (str): UDF specification. Supports three formats:
184
+ 1. Inline function: 'def my_func(control_message): ...'
185
+ 2. Import path: 'my_module.my_function'
186
+ 3. File path: '/path/to/file.py:function_name' or '/path/to/file.py' (assumes 'process' function)
187
+ - udf_function_name (str): Name of the function to execute from the UDF specification. Required.
188
+ - target_stage (str): Specific pipeline stage name to target for UDF execution (e.g.,
189
+ 'text_extractor', 'text_embedder', 'image_extractor'). Cannot be used with phase.
190
+ - run_before (bool): If True and target_stage is specified, run UDF before the target stage. Default: False.
191
+ - run_after (bool): If True and target_stage is specified, run UDF after the target stage. Default: False.
192
+ Examples:
193
+ --task 'udf:{"udf_function": "my_file.py:my_func", "target_stage": "text_embedder", "run_before": true}'
194
+ --task 'udf:{"udf_function": "def process(cm): return cm",
195
+ "target_stage": "image_extractor", "run_after": true}'
196
+ \b
172
197
  Note: The 'extract_method' automatically selects the optimal method based on 'document_type' if not explicitly stated.
173
198
  """,
174
199
  )
@@ -190,6 +215,12 @@ for locating portions of the system that might be bottlenecks for the overall ru
190
215
  )
191
216
  @click.option("--zipkin_host", default="localhost", help="DNS name or Zipkin API.")
192
217
  @click.option("--zipkin_port", default=9411, type=int, help="Port for the Zipkin trace API")
218
+ @click.option(
219
+ "--pdf_split_page_count",
220
+ default=None,
221
+ type=int,
222
+ help="Number of pages per PDF chunk for splitting. Allows per-request tuning of PDF split size in v2 api.",
223
+ )
193
224
  @click.option("--version", is_flag=True, help="Show version.")
194
225
  @click.pass_context
195
226
  def main(
@@ -198,6 +229,7 @@ def main(
198
229
  client_host: str,
199
230
  client_kwargs: str,
200
231
  client_port: int,
232
+ api_version: str,
201
233
  client_type: str,
202
234
  concurrency_n: int,
203
235
  dataset: str,
@@ -211,6 +243,7 @@ def main(
211
243
  collect_profiling_traces: bool,
212
244
  zipkin_host: str,
213
245
  zipkin_port: int,
246
+ pdf_split_page_count: int,
214
247
  task: [str],
215
248
  version: [bool],
216
249
  ):
@@ -221,7 +254,9 @@ def main(
221
254
 
222
255
  try:
223
256
  configure_logging(logger, log_level)
224
- logging.debug(f"nv-ingest-cli:params:\n{json.dumps(ctx.params, indent=2, default=repr)}")
257
+ # Sanitize CLI params before logging to avoid leaking secrets
258
+ _sanitized_params = sanitize_for_logging(dict(ctx.params))
259
+ logging.debug(f"nv-ingest-cli:params:\n{json.dumps(_sanitized_params, indent=2, default=repr)}")
225
260
 
226
261
  docs = list(doc)
227
262
  if dataset:
@@ -244,7 +279,20 @@ def main(
244
279
  logger.info(_msg)
245
280
 
246
281
  if not dry_run:
247
- logging.debug(f"Creating message client: {client_host} and port: {client_port} -> {client_kwargs}")
282
+ # Sanitize client kwargs (JSON string) before logging
283
+ try:
284
+ _client_kwargs_obj = json.loads(client_kwargs)
285
+ except Exception:
286
+ _client_kwargs_obj = {"raw": client_kwargs}
287
+
288
+ # Merge api_version into client_kwargs
289
+ _client_kwargs_obj["api_version"] = api_version
290
+
291
+ _sanitized_client_kwargs = sanitize_for_logging(_client_kwargs_obj)
292
+ logging.debug(
293
+ f"Creating message client: {client_host} and port: {client_port} -> "
294
+ f"{json.dumps(_sanitized_client_kwargs, indent=2, default=repr)}"
295
+ )
248
296
 
249
297
  if client_type == "rest":
250
298
  client_allocator = RestClient
@@ -257,20 +305,24 @@ def main(
257
305
  message_client_allocator=client_allocator,
258
306
  message_client_hostname=client_host,
259
307
  message_client_port=client_port,
260
- message_client_kwargs=json.loads(client_kwargs),
308
+ message_client_kwargs=_client_kwargs_obj,
261
309
  worker_pool_size=concurrency_n,
262
310
  )
263
311
 
264
312
  start_time_ns = time.time_ns()
265
- (total_files, trace_times, pages_processed, trace_ids) = create_and_process_jobs(
266
- files=docs,
313
+ handler = IngestJobHandler(
267
314
  client=ingest_client,
315
+ files=docs,
268
316
  tasks=task,
269
317
  output_directory=output_directory,
270
318
  batch_size=batch_size,
271
319
  fail_on_error=fail_on_error,
272
320
  save_images_separately=save_images_separately,
321
+ show_progress=True,
322
+ show_telemetry=True,
323
+ pdf_split_page_count=pdf_split_page_count,
273
324
  )
325
+ (total_files, trace_times, pages_processed, trace_ids) = handler.run()
274
326
 
275
327
  report_statistics(start_time_ns, trace_times, pages_processed, total_files)
276
328
 
@@ -110,6 +110,7 @@ class JobSpec:
110
110
  "job_id": str(self._job_id),
111
111
  "tasks": [task.to_dict() for task in self._tasks],
112
112
  "tracing_options": self._extended_options.get("tracing_options", {}),
113
+ "pdf_config": self._extended_options.get("pdf_config", {}),
113
114
  }
114
115
 
115
116
  @property
@@ -150,23 +151,48 @@ class JobSpec:
150
151
 
151
152
  def add_task(self, task) -> None:
152
153
  """
153
- Adds a task to the job specification.
154
+ Adds a task or list of tasks to the job specification.
155
+
156
+ Parameters
157
+ ----------
158
+ task : Task or list of Task
159
+ The task(s) to add to the job specification. Can be a single task or a list of tasks.
160
+ Each task must derive from the Task class and have a to_dict method.
161
+
162
+ Raises
163
+ ------
164
+ ValueError
165
+ If any task does not derive from the Task class.
166
+ """
167
+ # Handle both single tasks and lists of tasks
168
+ if isinstance(task, list):
169
+ # Process each task in the list
170
+ for single_task in task:
171
+ self._add_single_task(single_task)
172
+ else:
173
+ # Process single task
174
+ self._add_single_task(task)
175
+
176
+ def _add_single_task(self, task) -> None:
177
+ """
178
+ Adds a single task to the job specification with automatic task expansion.
154
179
 
155
180
  Parameters
156
181
  ----------
157
- task
158
- The task to add to the job specification. Assumes the task has a to_dict method.
182
+ task : Task
183
+ The task to add to the job specification.
159
184
 
160
185
  Raises
161
186
  ------
162
187
  ValueError
163
- If the task does not have a to_dict method.
188
+ If the task does not derive from the Task class.
164
189
  """
165
190
  if not isinstance(task, Task):
166
191
  raise ValueError("Task must derive from nv_ingest_client.primitives.Task class")
167
192
 
168
193
  self._tasks.append(task)
169
194
 
195
+ # Automatic task expansion for ExtractTask
170
196
  if isinstance(task, ExtractTask) and (task._extract_tables is True):
171
197
  self._tasks.append(TableExtractionTask())
172
198
  if isinstance(task, ExtractTask) and (task._extract_charts is True):
@@ -239,15 +265,16 @@ class BatchJobSpec:
239
265
  """
240
266
  from nv_ingest_client.util.util import create_job_specs_for_batch
241
267
  from nv_ingest_client.util.util import generate_matching_files
268
+ from nv_ingest_client.util.util import balanced_groups_flat_order
242
269
 
243
270
  if isinstance(files, str):
244
271
  files = [files]
245
272
 
246
273
  matching_files = list(generate_matching_files(files))
274
+ matching_files = balanced_groups_flat_order(matching_files)
247
275
  if not matching_files:
248
276
  logger.warning(f"No files found matching {files}.")
249
277
  return
250
-
251
278
  job_specs = create_job_specs_for_batch(matching_files)
252
279
  for job_spec in job_specs:
253
280
  self.add_job_spec(job_spec)
@@ -321,11 +348,6 @@ class BatchJobSpec:
321
348
  document_type : str, optional
322
349
  The document type used to filter job specifications. If not provided, the
323
350
  `document_type` is inferred from the task, or the task is applied to all job specifications.
324
-
325
- Raises
326
- ------
327
- ValueError
328
- If the task does not derive from the `Task` class.
329
351
  """
330
352
  if not isinstance(task, Task):
331
353
  raise ValueError("Task must derive from nv_ingest_client.primitives.Task class")
@@ -18,14 +18,18 @@ from .task_base import Task
18
18
  from .task_base import TaskType
19
19
  from .task_base import is_valid_task_type
20
20
  from .task_factory import task_factory
21
+ from .udf import UDFTask
21
22
 
22
23
  __all__ = [
23
24
  "AudioExtractionTask",
24
25
  "CaptionTask",
25
26
  "ChartExtractionTask",
27
+ "DedupTask",
28
+ "EmbedTask",
26
29
  "ExtractTask",
27
- "is_valid_task_type",
30
+ "FilterTask",
28
31
  "InfographicExtractionTask",
32
+ "is_valid_task_type",
29
33
  "SplitTask",
30
34
  "StoreEmbedTask",
31
35
  "StoreTask",
@@ -33,7 +37,5 @@ __all__ = [
33
37
  "Task",
34
38
  "task_factory",
35
39
  "TaskType",
36
- "DedupTask",
37
- "FilterTask",
38
- "EmbedTask",
40
+ "UDFTask",
39
41
  ]
@@ -10,33 +10,19 @@ import logging
10
10
  from typing import Dict
11
11
  from typing import Optional
12
12
 
13
- from pydantic import BaseModel
14
- from pydantic import ConfigDict
13
+ from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskAudioExtraction
15
14
 
16
15
  from .task_base import Task
17
16
 
18
17
  logger = logging.getLogger(__name__)
19
18
 
20
19
 
21
- class AudioExtractionSchema(BaseModel):
22
- auth_token: Optional[str] = None
23
- grpc_endpoint: Optional[str] = None
24
- http_endpoint: Optional[str] = None
25
- infer_protocol: Optional[str] = None
26
- function_id: Optional[str] = None
27
- use_ssl: Optional[bool] = None
28
- ssl_cert: Optional[str] = None
29
- segment_audio: Optional[bool] = None
30
-
31
- model_config = ConfigDict(extra="forbid")
32
- model_config["protected_namespaces"] = ()
33
-
34
-
35
20
  class AudioExtractionTask(Task):
36
21
  def __init__(
37
22
  self,
38
23
  auth_token: str = None,
39
24
  grpc_endpoint: str = None,
25
+ http_endpoint: str = None,
40
26
  infer_protocol: str = None,
41
27
  function_id: Optional[str] = None,
42
28
  use_ssl: bool = None,
@@ -45,13 +31,26 @@ class AudioExtractionTask(Task):
45
31
  ) -> None:
46
32
  super().__init__()
47
33
 
48
- self._auth_token = auth_token
49
- self._grpc_endpoint = grpc_endpoint
50
- self._infer_protocol = infer_protocol
51
- self._function_id = function_id
52
- self._use_ssl = use_ssl
53
- self._ssl_cert = ssl_cert
54
- self._segment_audio = segment_audio
34
+ # Use the API schema for validation
35
+ validated_data = IngestTaskAudioExtraction(
36
+ auth_token=auth_token,
37
+ grpc_endpoint=grpc_endpoint,
38
+ http_endpoint=http_endpoint,
39
+ infer_protocol=infer_protocol,
40
+ function_id=function_id,
41
+ use_ssl=use_ssl,
42
+ ssl_cert=ssl_cert,
43
+ segment_audio=segment_audio,
44
+ )
45
+
46
+ self._auth_token = validated_data.auth_token
47
+ self._grpc_endpoint = validated_data.grpc_endpoint
48
+ self._http_endpoint = validated_data.http_endpoint
49
+ self._infer_protocol = validated_data.infer_protocol
50
+ self._function_id = validated_data.function_id
51
+ self._use_ssl = validated_data.use_ssl
52
+ self._ssl_cert = validated_data.ssl_cert
53
+ self._segment_audio = validated_data.segment_audio
55
54
 
56
55
  def __str__(self) -> str:
57
56
  """
@@ -64,6 +63,8 @@ class AudioExtractionTask(Task):
64
63
  info += " auth_token: [redacted]\n"
65
64
  if self._grpc_endpoint:
66
65
  info += f" grpc_endpoint: {self._grpc_endpoint}\n"
66
+ if self._http_endpoint:
67
+ info += f" http_endpoint: {self._http_endpoint}\n"
67
68
  if self._infer_protocol:
68
69
  info += f" infer_protocol: {self._infer_protocol}\n"
69
70
  if self._function_id:
@@ -89,6 +90,9 @@ class AudioExtractionTask(Task):
89
90
  if self._grpc_endpoint:
90
91
  task_properties["grpc_endpoint"] = self._grpc_endpoint
91
92
 
93
+ if self._http_endpoint:
94
+ task_properties["http_endpoint"] = self._http_endpoint
95
+
92
96
  if self._infer_protocol:
93
97
  task_properties["infer_protocol"] = self._infer_protocol
94
98
 
@@ -8,25 +8,14 @@
8
8
 
9
9
  import logging
10
10
  from typing import Dict
11
- from typing import Optional
12
11
 
13
- from pydantic import ConfigDict, BaseModel
14
12
 
13
+ from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskCaptionSchema
15
14
  from .task_base import Task
16
15
 
17
16
  logger = logging.getLogger(__name__)
18
17
 
19
18
 
20
- class CaptionTaskSchema(BaseModel):
21
- api_key: Optional[str] = None
22
- endpoint_url: Optional[str] = None
23
- prompt: Optional[str] = None
24
- model_name: Optional[str] = None
25
-
26
- model_config = ConfigDict(extra="forbid")
27
- model_config["protected_namespaces"] = ()
28
-
29
-
30
19
  class CaptionTask(Task):
31
20
  def __init__(
32
21
  self,
@@ -37,10 +26,15 @@ class CaptionTask(Task):
37
26
  ) -> None:
38
27
  super().__init__()
39
28
 
40
- self._api_key = api_key
41
- self._endpoint_url = endpoint_url
42
- self._prompt = prompt
43
- self._model_name = model_name
29
+ # Use the API schema for validation
30
+ validated_data = IngestTaskCaptionSchema(
31
+ api_key=api_key, endpoint_url=endpoint_url, prompt=prompt, model_name=model_name
32
+ )
33
+
34
+ self._api_key = validated_data.api_key
35
+ self._endpoint_url = validated_data.endpoint_url
36
+ self._prompt = validated_data.prompt
37
+ self._model_name = validated_data.model_name
44
38
 
45
39
  def __str__(self) -> str:
46
40
  """
@@ -9,35 +9,41 @@
9
9
  import logging
10
10
  from typing import Dict
11
11
 
12
- from pydantic import BaseModel
12
+ from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskChartExtraction
13
13
 
14
14
  from .task_base import Task
15
15
 
16
16
  logger = logging.getLogger(__name__)
17
17
 
18
18
 
19
- class ChartExtractionSchema(BaseModel):
20
- class Config:
21
- extra = "forbid"
22
-
23
-
24
19
  class ChartExtractionTask(Task):
25
20
  """
26
21
  Object for chart extraction task
27
22
  """
28
23
 
29
- def __init__(self) -> None:
24
+ def __init__(self, params: dict = None) -> None:
30
25
  """
31
- Setup Dedup Task Config
26
+ Setup Chart Extraction Task Config
32
27
  """
33
28
  super().__init__()
34
29
 
30
+ # Handle None params by converting to empty dict for backward compatibility
31
+ if params is None:
32
+ params = {}
33
+
34
+ # Use the API schema for validation
35
+ validated_data = IngestTaskChartExtraction(params=params)
36
+
37
+ self._params = validated_data.params
38
+
35
39
  def __str__(self) -> str:
36
40
  """
37
41
  Returns a string with the object's config and run time state
38
42
  """
39
43
  info = ""
40
- info += "chart extraction task\n"
44
+ info += "Chart Extraction Task:\n"
45
+ if self._params:
46
+ info += f" params: {self._params}\n"
41
47
  return info
42
48
 
43
49
  def to_dict(self) -> Dict:
@@ -46,7 +52,7 @@ class ChartExtractionTask(Task):
46
52
  """
47
53
 
48
54
  task_properties = {
49
- "params": {},
55
+ "params": self._params,
50
56
  }
51
57
 
52
58
  return {"type": "chart_data_extract", "task_properties": task_properties}
@@ -10,29 +10,13 @@ import logging
10
10
  from typing import Dict
11
11
  from typing import Literal
12
12
 
13
- from pydantic import BaseModel, field_validator
14
-
13
+ from nv_ingest_api.internal.schemas.meta.ingest_job_schema import IngestTaskDedupSchema
15
14
 
16
15
  from .task_base import Task
17
16
 
18
17
  logger = logging.getLogger(__name__)
19
18
 
20
19
 
21
- class DedupTaskSchema(BaseModel):
22
- content_type: str = "image"
23
- filter: bool = False
24
-
25
- @field_validator("content_type")
26
- def content_type_must_be_valid(cls, v):
27
- valid_criteria = ["image"]
28
- if v not in valid_criteria:
29
- raise ValueError(f"content_type must be one of {valid_criteria}")
30
- return v
31
-
32
- class Config:
33
- extra = "forbid"
34
-
35
-
36
20
  class DedupTask(Task):
37
21
  """
38
22
  Object for document dedup task
@@ -49,8 +33,15 @@ class DedupTask(Task):
49
33
  Setup Dedup Task Config
50
34
  """
51
35
  super().__init__()
52
- self._content_type = content_type
53
- self._filter = filter
36
+
37
+ # Use the API schema for validation
38
+ validated_data = IngestTaskDedupSchema(
39
+ content_type=content_type,
40
+ params={"filter": filter},
41
+ )
42
+
43
+ self._content_type = validated_data.content_type
44
+ self._filter = validated_data.params.filter
54
45
 
55
46
  def __str__(self) -> str:
56
47
  """
@@ -58,7 +49,7 @@ class DedupTask(Task):
58
49
  """
59
50
  info = ""
60
51
  info += "Dedup Task:\n"
61
- info += f" content_type: {self._content_type}\n"
52
+ info += f" content_type: {self._content_type.value}\n"
62
53
  info += f" filter: {self._filter}\n"
63
54
  return info
64
55
 
@@ -69,7 +60,7 @@ class DedupTask(Task):
69
60
  dedup_params = {"filter": self._filter}
70
61
 
71
62
  task_properties = {
72
- "content_type": self._content_type,
63
+ "content_type": self._content_type.value,
73
64
  "params": dedup_params,
74
65
  }
75
66