deepdoctection 0.37.2__py3-none-any.whl → 0.38__py3-none-any.whl

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection has been flagged as potentially problematic. See the package's registry page for more details.

@@ -28,13 +28,13 @@ from typing import Literal, Optional, Sequence, Union
28
28
  import numpy as np
29
29
 
30
30
  from ..datapoint.annotation import ImageAnnotation
31
- from ..datapoint.box import BoundingBox, global_to_local_coords, intersection_boxes, iou
31
+ from ..datapoint.box import BoundingBox, global_to_local_coords, intersection_box, intersection_boxes, iou, merge_boxes
32
32
  from ..datapoint.image import Image
33
33
  from ..extern.base import DetectionResult
34
34
  from ..mapper.maputils import MappingContextManager
35
35
  from ..mapper.match import match_anns_by_intersection
36
36
  from ..utils.error import ImageError
37
- from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType
37
+ from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType, TypeOrStr, get_type
38
38
  from .base import MetaAnnotation, PipelineComponent
39
39
  from .refine import generate_html_string
40
40
  from .registry import pipeline_component_registry
@@ -55,6 +55,15 @@ class SegmentationResult:
55
55
  cs: int
56
56
 
57
57
 
58
+ @dataclass
59
+ class ItemHeaderResult:
60
+ """
61
+ Simple mutable storage for item header results
62
+ """
63
+
64
+ annotation_id: str
65
+
66
+
58
67
  def choose_items_by_iou(
59
68
  dp: Image,
60
69
  item_proposals: list[ImageAnnotation],
@@ -314,7 +323,7 @@ def _tile_by_stretching_rows_leftwise_column_downwise(
314
323
 
315
324
 
316
325
  def tile_tables_with_items_per_table(
317
- dp: Image, table: ImageAnnotation, item_name: str, stretch_rule: Literal["left", "equal"] = "left"
326
+ dp: Image, table: ImageAnnotation, item_name: ObjectTypes, stretch_rule: Literal["left", "equal"] = "left"
318
327
  ) -> Image:
319
328
  """
320
329
  Tiling a table with items (i.e. rows or columns). To ensure that every position in a table can be assigned to a row
@@ -355,9 +364,9 @@ def tile_tables_with_items_per_table(
355
364
 
356
365
  def stretch_items(
357
366
  dp: Image,
358
- table_name: str,
359
- row_name: str,
360
- col_name: str,
367
+ table_name: ObjectTypes,
368
+ row_name: ObjectTypes,
369
+ col_name: ObjectTypes,
361
370
  remove_iou_threshold_rows: float,
362
371
  remove_iou_threshold_cols: float,
363
372
  ) -> Image:
@@ -491,7 +500,7 @@ def create_intersection_cells(
491
500
  cols: Sequence[ImageAnnotation],
492
501
  table_annotation_id: str,
493
502
  cell_class_id: int,
494
- sub_item_names: Sequence[CellType],
503
+ sub_item_names: Sequence[ObjectTypes],
495
504
  ) -> tuple[Sequence[DetectionResult], Sequence[SegmentationResult]]:
496
505
  """
497
506
  Given rows and columns with row- and column number sub categories, create a list of `DetectionResult` and
@@ -511,6 +520,7 @@ def create_intersection_cells(
511
520
  detect_result_cells = []
512
521
  segment_result_cells = []
513
522
  idx = 0
523
+ break_outer_loop = False
514
524
  for row in rows:
515
525
  for col in cols:
516
526
  detect_result_cells.append(
@@ -531,17 +541,59 @@ def create_intersection_cells(
531
541
  )
532
542
  )
533
543
  idx += 1
534
- # it is possible to have less intersection boxes, e.g. if one cell has height/width 0
544
+ # it is possible to have less intersection boxes, e.g. if one cell has height/width 0. We need to break both
545
+ # loops.
535
546
  if idx >= len(boxes_cells):
547
+ break_outer_loop = True
536
548
  break
549
+ if break_outer_loop:
550
+ break
537
551
  return detect_result_cells, segment_result_cells
538
552
 
539
553
 
554
+ def header_cell_to_item_detect_result(
555
+ dp: Image,
556
+ table: ImageAnnotation,
557
+ item_name: ObjectTypes,
558
+ item_header_name: ObjectTypes,
559
+ segment_rule: Literal["iou", "ioa"],
560
+ threshold: float,
561
+ ) -> list[ItemHeaderResult]:
562
+ """
563
+ Match header cells to items (rows or columns) based on intersection-over-union (iou) or intersection-over-area (ioa)
564
+ and return a list of ItemHeaderResult.
565
+
566
+ :param dp: The image containing the table and items.
567
+ :param table: The table image annotation.
568
+ :param item_name: The type of items (e.g., rows or columns) to match with header cells.
569
+ :param item_header_name: The type of header cells to match with items.
570
+ :param segment_rule: The rule to use for matching, either 'iou' or 'ioa'.
571
+ :param threshold: The iou/ioa threshold for matching header cells with items.
572
+ :return: A list of ItemHeaderResult containing the matched header cells.
573
+ """
574
+ child_ann_ids = table.get_relationship(Relationships.CHILD)
575
+ item_index, _, items, _ = match_anns_by_intersection(
576
+ dp,
577
+ item_header_name,
578
+ item_name,
579
+ segment_rule,
580
+ threshold,
581
+ True,
582
+ child_ann_ids,
583
+ child_ann_ids,
584
+ )
585
+ item_headers = []
586
+ for idx, item in enumerate(items):
587
+ if idx in item_index:
588
+ item_headers.append(ItemHeaderResult(annotation_id=item.annotation_id))
589
+ return item_headers
590
+
591
+
540
592
  def segment_pubtables(
541
593
  dp: Image,
542
594
  table: ImageAnnotation,
543
- item_names: Sequence[LayoutType],
544
- spanning_cell_names: Sequence[Union[LayoutType, CellType]],
595
+ item_names: Sequence[ObjectTypes],
596
+ spanning_cell_names: Sequence[ObjectTypes],
545
597
  segment_rule: Literal["iou", "ioa"],
546
598
  threshold_rows: float,
547
599
  threshold_cols: float,
@@ -553,7 +605,7 @@ def segment_pubtables(
553
605
 
554
606
  Row and column positions as well as row and column lengths are determined for all types of spanning cells.
555
607
  All simple cells that are covered by a spanning cell as well in the table position (double allocation) are then
556
- removed.
608
+ replaced by the spanning cell and deactivated.
557
609
 
558
610
  :param dp: Image
559
611
  :param table: table ImageAnnotation
@@ -566,6 +618,7 @@ def segment_pubtables(
566
618
  to the column.
567
619
  :return: A list of len(number of cells) of SegmentationResult for spanning cells
568
620
  """
621
+
569
622
  child_ann_ids = table.get_relationship(Relationships.CHILD)
570
623
  cell_index_rows, row_index, _, _ = match_anns_by_intersection(
571
624
  dp,
@@ -600,29 +653,77 @@ def segment_pubtables(
600
653
  for idx, cell in enumerate(spanning_cells):
601
654
  cell_positions_rows = cell_index_rows == idx
602
655
  rows_of_cell = [rows[k] for k in row_index[cell_positions_rows]]
603
- rs = (
604
- max(row.get_sub_category(CellType.ROW_NUMBER).category_id for row in rows_of_cell)
605
- - min(row.get_sub_category(CellType.ROW_NUMBER).category_id for row in rows_of_cell)
606
- + 1
607
- )
608
- if len(rows_of_cell):
609
- row_number = min(row.get_sub_category(CellType.ROW_NUMBER).category_id for row in rows_of_cell)
656
+ if rows_of_cell:
657
+ min_row_cell = min(rows_of_cell, key=lambda row: row.get_sub_category(CellType.ROW_NUMBER).category_id)
658
+ max_row_cell = max(rows_of_cell, key=lambda row: row.get_sub_category(CellType.ROW_NUMBER).category_id)
659
+ max_row = max_row_cell.get_sub_category(CellType.ROW_NUMBER).category_id
660
+ min_row = min_row_cell.get_sub_category(CellType.ROW_NUMBER).category_id
661
+ rs = max_row - min_row + 1
662
+ row_number = min_row
610
663
  else:
664
+ rs = 0
611
665
  row_number = 0
612
666
 
613
667
  cell_positions_cols = cell_index_cols == idx
614
668
  cols_of_cell = [columns[k] for k in col_index[cell_positions_cols]]
615
- cs = (
616
- max(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
617
- - min(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
618
- + 1
619
- )
620
669
 
621
- if len(cols_of_cell):
622
- col_number = min(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
670
+ if cols_of_cell:
671
+ min_col_cell = min(
672
+ cols_of_cell, key=lambda col: col.get_sub_category(CellType.COLUMN_NUMBER).category_id
673
+ )
674
+ max_col_cell = max(
675
+ cols_of_cell, key=lambda col: col.get_sub_category(CellType.COLUMN_NUMBER).category_id
676
+ )
677
+ max_col = max_col_cell.get_sub_category(CellType.COLUMN_NUMBER).category_id
678
+ min_col = min_col_cell.get_sub_category(CellType.COLUMN_NUMBER).category_id
679
+ cs = max_col - min_col + 1
680
+ col_number = min_col
623
681
  else:
682
+ cs = 0
624
683
  col_number = 0
625
684
 
685
+ if rows_of_cell and cols_of_cell:
686
+ # We resize all bounding boxes of spanning cells so that they match with the grid structure, determined
687
+ # by the rows ans columns.
688
+ merge_box_image_row = merge_boxes(
689
+ *[min_row_cell.get_bounding_box(dp.image_id), max_row_cell.get_bounding_box(dp.image_id)]
690
+ )
691
+ merge_box_image_column = merge_boxes(
692
+ *[min_col_cell.get_bounding_box(dp.image_id), max_col_cell.get_bounding_box(dp.image_id)]
693
+ )
694
+ merge_box_image = intersection_box(merge_box_image_row, merge_box_image_column)
695
+ merge_box_table_row = merge_boxes(
696
+ *[
697
+ min_row_cell.get_bounding_box(table.annotation_id),
698
+ max_row_cell.get_bounding_box(table.annotation_id),
699
+ ]
700
+ )
701
+ merge_box_table_column = merge_boxes(
702
+ *[
703
+ min_col_cell.get_bounding_box(table.annotation_id),
704
+ max_col_cell.get_bounding_box(table.annotation_id),
705
+ ]
706
+ )
707
+ merge_box_table = intersection_box(merge_box_table_row, merge_box_table_column)
708
+ merge_box_spanning_cell_row = merge_boxes(
709
+ *[
710
+ min_row_cell.get_bounding_box(min_row_cell.annotation_id),
711
+ max_row_cell.get_bounding_box(max_row_cell.annotation_id),
712
+ ]
713
+ )
714
+ merge_box_spanning_cell_column = merge_boxes(
715
+ *[
716
+ min_col_cell.get_bounding_box(min_col_cell.annotation_id),
717
+ max_col_cell.get_bounding_box(max_col_cell.annotation_id),
718
+ ]
719
+ )
720
+ merge_box_spanning_cell = intersection_box(merge_box_spanning_cell_row, merge_box_spanning_cell_column)
721
+ if cell.image is None:
722
+ raise ImageError("cell.image cannot be None")
723
+ cell.image.set_embedding(dp.image_id, merge_box_image)
724
+ cell.image.set_embedding(table.annotation_id, merge_box_table)
725
+ cell.image.set_embedding(cell.annotation_id, merge_box_spanning_cell)
726
+
626
727
  raw_table_segments.append(
627
728
  SegmentationResult(
628
729
  annotation_id=cell.annotation_id,
@@ -674,10 +775,10 @@ class TableSegmentationService(PipelineComponent):
674
775
  tile_table_with_items: bool,
675
776
  remove_iou_threshold_rows: float,
676
777
  remove_iou_threshold_cols: float,
677
- table_name: ObjectTypes,
678
- cell_names: Sequence[ObjectTypes],
679
- item_names: Sequence[ObjectTypes],
680
- sub_item_names: Sequence[ObjectTypes],
778
+ table_name: TypeOrStr,
779
+ cell_names: Sequence[TypeOrStr],
780
+ item_names: Sequence[TypeOrStr],
781
+ sub_item_names: Sequence[TypeOrStr],
681
782
  stretch_rule: Literal["left", "equal"] = "left",
682
783
  ):
683
784
  """
@@ -705,10 +806,10 @@ class TableSegmentationService(PipelineComponent):
705
806
  self.tile_table = tile_table_with_items
706
807
  self.remove_iou_threshold_rows = remove_iou_threshold_rows
707
808
  self.remove_iou_threshold_cols = remove_iou_threshold_cols
708
- self.table_name = table_name
709
- self.cell_names = cell_names
710
- self.item_names = item_names # row names must be before column name
711
- self.sub_item_names = sub_item_names
809
+ self.table_name = get_type(table_name)
810
+ self.cell_names = [get_type(cell_name) for cell_name in cell_names]
811
+ self.item_names = [get_type(item_name) for item_name in item_names] # row names must be before column name
812
+ self.sub_item_names = [get_type(sub_item_name) for sub_item_name in sub_item_names]
712
813
  self.stretch_rule = stretch_rule
713
814
  self.item_iou_threshold = 0.0001
714
815
  super().__init__("table_segment")
@@ -876,11 +977,13 @@ class PubtablesSegmentationService(PipelineComponent):
876
977
  remove_iou_threshold_rows: float,
877
978
  remove_iou_threshold_cols: float,
878
979
  cell_class_id: int,
879
- table_name: ObjectTypes,
880
- cell_names: Sequence[Union[LayoutType, CellType]],
881
- spanning_cell_names: Sequence[Union[LayoutType, CellType]],
882
- item_names: Sequence[LayoutType],
883
- sub_item_names: Sequence[CellType],
980
+ table_name: TypeOrStr,
981
+ cell_names: Sequence[TypeOrStr],
982
+ spanning_cell_names: Sequence[TypeOrStr],
983
+ item_names: Sequence[TypeOrStr],
984
+ sub_item_names: Sequence[TypeOrStr],
985
+ item_header_cell_names: Sequence[TypeOrStr],
986
+ item_header_thresholds: Sequence[float],
884
987
  cell_to_image: bool = True,
885
988
  crop_cell_image: bool = False,
886
989
  stretch_rule: Literal["left", "equal"] = "left",
@@ -900,6 +1003,11 @@ class PubtablesSegmentationService(PipelineComponent):
900
1003
  :param spanning_cell_names: layout type of spanning cells
901
1004
  :param item_names: layout type of items (e.g. row and column)
902
1005
  :param sub_item_names: layout type of sub items (e.g. row number and column number)
1006
+ :param item_header_cell_names: layout type of item header cells (e.g. CellType.COLUMN_HEADER,
1007
+ CellType.ROW_HEADER). Note that column header, resp. row header will be first assigned to rows, resp. columns
1008
+ and then transferred to cells.
1009
+ :param item_header_thresholds: iou/ioa threshold for matching header cells with items. The first threshold
1010
+ corresponds to matching the first entry of item_names.
903
1011
  :param cell_to_image: If set to 'True' it will create an 'Image' for LayoutType.cell
904
1012
  :param crop_cell_image: If set to 'True' it will crop a numpy array image for LayoutType.cell.
905
1013
  Requires 'cell_to_image=True'
@@ -909,17 +1017,20 @@ class PubtablesSegmentationService(PipelineComponent):
909
1017
  self.threshold_rows = threshold_rows
910
1018
  self.threshold_cols = threshold_cols
911
1019
  self.tile_table = tile_table_with_items
912
- self.table_name = table_name
913
- self.cell_names = cell_names
914
- self.spanning_cell_names = spanning_cell_names
1020
+ self.table_name = get_type(table_name)
1021
+ self.cell_names = [get_type(cell_name) for cell_name in cell_names]
1022
+ self.spanning_cell_names = [get_type(cell_name) for cell_name in spanning_cell_names]
915
1023
  self.remove_iou_threshold_rows = remove_iou_threshold_rows
916
1024
  self.remove_iou_threshold_cols = remove_iou_threshold_cols
917
1025
  self.cell_class_id = cell_class_id
918
1026
  self.cell_to_image = cell_to_image
919
1027
  self.crop_cell_image = crop_cell_image
920
- self.item_names = item_names # row names must be before column name
921
- self.sub_item_names = sub_item_names
1028
+ self.item_names = [get_type(item_name) for item_name in item_names] # row names must be before column name
1029
+ self.sub_item_names = [get_type(item_name) for item_name in sub_item_names]
922
1030
  self.stretch_rule = stretch_rule
1031
+ self.item_header_cell_names = [get_type(item_name) for item_name in item_header_cell_names]
1032
+ self.item_header_thresholds = item_header_thresholds
1033
+
923
1034
  super().__init__("table_transformer_segment")
924
1035
 
925
1036
  def serve(self, dp: Image) -> None:
@@ -932,10 +1043,18 @@ class PubtablesSegmentationService(PipelineComponent):
932
1043
  self.remove_iou_threshold_cols,
933
1044
  )
934
1045
  table_anns = dp.get_annotation(category_names=self.table_name)
1046
+ has_item_headers = {item: False for item in self.item_names}
935
1047
  for table in table_anns:
936
1048
  item_ann_ids = table.get_relationship(Relationships.CHILD)
937
- for item_sub_item_name in zip(self.item_names, self.sub_item_names): # one pass for rows and one for cols
938
- item_name, sub_item_name = item_sub_item_name[0], item_sub_item_name[1]
1049
+ for item_sub_item_name in zip(
1050
+ self.item_names, self.sub_item_names, self.item_header_cell_names, self.item_header_thresholds
1051
+ ): # one pass for rows and one for cols
1052
+ item_name, sub_item_name, item_header_cell_name, item_header_threshold = (
1053
+ item_sub_item_name[0],
1054
+ item_sub_item_name[1],
1055
+ item_sub_item_name[2],
1056
+ item_sub_item_name[3],
1057
+ )
939
1058
  if self.tile_table:
940
1059
  dp = tile_tables_with_items_per_table(dp, table, item_name, self.stretch_rule)
941
1060
  items = dp.get_annotation(category_names=item_name, annotation_ids=item_ann_ids)
@@ -949,10 +1068,24 @@ class PubtablesSegmentationService(PipelineComponent):
949
1068
  )
950
1069
  )
951
1070
 
1071
+ item_headers_detect_results = header_cell_to_item_detect_result(
1072
+ dp, table, item_name, item_header_cell_name, self.segment_rule, item_header_threshold
1073
+ )
1074
+ if item_headers_detect_results:
1075
+ has_item_headers[item_name] = True
1076
+
952
1077
  for item_number, item in enumerate(items, 1):
953
1078
  self.dp_manager.set_category_annotation(
954
1079
  sub_item_name, item_number, sub_item_name, item.annotation_id
955
1080
  )
1081
+ for item_header_detect_result in item_headers_detect_results:
1082
+ self.dp_manager.set_category_annotation(
1083
+ category_name=item_header_cell_name,
1084
+ category_id=None,
1085
+ sub_cat_key=item_header_cell_name,
1086
+ annotation_id=item_header_detect_result.annotation_id,
1087
+ )
1088
+
956
1089
  rows = dp.get_annotation(category_names=self.item_names[0], annotation_ids=item_ann_ids)
957
1090
  columns = dp.get_annotation(category_names=self.item_names[1], annotation_ids=item_ann_ids)
958
1091
  detect_result_cells, segment_result_cells = create_intersection_cells(
@@ -979,6 +1112,7 @@ class PubtablesSegmentationService(PipelineComponent):
979
1112
  CellType.COLUMN_SPAN, segment_result.cs, CellType.COLUMN_SPAN, segment_result.annotation_id
980
1113
  )
981
1114
  cell_rn_cn_to_ann_id[(segment_result.row_num, segment_result.col_num)] = segment_result.annotation_id
1115
+
982
1116
  spanning_cell_raw_segments = segment_pubtables(
983
1117
  dp,
984
1118
  table,
@@ -988,7 +1122,15 @@ class PubtablesSegmentationService(PipelineComponent):
988
1122
  self.threshold_rows,
989
1123
  self.threshold_cols,
990
1124
  )
1125
+
991
1126
  for segment_result in spanning_cell_raw_segments:
1127
+ if (
1128
+ (segment_result.rs == 1 and segment_result.cs == 1)
1129
+ or segment_result.rs == 0
1130
+ or segment_result.cs == 0
1131
+ ):
1132
+ self.dp_manager.deactivate_annotation(segment_result.annotation_id)
1133
+ continue
992
1134
  self.dp_manager.set_category_annotation(
993
1135
  CellType.ROW_NUMBER, segment_result.row_num, CellType.ROW_NUMBER, segment_result.annotation_id
994
1136
  )
@@ -1009,6 +1151,19 @@ class PubtablesSegmentationService(PipelineComponent):
1009
1151
  cell_ann_id = cell_rn_cn_to_ann_id[cell_position]
1010
1152
  self.dp_manager.deactivate_annotation(cell_ann_id)
1011
1153
 
1154
+ for segment_result in spanning_cell_raw_segments:
1155
+ if (
1156
+ (segment_result.rs == 1 and segment_result.cs == 1)
1157
+ or segment_result.rs == 0
1158
+ or segment_result.cs == 0
1159
+ ):
1160
+ continue
1161
+ for rs in range(segment_result.rs):
1162
+ for cs in range(segment_result.cs):
1163
+ cell_rn_cn_to_ann_id[
1164
+ (segment_result.row_num + rs, segment_result.col_num + cs)
1165
+ ] = segment_result.annotation_id
1166
+
1012
1167
  cells = []
1013
1168
  if table.image:
1014
1169
  cells = table.image.get_annotation(category_names=self.cell_names)
@@ -1022,6 +1177,32 @@ class PubtablesSegmentationService(PipelineComponent):
1022
1177
  number_of_cols = 0
1023
1178
  max_row_span = 0
1024
1179
  max_col_span = 0
1180
+
1181
+ for idx, item_vals in enumerate(zip(self.item_names, self.item_header_cell_names, self.sub_item_names)):
1182
+ item_obj_type, item_header_cell_name, sub_item_name = item_vals[0], item_vals[1], item_vals[2]
1183
+
1184
+ if has_item_headers[item_obj_type]:
1185
+ items = dp.get_annotation(category_names=item_obj_type)
1186
+
1187
+ for item_ann in items:
1188
+ if item_header_cell_name in item_ann.sub_categories:
1189
+ item_number = item_ann.get_sub_category(sub_item_name).category_id
1190
+ for key, value in cell_rn_cn_to_ann_id.items():
1191
+ if key[idx] == item_number:
1192
+ cell_ann = dp.get_annotation(annotation_ids=value)[0]
1193
+ self.dp_manager.set_category_annotation(
1194
+ item_header_cell_name,
1195
+ None,
1196
+ item_header_cell_name,
1197
+ cell_ann.annotation_id
1198
+ )
1199
+ else:
1200
+ cell_ann = dp.get_annotation(annotation_ids=value)[0]
1201
+ self.dp_manager.set_category_annotation(item_header_cell_name,
1202
+ None,
1203
+ CellType.BODY,
1204
+ cell_ann.annotation_id)
1205
+
1025
1206
  # TODO: the summaries should be sub categories of the underlying ann
1026
1207
  self.dp_manager.set_summary_annotation(
1027
1208
  TableType.NUMBER_OF_ROWS, TableType.NUMBER_OF_ROWS, number_of_rows, annotation_id=table.annotation_id
@@ -1038,7 +1219,7 @@ class PubtablesSegmentationService(PipelineComponent):
1038
1219
  self.dp_manager.set_summary_annotation(
1039
1220
  TableType.MAX_COL_SPAN, TableType.MAX_COL_SPAN, max_col_span, annotation_id=table.annotation_id
1040
1221
  )
1041
- html = generate_html_string(table)
1222
+ html = generate_html_string(table, self.cell_names + self.spanning_cell_names)
1042
1223
  self.dp_manager.set_container_annotation(TableType.HTML, -1, TableType.HTML, table.annotation_id, html)
1043
1224
 
1044
1225
  def clone(self) -> PubtablesSegmentationService:
@@ -1055,6 +1236,8 @@ class PubtablesSegmentationService(PipelineComponent):
1055
1236
  self.spanning_cell_names,
1056
1237
  self.item_names,
1057
1238
  self.sub_item_names,
1239
+ self.item_header_cell_names,
1240
+ self.item_header_thresholds,
1058
1241
  self.cell_to_image,
1059
1242
  self.crop_cell_image,
1060
1243
  self.stretch_rule,
@@ -49,27 +49,27 @@ class DetectResultGenerator:
49
49
 
50
50
  def __init__(
51
51
  self,
52
- categories: Mapping[int, ObjectTypes],
53
- group_categories: Optional[list[list[int]]] = None,
54
- exclude_category_ids: Optional[Sequence[int]] = None,
52
+ categories_name_as_key: Mapping[ObjectTypes, int],
53
+ group_categories: Optional[list[list[ObjectTypes]]] = None,
54
+ exclude_category_names: Optional[Sequence[ObjectTypes]] = None,
55
55
  absolute_coords: bool = True,
56
56
  ) -> None:
57
57
  """
58
- :param categories: The dict of all possible detection categories
58
+ :param categories_name_as_key: The dict of all possible detection categories
59
59
  :param group_categories: If you only want to generate only one DetectResult for a group of categories, provided
60
60
  that the sum of the group is less than one, then you can pass a list of list for
61
61
  grouping category ids.
62
62
  :param absolute_coords: 'absolute_coords' value to be set in 'DetectionResult'
63
63
  """
64
- self.categories = MappingProxyType(dict(categories.items()))
64
+ self.categories_name_as_key = MappingProxyType(dict(categories_name_as_key.items()))
65
65
  self.width: Optional[int] = None
66
66
  self.height: Optional[int] = None
67
67
  if group_categories is None:
68
- group_categories = [[idx] for idx in self.categories]
68
+ group_categories = [[cat_name] for cat_name in self.categories_name_as_key]
69
69
  self.group_categories = group_categories
70
- if exclude_category_ids is None:
71
- exclude_category_ids = []
72
- self.exclude_category_ids = exclude_category_ids
70
+ if exclude_category_names is None:
71
+ exclude_category_names = []
72
+ self.exclude_category_names = exclude_category_names
73
73
  self.dummy_for_group_generated = [False for _ in self.group_categories]
74
74
  self.absolute_coords = absolute_coords
75
75
 
@@ -83,17 +83,17 @@ class DetectResultGenerator:
83
83
 
84
84
  if self.width is None and self.height is None:
85
85
  raise ValueError("Initialize height and width first")
86
-
86
+ detect_result_list = self._detection_result_sanity_check(detect_result_list)
87
87
  count = self._create_condition(detect_result_list)
88
- for category_id in self.categories:
89
- if category_id not in self.exclude_category_ids:
90
- if count[category_id] < 1:
91
- if not self._dummy_for_group_generated(category_id):
88
+ for category_name in self.categories_name_as_key:
89
+ if category_name not in self.exclude_category_names:
90
+ if count[category_name] < 1:
91
+ if not self._dummy_for_group_generated(category_name):
92
92
  detect_result_list.append(
93
93
  DetectionResult(
94
94
  box=[0.0, 0.0, float(self.width), float(self.height)], # type: ignore
95
- class_id=int(category_id),
96
- class_name=self.categories[category_id],
95
+ class_id=self.categories_name_as_key[category_name],
96
+ class_name=category_name,
97
97
  score=0.0,
98
98
  absolute_coords=self.absolute_coords,
99
99
  )
@@ -102,8 +102,8 @@ class DetectResultGenerator:
102
102
  self.dummy_for_group_generated = self._initialize_dummy_for_group_generated()
103
103
  return detect_result_list
104
104
 
105
- def _create_condition(self, detect_result_list: list[DetectionResult]) -> dict[int, int]:
106
- count = Counter([ann.class_id for ann in detect_result_list])
105
+ def _create_condition(self, detect_result_list: list[DetectionResult]) -> dict[ObjectTypes, int]:
106
+ count = Counter([ann.class_name for ann in detect_result_list])
107
107
  cat_to_group_sum = {}
108
108
  for group in self.group_categories:
109
109
  group_sum = 0
@@ -113,9 +113,25 @@ class DetectResultGenerator:
113
113
  cat_to_group_sum[el] = group_sum
114
114
  return cat_to_group_sum
115
115
 
116
- def _dummy_for_group_generated(self, category_id: int) -> bool:
116
+ @staticmethod
117
+ def _detection_result_sanity_check(detect_result_list: list[DetectionResult]) -> list[DetectionResult]:
118
+ """
119
+ Go through each detect_result in the list and check if the box argument has sensible coordinates:
120
+ ulx >= 0 and lrx - ulx >= 0 (same for y coordinate). Remove the detection result if this condition is not
121
+ satisfied. We need this check because if some detection results are not sane, we might end up with some
122
+ none existing categories.
123
+ """
124
+ sane_detect_results = []
125
+ for detect_result in detect_result_list:
126
+ if detect_result.box:
127
+ ulx, uly, lrx, lry = detect_result.box
128
+ if ulx >= 0 and lrx - ulx >= 0 and uly >= 0 and lry - uly >= 0:
129
+ sane_detect_results.append(detect_result)
130
+ return sane_detect_results
131
+
132
+ def _dummy_for_group_generated(self, category_name: ObjectTypes) -> bool:
117
133
  for idx, group in enumerate(self.group_categories):
118
- if category_id in group:
134
+ if category_name in group:
119
135
  is_generated = self.dummy_for_group_generated[idx]
120
136
  self.dummy_for_group_generated[idx] = True
121
137
  return is_generated
@@ -176,10 +192,12 @@ class SubImageLayoutService(PipelineComponent):
176
192
  self.predictor = sub_image_detector
177
193
  super().__init__(self._get_name(sub_image_detector.name), self.predictor.model_id)
178
194
  if self.detect_result_generator is not None:
179
- if self.detect_result_generator.categories != self.predictor.categories.get_categories():
195
+ if self.detect_result_generator.categories_name_as_key != self.predictor.categories.get_categories(
196
+ as_dict=True, name_as_key=True
197
+ ):
180
198
  raise ValueError(
181
199
  f"The categories of the 'detect_result_generator' must be the same as the categories of the "
182
- f"'sub_image_detector'. Got {self.detect_result_generator.categories} #"
200
+ f"'sub_image_detector'. Got {self.detect_result_generator.categories_name_as_key} #"
183
201
  f"and {self.predictor.categories.get_categories()}."
184
202
  )
185
203
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: deepdoctection
3
- Version: 0.37.2
3
+ Version: 0.38
4
4
  Summary: Repository for Document AI
5
5
  Home-page: https://github.com/deepdoctection/deepdoctection
6
6
  Author: Dr. Janis Meyer
@@ -127,6 +127,16 @@ Requires-Dist: types-urllib3>=1.26.25.14; extra == "dev"
127
127
  Provides-Extra: test
128
128
  Requires-Dist: pytest==8.0.2; extra == "test"
129
129
  Requires-Dist: pytest-cov; extra == "test"
130
+ Dynamic: author
131
+ Dynamic: classifier
132
+ Dynamic: description
133
+ Dynamic: description-content-type
134
+ Dynamic: home-page
135
+ Dynamic: license
136
+ Dynamic: provides-extra
137
+ Dynamic: requires-dist
138
+ Dynamic: requires-python
139
+ Dynamic: summary
130
140
 
131
141
 
132
142
  <p align="center">