deepdoctection 0.40.0__py3-none-any.whl → 0.42.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.


@@ -25,7 +25,7 @@ from .utils.logger import LoggingRecord, logger
 
 # pylint: enable=wrong-import-position
 
-__version__ = "0.40.0"
+__version__ = "0.42.0"
 
 _IMPORT_STRUCTURE = {
     "analyzer": ["config_sanity_checks", "get_dd_analyzer", "ServiceFactory"],
@@ -90,13 +90,12 @@ _IMPORT_STRUCTURE = {
         "convert_np_array_to_b64_b",
         "convert_bytes_to_np_array",
         "convert_pdf_bytes_to_np_array_v2",
-        "box_to_point4",
-        "point4_to_box",
         "as_dict",
         "ImageAnnotationBaseView",
         "Image",
         "Word",
         "Layout",
+        "List",
        "Cell",
         "Table",
         "Page",
@@ -164,6 +163,7 @@ _IMPORT_STRUCTURE = {
         "LMSequenceClassifier",
         "LanguageDetector",
         "ImageTransformer",
+        "DeterministicImageTransformer",
         "InferenceResize",
         "D2FrcnnDetector",
         "D2FrcnnTracingDetector",
@@ -401,11 +401,14 @@ _IMPORT_STRUCTURE = {
         "get_type",
         "get_tqdm",
         "get_tqdm_default_kwargs",
+        "box_to_point4",
+        "point4_to_box",
         "ResizeTransform",
         "InferenceResize",
         "normalize_image",
         "pad_image",
         "PadTransform",
+        "RotationTransform",
         "delete_keys_from_dict",
         "split_string",
         "string_to_dict",
@@ -438,7 +441,7 @@ if TYPE_CHECKING:
     from .eval import *
     from .extern import *  # type: ignore
     from .mapper import *  # type: ignore
-    from .pipe import *
+    from .pipe import *  # type: ignore
     from .train import *
     from .utils import *
 
@@ -40,7 +40,7 @@ cfg.TF.CELL.FILTER = None
 cfg.TF.ITEM.WEIGHTS = "item/model-1620000_inf_only.data-00000-of-00001"
 cfg.TF.ITEM.FILTER = None
 
-cfg.PT.ENFORCE_WEIGHTS = False
+cfg.PT.ENFORCE_WEIGHTS.LAYOUT = True
 cfg.PT.LAYOUT.WEIGHTS = "layout/d2_model_0829999_layout_inf_only.pt"
 cfg.PT.LAYOUT.WEIGHTS_TS = "layout/d2_model_0829999_layout_inf_only.ts"
 cfg.PT.LAYOUT.FILTER = None
@@ -49,6 +49,7 @@ cfg.PT.LAYOUT.PAD.RIGHT = 60
 cfg.PT.LAYOUT.PAD.BOTTOM = 60
 cfg.PT.LAYOUT.PAD.LEFT = 60
 
+cfg.PT.ENFORCE_WEIGHTS.ITEM = True
 cfg.PT.ITEM.WEIGHTS = "item/d2_model_1639999_item_inf_only.pt"
 cfg.PT.ITEM.WEIGHTS_TS = "item/d2_model_1639999_item_inf_only.ts"
 cfg.PT.ITEM.FILTER = None
@@ -57,6 +58,7 @@ cfg.PT.ITEM.PAD.RIGHT = 60
 cfg.PT.ITEM.PAD.BOTTOM = 60
 cfg.PT.ITEM.PAD.LEFT = 60
 
+cfg.PT.ENFORCE_WEIGHTS.CELL = True
 cfg.PT.CELL.WEIGHTS = "cell/d2_model_1849999_cell_inf_only.pt"
 cfg.PT.CELL.WEIGHTS_TS = "cell/d2_model_1849999_cell_inf_only.ts"
 cfg.PT.CELL.FILTER = None
@@ -137,6 +139,7 @@ cfg.TEXT_ORDERING.HEIGHT_TOLERANCE = 2.0
 cfg.TEXT_ORDERING.PARAGRAPH_BREAK = 0.035
 
 cfg.USE_LAYOUT_LINK = False
+cfg.USE_LINE_MATCHER = False
 cfg.LAYOUT_LINK.PARENTAL_CATEGORIES = []
 cfg.LAYOUT_LINK.CHILD_CATEGORIES = []
 
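Note that PT.ENFORCE_WEIGHTS changes from a single boolean into one flag per model (LAYOUT, ITEM, CELL), and USE_LINE_MATCHER is a new switch. A minimal sketch of targeting the new keys through get_dd_analyzer; the list-of-"KEY=VALUE" syntax for config_overwrite is an assumption based on deepdoctection's standard configuration mechanism:

    import deepdoctection as dd

    # Force the TorchScript fallback for the layout model and enable the new
    # line matcher. The keys mirror the 0.42.0 defaults shown above.
    analyzer = dd.get_dd_analyzer(
        config_overwrite=[
            "PT.ENFORCE_WEIGHTS.LAYOUT=False",
            "USE_LINE_MATCHER=True",
        ]
    )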
@@ -32,7 +32,7 @@ from ..extern.pt.ptutils import get_torch_device
 from ..extern.tp.tfutils import disable_tp_layer_logging, get_tf_device
 from ..pipe.doctectionpipe import DoctectionPipe
 from ..utils.env_info import ENV_VARS_TRUE
-from ..utils.file_utils import tensorpack_available
+from ..utils.file_utils import tensorpack_available, detectron2_available
 from ..utils.fs import get_configs_dir_path, get_package_path, maybe_copy_config_to_cache
 from ..utils.logger import LoggingRecord, logger
 from ..utils.metacfg import set_config_by_yaml
@@ -140,6 +140,12 @@ def get_dd_analyzer(
     cfg.LANGUAGE = None
     cfg.LIB = lib
     cfg.DEVICE = device
+    if not detectron2_available() or cfg.PT.LAYOUT.WEIGHTS is None:
+        cfg.PT.ENFORCE_WEIGHTS.LAYOUT = False
+    if not detectron2_available() or cfg.PT.ITEM.WEIGHTS is None:
+        cfg.PT.ENFORCE_WEIGHTS.ITEM = False
+    if not detectron2_available() or cfg.PT.CELL.WEIGHTS is None:
+        cfg.PT.ENFORCE_WEIGHTS.CELL = False
     cfg.freeze()
 
     if config_overwrite:
@@ -50,7 +50,6 @@ from ..pipe.sub_layout import DetectResultGenerator, SubImageLayoutService
 from ..pipe.text import TextExtractionService
 from ..pipe.transform import SimpleTransformService
 from ..utils.error import DependencyError
-from ..utils.file_utils import detectron2_available
 from ..utils.fs import get_configs_dir_path
 from ..utils.metacfg import AttrDict
 from ..utils.settings import CellType, LayoutType, Relationships
@@ -96,12 +95,13 @@ class ServiceFactory:
         """
         if config.LIB is None:
             raise DependencyError("At least one of the env variables DD_USE_TF or DD_USE_TORCH must be set.")
+
         weights = (
             getattr(config.TF, mode).WEIGHTS
             if config.LIB == "TF"
             else (
                 getattr(config.PT, mode).WEIGHTS
-                if detectron2_available() or config.PT.ENFORCE_WEIGHTS
+                if getattr(config.PT.ENFORCE_WEIGHTS, mode)
                 else getattr(config.PT, mode).WEIGHTS_TS
             )
         )
@@ -197,7 +197,7 @@ class ServiceFactory:
             getattr(config.PT, mode).PAD.BOTTOM,
             getattr(config.PT, mode).PAD.LEFT,
         )
-        return PadTransform(top=top, right=right, bottom=bottom, left=left)  #
+        return PadTransform(pad_top=top, pad_right=right, pad_bottom=bottom, pad_left=left)  #
 
     @staticmethod
     def build_padder(config: AttrDict, mode: str) -> PadTransform:
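The PadTransform keyword arguments gain a pad_ prefix in this release. A minimal sketch of constructing the padder directly, assuming the utils.transform import path used elsewhere in this diff:

    from deepdoctection.utils.transform import PadTransform

    # 0.40.0 spelling was PadTransform(top=60, right=60, bottom=60, left=60)
    padder = PadTransform(pad_top=60, pad_right=60, pad_bottom=60, pad_left=60)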
@@ -240,8 +240,6 @@ class ServiceFactory:
 
         :param config: configuration object
         """
-        if not detectron2_available() and config.LIB == "PT":
-            raise ModuleNotFoundError("LAYOUT_NMS_PAIRS is only available for detectron2")
         if not isinstance(config.LAYOUT_NMS_PAIRS.COMBINATIONS, list) and not isinstance(
             config.LAYOUT_NMS_PAIRS.COMBINATIONS[0], list
         ):
@@ -577,7 +575,14 @@ class ServiceFactory:
                 parent_categories=config.WORD_MATCHING.PARENTAL_CATEGORIES,
                 child_categories=config.TEXT_CONTAINER,
                 relationship_key=Relationships.CHILD,
-            )
+            ),
+            FamilyCompound(
+                parent_categories=[LayoutType.LIST],
+                child_categories=[LayoutType.LIST_ITEM],
+                relationship_key=Relationships.CHILD,
+                create_synthetic_parent=True,
+                synthetic_parent=LayoutType.LIST,
+            ),
         ]
         return MatchingService(
             family_compounds=family_compounds,
@@ -622,6 +627,34 @@ class ServiceFactory:
         """
         return ServiceFactory._build_layout_link_matching_service(config)
 
+    @staticmethod
+    def _build_line_matching_service(config: AttrDict) -> MatchingService:
+        matcher = IntersectionMatcher(
+            matching_rule=config.WORD_MATCHING.RULE,
+            threshold=config.WORD_MATCHING.THRESHOLD,
+            max_parent_only=config.WORD_MATCHING.MAX_PARENT_ONLY,
+        )
+        family_compounds = [
+            FamilyCompound(
+                parent_categories=[LayoutType.LIST],
+                child_categories=[LayoutType.LINE],
+                relationship_key=Relationships.CHILD,
+            ),
+        ]
+        return MatchingService(
+            family_compounds=family_compounds,
+            matcher=matcher,
+        )
+
+    @staticmethod
+    def build_line_matching_service(config: AttrDict) -> MatchingService:
+        """Building a line matching service
+
+        :param config: configuration object
+        :return: MatchingService
+        """
+        return ServiceFactory._build_line_matching_service(config)
+
     @staticmethod
     def _build_text_order_service(config: AttrDict) -> TextOrderService:
         """Building a text order service
@@ -748,6 +781,10 @@ class ServiceFactory:
             layout_link_matching_service = ServiceFactory.build_layout_link_matching_service(config)
             pipe_component_list.append(layout_link_matching_service)
 
+        if config.USE_LINE_MATCHER:
+            line_list_matching_service = ServiceFactory.build_line_matching_service(config)
+            pipe_component_list.append(line_list_matching_service)
+
         page_parsing_service = ServiceFactory.build_page_parsing_service(config)
 
         return DoctectionPipe(pipeline_component_list=pipe_component_list, page_parsing_service=page_parsing_service)
@@ -27,7 +27,6 @@ from typing import Any, Optional, Union, no_type_check
 
 import numpy as np
 from numpy import uint8
-from numpy.typing import NDArray
 from pypdf import PdfReader
 
 from ..utils.develop import deprecated
@@ -42,8 +41,6 @@ __all__ = [
     "convert_np_array_to_b64_b",
     "convert_bytes_to_np_array",
     "convert_pdf_bytes_to_np_array_v2",
-    "box_to_point4",
-    "point4_to_box",
     "as_dict",
 ]
 
@@ -187,24 +184,3 @@ def convert_pdf_bytes_to_np_array_v2(
         width = shape[2] - shape[0]
         return pdf_to_np_array(pdf_bytes, size=(int(width), int(height)))  # type: ignore
     return pdf_to_np_array(pdf_bytes, dpi=dpi)
-
-
-def box_to_point4(boxes: NDArray[np.float32]) -> NDArray[np.float32]:
-    """
-    :param boxes: nx4
-    :return: (nx4)x2
-    """
-    box = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]]
-    box = box.reshape((-1, 2))
-    return box
-
-
-def point4_to_box(points: NDArray[np.float32]) -> NDArray[np.float32]:
-    """
-    :param points: (nx4)x2
-    :return: nx4 boxes (x1y1x2y2)
-    """
-    points = points.reshape((-1, 4, 2))
-    min_xy = points.min(axis=1)  # nx2
-    max_xy = points.max(axis=1)  # nx2
-    return np.concatenate((min_xy, max_xy), axis=1)
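The two helpers are moved rather than deleted: they reappear in utils.transform (see the _IMPORT_STRUCTURE and import hunks) and stay re-exported at package level. A minimal round-trip sketch, assuming 0.42.0 is installed:

    import numpy as np
    import deepdoctection as dd

    boxes = np.array([[10.0, 20.0, 30.0, 40.0]], dtype=np.float32)  # nx4, x1y1x2y2
    points = dd.box_to_point4(boxes)     # (nx4)x2 corner points
    restored = dd.point4_to_box(points)  # back to nx4
    assert np.allclose(boxes, restored)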
@@ -25,7 +25,6 @@ from copy import copy
 from typing import Any, Mapping, Optional, Sequence, Type, TypedDict, Union, no_type_check
 
 import numpy as np
-from typing_extensions import LiteralString
 
 from ..utils.error import AnnotationError, ImageError
 from ..utils.logger import LoggingRecord, log_once, logger
@@ -41,12 +40,11 @@ from ..utils.settings import (
     WordType,
     get_type,
 )
-from ..utils.transform import ResizeTransform
+from ..utils.transform import ResizeTransform, box_to_point4, point4_to_box
 from ..utils.types import HTML, AnnotationDict, Chunks, ImageDict, PathLikeOrStr, PixelValues, Text_, csv
 from ..utils.viz import draw_boxes, interactive_imshow, viz_handler
 from .annotation import CategoryAnnotation, ContainerAnnotation, ImageAnnotation, ann_from_dict
 from .box import BoundingBox, crop_box_from_image
-from .convert import box_to_point4, point4_to_box
 from .image import Image
 
 
@@ -286,6 +284,52 @@ class Cell(Layout):
         return set(CellType).union(super().get_attribute_names())
 
 
+class List(Layout):
+    """
+    List specific subclass of `ImageAnnotationBaseView` modelled by `LayoutType`.
+    """
+
+    @property
+    def words(self) -> list[ImageAnnotationBaseView]:
+        """
+        Get a list of `ImageAnnotationBaseView` objects with `LayoutType` defined by `text_container`.
+        It will only select those among all annotations that have an entry in `Relationships.child`.
+        """
+        all_words: list[ImageAnnotationBaseView] = []
+        for list_item in self.list_items:
+            all_words.extend(list_item.words)  # type: ignore
+        return all_words
+
+    def get_ordered_words(self) -> list[ImageAnnotationBaseView]:
+        """Returns a list of words ordered by reading order. Words with no reading order will not be returned"""
+        try:
+            list_items = self.list_items
+            all_words = []
+            list_items.sort(key=lambda x: x.bbox[1])
+            for list_item in list_items:
+                all_words.extend(list_item.get_ordered_words())  # type: ignore
+            return all_words
+        except (TypeError, AnnotationError):
+            return super().get_ordered_words()
+
+    @property
+    def list_items(self) -> list[ImageAnnotationBaseView]:
+        """
+        A list of list items.
+        """
+        all_relation_ids = self.get_relationship(Relationships.CHILD)
+        list_items = self.base_page.get_annotation(
+            annotation_ids=all_relation_ids,
+            category_names=(
+                LayoutType.LIST_ITEM,
+                LayoutType.LINE,
+            ),
+        )
+        list_items.sort(key=lambda x: x.bbox[1])
+        return list_items
+
+
 class Table(Layout):
     """
     Table specific sub class of `ImageAnnotationBaseView` modelled by `TableType`.
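With the new view class, list layouts on a parsed page expose their items and reading-order words directly. A hedged sketch of downstream access, assuming the analyzer detected LayoutType.LIST segments ("sample.pdf" is a placeholder path):

    import deepdoctection as dd

    analyzer = dd.get_dd_analyzer()
    df = analyzer.analyze(path="sample.pdf")
    df.reset_state()

    for page in df:
        for layout in page.layouts:
            if layout.category_name == dd.LayoutType.LIST:
                items = layout.list_items            # children sorted by vertical position
                words = layout.get_ordered_words()   # reading-order words across all items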
@@ -373,7 +417,7 @@ class Table(Layout):
             category_names=[LayoutType.CELL, CellType.SPANNING], annotation_ids=all_relation_ids
         )
         row_cells = list(
-            filter(lambda c: row_number in (c.row_number, c.row_number + c.row_span), all_cells)  # type: ignore
+            filter(lambda c: c.row_number <= row_number <= c.row_number + c.row_span - 1, all_cells)  # type: ignore
         )
         row_cells.sort(key=lambda c: c.column_number)  # type: ignore
         column_header_cells = self.column_header_cells
@@ -561,6 +605,7 @@ IMAGE_ANNOTATION_TO_LAYOUTS: dict[ObjectTypes, Type[Union[Layout, Table, Word]]]
     LayoutType.TABLE_ROTATED: Table,
     LayoutType.WORD: Word,
     LayoutType.CELL: Cell,
+    LayoutType.LIST: List,
     CellType.SPANNING: Cell,
     CellType.ROW_HEADER: Cell,
     CellType.COLUMN_HEADER: Cell,
@@ -574,6 +619,7 @@ class ImageDefaults(TypedDict):
     text_container: LayoutType
     floating_text_block_categories: tuple[Union[LayoutType, CellType], ...]
     text_block_categories: tuple[Union[LayoutType, CellType], ...]
+    residual_layouts: tuple[LayoutType, ...]
 
 
 IMAGE_DEFAULTS: ImageDefaults = {
@@ -592,6 +638,7 @@ IMAGE_DEFAULTS: ImageDefaults = {
         LayoutType.FIGURE,
         CellType.SPANNING,
     ),
+    "residual_layouts": (LayoutType.LINE,),
 }
 
 
@@ -771,19 +818,8 @@ class Page(Image):
         """
         return self.get_annotation(category_names=self._get_residual_layout())
 
-    def _get_residual_layout(self) -> list[LiteralString]:
-        layouts = copy(list(self.floating_text_block_categories))
-        layouts.extend(
-            [
-                LayoutType.TABLE,
-                LayoutType.FIGURE,
-                self.text_container,
-                LayoutType.CELL,
-                LayoutType.ROW,
-                LayoutType.COLUMN,
-            ]
-        )
-        return [layout for layout in LayoutType if layout not in layouts]
+    def _get_residual_layout(self) -> tuple[LayoutType, ...]:
+        return IMAGE_DEFAULTS["residual_layouts"]
 
     @classmethod
     def from_image(
@@ -369,7 +369,9 @@ class MergeDataset(DatasetBase):
         self.buffer_datasets(**dataflow_build_kwargs)
         split_defaultdict = defaultdict(list)
         for image in self.datapoint_list:  # type: ignore
-            split_defaultdict[ann_id_to_split[image.image_id]].append(image)
+            maybe_image_id = ann_id_to_split.get(image.image_id)
+            if maybe_image_id is not None:
+                split_defaultdict[maybe_image_id].append(image)
         train_dataset = split_defaultdict["train"]
         val_dataset = split_defaultdict["val"]
         test_dataset = split_defaultdict["test"]
@@ -26,6 +26,7 @@ from dataclasses import dataclass, field
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Sequence, Union, overload
 
+import numpy as np
 from lazy_imports import try_import
 
 from ..utils.identifier import get_uuid_from_str
@@ -38,6 +39,7 @@ from ..utils.settings import (
     token_class_tag_to_token_class_with_tag,
     token_class_with_tag_to_token_class_and_tag,
 )
+from ..utils.transform import BaseTransform, box_to_point4, point4_to_box
 from ..utils.types import JsonDict, PixelValues, Requirement
 
 if TYPE_CHECKING:
@@ -621,7 +623,7 @@ class ImageTransformer(PredictorBase, ABC):
     """
 
     @abstractmethod
-    def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
+    def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
         """
         Abstract method transform
         """
@@ -641,3 +643,108 @@ class ImageTransformer(PredictorBase, ABC):
     def get_category_names(self) -> tuple[ObjectTypes, ...]:
         """returns category names"""
         raise NotImplementedError()
+
+    def transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
+        """
+        Transform coordinates aligned with the transform_image method.
+
+        :param detect_results: List of DetectionResults
+        :return: List of DetectionResults. If you pass uuid it is possible to track the transformed bounding boxes.
+        """
+        raise NotImplementedError()
+
+    def inverse_transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
+        """
+        Inverse transform coordinates aligned with the transform_image method. Composing transform_coords with
+        inverse_transform_coords should return the original coordinates.
+
+        :param detect_results: List of DetectionResults
+        :return: List of DetectionResults. If you pass uuid it is possible to track the transformed bounding boxes.
+        """
+        raise NotImplementedError()
+
+
+class DeterministicImageTransformer(ImageTransformer):
+    """
+    A wrapper for BaseTransform classes that implements the ImageTransformer interface.
+
+    This class provides a bridge between the BaseTransform system (which handles image and coordinate
+    transformations like rotation, padding, etc.) and the predictors framework by implementing the
+    ImageTransformer interface. It allows BaseTransform objects to be used within pipelines that
+    expect ImageTransformer components.
+
+    The transformer performs deterministic transformations on images and their associated coordinates,
+    enabling operations like padding, rotation, and other geometric transformations while maintaining
+    the relationship between image content and annotation coordinates.
+
+    :param base_transform: A BaseTransform instance that defines the actual transformation operations
+                           to be applied to images and coordinates.
+    """

+    def __init__(self, base_transform: BaseTransform):
+        """
+        Initialize the DeterministicImageTransformer with a BaseTransform instance.
+
+        :param base_transform: A BaseTransform instance that defines the actual transformation operations
+        """
+        self.base_transform = base_transform
+        self.name = base_transform.__class__.__name__
+        self.model_id = self.get_model_id()
+
+    def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
+        return self.base_transform.apply_image(np_img)
+
+    def transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
+        boxes = np.array([detect_result.box for detect_result in detect_results])
+        # boxes = box_to_point4(boxes)
+        boxes = self.base_transform.apply_coords(boxes)
+        # boxes = point4_to_box(boxes)
+        detection_results = []
+        for idx, detect_result in enumerate(detect_results):
+            detection_results.append(
+                DetectionResult(
+                    box=boxes[idx, :].tolist(),
+                    class_name=detect_result.class_name,
+                    class_id=detect_result.class_id,
+                    score=detect_result.score,
+                    absolute_coords=detect_result.absolute_coords,
+                    uuid=detect_result.uuid,
+                )
+            )
+        return detection_results
+
+    def inverse_transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
+        boxes = np.array([detect_result.box for detect_result in detect_results])
+        boxes = box_to_point4(boxes)
+        boxes = self.base_transform.inverse_apply_coords(boxes)
+        boxes = point4_to_box(boxes)
+        detection_results = []
+        for idx, detect_result in enumerate(detect_results):
+            detection_results.append(
+                DetectionResult(
+                    box=boxes[idx, :].tolist(),
+                    class_id=detect_result.class_id,
+                    score=detect_result.score,
+                    absolute_coords=detect_result.absolute_coords,
+                    uuid=detect_result.uuid,
+                )
+            )
+        return detection_results
+
+    def clone(self) -> DeterministicImageTransformer:
+        return self.__class__(self.base_transform)
+
+    def predict(self, np_img: PixelValues) -> DetectionResult:
+        detect_result = DetectionResult()
+        for init_arg in self.base_transform.get_init_args():
+            setattr(detect_result, init_arg, getattr(self.base_transform, init_arg))
+        return detect_result
+
+    def get_category_names(self) -> tuple[ObjectTypes, ...]:
+        return self.base_transform.get_category_names()
+
+    @classmethod
+    def get_requirements(cls) -> list[Requirement]:
+        return []
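The wrapper lets any BaseTransform run as a pipeline transformer. A hedged sketch combining it with the renamed PadTransform keywords; both names are top-level exports per the _IMPORT_STRUCTURE hunks above:

    import deepdoctection as dd

    padder = dd.PadTransform(pad_top=10, pad_right=10, pad_bottom=10, pad_left=10)
    transformer = dd.DeterministicImageTransformer(padder)

    # np_img would be a numpy uint8 image; predict() returns a DetectionResult
    # carrying the transform's init args, which transform_image accepts as its
    # specification:
    # padded = transformer.transform_image(np_img, transformer.predict(np_img))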
@@ -43,7 +43,7 @@ class Jdeskewer(ImageTransformer):
         self.model_id = self.get_model_id()
         self.min_angle_rotation = min_angle_rotation
 
-    def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
+    def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
         """
         Rotation of the image according to the angle determined by the jdeskew estimator.
 
@@ -424,7 +424,8 @@ class DoctrTextRecognizer(TextRecognizer):
         custom_configs.pop("task", None)
         recognition_configs["mean"] = custom_configs.pop("mean")
         recognition_configs["std"] = custom_configs.pop("std")
-        batch_size = custom_configs.pop("batch_size")
+        if "batch_size" in custom_configs:
+            batch_size = custom_configs.pop("batch_size")
         recognition_configs["batch_size"] = batch_size
 
         if isinstance(architecture, str):
@@ -514,8 +515,9 @@ class DocTrRotationTransformer(ImageTransformer):
         self.number_contours = number_contours
         self.ratio_threshold_for_lines = ratio_threshold_for_lines
         self.name = "doctr_rotation_transformer"
+        self.model_id = self.get_model_id()
 
-    def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
+    def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
         """
         Applies the predicted rotation to the image, effectively rotating the image backwards.
         This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
@@ -423,7 +423,7 @@ class TesseractRotationTransformer(ImageTransformer):
         self.categories = ModelCategories(init_categories={1: PageType.ANGLE})
         self.model_id = self.get_model_id()
 
-    def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
+    def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
         """
         Applies the predicted rotation to the image, effectively rotating the image backwards.
         This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
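The transform to transform_image rename applies across all ImageTransformer subclasses, so custom transformers must rename their overrides as well. An illustrative sketch (MyTransformer is hypothetical; remaining PredictorBase hooks omitted):

    import deepdoctection as dd

    class MyTransformer(dd.ImageTransformer):
        # 0.40.0: def transform(self, np_img, specification): ...
        def transform_image(self, np_img, specification):
            # no-op override under the renamed hook
            return np_img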
@@ -15,9 +15,9 @@ from typing import Any, List, Optional, Tuple
 import numpy as np
 from lazy_imports import try_import
 
-from ....datapoint.convert import box_to_point4, point4_to_box
 from ....utils.error import MalformedData
 from ....utils.logger import log_once
+from ....utils.transform import box_to_point4, point4_to_box
 from ....utils.types import JsonDict, PixelValues
 from .common import filter_boxes_inside_shape, np_iou
 from .modeling.model_fpn import get_all_anchors_fpn
@@ -102,7 +102,7 @@ def image_to_d2_frcnn_training(
     return output
 
 
-def pt_nms_image_annotations(
+def pt_nms_image_annotations_depr(
     anns: Sequence[ImageAnnotation], threshold: float, image_id: Optional[str] = None, prio: str = ""
 ) -> Sequence[str]:
     """
@@ -147,6 +147,69 @@ def pt_nms_image_annotations(
     return []
 
 
+def pt_nms_image_annotations(
+    anns: Sequence[ImageAnnotation], threshold: float, image_id: Optional[str] = None, prio: str = ""
+) -> Sequence[str]:
+    """
+    Processing given image annotations through NMS. This is useful, if you want to suppress some specific image
+    annotation, e.g. given by name or returned through different predictors. This is the pt version, for tf check
+    `mapper.tpstruct`.
+
+    :param anns: A sequence of ImageAnnotations. All annotations will be treated as if they belong to one category
+    :param threshold: NMS threshold
+    :param image_id: id in order to get the embedding bounding box
+    :param prio: If an annotation has prio, it will overwrite its given score to 1 so that it will never be suppressed
+    :return: A list of annotation_ids that belong to the given input sequence and that survive the NMS process
+    """
+    if len(anns) == 1:
+        return [anns[0].annotation_id]
+
+    if not anns:
+        return []
+
+    # First, identify priority annotations that should always be kept
+    priority_ann_ids = []
+
+    if prio:
+        for ann in anns:
+            if ann.category_name == prio:
+                priority_ann_ids.append(ann.annotation_id)
+
+    # If all annotations are priority or none are left for NMS, return all priority IDs
+    if len(priority_ann_ids) == len(anns):
+        return priority_ann_ids
+
+    def priority_to_confidence(ann: ImageAnnotation, priority: str) -> float:
+        if ann.category_name == priority:
+            return 1.0
+        if ann.score:
+            return ann.score
+        raise ValueError("score cannot be None")
+
+    # Perform NMS only on non-priority annotations
+    ann_ids = np.array([ann.annotation_id for ann in anns], dtype="object")
+
+    # Get boxes for non-priority annotations
+    boxes = torch.tensor(
+        [ann.get_bounding_box(image_id).to_list(mode="xyxy") for ann in anns if ann.bounding_box is not None]
+    )
+
+    scores = torch.tensor([priority_to_confidence(ann, prio) for ann in anns])
+    class_mask = torch.ones(len(boxes), dtype=torch.uint8)
+
+    keep = batched_nms(boxes, scores, class_mask, threshold)
+    kept_ids = ann_ids[keep]
+
+    # Convert to list if necessary
+    if isinstance(kept_ids, str):
+        kept_ids = [kept_ids]
+    elif not isinstance(kept_ids, list):
+        kept_ids = kept_ids.tolist()
+
+    # Combine priority annotations with surviving non-priority annotations
+    return list(set(priority_ann_ids + kept_ids))
+
+
 def _get_category_attributes(
     ann: ImageAnnotation, cat_to_sub_cat: Optional[Mapping[ObjectTypes, ObjectTypes]] = None
 ) -> tuple[ObjectTypes, int, Optional[float]]:
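A hedged usage sketch of the reworked NMS helper; anns stands for a sequence of ImageAnnotation objects gathered beforehand, and the module path is an assumption based on the d2 mapper context of this hunk:

    from deepdoctection.mapper.d2struct import pt_nms_image_annotations  # assumed path

    # Annotations whose category_name equals prio are kept unconditionally;
    # the rest go through torch's batched NMS with the given IoU threshold.
    surviving_ids = pt_nms_image_annotations(anns, threshold=0.05, prio="table")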
@@ -31,11 +31,10 @@ import numpy.typing as npt
 from lazy_imports import try_import
 
 from ..datapoint.annotation import ContainerAnnotation
-from ..datapoint.convert import box_to_point4, point4_to_box
 from ..datapoint.image import Image
 from ..datapoint.view import Page
 from ..utils.settings import DatasetType, LayoutType, PageType, Relationships, WordType
-from ..utils.transform import ResizeTransform, normalize_image
+from ..utils.transform import ResizeTransform, box_to_point4, normalize_image, point4_to_box
 from ..utils.types import JsonDict
 from .maputils import curry