deepdoctection-0.46-py3-none-any.whl → deepdoctection-0.46.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. See the package's registry page for more details.

@@ -22,7 +22,7 @@
22
22
  from __future__ import annotations
23
23
 
24
24
  from os import environ
25
- from typing import TYPE_CHECKING, Literal, Union
25
+ from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Sequence, Union
26
26
 
27
27
  from lazy_imports import try_import
28
28
 
@@ -42,7 +42,7 @@ from ..extern.hflayoutlm import (
42
42
  get_tokenizer_from_model_class,
43
43
  )
44
44
  from ..extern.hflm import HFLmSequenceClassifier, HFLmTokenClassifier
45
- from ..extern.model import ModelCatalog, ModelDownloadManager
45
+ from ..extern.model import ModelCatalog, ModelDownloadManager, ModelProfile
46
46
  from ..extern.pdftext import PdfPlumberTextDetector
47
47
  from ..extern.tessocr import TesseractOcrDetector, TesseractRotationTransformer
48
48
  from ..extern.texocr import TextractOcrDetector
@@ -68,7 +68,7 @@ from ..pipe.transform import SimpleTransformService
68
68
  from ..utils.error import DependencyError
69
69
  from ..utils.fs import get_configs_dir_path
70
70
  from ..utils.metacfg import AttrDict
71
- from ..utils.settings import CellType, LayoutType, Relationships
71
+ from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships
72
72
  from ..utils.transform import PadTransform
73
73
 
74
74
  with try_import() as image_guard:
@@ -104,20 +104,14 @@ class ServiceFactory:
104
104
  """
105
105
 
106
106
  @staticmethod
107
- def _build_layout_detector(
108
- config: AttrDict,
109
- mode: str,
110
- ) -> Union[D2FrcnnDetector, TPFrcnnDetector, HFDetrDerivedDetector, D2FrcnnTracingDetector]:
107
+ def _get_layout_detector_kwargs_from_config(config: AttrDict, mode: str) -> dict[str, Any]:
111
108
  """
112
- Building a D2-Detector, a TP-Detector as Detr-Detector or a D2-Torch Tracing Detector according to
113
- the config.
109
+ Extracting layout detector kwargs from config.
114
110
 
115
111
  Args:
116
112
  config: Configuration object.
117
113
  mode: Either `LAYOUT`, `CELL`, or `ITEM`.
118
114
  """
119
- if config.LIB is None:
120
- raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
121
115
 
122
116
  weights = (
123
117
  getattr(config.TF, mode).WEIGHTS
@@ -128,16 +122,52 @@ class ServiceFactory:
128
122
  else getattr(config.PT, mode).WEIGHTS_TS
129
123
  )
130
124
  )
125
+
131
126
  filter_categories = (
132
127
  getattr(getattr(config.TF, mode), "FILTER")
133
128
  if config.LIB == "TF"
134
129
  else getattr(getattr(config.PT, mode), "FILTER")
135
130
  )
136
- config_path = ModelCatalog.get_full_path_configs(weights)
137
- weights_path = ModelDownloadManager.maybe_download_weights_and_configs(weights)
131
+
138
132
  profile = ModelCatalog.get_profile(weights)
133
+
139
134
  if config.LIB == "PT" and profile.padding is not None:
140
135
  getattr(config.PT, mode).PADDING = profile.padding
136
+
137
+ device = config.DEVICE
138
+
139
+ return {
140
+ "weights": weights,
141
+ "filter_categories": filter_categories,
142
+ "profile": profile,
143
+ "device": device,
144
+ "lib": config.LIB,
145
+ }
146
+
147
+ @staticmethod
148
+ def _build_layout_detector(
149
+ weights: str,
150
+ filter_categories: list[str],
151
+ profile: ModelProfile,
152
+ device: Literal["cpu", "cuda"],
153
+ lib: Literal["TF", "PT", None],
154
+ ) -> Union[D2FrcnnDetector, TPFrcnnDetector, HFDetrDerivedDetector, D2FrcnnTracingDetector]:
155
+ """
156
+ Building a D2-Detector, a TP-Detector as Detr-Detector or a D2-Torch Tracing Detector according to
157
+ the config.
158
+
159
+ Args:
160
+ weights: Weights for the layout detector.
161
+ filter_categories: Categories to filter during detection.
162
+ profile: Model profile for the layout detector.
163
+ device: Device to use for computation.
164
+ lib: Deep learning library to use.
165
+ """
166
+ if lib is None:
167
+ raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
168
+
169
+ config_path = ModelCatalog.get_full_path_configs(weights)
170
+ weights_path = ModelDownloadManager.maybe_download_weights_and_configs(weights)
141
171
  categories = profile.categories if profile.categories is not None else {}
142
172
 
143
173
  if profile.model_wrapper in ("TPFrcnnDetector",):
@@ -152,7 +182,7 @@ class ServiceFactory:
152
182
  path_yaml=config_path,
153
183
  path_weights=weights_path,
154
184
  categories=categories,
155
- device=config.DEVICE,
185
+ device=device,
156
186
  filter_categories=filter_categories,
157
187
  )
158
188
  if profile.model_wrapper in ("D2FrcnnTracingDetector",):
@@ -169,7 +199,7 @@ class ServiceFactory:
169
199
  path_weights=weights_path,
170
200
  path_feature_extractor_config_json=preprocessor_config,
171
201
  categories=categories,
172
- device=config.DEVICE,
202
+ device=device,
173
203
  filter_categories=filter_categories,
174
204
  )
175
205
  raise TypeError(
@@ -188,7 +218,8 @@ class ServiceFactory:
188
218
  config: Configuration object.
189
219
  mode: Either `LAYOUT`, `CELL`, or `ITEM`.
190
220
  """
191
- return ServiceFactory._build_layout_detector(config, mode)
221
+ layout_detector_kwargs = ServiceFactory._get_layout_detector_kwargs_from_config(config, mode)
222
+ return ServiceFactory._build_layout_detector(**layout_detector_kwargs)
192
223
 
193
224
  @staticmethod
194
225
  def _build_rotation_detector(rotator_name: Literal["tesseract", "doctr"]) -> RotationTransformer:
@@ -245,24 +276,36 @@ class ServiceFactory:
245
276
  return ServiceFactory._build_transform_service(transform_predictor)
246
277
 
247
278
  @staticmethod
248
- def _build_padder(config: AttrDict, mode: str) -> PadTransform:
279
+ def _get_padder_kwargs_from_config(config: AttrDict, mode: str) -> dict[str, Any]:
249
280
  """
250
- Building a padder according to the config.
281
+ Extracting padder kwargs from config.
251
282
 
252
283
  Args:
253
284
  config: Configuration object.
254
285
  mode: Either `LAYOUT`, `CELL`, or `ITEM`.
286
+ """
287
+ return {
288
+ "top": getattr(config.PT, mode).PAD.TOP,
289
+ "right": getattr(config.PT, mode).PAD.RIGHT,
290
+ "bottom": getattr(config.PT, mode).PAD.BOTTOM,
291
+ "left": getattr(config.PT, mode).PAD.LEFT,
292
+ }
293
+
294
+ @staticmethod
295
+ def _build_padder(top: int, right: int, bottom: int, left: int) -> PadTransform:
296
+ """
297
+ Building a padder according to the config.
298
+
299
+ Args:
300
+ top: Padding on the top side.
301
+ right: Padding on the right side.
302
+ bottom: Padding on the bottom side.
303
+ left: Padding on the left side.
255
304
 
256
305
  Returns:
257
306
  PadTransform: `PadTransform` instance.
258
307
  """
259
- top, right, bottom, left = (
260
- getattr(config.PT, mode).PAD.TOP,
261
- getattr(config.PT, mode).PAD.RIGHT,
262
- getattr(config.PT, mode).PAD.BOTTOM,
263
- getattr(config.PT, mode).PAD.LEFT,
264
- )
265
- return PadTransform(pad_top=top, pad_right=right, pad_bottom=bottom, pad_left=left) #
308
+ return PadTransform(pad_top=top, pad_right=right, pad_bottom=bottom, pad_left=left)
266
309
 
267
310
  @staticmethod
268
311
  def build_padder(config: AttrDict, mode: str) -> PadTransform:
@@ -276,24 +319,37 @@ class ServiceFactory:
276
319
  Returns:
277
320
  PadTransform: `PadTransform` instance.
278
321
  """
279
- return ServiceFactory._build_padder(config, mode)
322
+ padder_kwargs = ServiceFactory._get_padder_kwargs_from_config(config, mode)
323
+ return ServiceFactory._build_padder(**padder_kwargs)
280
324
 
281
325
  @staticmethod
282
- def _build_layout_service(config: AttrDict, detector: ObjectDetector, mode: str) -> ImageLayoutService:
326
+ def _get_layout_service_kwargs_from_config(config: AttrDict, mode: str) -> dict[str, Any]:
283
327
  """
284
- Building a layout service with a given detector.
328
+ Extracting layout service kwargs from config.
285
329
 
286
330
  Args:
287
331
  config: Configuration object.
288
- detector: Will be passed to the `ImageLayoutService`.
289
332
  mode: Either `LAYOUT`, `CELL`, or `ITEM`.
290
-
291
- Returns:
292
- ImageLayoutService: `ImageLayoutService` instance.
293
333
  """
294
334
  padder = None
295
335
  if getattr(config.PT, mode).PADDING:
296
336
  padder = ServiceFactory.build_padder(config, mode=mode)
337
+ return {
338
+ "padder": padder,
339
+ }
340
+
341
+ @staticmethod
342
+ def _build_layout_service(detector: ObjectDetector, padder: PadTransform) -> ImageLayoutService:
343
+ """
344
+ Building a layout service with a given detector.
345
+
346
+ Args:
347
+ detector: Will be passed to the `ImageLayoutService`.
348
+ padder: PadTransform instance.
349
+
350
+ Returns:
351
+ ImageLayoutService: `ImageLayoutService` instance.
352
+ """
297
353
  return ImageLayoutService(layout_detector=detector, to_image=True, crop_image=True, padder=padder)
298
354
 
299
355
  @staticmethod
@@ -309,27 +365,51 @@ class ServiceFactory:
309
365
  Returns:
310
366
  ImageLayoutService: `ImageLayoutService` instance.
311
367
  """
312
- return ServiceFactory._build_layout_service(config, detector, mode)
368
+ layout_service_kwargs = ServiceFactory._get_layout_service_kwargs_from_config(config, mode)
369
+ return ServiceFactory._build_layout_service(detector, **layout_service_kwargs)
313
370
 
314
371
  @staticmethod
315
- def _build_layout_nms_service(config: AttrDict) -> AnnotationNmsService:
372
+ def _get_layout_nms_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
316
373
  """
317
- Building a NMS service for layout annotations.
374
+ Extracting layout NMS service kwargs from config.
318
375
 
319
376
  Args:
320
377
  config: Configuration object.
321
-
322
- Returns:
323
- AnnotationNmsService: NMS service instance.
324
378
  """
379
+
325
380
  if not isinstance(config.LAYOUT_NMS_PAIRS.COMBINATIONS, list) and not isinstance(
326
381
  config.LAYOUT_NMS_PAIRS.COMBINATIONS[0], list
327
382
  ):
328
383
  raise ValueError("LAYOUT_NMS_PAIRS must be a list of lists")
384
+
385
+ return {
386
+ "nms_pairs": config.LAYOUT_NMS_PAIRS.COMBINATIONS,
387
+ "thresholds": config.LAYOUT_NMS_PAIRS.THRESHOLDS,
388
+ "priority": config.LAYOUT_NMS_PAIRS.PRIORITY,
389
+ }
390
+
391
+ @staticmethod
392
+ def _build_layout_nms_service(
393
+ nms_pairs: Sequence[Sequence[Union[ObjectTypes, str]]],
394
+ thresholds: Union[float, Sequence[float]],
395
+ priority: Sequence[Union[ObjectTypes, str, None]],
396
+ ) -> AnnotationNmsService:
397
+ """
398
+ Building a NMS service for layout annotations.
399
+
400
+ Args:
401
+ nms_pairs: Pairs of categories for NMS.
402
+ thresholds: NMS thresholds.
403
+ priority: Priority of categories.
404
+
405
+ Returns:
406
+ AnnotationNmsService: NMS service instance.
407
+ """
408
+
329
409
  return AnnotationNmsService(
330
- nms_pairs=config.LAYOUT_NMS_PAIRS.COMBINATIONS,
331
- thresholds=config.LAYOUT_NMS_PAIRS.THRESHOLDS,
332
- priority=config.LAYOUT_NMS_PAIRS.PRIORITY,
410
+ nms_pairs=nms_pairs,
411
+ thresholds=thresholds,
412
+ priority=priority,
333
413
  )
334
414
 
335
415
  @staticmethod
@@ -343,29 +423,41 @@ class ServiceFactory:
343
423
  Returns:
344
424
  AnnotationNmsService: NMS service instance.
345
425
  """
346
- return ServiceFactory._build_layout_nms_service(config)
426
+ nms_service_kwargs = ServiceFactory._get_layout_nms_service_kwargs_from_config(config)
427
+ return ServiceFactory._build_layout_nms_service(**nms_service_kwargs)
347
428
 
348
429
  @staticmethod
349
- def _build_sub_image_service(config: AttrDict, detector: ObjectDetector, mode: str) -> SubImageLayoutService:
430
+ def _get_sub_image_layout_service_kwargs_from_config(detector: ObjectDetector, mode: str) -> dict[str, Any]:
350
431
  """
351
- Building a sub image layout service with a given detector.
432
+ Extracting sub image service kwargs from config.
352
433
 
353
434
  Args:
354
- config: Configuration object.
355
- detector: Will be passed to the `SubImageLayoutService`.
356
435
  mode: Either `LAYOUT`, `CELL`, or `ITEM`.
357
-
358
- Returns:
359
- SubImageLayoutService: `SubImageLayoutService` instance.
360
436
  """
437
+
361
438
  exclude_category_names = []
362
- padder = None
363
439
  if mode == "ITEM":
364
440
  if detector.__class__.__name__ in ("HFDetrDerivedDetector",):
365
441
  exclude_category_names.extend(
366
442
  [LayoutType.TABLE, CellType.COLUMN_HEADER, CellType.PROJECTED_ROW_HEADER, CellType.SPANNING]
367
443
  )
368
- padder = ServiceFactory.build_padder(config, mode)
444
+ return {"exclude_category_names": exclude_category_names}
445
+
446
+ @staticmethod
447
+ def _build_sub_image_service(
448
+ detector: ObjectDetector, padder: Optional[PadTransform], exclude_category_names: list[ObjectTypes]
449
+ ) -> SubImageLayoutService:
450
+ """
451
+ Building a sub image layout service with a given detector.
452
+
453
+ Args:
454
+ detector: Will be passed to the `SubImageLayoutService`.
455
+ padder: PadTransform instance.
456
+ exclude_category_names: Category names to exclude during detection.
457
+
458
+ Returns:
459
+ SubImageLayoutService: `SubImageLayoutService` instance.
460
+ """
369
461
  detect_result_generator = DetectResultGenerator(
370
462
  categories_name_as_key=detector.categories.get_categories(as_dict=True, name_as_key=True),
371
463
  exclude_category_names=exclude_category_names,
@@ -390,26 +482,37 @@ class ServiceFactory:
390
482
  Returns:
391
483
  SubImageLayoutService: `SubImageLayoutService` instance.
392
484
  """
393
- return ServiceFactory._build_sub_image_service(config, detector, mode)
485
+ padder = None
486
+ if mode == "ITEM":
487
+ padder = ServiceFactory.build_padder(config, mode)
488
+ sub_image_layout_service_kwargs = ServiceFactory._get_sub_image_layout_service_kwargs_from_config(
489
+ detector, mode
490
+ )
491
+ return ServiceFactory._build_sub_image_service(detector, padder, **sub_image_layout_service_kwargs)
394
492
 
395
493
  @staticmethod
396
- def _build_ocr_detector(config: AttrDict) -> Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector]:
494
+ def _get_ocr_detector_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
397
495
  """
398
- Building OCR predictor.
496
+ Extracting OCR detector kwargs from config.
399
497
 
400
498
  Args:
401
499
  config: Configuration object.
402
-
403
- Returns:
404
- Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector]: OCR detector instance.
405
500
  """
501
+ ocr_config_path = None
502
+ weights = None
503
+ languages = None
504
+ credentials_kwargs = None
505
+ use_tesseract = False
506
+ use_doctr = False
507
+ use_textract = False
508
+
406
509
  if config.OCR.USE_TESSERACT:
510
+ use_tesseract = True
407
511
  ocr_config_path = get_configs_dir_path() / config.OCR.CONFIG.TESSERACT
408
- return TesseractOcrDetector(
409
- ocr_config_path,
410
- config_overwrite=[f"LANGUAGES={config.LANGUAGE}"] if config.LANGUAGE is not None else None,
411
- )
512
+ languages = [f"LANGUAGES={config.LANGUAGE}"] if config.LANGUAGE is not None else None
513
+
412
514
  if config.OCR.USE_DOCTR:
515
+ use_doctr = True
413
516
  if config.LIB is None:
414
517
  raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
415
518
  weights = (
@@ -417,6 +520,63 @@ class ServiceFactory:
417
520
  if config.LIB == "TF"
418
521
  else (config.OCR.WEIGHTS.DOCTR_RECOGNITION.PT)
419
522
  )
523
+ if config.OCR.USE_TEXTRACT:
524
+ use_textract = True
525
+ credentials_kwargs = {
526
+ "aws_access_key_id": environ.get("AWS_ACCESS_KEY", None),
527
+ "aws_secret_access_key": environ.get("AWS_SECRET_KEY", None),
528
+ "config": Config(region_name=environ.get("AWS_REGION", None)),
529
+ }
530
+
531
+ return {
532
+ "use_tesseract": use_tesseract,
533
+ "use_doctr": use_doctr,
534
+ "use_textract": use_textract,
535
+ "ocr_config_path": ocr_config_path,
536
+ "languages": languages,
537
+ "weights": weights,
538
+ "credentials_kwargs": credentials_kwargs,
539
+ "lib": config.LIB,
540
+ "device": config.DEVICE,
541
+ }
542
+
543
+ @staticmethod
544
+ def _build_ocr_detector(
545
+ use_tesseract: bool,
546
+ use_doctr: bool,
547
+ use_textract: bool,
548
+ ocr_config_path: str,
549
+ languages: Union[list[str], None],
550
+ weights: str,
551
+ credentials_kwargs: dict[str, Any],
552
+ lib: Literal["TF", "PT", None],
553
+ device: Literal["cuda", "cpu"],
554
+ ) -> Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector]:
555
+ """
556
+ Building OCR predictor.
557
+
558
+ Args:
559
+ use_tesseract: Whether to use Tesseract OCR.
560
+ use_doctr: Whether to use Doctr OCR.
561
+ use_textract: Whether to use Textract OCR.
562
+ ocr_config_path: Path to OCR config.
563
+ languages: Languages for OCR.
564
+ weights: Weights for Doctr OCR.
565
+ credentials_kwargs: Credentials for Textract OCR.
566
+ lib: Deep learning library to use.
567
+ device: Device to use for computation.
568
+
569
+ Returns:
570
+ Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector]: OCR detector instance.
571
+ """
572
+ if use_tesseract:
573
+ return TesseractOcrDetector(
574
+ ocr_config_path,
575
+ config_overwrite=languages,
576
+ )
577
+ if use_doctr:
578
+ if lib is None:
579
+ raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
420
580
  weights_path = ModelDownloadManager.maybe_download_weights_and_configs(weights)
421
581
  profile = ModelCatalog.get_profile(weights)
422
582
  # get_full_path_configs will complete the path even if the model is not registered
@@ -426,16 +586,11 @@ class ServiceFactory:
426
586
  return DoctrTextRecognizer(
427
587
  architecture=profile.architecture,
428
588
  path_weights=weights_path,
429
- device=config.DEVICE,
430
- lib=config.LIB,
589
+ device=device,
590
+ lib=lib,
431
591
  path_config_json=config_path,
432
592
  )
433
- if config.OCR.USE_TEXTRACT:
434
- credentials_kwargs = {
435
- "aws_access_key_id": environ.get("AWS_ACCESS_KEY", None),
436
- "aws_secret_access_key": environ.get("AWS_SECRET_KEY", None),
437
- "config": Config(region_name=environ.get("AWS_REGION", None)),
438
- }
593
+ if use_textract:
439
594
  return TextractOcrDetector(**credentials_kwargs)
440
595
  raise ValueError("You have set USE_OCR=True but any of USE_TESSERACT, USE_DOCTR, USE_TEXTRACT is set to False")
441
596
 
@@ -450,36 +605,114 @@ class ServiceFactory:
450
605
  Returns:
451
606
  Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector]: OCR detector instance.
452
607
  """
453
- return ServiceFactory._build_ocr_detector(config)
608
+ ocr_detector_kwargs = ServiceFactory._get_ocr_detector_kwargs_from_config(config)
609
+ return ServiceFactory._build_ocr_detector(**ocr_detector_kwargs)
454
610
 
455
611
  @staticmethod
456
- def build_doctr_word_detector(config: AttrDict) -> DoctrTextlineDetector:
612
+ def _get_doctr_word_detector_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
457
613
  """
458
- Building `DoctrTextlineDetector` instance.
614
+ Extracting Doctr word detector kwargs from config.
459
615
 
460
616
  Args:
461
617
  config: Configuration object.
618
+ """
619
+ weights = config.OCR.WEIGHTS.DOCTR_WORD.TF if config.LIB == "TF" else config.OCR.WEIGHTS.DOCTR_WORD.PT
620
+ profile = ModelCatalog.get_profile(weights)
621
+ return {
622
+ "weights": weights,
623
+ "profile": profile,
624
+ "device": config.DEVICE,
625
+ "lib": config.LIB,
626
+ }
627
+
628
+ @staticmethod
629
+ def _build_doctr_word_detector(
630
+ weights: str, profile: ModelProfile, device: Literal["cuda", "cpu"], lib: Literal["PT", "TF"]
631
+ ) -> DoctrTextlineDetector:
632
+ """
633
+ Building `DoctrTextlineDetector` instance.
634
+
635
+ Args:
636
+ weights: Weights for Doctr word detector.
637
+ profile: Model profile for Doctr word detector.
638
+ device: Device to use for computation.
639
+ lib: Deep learning library to use.
462
640
 
463
641
  Returns:
464
642
  DoctrTextlineDetector: Textline detector instance.
465
643
  """
466
- if config.LIB is None:
644
+ if lib is None:
467
645
  raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
468
- weights = config.OCR.WEIGHTS.DOCTR_WORD.TF if config.LIB == "TF" else config.OCR.WEIGHTS.DOCTR_WORD.PT
469
646
  weights_path = ModelDownloadManager.maybe_download_weights_and_configs(weights)
470
- profile = ModelCatalog.get_profile(weights)
471
647
  if profile.architecture is None:
472
648
  raise ValueError("model profile.architecture must be specified")
473
649
  if profile.categories is None:
474
650
  raise ValueError("model profile.categories must be specified")
475
- return DoctrTextlineDetector(
476
- profile.architecture, weights_path, profile.categories, config.DEVICE, lib=config.LIB
477
- )
651
+ return DoctrTextlineDetector(profile.architecture, weights_path, profile.categories, device, lib=lib)
652
+
653
+ @staticmethod
654
+ def build_doctr_word_detector(config: AttrDict) -> DoctrTextlineDetector:
655
+ """
656
+ Building `DoctrTextlineDetector` instance.
657
+
658
+ Args:
659
+ config: Configuration object.
660
+
661
+ Returns:
662
+ DoctrTextlineDetector: Textline detector instance.
663
+ """
664
+ doctr_word_detector_kwargs = ServiceFactory._get_doctr_word_detector_kwargs_from_config(config)
665
+ return ServiceFactory._build_doctr_word_detector(**doctr_word_detector_kwargs)
666
+
667
+ @staticmethod
668
+ def _get_table_segmentation_service_kwargs_from_config(config: AttrDict, detector_name: str) -> dict[str, Any]:
669
+ """
670
+ Extracting table segmentation service kwargs from config.
671
+
672
+ Args:
673
+ config: Configuration object.
674
+ detector_name: An instance name of `ObjectDetector`.
675
+ """
676
+ return {
677
+ "segment_rule": config.SEGMENTATION.ASSIGNMENT_RULE,
678
+ "threshold_rows": config.SEGMENTATION.THRESHOLD_ROWS,
679
+ "threshold_cols": config.SEGMENTATION.THRESHOLD_COLS,
680
+ "tile_table_with_items": config.SEGMENTATION.FULL_TABLE_TILING,
681
+ "remove_iou_threshold_rows": config.SEGMENTATION.REMOVE_IOU_THRESHOLD_ROWS,
682
+ "remove_iou_threshold_cols": config.SEGMENTATION.REMOVE_IOU_THRESHOLD_COLS,
683
+ "table_name": config.SEGMENTATION.TABLE_NAME,
684
+ "cell_names": config.SEGMENTATION.PUBTABLES_CELL_NAMES
685
+ if detector_name in ("HFDetrDerivedDetector",)
686
+ else config.SEGMENTATION.CELL_NAMES,
687
+ "spanning_cell_names": config.SEGMENTATION.PUBTABLES_SPANNING_CELL_NAMES,
688
+ "item_names": config.SEGMENTATION.PUBTABLES_ITEM_NAMES
689
+ if detector_name in ("HFDetrDerivedDetector",)
690
+ else config.SEGMENTATION.ITEM_NAMES,
691
+ "sub_item_names": config.SEGMENTATION.PUBTABLES_SUB_ITEM_NAMES
692
+ if detector_name in ("HFDetrDerivedDetector",)
693
+ else config.SEGMENTATION.SUB_ITEM_NAMES,
694
+ "item_header_cell_names": config.SEGMENTATION.PUBTABLES_ITEM_HEADER_CELL_NAMES,
695
+ "item_header_thresholds": config.SEGMENTATION.PUBTABLES_ITEM_HEADER_THRESHOLDS,
696
+ "stretch_rule": config.SEGMENTATION.STRETCH_RULE,
697
+ }
478
698
 
479
699
  @staticmethod
480
700
  def _build_table_segmentation_service(
481
- config: AttrDict,
482
701
  detector: ObjectDetector,
702
+ segment_rule: Literal["iou", "ioa"],
703
+ threshold_rows: float,
704
+ threshold_cols: float,
705
+ tile_table_with_items: bool,
706
+ remove_iou_threshold_rows: float,
707
+ remove_iou_threshold_cols: float,
708
+ table_name: Union[ObjectTypes, str],
709
+ cell_names: Sequence[Union[ObjectTypes, str]],
710
+ spanning_cell_names: Sequence[Union[ObjectTypes, str]],
711
+ item_names: Sequence[Union[ObjectTypes, str]],
712
+ sub_item_names: Sequence[Union[ObjectTypes, str]],
713
+ item_header_cell_names: Sequence[Union[ObjectTypes, str]],
714
+ item_header_thresholds: Sequence[float],
715
+ stretch_rule: Literal["left", "equal"],
483
716
  ) -> Union[PubtablesSegmentationService, TableSegmentationService]:
484
717
  """
485
718
  Build and return a table segmentation service based on the provided detector.
@@ -495,8 +728,21 @@ class ServiceFactory:
495
728
  configuration parameters from the `cfg` object but is tailored for different segmentation needs.
496
729
 
497
730
  Args:
498
- config: Configuration object.
499
731
  detector: An instance of `ObjectDetector` used to determine the type of table segmentation service to build.
732
+ segment_rule: Rule for segmenting tables.
733
+ threshold_rows: Threshold for row segmentation.
734
+ threshold_cols: Threshold for column segmentation.
735
+ tile_table_with_items: Whether to tile the table with items.
736
+ remove_iou_threshold_rows: IOU threshold for removing rows.
737
+ remove_iou_threshold_cols: IOU threshold for removing columns.
738
+ table_name: Name of the table object type.
739
+ cell_names: Names of the cell object types.
740
+ spanning_cell_names: Names of the spanning cell object types.
741
+ item_names: Names of the item object types.
742
+ sub_item_names: Names of the sub-item object types.
743
+ item_header_cell_names: Names of the item header cell object types.
744
+ item_header_thresholds: Thresholds for item header segmentation.
745
+ stretch_rule: Rule for stretching cells.
500
746
 
501
747
  Returns:
502
748
  Table segmentation service instance.
@@ -504,35 +750,35 @@ class ServiceFactory:
504
750
  table_segmentation: Union[PubtablesSegmentationService, TableSegmentationService]
505
751
  if detector.__class__.__name__ in ("HFDetrDerivedDetector",):
506
752
  table_segmentation = PubtablesSegmentationService(
507
- segment_rule=config.SEGMENTATION.ASSIGNMENT_RULE,
508
- threshold_rows=config.SEGMENTATION.THRESHOLD_ROWS,
509
- threshold_cols=config.SEGMENTATION.THRESHOLD_COLS,
510
- tile_table_with_items=config.SEGMENTATION.FULL_TABLE_TILING,
511
- remove_iou_threshold_rows=config.SEGMENTATION.REMOVE_IOU_THRESHOLD_ROWS,
512
- remove_iou_threshold_cols=config.SEGMENTATION.REMOVE_IOU_THRESHOLD_COLS,
513
- table_name=config.SEGMENTATION.TABLE_NAME,
514
- cell_names=config.SEGMENTATION.PUBTABLES_CELL_NAMES,
515
- spanning_cell_names=config.SEGMENTATION.PUBTABLES_SPANNING_CELL_NAMES,
516
- item_names=config.SEGMENTATION.PUBTABLES_ITEM_NAMES,
517
- sub_item_names=config.SEGMENTATION.PUBTABLES_SUB_ITEM_NAMES,
518
- item_header_cell_names=config.SEGMENTATION.PUBTABLES_ITEM_HEADER_CELL_NAMES,
519
- item_header_thresholds=config.SEGMENTATION.PUBTABLES_ITEM_HEADER_THRESHOLDS,
520
- stretch_rule=config.SEGMENTATION.STRETCH_RULE,
753
+ segment_rule=segment_rule,
754
+ threshold_rows=threshold_rows,
755
+ threshold_cols=threshold_cols,
756
+ tile_table_with_items=tile_table_with_items,
757
+ remove_iou_threshold_rows=remove_iou_threshold_rows,
758
+ remove_iou_threshold_cols=remove_iou_threshold_cols,
759
+ table_name=table_name,
760
+ cell_names=cell_names,
761
+ spanning_cell_names=spanning_cell_names,
762
+ item_names=item_names,
763
+ sub_item_names=sub_item_names,
764
+ item_header_cell_names=item_header_cell_names,
765
+ item_header_thresholds=item_header_thresholds,
766
+ stretch_rule=stretch_rule,
521
767
  )
522
768
 
523
769
  else:
524
770
  table_segmentation = TableSegmentationService(
525
- segment_rule=config.SEGMENTATION.ASSIGNMENT_RULE,
526
- threshold_rows=config.SEGMENTATION.THRESHOLD_ROWS,
527
- threshold_cols=config.SEGMENTATION.THRESHOLD_COLS,
528
- tile_table_with_items=config.SEGMENTATION.FULL_TABLE_TILING,
529
- remove_iou_threshold_rows=config.SEGMENTATION.REMOVE_IOU_THRESHOLD_ROWS,
530
- remove_iou_threshold_cols=config.SEGMENTATION.REMOVE_IOU_THRESHOLD_COLS,
531
- table_name=config.SEGMENTATION.TABLE_NAME,
532
- cell_names=config.SEGMENTATION.CELL_NAMES,
533
- item_names=config.SEGMENTATION.ITEM_NAMES,
534
- sub_item_names=config.SEGMENTATION.SUB_ITEM_NAMES,
535
- stretch_rule=config.SEGMENTATION.STRETCH_RULE,
771
+ segment_rule=segment_rule,
772
+ threshold_rows=threshold_rows,
773
+ threshold_cols=threshold_cols,
774
+ tile_table_with_items=tile_table_with_items,
775
+ remove_iou_threshold_rows=remove_iou_threshold_rows,
776
+ remove_iou_threshold_cols=remove_iou_threshold_cols,
777
+ table_name=table_name,
778
+ cell_names=cell_names,
779
+ item_names=item_names,
780
+ sub_item_names=sub_item_names,
781
+ stretch_rule=stretch_rule,
536
782
  )
537
783
  return table_segmentation
538
784
 
@@ -561,10 +807,29 @@ class ServiceFactory:
561
807
  Returns:
562
808
  Table segmentation service instance.
563
809
  """
564
- return ServiceFactory._build_table_segmentation_service(config, detector)
810
+ table_segmentation_service_kwargs = ServiceFactory._get_table_segmentation_service_kwargs_from_config(
811
+ config, detector.__class__.__name__
812
+ )
813
+ return ServiceFactory._build_table_segmentation_service(detector, **table_segmentation_service_kwargs)
565
814
 
566
815
  @staticmethod
567
- def _build_table_refinement_service(config: AttrDict) -> TableSegmentationRefinementService:
816
+ def _get_table_refinement_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
817
+ """
818
+ Extracting table segmentation refinement service kwargs from config.
819
+
820
+ Args:
821
+ config: Configuration object.
822
+ """
823
+
824
+ return {
825
+ "table_names": [config.SEGMENTATION.TABLE_NAME],
826
+ "cell_names": config.SEGMENTATION.PUBTABLES_CELL_NAMES,
827
+ }
828
+
829
+ @staticmethod
830
+ def _build_table_refinement_service(
831
+ table_names: Sequence[ObjectTypes], cell_names: Sequence[ObjectTypes]
832
+ ) -> TableSegmentationRefinementService:
568
833
  """
569
834
  Building a table segmentation refinement service.
570
835
 
@@ -574,10 +839,7 @@ class ServiceFactory:
574
839
  Returns:
575
840
  TableSegmentationRefinementService: Refinement service instance.
576
841
  """
577
- return TableSegmentationRefinementService(
578
- [config.SEGMENTATION.TABLE_NAME],
579
- config.SEGMENTATION.PUBTABLES_CELL_NAMES,
580
- )
842
+ return TableSegmentationRefinementService(table_names=table_names, cell_names=cell_names)
581
843
 
582
844
  @staticmethod
583
845
  def build_table_refinement_service(config: AttrDict) -> TableSegmentationRefinementService:
@@ -590,22 +852,35 @@ class ServiceFactory:
590
852
  Returns:
591
853
  TableSegmentationRefinementService: Refinement service instance.
592
854
  """
593
- return ServiceFactory._build_table_refinement_service(config)
855
+ table_refinement_service_kwargs = ServiceFactory._get_table_refinement_service_kwargs_from_config(config)
856
+ return ServiceFactory._build_table_refinement_service(**table_refinement_service_kwargs)
594
857
 
595
858
  @staticmethod
596
- def _build_pdf_text_detector(config: AttrDict) -> PdfPlumberTextDetector:
859
+ def _get_pdf_text_detector_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
597
860
  """
598
- Building a PDF text detector.
861
+ Extracting PDF text detector kwargs from config.
599
862
 
600
863
  Args:
601
864
  config: Configuration object.
865
+ """
866
+ return {
867
+ "x_tolerance": config.PDF_MINER.X_TOLERANCE,
868
+ "y_tolerance": config.PDF_MINER.Y_TOLERANCE,
869
+ }
870
+
871
+ @staticmethod
872
+ def _build_pdf_text_detector(x_tolerance: int, y_tolerance: int) -> PdfPlumberTextDetector:
873
+ """
874
+ Building a PDF text detector.
875
+
876
+ Args:
877
+ x_tolerance: X tolerance for text extraction.
878
+ y_tolerance: Y tolerance for text extraction.
602
879
 
603
880
  Returns:
604
881
  PdfPlumberTextDetector: PDF text detector instance.
605
882
  """
606
- return PdfPlumberTextDetector(
607
- x_tolerance=config.PDF_MINER.X_TOLERANCE, y_tolerance=config.PDF_MINER.Y_TOLERANCE
608
- )
883
+ return PdfPlumberTextDetector(x_tolerance=x_tolerance, y_tolerance=y_tolerance)
609
884
 
610
885
  @staticmethod
611
886
  def build_pdf_text_detector(config: AttrDict) -> PdfPlumberTextDetector:
@@ -618,7 +893,8 @@ class ServiceFactory:
618
893
  Returns:
619
894
  PdfPlumberTextDetector: PDF text detector instance.
620
895
  """
621
- return ServiceFactory._build_pdf_text_detector(config)
896
+ pdf_text_detector_kwargs = ServiceFactory._get_pdf_text_detector_kwargs_from_config(config)
897
+ return ServiceFactory._build_pdf_text_detector(**pdf_text_detector_kwargs)
622
898
 
623
899
  @staticmethod
624
900
  def _build_pdf_miner_text_service(detector: PdfMiner) -> TextExtractionService:
@@ -672,24 +948,34 @@ class ServiceFactory:
672
948
  """
673
949
  return ServiceFactory._build_doctr_word_detector_service(detector)
674
950
 
951
+ @staticmethod
952
+ def _get_text_extraction_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
953
+ """
954
+ Extracting text extraction service kwargs from config.
955
+
956
+ Args:
957
+ config: Configuration object.
958
+ """
959
+ return {
960
+ "extract_from_roi": config.TEXT_CONTAINER if config.OCR.USE_DOCTR else None,
961
+ }
962
+
675
963
  @staticmethod
676
964
  def _build_text_extraction_service(
677
- config: AttrDict, detector: Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector]
965
+ detector: Union[TesseractOcrDetector, DoctrTextRecognizer, TextractOcrDetector],
966
+ extract_from_roi: Union[Sequence[ObjectTypes], ObjectTypes, None] = None,
678
967
  ) -> TextExtractionService:
679
968
  """
680
969
  Building a text extraction service.
681
970
 
682
971
  Args:
683
- config: Configuration object.
684
972
  detector: OCR detector instance.
973
+ extract_from_roi: ROI categories to extract text from.
685
974
 
686
975
  Returns:
687
976
  TextExtractionService: Text extraction service instance.
688
977
  """
689
- return TextExtractionService(
690
- detector,
691
- extract_from_roi=config.TEXT_CONTAINER if config.OCR.USE_DOCTR else None,
692
- )
978
+ return TextExtractionService(detector, extract_from_roi=extract_from_roi)
693
979
 
694
980
  @staticmethod
695
981
  def build_text_extraction_service(
@@ -705,28 +991,55 @@ class ServiceFactory:
705
991
  Returns:
706
992
  TextExtractionService: Text extraction service instance.
707
993
  """
708
- return ServiceFactory._build_text_extraction_service(config, detector)
994
+ text_extraction_service_kwargs = ServiceFactory._get_text_extraction_service_kwargs_from_config(config)
995
+ return ServiceFactory._build_text_extraction_service(detector, **text_extraction_service_kwargs)
709
996
 
710
997
  @staticmethod
711
- def _build_word_matching_service(config: AttrDict) -> MatchingService:
998
+ def _get_word_matching_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
712
999
  """
713
- Building a word matching service.
1000
+ Extracting word matching service kwargs from config.
714
1001
 
715
1002
  Args:
716
1003
  config: Configuration object.
1004
+ """
1005
+ return {
1006
+ "matching_rule": config.WORD_MATCHING.RULE,
1007
+ "threshold": config.WORD_MATCHING.THRESHOLD,
1008
+ "max_parent_only": config.WORD_MATCHING.MAX_PARENT_ONLY,
1009
+ "parental_categories": config.WORD_MATCHING.PARENTAL_CATEGORIES,
1010
+ "text_container": config.TEXT_CONTAINER,
1011
+ }
1012
+
1013
+ @staticmethod
1014
+ def _build_word_matching_service(
1015
+ matching_rule: Literal["iou", "ioa"],
1016
+ threshold: float,
1017
+ max_parent_only: bool,
1018
+ parental_categories: Union[Sequence[ObjectTypes], ObjectTypes, None],
1019
+ text_container: Union[Sequence[ObjectTypes], ObjectTypes, None],
1020
+ ) -> MatchingService:
1021
+ """
1022
+ Building a word matching service.
1023
+
1024
+ Args:
1025
+ matching_rule: Matching rule for intersection matcher.
1026
+ threshold: Threshold for intersection matcher.
1027
+ max_parent_only: Whether to use max parent only.
1028
+ parental_categories: Parent categories for matching.
1029
+ text_container: Text container categories.
717
1030
 
718
1031
  Returns:
719
1032
  MatchingService: Word matching service instance.
720
1033
  """
721
1034
  matcher = IntersectionMatcher(
722
- matching_rule=config.WORD_MATCHING.RULE,
723
- threshold=config.WORD_MATCHING.THRESHOLD,
724
- max_parent_only=config.WORD_MATCHING.MAX_PARENT_ONLY,
1035
+ matching_rule=matching_rule,
1036
+ threshold=threshold,
1037
+ max_parent_only=max_parent_only,
725
1038
  )
726
1039
  family_compounds = [
727
1040
  FamilyCompound(
728
- parent_categories=config.WORD_MATCHING.PARENTAL_CATEGORIES,
729
- child_categories=config.TEXT_CONTAINER,
1041
+ parent_categories=parental_categories,
1042
+ child_categories=text_container,
730
1043
  relationship_key=Relationships.CHILD,
731
1044
  ),
732
1045
  FamilyCompound(
@@ -753,15 +1066,33 @@ class ServiceFactory:
753
1066
  Returns:
754
1067
  MatchingService: Word matching service instance.
755
1068
  """
756
- return ServiceFactory._build_word_matching_service(config)
1069
+ word_matching_service_kwargs = ServiceFactory._get_word_matching_service_kwargs_from_config(config)
1070
+ return ServiceFactory._build_word_matching_service(**word_matching_service_kwargs)
757
1071
 
758
1072
  @staticmethod
759
- def _build_layout_link_matching_service(config: AttrDict) -> MatchingService:
1073
+ def _get_layout_link_matching_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
760
1074
  """
761
- Building a layout link matching service.
1075
+ Extracting layout link matching service kwargs from config.
762
1076
 
763
1077
  Args:
764
1078
  config: Configuration object.
1079
+ """
1080
+ return {
1081
+ "parental_categories": config.LAYOUT_LINK.PARENTAL_CATEGORIES,
1082
+ "child_categories": config.LAYOUT_LINK.CHILD_CATEGORIES,
1083
+ }
1084
+
1085
+ @staticmethod
1086
+ def _build_layout_link_matching_service(
1087
+ parental_categories: Union[Sequence[ObjectTypes], ObjectTypes, None],
1088
+ child_categories: Union[Sequence[ObjectTypes], ObjectTypes, None],
1089
+ ) -> MatchingService:
1090
+ """
1091
+ Building a layout link matching service.
1092
+
1093
+ Args:
1094
+ parental_categories: Parent categories for layout linking.
1095
+ child_categories: Child categories for layout linking.
765
1096
 
766
1097
  Returns:
767
1098
  MatchingService: Layout link matching service instance.
@@ -769,8 +1100,8 @@ class ServiceFactory:
769
1100
  neighbor_matcher = NeighbourMatcher()
770
1101
  family_compounds = [
771
1102
  FamilyCompound(
772
- parent_categories=config.LAYOUT_LINK.PARENTAL_CATEGORIES,
773
- child_categories=config.LAYOUT_LINK.CHILD_CATEGORIES,
1103
+ parent_categories=parental_categories,
1104
+ child_categories=child_categories,
774
1105
  relationship_key=Relationships.LAYOUT_LINK,
775
1106
  )
776
1107
  ]
@@ -790,23 +1121,44 @@ class ServiceFactory:
790
1121
  Returns:
791
1122
  MatchingService: Layout link matching service instance.
792
1123
  """
793
- return ServiceFactory._build_layout_link_matching_service(config)
1124
+ layout_link_matching_service_kwargs = ServiceFactory._get_layout_link_matching_service_kwargs_from_config(
1125
+ config
1126
+ )
1127
+ return ServiceFactory._build_layout_link_matching_service(**layout_link_matching_service_kwargs)
794
1128
 
795
1129
  @staticmethod
796
- def _build_line_matching_service(config: AttrDict) -> MatchingService:
1130
+ def _get_line_matching_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
797
1131
  """
798
- Building a line matching service.
1132
+ Extracting line matching service kwargs from config.
799
1133
 
800
1134
  Args:
801
1135
  config: Configuration object.
1136
+ """
1137
+ return {
1138
+ "matching_rule": config.WORD_MATCHING.RULE,
1139
+ "threshold": config.WORD_MATCHING.THRESHOLD,
1140
+ "max_parent_only": config.WORD_MATCHING.MAX_PARENT_ONLY,
1141
+ }
1142
+
1143
+ @staticmethod
1144
+ def _build_line_matching_service(
1145
+ matching_rule: Literal["iou", "ioa"], threshold: float, max_parent_only: bool
1146
+ ) -> MatchingService:
1147
+ """
1148
+ Building a line matching service.
1149
+
1150
+ Args:
1151
+ matching_rule: Matching rule for intersection matcher.
1152
+ threshold: Threshold for intersection matcher.
1153
+ max_parent_only: Whether to use max parent only.
802
1154
 
803
1155
  Returns:
804
1156
  MatchingService: Line matching service instance.
805
1157
  """
806
1158
  matcher = IntersectionMatcher(
807
- matching_rule=config.WORD_MATCHING.RULE,
808
- threshold=config.WORD_MATCHING.THRESHOLD,
809
- max_parent_only=config.WORD_MATCHING.MAX_PARENT_ONLY,
1159
+ matching_rule=matching_rule,
1160
+ threshold=threshold,
1161
+ max_parent_only=max_parent_only,
810
1162
  )
811
1163
  family_compounds = [
812
1164
  FamilyCompound(
@@ -831,28 +1183,64 @@ class ServiceFactory:
831
1183
  Returns:
832
1184
  MatchingService: Line matching service instance.
833
1185
  """
834
- return ServiceFactory._build_line_matching_service(config)
1186
+ line_matching_service_kwargs = ServiceFactory._get_line_matching_service_kwargs_from_config(config)
1187
+ return ServiceFactory._build_line_matching_service(**line_matching_service_kwargs)
835
1188
 
836
1189
  @staticmethod
837
- def _build_text_order_service(config: AttrDict) -> TextOrderService:
1190
+ def _get_text_order_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
838
1191
  """
839
- Building a text order service.
1192
+ Extracting text order service kwargs from config.
840
1193
 
841
1194
  Args:
842
1195
  config: Configuration object.
1196
+ """
1197
+ return {
1198
+ "text_container": config.TEXT_CONTAINER,
1199
+ "text_block_categories": config.TEXT_ORDERING.TEXT_BLOCK_CATEGORIES,
1200
+ "floating_text_block_categories": config.TEXT_ORDERING.FLOATING_TEXT_BLOCK_CATEGORIES,
1201
+ "include_residual_text_container": config.TEXT_ORDERING.INCLUDE_RESIDUAL_TEXT_CONTAINER,
1202
+ "starting_point_tolerance": config.TEXT_ORDERING.STARTING_POINT_TOLERANCE,
1203
+ "broken_line_tolerance": config.TEXT_ORDERING.BROKEN_LINE_TOLERANCE,
1204
+ "height_tolerance": config.TEXT_ORDERING.HEIGHT_TOLERANCE,
1205
+ "paragraph_break": config.TEXT_ORDERING.PARAGRAPH_BREAK,
1206
+ }
1207
+
1208
+ @staticmethod
1209
+ def _build_text_order_service(
1210
+ text_container: str,
1211
+ text_block_categories: Sequence[str],
1212
+ floating_text_block_categories: Sequence[str],
1213
+ include_residual_text_container: bool,
1214
+ starting_point_tolerance: float,
1215
+ broken_line_tolerance: float,
1216
+ height_tolerance: float,
1217
+ paragraph_break: float,
1218
+ ) -> TextOrderService:
1219
+ """
1220
+ Building a text order service.
1221
+
1222
+ Args:
1223
+ text_container: Text container categories.
1224
+ text_block_categories: Text block categories for ordering.
1225
+ floating_text_block_categories: Floating text block categories.
1226
+ include_residual_text_container: Whether to include residual text container.
1227
+ starting_point_tolerance: Starting point tolerance for text ordering.
1228
+ broken_line_tolerance: Broken line tolerance for text ordering.
1229
+ height_tolerance: Height tolerance for text ordering.
1230
+ paragraph_break: Paragraph break threshold.
843
1231
 
844
1232
  Returns:
845
1233
  TextOrderService: Text order service instance.
846
1234
  """
847
1235
  return TextOrderService(
848
- text_container=config.TEXT_CONTAINER,
849
- text_block_categories=config.TEXT_ORDERING.TEXT_BLOCK_CATEGORIES,
850
- floating_text_block_categories=config.TEXT_ORDERING.FLOATING_TEXT_BLOCK_CATEGORIES,
851
- include_residual_text_container=config.TEXT_ORDERING.INCLUDE_RESIDUAL_TEXT_CONTAINER,
852
- starting_point_tolerance=config.TEXT_ORDERING.STARTING_POINT_TOLERANCE,
853
- broken_line_tolerance=config.TEXT_ORDERING.BROKEN_LINE_TOLERANCE,
854
- height_tolerance=config.TEXT_ORDERING.HEIGHT_TOLERANCE,
855
- paragraph_break=config.TEXT_ORDERING.PARAGRAPH_BREAK,
1236
+ text_container=text_container,
1237
+ text_block_categories=text_block_categories,
1238
+ floating_text_block_categories=floating_text_block_categories,
1239
+ include_residual_text_container=include_residual_text_container,
1240
+ starting_point_tolerance=starting_point_tolerance,
1241
+ broken_line_tolerance=broken_line_tolerance,
1242
+ height_tolerance=height_tolerance,
1243
+ paragraph_break=paragraph_break,
856
1244
  )
857
1245
 
858
1246
  @staticmethod
@@ -866,18 +1254,16 @@ class ServiceFactory:
866
1254
  Returns:
867
1255
  TextOrderService: Text order service instance.
868
1256
  """
869
- return ServiceFactory._build_text_order_service(config)
1257
+ text_order_service_kwargs = ServiceFactory._get_text_order_service_kwargs_from_config(config)
1258
+ return ServiceFactory._build_text_order_service(**text_order_service_kwargs)
870
1259
 
871
1260
  @staticmethod
872
- def _build_sequence_classifier(config: AttrDict) -> Union[LayoutSequenceModels, LmSequenceModels]:
1261
+ def _get_sequence_classifier_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
873
1262
  """
874
- Builds and returns a sequence classifier instance.
1263
+ Extracting sequence classifier kwargs from config.
875
1264
 
876
1265
  Args:
877
- config: Configuration object that determines the type of sequence classifier to construct.
878
-
879
- Returns:
880
- A sequence classifier instance constructed according to the specified configuration.
1266
+ config: Configuration object.
881
1267
  """
882
1268
  config_path = ModelCatalog.get_full_path_configs(config.LM_SEQUENCE_CLASS.WEIGHTS)
883
1269
  weights_path = ModelDownloadManager.maybe_download_weights_and_configs(config.LM_SEQUENCE_CLASS.WEIGHTS)
@@ -885,47 +1271,79 @@ class ServiceFactory:
885
1271
  categories = profile.categories if profile.categories is not None else {}
886
1272
  use_xlm_tokenizer = "xlm_tokenizer" == profile.architecture
887
1273
 
888
- if profile.model_wrapper in ("HFLayoutLmSequenceClassifier",):
1274
+ return {
1275
+ "config_path": config_path,
1276
+ "weights_path": weights_path,
1277
+ "categories": categories,
1278
+ "device": config.DEVICE,
1279
+ "use_xlm_tokenizer": use_xlm_tokenizer,
1280
+ "model_wrapper": profile.model_wrapper,
1281
+ }
1282
+
1283
+ @staticmethod
1284
+ def _build_sequence_classifier(
1285
+ config_path: str,
1286
+ weights_path: str,
1287
+ categories: Mapping[int, Union[ObjectTypes, str]],
1288
+ device: Literal["cuda", "cpu"],
1289
+ use_xlm_tokenizer: bool,
1290
+ model_wrapper: str,
1291
+ ) -> Union[LayoutSequenceModels, LmSequenceModels]:
1292
+ """
1293
+ Builds and returns a sequence classifier instance.
1294
+
1295
+ Args:
1296
+ config_path: Path to model configuration.
1297
+ weights_path: Path to model weights.
1298
+ categories: Model categories mapping.
1299
+ device: Device to run model on.
1300
+ use_xlm_tokenizer: Whether to use XLM tokenizer.
1301
+ model_wrapper: Model wrapper class name.
1302
+
1303
+ Returns:
1304
+ A sequence classifier instance constructed according to the specified configuration.
1305
+ """
1306
+ if model_wrapper in ("HFLayoutLmSequenceClassifier",):
889
1307
  return HFLayoutLmSequenceClassifier(
890
1308
  path_config_json=config_path,
891
1309
  path_weights=weights_path,
892
1310
  categories=categories,
893
- device=config.DEVICE,
1311
+ device=device,
894
1312
  use_xlm_tokenizer=use_xlm_tokenizer,
895
1313
  )
896
- if profile.model_wrapper in ("HFLayoutLmv2SequenceClassifier",):
1314
+ if model_wrapper in ("HFLayoutLmv2SequenceClassifier",):
897
1315
  return HFLayoutLmv2SequenceClassifier(
898
1316
  path_config_json=config_path,
899
1317
  path_weights=weights_path,
900
1318
  categories=categories,
901
- device=config.DEVICE,
1319
+ device=device,
902
1320
  use_xlm_tokenizer=use_xlm_tokenizer,
903
1321
  )
904
- if profile.model_wrapper in ("HFLayoutLmv3SequenceClassifier",):
1322
+ if model_wrapper in ("HFLayoutLmv3SequenceClassifier",):
905
1323
  return HFLayoutLmv3SequenceClassifier(
906
1324
  path_config_json=config_path,
907
1325
  path_weights=weights_path,
908
1326
  categories=categories,
909
- device=config.DEVICE,
1327
+ device=device,
910
1328
  use_xlm_tokenizer=use_xlm_tokenizer,
911
1329
  )
912
- if profile.model_wrapper in ("HFLiltSequenceClassifier",):
1330
+ if model_wrapper in ("HFLiltSequenceClassifier",):
913
1331
  return HFLiltSequenceClassifier(
914
1332
  path_config_json=config_path,
915
1333
  path_weights=weights_path,
916
1334
  categories=categories,
917
- device=config.DEVICE,
1335
+ device=device,
918
1336
  use_xlm_tokenizer=use_xlm_tokenizer,
919
1337
  )
920
- if profile.model_wrapper in ("HFLmSequenceClassifier",):
1338
+ if model_wrapper in ("HFLmSequenceClassifier",):
921
1339
  return HFLmSequenceClassifier(
922
1340
  path_config_json=config_path,
923
1341
  path_weights=weights_path,
924
1342
  categories=categories,
925
- device=config.DEVICE,
1343
+ device=device,
926
1344
  use_xlm_tokenizer=use_xlm_tokenizer,
927
1345
  )
928
- raise ValueError(f"Unsupported model wrapper: {profile.model_wrapper}")
1346
+ raise ValueError(f"Unsupported model wrapper: {model_wrapper}")
929
1347
 
930
1348
  @staticmethod
931
1349
  def build_sequence_classifier(config: AttrDict) -> Union[LayoutSequenceModels, LmSequenceModels]:
@@ -938,18 +1356,31 @@ class ServiceFactory:
938
1356
  Returns:
939
1357
  A sequence classifier instance constructed according to the specified configuration.
940
1358
  """
941
- return ServiceFactory._build_sequence_classifier(config)
1359
+ sequence_classifier_kwargs = ServiceFactory._get_sequence_classifier_kwargs_from_config(config)
1360
+ return ServiceFactory._build_sequence_classifier(**sequence_classifier_kwargs)
1361
+
1362
+ @staticmethod
1363
+ def _get_sequence_classifier_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
1364
+ """
1365
+ Extracting sequence classifier service kwargs from config.
1366
+
1367
+ Args:
1368
+ config: Configuration object.
1369
+ """
1370
+ return {
1371
+ "use_other_as_default_category": config.LM_SEQUENCE_CLASS.USE_OTHER_AS_DEFAULT_CATEGORY,
1372
+ }
942
1373
 
943
1374
  @staticmethod
944
1375
  def _build_sequence_classifier_service(
945
- config: AttrDict, sequence_classifier: Union[LayoutSequenceModels, LmSequenceModels]
1376
+ sequence_classifier: Union[LayoutSequenceModels, LmSequenceModels], use_other_as_default_category: bool
946
1377
  ) -> LMSequenceClassifierService:
947
1378
  """
948
1379
  Building a sequence classifier service.
949
1380
 
950
1381
  Args:
951
- config: Configuration object.
952
1382
  sequence_classifier: Sequence classifier instance.
1383
+ use_other_as_default_category: Whether to use other as default category.
953
1384
 
954
1385
  Returns:
955
1386
  LMSequenceClassifierService: Text order service instance.
@@ -961,7 +1392,7 @@ class ServiceFactory:
961
1392
  return LMSequenceClassifierService(
962
1393
  tokenizer=tokenizer_fast,
963
1394
  language_model=sequence_classifier,
964
- use_other_as_default_category=config.LM_SEQUENCE_CLASS.USE_OTHER_AS_DEFAULT_CATEGORY,
1395
+ use_other_as_default_category=use_other_as_default_category,
965
1396
  )
966
1397
 
967
1398
  @staticmethod
@@ -978,60 +1409,93 @@ class ServiceFactory:
978
1409
  Returns:
979
1410
  LMSequenceClassifierService: Text order service instance.
980
1411
  """
981
- return ServiceFactory._build_sequence_classifier_service(config, sequence_classifier)
1412
+ sequence_classifier_service_kwargs = ServiceFactory._get_sequence_classifier_service_kwargs_from_config(config)
1413
+ return ServiceFactory._build_sequence_classifier_service(
1414
+ sequence_classifier, **sequence_classifier_service_kwargs
1415
+ )
982
1416
 
983
1417
  @staticmethod
984
- def _build_token_classifier(config: AttrDict) -> Union[LayoutTokenModels, LmTokenModels]:
1418
+ def _get_token_classifier_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
985
1419
  """
986
- Builds and returns a token classifier model.
1420
+ Extracting token classifier kwargs from config.
987
1421
 
988
1422
  Args:
989
1423
  config: Configuration object.
990
-
991
- Returns:
992
- The instantiated token classifier model.
993
1424
  """
994
1425
  config_path = ModelCatalog.get_full_path_configs(config.LM_TOKEN_CLASS.WEIGHTS)
995
1426
  weights_path = ModelDownloadManager.maybe_download_weights_and_configs(config.LM_TOKEN_CLASS.WEIGHTS)
996
1427
  profile = ModelCatalog.get_profile(config.LM_TOKEN_CLASS.WEIGHTS)
997
1428
  categories = profile.categories if profile.categories is not None else {}
998
1429
  use_xlm_tokenizer = "xlm_tokenizer" == profile.architecture
999
- if profile.model_wrapper in ("HFLayoutLmTokenClassifier",):
1430
+
1431
+ return {
1432
+ "config_path": config_path,
1433
+ "weights_path": weights_path,
1434
+ "categories": categories,
1435
+ "device": config.DEVICE,
1436
+ "use_xlm_tokenizer": use_xlm_tokenizer,
1437
+ "model_wrapper": profile.model_wrapper,
1438
+ }
1439
+
1440
+ @staticmethod
1441
+ def _build_token_classifier(
1442
+ config_path: str,
1443
+ weights_path: str,
1444
+ categories: Mapping[int, Union[ObjectTypes, str]],
1445
+ device: Literal["cpu", "cuda"],
1446
+ use_xlm_tokenizer: bool,
1447
+ model_wrapper: str,
1448
+ ) -> Union[LayoutTokenModels, LmTokenModels]:
1449
+ """
1450
+ Builds and returns a token classifier model.
1451
+
1452
+ Args:
1453
+ config_path: Path to model configuration.
1454
+ weights_path: Path to model weights.
1455
+ categories: Model categories mapping.
1456
+ device: Device to run model on.
1457
+ use_xlm_tokenizer: Whether to use XLM tokenizer.
1458
+ model_wrapper: Model wrapper class name.
1459
+
1460
+ Returns:
1461
+ The instantiated token classifier model.
1462
+ """
1463
+ if model_wrapper in ("HFLayoutLmTokenClassifier",):
1000
1464
  return HFLayoutLmTokenClassifier(
1001
1465
  path_config_json=config_path,
1002
1466
  path_weights=weights_path,
1003
1467
  categories=categories,
1004
- device=config.DEVICE,
1468
+ device=device,
1005
1469
  use_xlm_tokenizer=use_xlm_tokenizer,
1006
1470
  )
1007
- if profile.model_wrapper in ("HFLayoutLmv2TokenClassifier",):
1471
+ if model_wrapper in ("HFLayoutLmv2TokenClassifier",):
1008
1472
  return HFLayoutLmv2TokenClassifier(
1009
1473
  path_config_json=config_path,
1010
1474
  path_weights=weights_path,
1011
1475
  categories=categories,
1012
- device=config.DEVICE,
1476
+ device=device,
1013
1477
  )
1014
- if profile.model_wrapper in ("HFLayoutLmv3TokenClassifier",):
1478
+ if model_wrapper in ("HFLayoutLmv3TokenClassifier",):
1015
1479
  return HFLayoutLmv3TokenClassifier(
1016
1480
  path_config_json=config_path,
1017
1481
  path_weights=weights_path,
1018
1482
  categories=categories,
1019
- device=config.DEVICE,
1483
+ device=device,
1020
1484
  )
1021
- if profile.model_wrapper in ("HFLiltTokenClassifier",):
1485
+ if model_wrapper in ("HFLiltTokenClassifier",):
1022
1486
  return HFLiltTokenClassifier(
1023
1487
  path_config_json=config_path,
1024
1488
  path_weights=weights_path,
1025
1489
  categories=categories,
1026
- device=config.DEVICE,
1490
+ device=device,
1027
1491
  )
1028
- if profile.model_wrapper in ("HFLmTokenClassifier",):
1492
+ if model_wrapper in ("HFLmTokenClassifier",):
1029
1493
  return HFLmTokenClassifier(
1030
1494
  path_config_json=config_path,
1031
1495
  path_weights=weights_path,
1032
1496
  categories=categories,
1033
1497
  )
1034
- raise ValueError(f"Unsupported model wrapper: {profile.model_wrapper}")
1498
+ raise ValueError(f"Unsupported model wrapper: {model_wrapper}")
1035
1499
 
1036
1500
  @staticmethod
1037
1501
  def build_token_classifier(config: AttrDict) -> Union[LayoutTokenModels, LmTokenModels]:
@@ -1044,18 +1508,38 @@ class ServiceFactory:
1044
1508
  Returns:
1045
1509
  The instantiated token classifier model.
1046
1510
  """
1047
- return ServiceFactory._build_token_classifier(config)
1511
+ token_classifier_kwargs = ServiceFactory._get_token_classifier_kwargs_from_config(config)
1512
+ return ServiceFactory._build_token_classifier(**token_classifier_kwargs)
1513
+
1514
+ @staticmethod
1515
+ def _get_token_classifier_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
1516
+ """
1517
+ Extracting token classifier service kwargs from config.
1518
+
1519
+ Args:
1520
+ config: Configuration object.
1521
+ """
1522
+ return {
1523
+ "use_other_as_default_category": config.LM_TOKEN_CLASS.USE_OTHER_AS_DEFAULT_CATEGORY,
1524
+ "segment_positions": config.LM_TOKEN_CLASS.SEGMENT_POSITIONS,
1525
+ "sliding_window_stride": config.LM_TOKEN_CLASS.SLIDING_WINDOW_STRIDE,
1526
+ }
1048
1527
 
1049
1528
  @staticmethod
1050
1529
  def _build_token_classifier_service(
1051
- config: AttrDict, token_classifier: Union[LayoutTokenModels, LmTokenModels]
1530
+ token_classifier: Union[LayoutTokenModels, LmTokenModels],
1531
+ use_other_as_default_category: bool,
1532
+ segment_positions: Union[LayoutType, Sequence[LayoutType], None],
1533
+ sliding_window_stride: int,
1052
1534
  ) -> LMTokenClassifierService:
1053
1535
  """
1054
1536
  Building a token classifier service.
1055
1537
 
1056
1538
  Args:
1057
- config: Configuration object.
1058
1539
  token_classifier: Token classifier instance.
1540
+ use_other_as_default_category: Whether to use other as default category.
1541
+ segment_positions: Segment positions configuration.
1542
+ sliding_window_stride: Sliding window stride.
1059
1543
 
1060
1544
  Returns:
1061
1545
  A LMTokenClassifierService instance.
@@ -1067,9 +1551,9 @@ class ServiceFactory:
1067
1551
  return LMTokenClassifierService(
1068
1552
  tokenizer=tokenizer_fast,
1069
1553
  language_model=token_classifier,
1070
- use_other_as_default_category=config.LM_TOKEN_CLASS.USE_OTHER_AS_DEFAULT_CATEGORY,
1071
- segment_positions=config.LM_TOKEN_CLASS.SEGMENT_POSITIONS,
1072
- sliding_window_stride=config.LM_TOKEN_CLASS.SLIDING_WINDOW_STRIDE,
1554
+ use_other_as_default_category=use_other_as_default_category,
1555
+ segment_positions=segment_positions,
1556
+ sliding_window_stride=sliding_window_stride,
1073
1557
  )
1074
1558
 
1075
1559
  @staticmethod
@@ -1086,23 +1570,44 @@ class ServiceFactory:
1086
1570
  Returns:
1087
1571
  A LMTokenClassifierService instance.
1088
1572
  """
1089
- return ServiceFactory._build_token_classifier_service(config, token_classifier)
1573
+ token_classifier_service_kwargs = ServiceFactory._get_token_classifier_service_kwargs_from_config(config)
1574
+ return ServiceFactory._build_token_classifier_service(token_classifier, **token_classifier_service_kwargs)
1090
1575
 
1091
1576
  @staticmethod
1092
- def _build_page_parsing_service(config: AttrDict) -> PageParsingService:
1577
+ def _get_page_parsing_service_kwargs_from_config(config: AttrDict) -> dict[str, Any]:
1093
1578
  """
1094
- Building a page parsing service.
1579
+ Extracting page parsing service kwargs from config.
1095
1580
 
1096
1581
  Args:
1097
1582
  config: Configuration object.
1583
+ """
1584
+ return {
1585
+ "text_container": config.TEXT_CONTAINER,
1586
+ "floating_text_block_categories": config.TEXT_ORDERING.FLOATING_TEXT_BLOCK_CATEGORIES,
1587
+ "include_residual_text_container": config.TEXT_ORDERING.INCLUDE_RESIDUAL_TEXT_CONTAINER,
1588
+ }
1589
+
1590
+ @staticmethod
1591
+ def _build_page_parsing_service(
1592
+ text_container: Union[ObjectTypes, str],
1593
+ floating_text_block_categories: Sequence[str],
1594
+ include_residual_text_container: bool,
1595
+ ) -> PageParsingService:
1596
+ """
1597
+ Building a page parsing service.
1598
+
1599
+ Args:
1600
+ text_container: Text container categories.
1601
+ floating_text_block_categories: Floating text block categories.
1602
+ include_residual_text_container: Whether to include residual text container.
1098
1603
 
1099
1604
  Returns:
1100
1605
  PageParsingService: Page parsing service instance.
1101
1606
  """
1102
1607
  return PageParsingService(
1103
- text_container=config.TEXT_CONTAINER,
1104
- floating_text_block_categories=config.TEXT_ORDERING.FLOATING_TEXT_BLOCK_CATEGORIES,
1105
- include_residual_text_container=config.TEXT_ORDERING.INCLUDE_RESIDUAL_TEXT_CONTAINER,
1608
+ text_container=text_container,
1609
+ floating_text_block_categories=floating_text_block_categories,
1610
+ include_residual_text_container=include_residual_text_container,
1106
1611
  )
1107
1612
 
1108
1613
  @staticmethod
@@ -1116,7 +1621,8 @@ class ServiceFactory:
1116
1621
  Returns:
1117
1622
  PageParsingService: Page parsing service instance.
1118
1623
  """
1119
- return ServiceFactory._build_page_parsing_service(config)
1624
+ page_parsing_service_kwargs = ServiceFactory._get_page_parsing_service_kwargs_from_config(config)
1625
+ return ServiceFactory._build_page_parsing_service(**page_parsing_service_kwargs)
1120
1626
 
1121
1627
  @staticmethod
1122
1628
  def build_analyzer(config: AttrDict) -> DoctectionPipe: