deepdoctection-0.39.7-py3-none-any.whl → deepdoctection-0.41.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details (the link target was not preserved in this text export).

@@ -25,7 +25,7 @@ from .utils.logger import LoggingRecord, logger
25
25
 
26
26
  # pylint: enable=wrong-import-position
27
27
 
28
- __version__ = "0.39.7"
28
+ __version__ = "0.41.0"
29
29
 
30
30
  _IMPORT_STRUCTURE = {
31
31
  "analyzer": ["config_sanity_checks", "get_dd_analyzer", "ServiceFactory"],
@@ -90,8 +90,6 @@ _IMPORT_STRUCTURE = {
90
90
  "convert_np_array_to_b64_b",
91
91
  "convert_bytes_to_np_array",
92
92
  "convert_pdf_bytes_to_np_array_v2",
93
- "box_to_point4",
94
- "point4_to_box",
95
93
  "as_dict",
96
94
  "ImageAnnotationBaseView",
97
95
  "Image",
@@ -164,6 +162,7 @@ _IMPORT_STRUCTURE = {
164
162
  "LMSequenceClassifier",
165
163
  "LanguageDetector",
166
164
  "ImageTransformer",
165
+ "DeterministicImageTransformer",
167
166
  "InferenceResize",
168
167
  "D2FrcnnDetector",
169
168
  "D2FrcnnTracingDetector",
@@ -260,6 +259,7 @@ _IMPORT_STRUCTURE = {
260
259
  "ImageCroppingService",
261
260
  "IntersectionMatcher",
262
261
  "NeighbourMatcher",
262
+ "FamilyCompound",
263
263
  "MatchingService",
264
264
  "PageParsingService",
265
265
  "AnnotationNmsService",
@@ -400,11 +400,14 @@ _IMPORT_STRUCTURE = {
400
400
  "get_type",
401
401
  "get_tqdm",
402
402
  "get_tqdm_default_kwargs",
403
+ "box_to_point4",
404
+ "point4_to_box",
403
405
  "ResizeTransform",
404
406
  "InferenceResize",
405
407
  "normalize_image",
406
408
  "pad_image",
407
409
  "PadTransform",
410
+ "RotationTransform",
408
411
  "delete_keys_from_dict",
409
412
  "split_string",
410
413
  "string_to_dict",
@@ -72,7 +72,6 @@ cfg.SEGMENTATION.THRESHOLD_COLS = 0.4
72
72
  cfg.SEGMENTATION.FULL_TABLE_TILING = True
73
73
  cfg.SEGMENTATION.REMOVE_IOU_THRESHOLD_ROWS = 0.001
74
74
  cfg.SEGMENTATION.REMOVE_IOU_THRESHOLD_COLS = 0.001
75
- cfg.SEGMENTATION.CELL_CATEGORY_ID = 12
76
75
  cfg.SEGMENTATION.TABLE_NAME = LayoutType.TABLE
77
76
  cfg.SEGMENTATION.PUBTABLES_CELL_NAMES = [
78
77
  CellType.SPANNING,
@@ -35,13 +35,14 @@ from ..extern.tpdetect import TPFrcnnDetector
35
35
  from ..pipe.base import PipelineComponent
36
36
  from ..pipe.common import (
37
37
  AnnotationNmsService,
38
+ FamilyCompound,
38
39
  IntersectionMatcher,
39
40
  MatchingService,
40
41
  NeighbourMatcher,
41
42
  PageParsingService,
42
43
  )
43
44
  from ..pipe.doctectionpipe import DoctectionPipe
44
- from ..pipe.layout import ImageLayoutService
45
+ from ..pipe.layout import ImageLayoutService, skip_if_category_or_service_extracted
45
46
  from ..pipe.order import TextOrderService
46
47
  from ..pipe.refine import TableSegmentationRefinementService
47
48
  from ..pipe.segment import PubtablesSegmentationService, TableSegmentationService
@@ -196,7 +197,7 @@ class ServiceFactory:
196
197
  getattr(config.PT, mode).PAD.BOTTOM,
197
198
  getattr(config.PT, mode).PAD.LEFT,
198
199
  )
199
- return PadTransform(top=top, right=right, bottom=bottom, left=left) #
200
+ return PadTransform(pad_top=top, pad_right=right, pad_bottom=bottom, pad_left=left) #
200
201
 
201
202
  @staticmethod
202
203
  def build_padder(config: AttrDict, mode: str) -> PadTransform:
@@ -284,7 +285,6 @@ class ServiceFactory:
284
285
  return SubImageLayoutService(
285
286
  sub_image_detector=detector,
286
287
  sub_image_names=[LayoutType.TABLE, LayoutType.TABLE_ROTATED],
287
- category_id_mapping=None,
288
288
  detect_result_generator=detect_result_generator,
289
289
  padder=padder,
290
290
  )
@@ -405,7 +405,6 @@ class ServiceFactory:
405
405
  tile_table_with_items=config.SEGMENTATION.FULL_TABLE_TILING,
406
406
  remove_iou_threshold_rows=config.SEGMENTATION.REMOVE_IOU_THRESHOLD_ROWS,
407
407
  remove_iou_threshold_cols=config.SEGMENTATION.REMOVE_IOU_THRESHOLD_COLS,
408
- cell_class_id=config.SEGMENTATION.CELL_CATEGORY_ID,
409
408
  table_name=config.SEGMENTATION.TABLE_NAME,
410
409
  cell_names=config.SEGMENTATION.PUBTABLES_CELL_NAMES,
411
410
  spanning_cell_names=config.SEGMENTATION.PUBTABLES_SPANNING_CELL_NAMES,
@@ -516,6 +515,15 @@ class ServiceFactory:
516
515
  """
517
516
  return ServiceFactory._build_pdf_miner_text_service(detector)
518
517
 
518
+ @staticmethod
519
+ def _build_doctr_word_detector_service(detector: DoctrTextlineDetector) -> ImageLayoutService:
520
+ """Building a Doctr word detector service
521
+
522
+ :param detector: DoctrTextlineDetector
523
+ :return: ImageLayoutService
524
+ """
525
+ return ImageLayoutService(layout_detector=detector, to_image=True, crop_image=True)
526
+
519
527
  @staticmethod
520
528
  def build_doctr_word_detector_service(detector: DoctrTextlineDetector) -> ImageLayoutService:
521
529
  """Building a Doctr word detector service
@@ -523,9 +531,7 @@ class ServiceFactory:
523
531
  :param detector: DoctrTextlineDetector
524
532
  :return: ImageLayoutService
525
533
  """
526
- return ImageLayoutService(
527
- layout_detector=detector, to_image=True, crop_image=True, skip_if_layout_extracted=True
528
- )
534
+ return ServiceFactory._build_doctr_word_detector_service(detector)
529
535
 
530
536
  @staticmethod
531
537
  def _build_text_extraction_service(
@@ -539,7 +545,6 @@ class ServiceFactory:
539
545
  """
540
546
  return TextExtractionService(
541
547
  detector,
542
- skip_if_text_extracted=config.USE_PDF_MINER,
543
548
  extract_from_roi=config.TEXT_CONTAINER if config.OCR.USE_DOCTR else None,
544
549
  )
545
550
 
@@ -567,11 +572,16 @@ class ServiceFactory:
567
572
  threshold=config.WORD_MATCHING.THRESHOLD,
568
573
  max_parent_only=config.WORD_MATCHING.MAX_PARENT_ONLY,
569
574
  )
575
+ family_compounds = [
576
+ FamilyCompound(
577
+ parent_categories=config.WORD_MATCHING.PARENTAL_CATEGORIES,
578
+ child_categories=config.TEXT_CONTAINER,
579
+ relationship_key=Relationships.CHILD,
580
+ )
581
+ ]
570
582
  return MatchingService(
571
- parent_categories=config.WORD_MATCHING.PARENTAL_CATEGORIES,
572
- child_categories=config.TEXT_CONTAINER,
583
+ family_compounds=family_compounds,
573
584
  matcher=matcher,
574
- relationship_key=Relationships.CHILD,
575
585
  )
576
586
 
577
587
  @staticmethod
@@ -591,11 +601,16 @@ class ServiceFactory:
591
601
  :return: MatchingService
592
602
  """
593
603
  neighbor_matcher = NeighbourMatcher()
604
+ family_compounds = [
605
+ FamilyCompound(
606
+ parent_categories=config.LAYOUT_LINK.PARENTAL_CATEGORIES,
607
+ child_categories=config.LAYOUT_LINK.CHILD_CATEGORIES,
608
+ relationship_key=Relationships.LAYOUT_LINK,
609
+ )
610
+ ]
594
611
  return MatchingService(
595
- parent_categories=config.LAYOUT_LINK.PARENTAL_CATEGORIES,
596
- child_categories=config.LAYOUT_LINK.CHILD_CATEGORIES,
612
+ family_compounds=family_compounds,
597
613
  matcher=neighbor_matcher,
598
- relationship_key=Relationships.LAYOUT_LINK,
599
614
  )
600
615
 
601
616
  @staticmethod
@@ -699,9 +714,11 @@ class ServiceFactory:
699
714
  table_refinement_service = ServiceFactory.build_table_refinement_service(config)
700
715
  pipe_component_list.append(table_refinement_service)
701
716
 
717
+ d_text_service_id = ""
702
718
  if config.USE_PDF_MINER:
703
719
  pdf_miner = ServiceFactory.build_pdf_text_detector(config)
704
720
  d_text = ServiceFactory.build_pdf_miner_text_service(pdf_miner)
721
+ d_text_service_id = d_text.service_id
705
722
  pipe_component_list.append(d_text)
706
723
 
707
724
  # setup ocr
@@ -710,10 +727,14 @@ class ServiceFactory:
710
727
  if config.OCR.USE_DOCTR:
711
728
  word_detector = ServiceFactory.build_doctr_word_detector(config)
712
729
  word_service = ServiceFactory.build_doctr_word_detector_service(word_detector)
730
+ word_service.set_inbound_filter(skip_if_category_or_service_extracted(service_ids=d_text_service_id))
713
731
  pipe_component_list.append(word_service)
714
732
 
715
733
  ocr_detector = ServiceFactory.build_ocr_detector(config)
716
734
  text_extraction_service = ServiceFactory.build_text_extraction_service(config, ocr_detector)
735
+ text_extraction_service.set_inbound_filter(
736
+ skip_if_category_or_service_extracted(service_ids=d_text_service_id)
737
+ )
717
738
  pipe_component_list.append(text_extraction_service)
718
739
 
719
740
  if config.USE_PDF_MINER or config.USE_OCR:
@@ -27,7 +27,6 @@ from typing import Any, Optional, Union, no_type_check
27
27
 
28
28
  import numpy as np
29
29
  from numpy import uint8
30
- from numpy.typing import NDArray
31
30
  from pypdf import PdfReader
32
31
 
33
32
  from ..utils.develop import deprecated
@@ -42,8 +41,6 @@ __all__ = [
42
41
  "convert_np_array_to_b64_b",
43
42
  "convert_bytes_to_np_array",
44
43
  "convert_pdf_bytes_to_np_array_v2",
45
- "box_to_point4",
46
- "point4_to_box",
47
44
  "as_dict",
48
45
  ]
49
46
 
@@ -187,24 +184,3 @@ def convert_pdf_bytes_to_np_array_v2(
187
184
  width = shape[2] - shape[0]
188
185
  return pdf_to_np_array(pdf_bytes, size=(int(width), int(height))) # type: ignore
189
186
  return pdf_to_np_array(pdf_bytes, dpi=dpi)
190
-
191
-
192
- def box_to_point4(boxes: NDArray[np.float32]) -> NDArray[np.float32]:
193
- """
194
- :param boxes: nx4
195
- :return: (nx4)x2
196
- """
197
- box = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]]
198
- box = box.reshape((-1, 2))
199
- return box
200
-
201
-
202
- def point4_to_box(points: NDArray[np.float32]) -> NDArray[np.float32]:
203
- """
204
- :param points: (nx4)x2
205
- :return: nx4 boxes (x1y1x2y2)
206
- """
207
- points = points.reshape((-1, 4, 2))
208
- min_xy = points.min(axis=1) # nx2
209
- max_xy = points.max(axis=1) # nx2
210
- return np.concatenate((min_xy, max_xy), axis=1)
@@ -342,7 +342,7 @@ class Image:
342
342
  self,
343
343
  category_names: Optional[Union[str, ObjectTypes, Sequence[Union[str, ObjectTypes]]]] = None,
344
344
  annotation_ids: Optional[Union[str, Sequence[str]]] = None,
345
- service_id: Optional[Union[str, Sequence[str]]] = None,
345
+ service_ids: Optional[Union[str, Sequence[str]]] = None,
346
346
  model_id: Optional[Union[str, Sequence[str]]] = None,
347
347
  session_ids: Optional[Union[str, Sequence[str]]] = None,
348
348
  ignore_inactive: bool = True,
@@ -356,7 +356,7 @@ class Image:
356
356
 
357
357
  :param category_names: A single name or list of names
358
358
  :param annotation_ids: A single id or list of ids
359
- :param service_id: A single service name or list of service names
359
+ :param service_ids: A single service name or list of service names
360
360
  :param model_id: A single model name or list of model names
361
361
  :param session_ids: A single session id or list of session ids
362
362
  :param ignore_inactive: If set to `True` only active annotations are returned.
@@ -372,7 +372,7 @@ class Image:
372
372
  )
373
373
 
374
374
  ann_ids = [annotation_ids] if isinstance(annotation_ids, str) else annotation_ids
375
- service_id = [service_id] if isinstance(service_id, str) else service_id
375
+ service_ids = [service_ids] if isinstance(service_ids, str) else service_ids
376
376
  model_id = [model_id] if isinstance(model_id, str) else model_id
377
377
  session_id = [session_ids] if isinstance(session_ids, str) else session_ids
378
378
 
@@ -387,8 +387,8 @@ class Image:
387
387
  if ann_ids is not None:
388
388
  anns = filter(lambda x: x.annotation_id in ann_ids, anns)
389
389
 
390
- if service_id is not None:
391
- anns = filter(lambda x: x.service_id in service_id, anns)
390
+ if service_ids is not None:
391
+ anns = filter(lambda x: x.service_id in service_ids, anns)
392
392
 
393
393
  if model_id is not None:
394
394
  anns = filter(lambda x: x.model_id in model_id, anns)
@@ -41,12 +41,11 @@ from ..utils.settings import (
41
41
  WordType,
42
42
  get_type,
43
43
  )
44
- from ..utils.transform import ResizeTransform
44
+ from ..utils.transform import ResizeTransform, box_to_point4, point4_to_box
45
45
  from ..utils.types import HTML, AnnotationDict, Chunks, ImageDict, PathLikeOrStr, PixelValues, Text_, csv
46
46
  from ..utils.viz import draw_boxes, interactive_imshow, viz_handler
47
47
  from .annotation import CategoryAnnotation, ContainerAnnotation, ImageAnnotation, ann_from_dict
48
48
  from .box import BoundingBox, crop_box_from_image
49
- from .convert import box_to_point4, point4_to_box
50
49
  from .image import Image
51
50
 
52
51
 
@@ -659,7 +658,7 @@ class Page(Image):
659
658
  self,
660
659
  category_names: Optional[Union[str, ObjectTypes, Sequence[Union[str, ObjectTypes]]]] = None,
661
660
  annotation_ids: Optional[Union[str, Sequence[str]]] = None,
662
- service_id: Optional[Union[str, Sequence[str]]] = None,
661
+ service_ids: Optional[Union[str, Sequence[str]]] = None,
663
662
  model_id: Optional[Union[str, Sequence[str]]] = None,
664
663
  session_ids: Optional[Union[str, Sequence[str]]] = None,
665
664
  ignore_inactive: bool = True,
@@ -676,7 +675,7 @@ class Page(Image):
676
675
 
677
676
  :param category_names: A single name or list of names
678
677
  :param annotation_ids: A single id or list of ids
679
- :param service_id: A single service name or list of service names
678
+ :param service_ids: A single service name or list of service names
680
679
  :param model_id: A single model name or list of model names
681
680
  :param session_ids: A single session id or list of session ids
682
681
  :param ignore_inactive: If set to `True` only active annotations are returned.
@@ -691,7 +690,7 @@ class Page(Image):
691
690
  else tuple(get_type(cat_name) for cat_name in category_names)
692
691
  )
693
692
  ann_ids = [annotation_ids] if isinstance(annotation_ids, str) else annotation_ids
694
- service_id = [service_id] if isinstance(service_id, str) else service_id
693
+ service_ids = [service_ids] if isinstance(service_ids, str) else service_ids
695
694
  model_id = [model_id] if isinstance(model_id, str) else model_id
696
695
  session_id = [session_ids] if isinstance(session_ids, str) else session_ids
697
696
 
@@ -706,8 +705,8 @@ class Page(Image):
706
705
  if ann_ids is not None:
707
706
  anns = filter(lambda x: x.annotation_id in ann_ids, anns)
708
707
 
709
- if service_id is not None:
710
- anns = filter(lambda x: x.generating_service in service_id, anns)
708
+ if service_ids is not None:
709
+ anns = filter(lambda x: x.generating_service in service_ids, anns)
711
710
 
712
711
  if model_id is not None:
713
712
  anns = filter(lambda x: x.generating_model in model_id, anns)
@@ -369,7 +369,9 @@ class MergeDataset(DatasetBase):
369
369
  self.buffer_datasets(**dataflow_build_kwargs)
370
370
  split_defaultdict = defaultdict(list)
371
371
  for image in self.datapoint_list: # type: ignore
372
- split_defaultdict[ann_id_to_split[image.image_id]].append(image)
372
+ maybe_image_id = ann_id_to_split.get(image.image_id)
373
+ if maybe_image_id is not None:
374
+ split_defaultdict[maybe_image_id].append(image)
373
375
  train_dataset = split_defaultdict["train"]
374
376
  val_dataset = split_defaultdict["val"]
375
377
  test_dataset = split_defaultdict["test"]
@@ -26,6 +26,7 @@ from dataclasses import dataclass, field
26
26
  from types import MappingProxyType
27
27
  from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Sequence, Union, overload
28
28
 
29
+ import numpy as np
29
30
  from lazy_imports import try_import
30
31
 
31
32
  from ..utils.identifier import get_uuid_from_str
@@ -38,6 +39,7 @@ from ..utils.settings import (
38
39
  token_class_tag_to_token_class_with_tag,
39
40
  token_class_with_tag_to_token_class_and_tag,
40
41
  )
42
+ from ..utils.transform import BaseTransform, box_to_point4, point4_to_box
41
43
  from ..utils.types import JsonDict, PixelValues, Requirement
42
44
 
43
45
  if TYPE_CHECKING:
@@ -621,7 +623,7 @@ class ImageTransformer(PredictorBase, ABC):
621
623
  """
622
624
 
623
625
  @abstractmethod
624
- def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
626
+ def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
625
627
  """
626
628
  Abstract method transform
627
629
  """
@@ -641,3 +643,108 @@ class ImageTransformer(PredictorBase, ABC):
641
643
  def get_category_names(self) -> tuple[ObjectTypes, ...]:
642
644
  """returns category names"""
643
645
  raise NotImplementedError()
646
+
647
+ def transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
648
+ """
649
+ Transform coordinates aligned with the transform_image method.
650
+
651
+ :param detect_results: List of DetectionResults
652
+ :return: List of DetectionResults. If you pass uuid it is possible to track the transformed bounding boxes.
653
+ """
654
+
655
+ raise NotImplementedError()
656
+
657
+ def inverse_transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
658
+ """
659
+ Inverse transform coordinates aligned with the transform_image method. Composing transform_coords with
660
+ inverse_transform_coords should return the original coordinates.
661
+
662
+ :param detect_results: List of DetectionResults
663
+ :return: List of DetectionResults. If you pass uuid it is possible to track the transformed bounding boxes.
664
+ """
665
+
666
+ raise NotImplementedError()
667
+
668
+
669
+ class DeterministicImageTransformer(ImageTransformer):
670
+ """
671
+ A wrapper for BaseTransform classes that implements the ImageTransformer interface.
672
+
673
+ This class provides a bridge between the BaseTransform system (which handles image and coordinate
674
+ transformations like rotation, padding, etc.) and the predictors framework by implementing the
675
+ ImageTransformer interface. It allows BaseTransform objects to be used within pipelines that
676
+ expect ImageTransformer components.
677
+
678
+ The transformer performs deterministic transformations on images and their associated coordinates,
679
+ enabling operations like padding, rotation, and other geometric transformations while maintaining
680
+ the relationship between image content and annotation coordinates.
681
+
682
+ :param base_transform: A BaseTransform instance that defines the actual transformation operations
683
+ to be applied to images and coordinates.
684
+ """
685
+
686
+ def __init__(self, base_transform: BaseTransform):
687
+ """
688
+ Initialize the DeterministicImageTransformer with a BaseTransform instance.
689
+
690
+ :param base_transform: A BaseTransform instance that defines the actual transformation operations
691
+ """
692
+ self.base_transform = base_transform
693
+ self.name = base_transform.__class__.__name__
694
+ self.model_id = self.get_model_id()
695
+
696
+ def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
697
+ return self.base_transform.apply_image(np_img)
698
+
699
+ def transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
700
+ boxes = np.array([detect_result.box for detect_result in detect_results])
701
+ # boxes = box_to_point4(boxes)
702
+ boxes = self.base_transform.apply_coords(boxes)
703
+ # boxes = point4_to_box(boxes)
704
+ detection_results = []
705
+ for idx, detect_result in enumerate(detect_results):
706
+ detection_results.append(
707
+ DetectionResult(
708
+ box=boxes[idx, :].tolist(),
709
+ class_name=detect_result.class_name,
710
+ class_id=detect_result.class_id,
711
+ score=detect_result.score,
712
+ absolute_coords=detect_result.absolute_coords,
713
+ uuid=detect_result.uuid,
714
+ )
715
+ )
716
+ return detection_results
717
+
718
+ def inverse_transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
719
+ boxes = np.array([detect_result.box for detect_result in detect_results])
720
+ boxes = box_to_point4(boxes)
721
+ boxes = self.base_transform.inverse_apply_coords(boxes)
722
+ boxes = point4_to_box(boxes)
723
+ detection_results = []
724
+ for idx, detect_result in enumerate(detect_results):
725
+ detection_results.append(
726
+ DetectionResult(
727
+ box=boxes[idx, :].tolist(),
728
+ class_id=detect_result.class_id,
729
+ score=detect_result.score,
730
+ absolute_coords=detect_result.absolute_coords,
731
+ uuid=detect_result.uuid,
732
+ )
733
+ )
734
+ return detection_results
735
+
736
+ def clone(self) -> DeterministicImageTransformer:
737
+ return self.__class__(self.base_transform)
738
+
739
+ def predict(self, np_img: PixelValues) -> DetectionResult:
740
+ detect_result = DetectionResult()
741
+ for init_arg in self.base_transform.get_init_args():
742
+ setattr(detect_result, init_arg, getattr(self.base_transform, init_arg))
743
+ return detect_result
744
+
745
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
746
+ return self.base_transform.get_category_names()
747
+
748
+ @classmethod
749
+ def get_requirements(cls) -> list[Requirement]:
750
+ return []
@@ -43,7 +43,7 @@ class Jdeskewer(ImageTransformer):
43
43
  self.model_id = self.get_model_id()
44
44
  self.min_angle_rotation = min_angle_rotation
45
45
 
46
- def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
46
+ def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
47
47
  """
48
48
  Rotation of the image according to the angle determined by the jdeskew estimator.
49
49
 
@@ -514,8 +514,9 @@ class DocTrRotationTransformer(ImageTransformer):
514
514
  self.number_contours = number_contours
515
515
  self.ratio_threshold_for_lines = ratio_threshold_for_lines
516
516
  self.name = "doctr_rotation_transformer"
517
+ self.model_id = self.get_model_id()
517
518
 
518
- def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
519
+ def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
519
520
  """
520
521
  Applies the predicted rotation to the image, effectively rotating the image backwards.
521
522
  This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
@@ -423,7 +423,7 @@ class TesseractRotationTransformer(ImageTransformer):
423
423
  self.categories = ModelCategories(init_categories={1: PageType.ANGLE})
424
424
  self.model_id = self.get_model_id()
425
425
 
426
- def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
426
+ def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
427
427
  """
428
428
  Applies the predicted rotation to the image, effectively rotating the image backwards.
429
429
  This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
@@ -15,9 +15,9 @@ from typing import Any, List, Optional, Tuple
15
15
  import numpy as np
16
16
  from lazy_imports import try_import
17
17
 
18
- from ....datapoint.convert import box_to_point4, point4_to_box
19
18
  from ....utils.error import MalformedData
20
19
  from ....utils.logger import log_once
20
+ from ....utils.transform import box_to_point4, point4_to_box
21
21
  from ....utils.types import JsonDict, PixelValues
22
22
  from .common import filter_boxes_inside_shape, np_iou
23
23
  from .modeling.model_fpn import get_all_anchors_fpn
@@ -31,11 +31,10 @@ import numpy.typing as npt
31
31
  from lazy_imports import try_import
32
32
 
33
33
  from ..datapoint.annotation import ContainerAnnotation
34
- from ..datapoint.convert import box_to_point4, point4_to_box
35
34
  from ..datapoint.image import Image
36
35
  from ..datapoint.view import Page
37
36
  from ..utils.settings import DatasetType, LayoutType, PageType, Relationships, WordType
38
- from ..utils.transform import ResizeTransform, normalize_image
37
+ from ..utils.transform import ResizeTransform, box_to_point4, normalize_image, point4_to_box
39
38
  from ..utils.types import JsonDict
40
39
  from .maputils import curry
41
40
 
@@ -34,13 +34,15 @@ from ..utils.settings import TypeOrStr
34
34
 
35
35
  def match_anns_by_intersection(
36
36
  dp: Image,
37
- parent_ann_category_names: Union[TypeOrStr, Sequence[TypeOrStr]],
38
- child_ann_category_names: Union[TypeOrStr, Sequence[TypeOrStr]],
39
37
  matching_rule: Literal["iou", "ioa"],
40
38
  threshold: float,
41
39
  use_weighted_intersections: bool = False,
40
+ parent_ann_category_names: Optional[Union[TypeOrStr, Sequence[TypeOrStr]]] = None,
41
+ child_ann_category_names: Optional[Union[TypeOrStr, Sequence[TypeOrStr]]] = None,
42
42
  parent_ann_ids: Optional[Union[Sequence[str], str]] = None,
43
43
  child_ann_ids: Optional[Union[str, Sequence[str]]] = None,
44
+ parent_ann_service_ids: Optional[Union[str, Sequence[str]]] = None,
45
+ child_ann_service_ids: Optional[Union[str, Sequence[str]]] = None,
44
46
  max_parent_only: bool = False,
45
47
  ) -> tuple[Any, Any, Sequence[ImageAnnotation], Sequence[ImageAnnotation]]:
46
48
  """
@@ -87,13 +89,19 @@ def match_anns_by_intersection(
87
89
  dates which are not in the list.
88
90
  :param child_ann_ids: Additional filter condition. If some ids are selected, it will ignore all other children
89
91
  candidates which are not in the list.
92
+ :param parent_ann_service_ids: Additional filter condition. If some ids are selected, it will ignore all other
93
+ parent candidates which are not in the list.
94
+ :param child_ann_service_ids: Additional filter condition. If some ids are selected, it will ignore all other
95
+ children candidates which are not in the list.
90
96
  :param max_parent_only: Will assign to each child at most one parent with maximum ioa
91
97
  :return: child indices, parent indices (see Example), list of parent ids and list of children ids.
92
98
  """
93
99
 
94
100
  assert matching_rule in ["iou", "ioa"], "matching rule must be either iou or ioa"
95
101
 
96
- child_anns = dp.get_annotation(annotation_ids=child_ann_ids, category_names=child_ann_category_names)
102
+ child_anns = dp.get_annotation(
103
+ annotation_ids=child_ann_ids, category_names=child_ann_category_names, service_ids=child_ann_service_ids
104
+ )
97
105
  child_ann_boxes = np.array(
98
106
  [
99
107
  ann.get_bounding_box(dp.image_id).transform(dp.width, dp.height, absolute_coords=True).to_list(mode="xyxy")
@@ -101,7 +109,9 @@ def match_anns_by_intersection(
101
109
  ]
102
110
  )
103
111
 
104
- parent_anns = dp.get_annotation(annotation_ids=parent_ann_ids, category_names=parent_ann_category_names)
112
+ parent_anns = dp.get_annotation(
113
+ annotation_ids=parent_ann_ids, category_names=parent_ann_category_names, service_ids=parent_ann_service_ids
114
+ )
105
115
  parent_ann_boxes = np.array(
106
116
  [
107
117
  ann.get_bounding_box(dp.image_id).transform(dp.width, dp.height, absolute_coords=True).to_list(mode="xyxy")
@@ -147,10 +157,12 @@ def match_anns_by_intersection(
147
157
 
148
158
  def match_anns_by_distance(
149
159
  dp: Image,
150
- parent_ann_category_names: Union[TypeOrStr, Sequence[TypeOrStr]],
151
- child_ann_category_names: Union[TypeOrStr, Sequence[TypeOrStr]],
160
+ parent_ann_category_names: Optional[Union[TypeOrStr, Sequence[TypeOrStr]]] = None,
161
+ child_ann_category_names: Optional[Union[TypeOrStr, Sequence[TypeOrStr]]] = None,
152
162
  parent_ann_ids: Optional[Union[Sequence[str], str]] = None,
153
163
  child_ann_ids: Optional[Union[str, Sequence[str]]] = None,
164
+ parent_ann_service_ids: Optional[Union[str, Sequence[str]]] = None,
165
+ child_ann_service_ids: Optional[Union[str, Sequence[str]]] = None,
154
166
  ) -> list[tuple[ImageAnnotation, ImageAnnotation]]:
155
167
  """
156
168
  Generates pairs of parent and child annotations by calculating the euclidean distance between the centers of the
@@ -164,11 +176,19 @@ def match_anns_by_distance(
164
176
  dates which are not in the list.
165
177
  :param child_ann_ids: Additional filter condition. If some ids are selected, it will ignore all other children
166
178
  candidates which are not in the list.
179
+ :param parent_ann_service_ids: Additional filter condition. If some ids are selected, it will ignore all other
180
+ parent candidates which are not in the list.
181
+ :param child_ann_service_ids: Additional filter condition. If some ids are selected, it will ignore all other
182
+ children candidates which are not in the list.
167
183
  :return:
168
184
  """
169
185
 
170
- parent_anns = dp.get_annotation(annotation_ids=parent_ann_ids, category_names=parent_ann_category_names)
171
- child_anns = dp.get_annotation(annotation_ids=child_ann_ids, category_names=child_ann_category_names)
186
+ parent_anns = dp.get_annotation(
187
+ annotation_ids=parent_ann_ids, category_names=parent_ann_category_names, service_ids=parent_ann_service_ids
188
+ )
189
+ child_anns = dp.get_annotation(
190
+ annotation_ids=child_ann_ids, category_names=child_ann_category_names, service_ids=child_ann_service_ids
191
+ )
172
192
  child_centers = [block.get_bounding_box(dp.image_id).center for block in child_anns]
173
193
  parent_centers = [block.get_bounding_box(dp.image_id).center for block in parent_anns]
174
194
  if child_centers and parent_centers:
@@ -75,27 +75,6 @@ class DatapointManager:
75
75
  """
76
76
  assert self.datapoint_is_passed, "Pass datapoint to DatapointManager before creating anns"
77
77
 
78
- def maybe_map_category_id(self, category_id: Union[str, int]) -> int:
79
- """
80
- Maps categories if a category id mapping is provided in `__init__`.
81
-
82
- :param category_id: category id via integer or string.
83
- :return: mapped category id
84
- """
85
- if self.category_id_mapping is None:
86
- return int(category_id)
87
- return self.category_id_mapping[int(category_id)]
88
-
89
- def set_category_id_mapping(self, category_id_mapping: Mapping[int, int]) -> None:
90
- """
91
- In many cases the category ids sent back from a model have to be modified. Pass a mapping from model
92
- category ids to target annotation category ids.
93
-
94
- :param category_id_mapping: A mapping of model category ids (sent from DetectionResult) to category ids (saved
95
- in annotations)
96
- """
97
- self.category_id_mapping = category_id_mapping
98
-
99
78
  def set_image_annotation(
100
79
  self,
101
80
  detect_result: DetectionResult,
@@ -127,13 +106,10 @@ class DatapointManager:
127
106
  :return: the annotation_id of the generated image annotation
128
107
  """
129
108
  self.assert_datapoint_passed()
130
- if detect_result.class_id is None:
131
- raise ValueError("class_id of detect_result cannot be None")
132
109
  if not isinstance(detect_result.box, (list, np.ndarray)):
133
110
  raise TypeError(
134
111
  f"detect_result.box must be of type list or np.ndarray, but is of type {(type(detect_result.box))}"
135
112
  )
136
- detect_result.class_id = self.maybe_map_category_id(detect_result.class_id)
137
113
  with MappingContextManager(
138
114
  dp_name=self.datapoint.file_name, filter_level="annotation", detect_result=asdict(detect_result)
139
115
  ) as annotation_context:
@@ -155,7 +131,7 @@ class DatapointManager:
155
131
  ann = ImageAnnotation(
156
132
  category_name=detect_result.class_name,
157
133
  bounding_box=box,
158
- category_id=detect_result.class_id,
134
+ category_id=detect_result.class_id if detect_result.class_id is not None else DEFAULT_CATEGORY_ID,
159
135
  score=detect_result.score,
160
136
  service_id=self.service_id,
161
137
  model_id=self.model_id,