deepdoctection 0.37.3__py3-none-any.whl → 0.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

@@ -109,8 +109,13 @@ def _proto_process(
109
109
 
110
110
 
111
111
  @curry
112
- def _to_image(dp: Union[str, Mapping[str, Union[str, bytes]]], dpi: Optional[int] = None) -> Optional[Image]:
113
- return to_image(dp, dpi)
112
+ def _to_image(
113
+ dp: Union[str, Mapping[str, Union[str, bytes]]],
114
+ dpi: Optional[int] = None,
115
+ width: Optional[int] = None,
116
+ height: Optional[int] = None,
117
+ ) -> Optional[Image]:
118
+ return to_image(dp, dpi, width, height)
114
119
 
115
120
 
116
121
  def _doc_to_dataflow(path: PathLikeOrStr, max_datapoints: Optional[int] = None) -> DataFlow:
@@ -188,7 +193,19 @@ class DoctectionPipe(Pipeline):
188
193
 
189
194
  df = MapData(df, _proto_process(path, doc_path))
190
195
  if dataset_dataflow is None:
191
- df = MapData(df, _to_image(dpi=os.environ.get("DPI", 300))) # pylint: disable=E1120
196
+ if dpi := os.environ["DPI"]:
197
+ df = MapData(df, _to_image(dpi=int(dpi))) # pylint: disable=E1120
198
+ else:
199
+ width, height = kwargs.get("width", ""), kwargs.get("height", "")
200
+ if not width or not height:
201
+ width = os.environ["IMAGE_WIDTH"]
202
+ height = os.environ["IMAGE_HEIGHT"]
203
+ if not width or not height:
204
+ raise ValueError(
205
+ "DPI, IMAGE_WIDTH and IMAGE_HEIGHT are all None, but "
206
+ "either DPI or IMAGE_WIDTH and IMAGE_HEIGHT must be set"
207
+ )
208
+ df = MapData(df, _to_image(width=int(width), height=int(height))) # pylint: disable=E1120
192
209
  return df
193
210
 
194
211
  @staticmethod
deepdoctection/pipe/lm.py CHANGED
@@ -24,6 +24,7 @@ from copy import copy
24
24
  from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Sequence, Union
25
25
 
26
26
  from ..datapoint.image import Image
27
+ from ..extern.base import SequenceClassResult
27
28
  from ..mapper.laylmstruct import image_to_layoutlm_features, image_to_lm_features
28
29
  from ..utils.settings import BioTag, LayoutType, ObjectTypes, PageType, TokenClasses, WordType
29
30
  from .base import MetaAnnotation, PipelineComponent
@@ -264,6 +265,7 @@ class LMSequenceClassifierService(PipelineComponent):
264
265
  padding: Literal["max_length", "do_not_pad", "longest"] = "max_length",
265
266
  truncation: bool = True,
266
267
  return_overflowing_tokens: bool = False,
268
+ use_other_as_default_category: bool = False
267
269
  ) -> None:
268
270
  """
269
271
  :param tokenizer: Tokenizer, typing allows currently anything. This will be changed in the future
@@ -279,11 +281,16 @@ class LMSequenceClassifierService(PipelineComponent):
279
281
  :param return_overflowing_tokens: If a sequence (due to a truncation strategy) overflows the overflowing tokens
280
282
  can be returned as an additional batch element. Note that in this case, the number of input
281
283
  batch samples will be smaller than the output batch samples.
284
+ :param use_other_as_default_category: When predicting document classes, it might be possible that some pages
285
+ do not get sent to the model because they are empty. If set to `True` it
286
+ will assign images with no features the category `TokenClasses.OTHER`.
287
+
282
288
  """
283
289
  self.language_model = language_model
284
290
  self.padding = padding
285
291
  self.truncation = truncation
286
292
  self.return_overflowing_tokens = return_overflowing_tokens
293
+ self.use_other_as_default_category = use_other_as_default_category
287
294
  self.tokenizer = tokenizer
288
295
  self.mapping_to_lm_input_func = self.image_to_features_func(self.language_model.image_to_features_mapping())
289
296
  super().__init__(self._get_name(), self.language_model.model_id)
@@ -299,12 +306,20 @@ class LMSequenceClassifierService(PipelineComponent):
299
306
 
300
307
  def serve(self, dp: Image) -> None:
301
308
  lm_input = self.mapping_to_lm_input_func(**self.required_kwargs)(dp)
309
+ lm_output = None
302
310
  if lm_input is None:
303
- return
304
- lm_output = self.language_model.predict(**lm_input)
305
- self.dp_manager.set_summary_annotation(
306
- PageType.DOCUMENT_TYPE, lm_output.class_name, lm_output.class_id, None, lm_output.score
307
- )
311
+ if self.use_other_as_default_category:
312
+ class_id = self.language_model.categories.get_categories(as_dict=True,
313
+ name_as_key=True).get(TokenClasses.OTHER, 1)
314
+ lm_output = SequenceClassResult(class_name=TokenClasses.OTHER,
315
+ class_id = class_id,
316
+ score=-1.)
317
+ else:
318
+ lm_output = self.language_model.predict(**lm_input)
319
+ if lm_output:
320
+ self.dp_manager.set_summary_annotation(
321
+ PageType.DOCUMENT_TYPE, lm_output.class_name, lm_output.class_id, None, lm_output.score
322
+ )
308
323
 
309
324
  def clone(self) -> LMSequenceClassifierService:
310
325
  return self.__class__(
@@ -295,28 +295,21 @@ def _html_table(
295
295
  return html
296
296
 
297
297
 
298
- def generate_html_string(table: ImageAnnotation) -> list[str]:
298
+ def generate_html_string(table: ImageAnnotation, cell_names: Sequence[ObjectTypes]) -> list[str]:
299
299
  """
300
300
  Takes the table segmentation by using table cells row number, column numbers etc. and generates a html
301
301
  representation.
302
302
 
303
303
  :param table: An annotation that has a not None image and fully segmented cell annotation.
304
+ :param cell_names: List of cell names that are used for the table segmentation. Note: It must be ensured that
305
+ all cells have a row number, column number, row span and column span and that the dissection
306
+ by rows and columns is completely covered by cells.
304
307
  :return: HTML representation of the table
305
308
  """
306
309
  if table.image is None:
307
310
  raise ImageError("table.image cannot be None")
308
311
  table_image = table.image
309
- cells = table_image.get_annotation(
310
- category_names=[
311
- LayoutType.CELL,
312
- CellType.HEADER,
313
- CellType.BODY,
314
- CellType.SPANNING,
315
- CellType.ROW_HEADER,
316
- CellType.COLUMN_HEADER,
317
- CellType.PROJECTED_ROW_HEADER,
318
- ]
319
- )
312
+ cells = table_image.get_annotation(category_names=cell_names)
320
313
  number_of_rows = table_image.summary.get_sub_category(TableType.NUMBER_OF_ROWS).category_id
321
314
  number_of_cols = table_image.summary.get_sub_category(TableType.NUMBER_OF_COLUMNS).category_id
322
315
  table_list = []
@@ -485,7 +478,7 @@ class TableSegmentationRefinementService(PipelineComponent):
485
478
  self.dp_manager.set_summary_annotation(
486
479
  TableType.MAX_COL_SPAN, TableType.MAX_COL_SPAN, max_col_span, annotation_id=table.annotation_id
487
480
  )
488
- html = generate_html_string(table)
481
+ html = generate_html_string(table, self.cell_names)
489
482
  self.dp_manager.set_container_annotation(TableType.HTML, -1, TableType.HTML, table.annotation_id, html)
490
483
 
491
484
  def clone(self) -> TableSegmentationRefinementService:
@@ -28,13 +28,13 @@ from typing import Literal, Optional, Sequence, Union
28
28
  import numpy as np
29
29
 
30
30
  from ..datapoint.annotation import ImageAnnotation
31
- from ..datapoint.box import BoundingBox, global_to_local_coords, intersection_boxes, iou
31
+ from ..datapoint.box import BoundingBox, global_to_local_coords, intersection_box, intersection_boxes, iou, merge_boxes
32
32
  from ..datapoint.image import Image
33
33
  from ..extern.base import DetectionResult
34
34
  from ..mapper.maputils import MappingContextManager
35
35
  from ..mapper.match import match_anns_by_intersection
36
36
  from ..utils.error import ImageError
37
- from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType
37
+ from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType, TypeOrStr, get_type
38
38
  from .base import MetaAnnotation, PipelineComponent
39
39
  from .refine import generate_html_string
40
40
  from .registry import pipeline_component_registry
@@ -55,6 +55,15 @@ class SegmentationResult:
55
55
  cs: int
56
56
 
57
57
 
58
+ @dataclass
59
+ class ItemHeaderResult:
60
+ """
61
+ Simple mutable storage for item header results
62
+ """
63
+
64
+ annotation_id: str
65
+
66
+
58
67
  def choose_items_by_iou(
59
68
  dp: Image,
60
69
  item_proposals: list[ImageAnnotation],
@@ -314,7 +323,7 @@ def _tile_by_stretching_rows_leftwise_column_downwise(
314
323
 
315
324
 
316
325
  def tile_tables_with_items_per_table(
317
- dp: Image, table: ImageAnnotation, item_name: str, stretch_rule: Literal["left", "equal"] = "left"
326
+ dp: Image, table: ImageAnnotation, item_name: ObjectTypes, stretch_rule: Literal["left", "equal"] = "left"
318
327
  ) -> Image:
319
328
  """
320
329
  Tiling a table with items (i.e. rows or columns). To ensure that every position in a table can be assigned to a row
@@ -355,9 +364,9 @@ def tile_tables_with_items_per_table(
355
364
 
356
365
  def stretch_items(
357
366
  dp: Image,
358
- table_name: str,
359
- row_name: str,
360
- col_name: str,
367
+ table_name: ObjectTypes,
368
+ row_name: ObjectTypes,
369
+ col_name: ObjectTypes,
361
370
  remove_iou_threshold_rows: float,
362
371
  remove_iou_threshold_cols: float,
363
372
  ) -> Image:
@@ -491,7 +500,7 @@ def create_intersection_cells(
491
500
  cols: Sequence[ImageAnnotation],
492
501
  table_annotation_id: str,
493
502
  cell_class_id: int,
494
- sub_item_names: Sequence[CellType],
503
+ sub_item_names: Sequence[ObjectTypes],
495
504
  ) -> tuple[Sequence[DetectionResult], Sequence[SegmentationResult]]:
496
505
  """
497
506
  Given rows and columns with row- and column number sub categories, create a list of `DetectionResult` and
@@ -511,6 +520,7 @@ def create_intersection_cells(
511
520
  detect_result_cells = []
512
521
  segment_result_cells = []
513
522
  idx = 0
523
+ break_outer_loop = False
514
524
  for row in rows:
515
525
  for col in cols:
516
526
  detect_result_cells.append(
@@ -531,17 +541,59 @@ def create_intersection_cells(
531
541
  )
532
542
  )
533
543
  idx += 1
534
- # it is possible to have less intersection boxes, e.g. if one cell has height/width 0
544
+ # it is possible to have fewer intersection boxes, e.g. if one cell has height/width 0. We need to break both
545
+ # loops.
535
546
  if idx >= len(boxes_cells):
547
+ break_outer_loop = True
536
548
  break
549
+ if break_outer_loop:
550
+ break
537
551
  return detect_result_cells, segment_result_cells
538
552
 
539
553
 
554
+ def header_cell_to_item_detect_result(
555
+ dp: Image,
556
+ table: ImageAnnotation,
557
+ item_name: ObjectTypes,
558
+ item_header_name: ObjectTypes,
559
+ segment_rule: Literal["iou", "ioa"],
560
+ threshold: float,
561
+ ) -> list[ItemHeaderResult]:
562
+ """
563
+ Match header cells to items (rows or columns) based on intersection-over-union (iou) or intersection-over-area (ioa)
564
+ and return a list of ItemHeaderResult.
565
+
566
+ :param dp: The image containing the table and items.
567
+ :param table: The table image annotation.
568
+ :param item_name: The type of items (e.g., rows or columns) to match with header cells.
569
+ :param item_header_name: The type of header cells to match with items.
570
+ :param segment_rule: The rule to use for matching, either 'iou' or 'ioa'.
571
+ :param threshold: The iou/ioa threshold for matching header cells with items.
572
+ :return: A list of ItemHeaderResult containing the matched header cells.
573
+ """
574
+ child_ann_ids = table.get_relationship(Relationships.CHILD)
575
+ item_index, _, items, _ = match_anns_by_intersection(
576
+ dp,
577
+ item_header_name,
578
+ item_name,
579
+ segment_rule,
580
+ threshold,
581
+ True,
582
+ child_ann_ids,
583
+ child_ann_ids,
584
+ )
585
+ item_headers = []
586
+ for idx, item in enumerate(items):
587
+ if idx in item_index:
588
+ item_headers.append(ItemHeaderResult(annotation_id=item.annotation_id))
589
+ return item_headers
590
+
591
+
540
592
  def segment_pubtables(
541
593
  dp: Image,
542
594
  table: ImageAnnotation,
543
- item_names: Sequence[LayoutType],
544
- spanning_cell_names: Sequence[Union[LayoutType, CellType]],
595
+ item_names: Sequence[ObjectTypes],
596
+ spanning_cell_names: Sequence[ObjectTypes],
545
597
  segment_rule: Literal["iou", "ioa"],
546
598
  threshold_rows: float,
547
599
  threshold_cols: float,
@@ -553,7 +605,7 @@ def segment_pubtables(
553
605
 
554
606
  Row and column positions as well as row and column lengths are determined for all types of spanning cells.
555
607
  All simple cells that are covered by a spanning cell as well in the table position (double allocation) are then
556
- removed.
608
+ replaced by the spanning cell and deactivated.
557
609
 
558
610
  :param dp: Image
559
611
  :param table: table ImageAnnotation
@@ -566,6 +618,7 @@ def segment_pubtables(
566
618
  to the column.
567
619
  :return: A list of len(number of cells) of SegmentationResult for spanning cells
568
620
  """
621
+
569
622
  child_ann_ids = table.get_relationship(Relationships.CHILD)
570
623
  cell_index_rows, row_index, _, _ = match_anns_by_intersection(
571
624
  dp,
@@ -600,29 +653,77 @@ def segment_pubtables(
600
653
  for idx, cell in enumerate(spanning_cells):
601
654
  cell_positions_rows = cell_index_rows == idx
602
655
  rows_of_cell = [rows[k] for k in row_index[cell_positions_rows]]
603
- rs = (
604
- max(row.get_sub_category(CellType.ROW_NUMBER).category_id for row in rows_of_cell)
605
- - min(row.get_sub_category(CellType.ROW_NUMBER).category_id for row in rows_of_cell)
606
- + 1
607
- )
608
- if len(rows_of_cell):
609
- row_number = min(row.get_sub_category(CellType.ROW_NUMBER).category_id for row in rows_of_cell)
656
+ if rows_of_cell:
657
+ min_row_cell = min(rows_of_cell, key=lambda row: row.get_sub_category(CellType.ROW_NUMBER).category_id)
658
+ max_row_cell = max(rows_of_cell, key=lambda row: row.get_sub_category(CellType.ROW_NUMBER).category_id)
659
+ max_row = max_row_cell.get_sub_category(CellType.ROW_NUMBER).category_id
660
+ min_row = min_row_cell.get_sub_category(CellType.ROW_NUMBER).category_id
661
+ rs = max_row - min_row + 1
662
+ row_number = min_row
610
663
  else:
664
+ rs = 0
611
665
  row_number = 0
612
666
 
613
667
  cell_positions_cols = cell_index_cols == idx
614
668
  cols_of_cell = [columns[k] for k in col_index[cell_positions_cols]]
615
- cs = (
616
- max(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
617
- - min(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
618
- + 1
619
- )
620
669
 
621
- if len(cols_of_cell):
622
- col_number = min(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
670
+ if cols_of_cell:
671
+ min_col_cell = min(
672
+ cols_of_cell, key=lambda col: col.get_sub_category(CellType.COLUMN_NUMBER).category_id
673
+ )
674
+ max_col_cell = max(
675
+ cols_of_cell, key=lambda col: col.get_sub_category(CellType.COLUMN_NUMBER).category_id
676
+ )
677
+ max_col = max_col_cell.get_sub_category(CellType.COLUMN_NUMBER).category_id
678
+ min_col = min_col_cell.get_sub_category(CellType.COLUMN_NUMBER).category_id
679
+ cs = max_col - min_col + 1
680
+ col_number = min_col
623
681
  else:
682
+ cs = 0
624
683
  col_number = 0
625
684
 
685
+ if rows_of_cell and cols_of_cell:
686
+ # We resize all bounding boxes of spanning cells so that they match with the grid structure, determined
687
+ # by the rows and columns.
688
+ merge_box_image_row = merge_boxes(
689
+ *[min_row_cell.get_bounding_box(dp.image_id), max_row_cell.get_bounding_box(dp.image_id)]
690
+ )
691
+ merge_box_image_column = merge_boxes(
692
+ *[min_col_cell.get_bounding_box(dp.image_id), max_col_cell.get_bounding_box(dp.image_id)]
693
+ )
694
+ merge_box_image = intersection_box(merge_box_image_row, merge_box_image_column)
695
+ merge_box_table_row = merge_boxes(
696
+ *[
697
+ min_row_cell.get_bounding_box(table.annotation_id),
698
+ max_row_cell.get_bounding_box(table.annotation_id),
699
+ ]
700
+ )
701
+ merge_box_table_column = merge_boxes(
702
+ *[
703
+ min_col_cell.get_bounding_box(table.annotation_id),
704
+ max_col_cell.get_bounding_box(table.annotation_id),
705
+ ]
706
+ )
707
+ merge_box_table = intersection_box(merge_box_table_row, merge_box_table_column)
708
+ merge_box_spanning_cell_row = merge_boxes(
709
+ *[
710
+ min_row_cell.get_bounding_box(min_row_cell.annotation_id),
711
+ max_row_cell.get_bounding_box(max_row_cell.annotation_id),
712
+ ]
713
+ )
714
+ merge_box_spanning_cell_column = merge_boxes(
715
+ *[
716
+ min_col_cell.get_bounding_box(min_col_cell.annotation_id),
717
+ max_col_cell.get_bounding_box(max_col_cell.annotation_id),
718
+ ]
719
+ )
720
+ merge_box_spanning_cell = intersection_box(merge_box_spanning_cell_row, merge_box_spanning_cell_column)
721
+ if cell.image is None:
722
+ raise ImageError("cell.image cannot be None")
723
+ cell.image.set_embedding(dp.image_id, merge_box_image)
724
+ cell.image.set_embedding(table.annotation_id, merge_box_table)
725
+ cell.image.set_embedding(cell.annotation_id, merge_box_spanning_cell)
726
+
626
727
  raw_table_segments.append(
627
728
  SegmentationResult(
628
729
  annotation_id=cell.annotation_id,
@@ -674,10 +775,10 @@ class TableSegmentationService(PipelineComponent):
674
775
  tile_table_with_items: bool,
675
776
  remove_iou_threshold_rows: float,
676
777
  remove_iou_threshold_cols: float,
677
- table_name: ObjectTypes,
678
- cell_names: Sequence[ObjectTypes],
679
- item_names: Sequence[ObjectTypes],
680
- sub_item_names: Sequence[ObjectTypes],
778
+ table_name: TypeOrStr,
779
+ cell_names: Sequence[TypeOrStr],
780
+ item_names: Sequence[TypeOrStr],
781
+ sub_item_names: Sequence[TypeOrStr],
681
782
  stretch_rule: Literal["left", "equal"] = "left",
682
783
  ):
683
784
  """
@@ -705,10 +806,10 @@ class TableSegmentationService(PipelineComponent):
705
806
  self.tile_table = tile_table_with_items
706
807
  self.remove_iou_threshold_rows = remove_iou_threshold_rows
707
808
  self.remove_iou_threshold_cols = remove_iou_threshold_cols
708
- self.table_name = table_name
709
- self.cell_names = cell_names
710
- self.item_names = item_names # row names must be before column name
711
- self.sub_item_names = sub_item_names
809
+ self.table_name = get_type(table_name)
810
+ self.cell_names = [get_type(cell_name) for cell_name in cell_names]
811
+ self.item_names = [get_type(item_name) for item_name in item_names] # row names must be before column name
812
+ self.sub_item_names = [get_type(sub_item_name) for sub_item_name in sub_item_names]
712
813
  self.stretch_rule = stretch_rule
713
814
  self.item_iou_threshold = 0.0001
714
815
  super().__init__("table_segment")
@@ -876,11 +977,13 @@ class PubtablesSegmentationService(PipelineComponent):
876
977
  remove_iou_threshold_rows: float,
877
978
  remove_iou_threshold_cols: float,
878
979
  cell_class_id: int,
879
- table_name: ObjectTypes,
880
- cell_names: Sequence[Union[LayoutType, CellType]],
881
- spanning_cell_names: Sequence[Union[LayoutType, CellType]],
882
- item_names: Sequence[LayoutType],
883
- sub_item_names: Sequence[CellType],
980
+ table_name: TypeOrStr,
981
+ cell_names: Sequence[TypeOrStr],
982
+ spanning_cell_names: Sequence[TypeOrStr],
983
+ item_names: Sequence[TypeOrStr],
984
+ sub_item_names: Sequence[TypeOrStr],
985
+ item_header_cell_names: Sequence[TypeOrStr],
986
+ item_header_thresholds: Sequence[float],
884
987
  cell_to_image: bool = True,
885
988
  crop_cell_image: bool = False,
886
989
  stretch_rule: Literal["left", "equal"] = "left",
@@ -900,6 +1003,11 @@ class PubtablesSegmentationService(PipelineComponent):
900
1003
  :param spanning_cell_names: layout type of spanning cells
901
1004
  :param item_names: layout type of items (e.g. row and column)
902
1005
  :param sub_item_names: layout type of sub items (e.g. row number and column number)
1006
+ :param item_header_cell_names: layout type of item header cells (e.g. CellType.COLUMN_HEADER,
1007
+ CellType.ROW_HEADER). Note that column header, resp. row header will be first assigned to rows, resp. columns
1008
+ and then transferred to cells.
1009
+ :param item_header_thresholds: iou/ioa threshold for matching header cells with items. The first threshold
1010
+ corresponds to matching the first entry of item_names.
903
1011
  :param cell_to_image: If set to 'True' it will create an 'Image' for LayoutType.cell
904
1012
  :param crop_cell_image: If set to 'True' it will crop a numpy array image for LayoutType.cell.
905
1013
  Requires 'cell_to_image=True'
@@ -909,17 +1017,20 @@ class PubtablesSegmentationService(PipelineComponent):
909
1017
  self.threshold_rows = threshold_rows
910
1018
  self.threshold_cols = threshold_cols
911
1019
  self.tile_table = tile_table_with_items
912
- self.table_name = table_name
913
- self.cell_names = cell_names
914
- self.spanning_cell_names = spanning_cell_names
1020
+ self.table_name = get_type(table_name)
1021
+ self.cell_names = [get_type(cell_name) for cell_name in cell_names]
1022
+ self.spanning_cell_names = [get_type(cell_name) for cell_name in spanning_cell_names]
915
1023
  self.remove_iou_threshold_rows = remove_iou_threshold_rows
916
1024
  self.remove_iou_threshold_cols = remove_iou_threshold_cols
917
1025
  self.cell_class_id = cell_class_id
918
1026
  self.cell_to_image = cell_to_image
919
1027
  self.crop_cell_image = crop_cell_image
920
- self.item_names = item_names # row names must be before column name
921
- self.sub_item_names = sub_item_names
1028
+ self.item_names = [get_type(item_name) for item_name in item_names] # row names must be before column name
1029
+ self.sub_item_names = [get_type(item_name) for item_name in sub_item_names]
922
1030
  self.stretch_rule = stretch_rule
1031
+ self.item_header_cell_names = [get_type(item_name) for item_name in item_header_cell_names]
1032
+ self.item_header_thresholds = item_header_thresholds
1033
+
923
1034
  super().__init__("table_transformer_segment")
924
1035
 
925
1036
  def serve(self, dp: Image) -> None:
@@ -932,10 +1043,18 @@ class PubtablesSegmentationService(PipelineComponent):
932
1043
  self.remove_iou_threshold_cols,
933
1044
  )
934
1045
  table_anns = dp.get_annotation(category_names=self.table_name)
1046
+ has_item_headers = {item: False for item in self.item_names}
935
1047
  for table in table_anns:
936
1048
  item_ann_ids = table.get_relationship(Relationships.CHILD)
937
- for item_sub_item_name in zip(self.item_names, self.sub_item_names): # one pass for rows and one for cols
938
- item_name, sub_item_name = item_sub_item_name[0], item_sub_item_name[1]
1049
+ for item_sub_item_name in zip(
1050
+ self.item_names, self.sub_item_names, self.item_header_cell_names, self.item_header_thresholds
1051
+ ): # one pass for rows and one for cols
1052
+ item_name, sub_item_name, item_header_cell_name, item_header_threshold = (
1053
+ item_sub_item_name[0],
1054
+ item_sub_item_name[1],
1055
+ item_sub_item_name[2],
1056
+ item_sub_item_name[3],
1057
+ )
939
1058
  if self.tile_table:
940
1059
  dp = tile_tables_with_items_per_table(dp, table, item_name, self.stretch_rule)
941
1060
  items = dp.get_annotation(category_names=item_name, annotation_ids=item_ann_ids)
@@ -949,10 +1068,24 @@ class PubtablesSegmentationService(PipelineComponent):
949
1068
  )
950
1069
  )
951
1070
 
1071
+ item_headers_detect_results = header_cell_to_item_detect_result(
1072
+ dp, table, item_name, item_header_cell_name, self.segment_rule, item_header_threshold
1073
+ )
1074
+ if item_headers_detect_results:
1075
+ has_item_headers[item_name] = True
1076
+
952
1077
  for item_number, item in enumerate(items, 1):
953
1078
  self.dp_manager.set_category_annotation(
954
1079
  sub_item_name, item_number, sub_item_name, item.annotation_id
955
1080
  )
1081
+ for item_header_detect_result in item_headers_detect_results:
1082
+ self.dp_manager.set_category_annotation(
1083
+ category_name=item_header_cell_name,
1084
+ category_id=None,
1085
+ sub_cat_key=item_header_cell_name,
1086
+ annotation_id=item_header_detect_result.annotation_id,
1087
+ )
1088
+
956
1089
  rows = dp.get_annotation(category_names=self.item_names[0], annotation_ids=item_ann_ids)
957
1090
  columns = dp.get_annotation(category_names=self.item_names[1], annotation_ids=item_ann_ids)
958
1091
  detect_result_cells, segment_result_cells = create_intersection_cells(
@@ -979,6 +1112,7 @@ class PubtablesSegmentationService(PipelineComponent):
979
1112
  CellType.COLUMN_SPAN, segment_result.cs, CellType.COLUMN_SPAN, segment_result.annotation_id
980
1113
  )
981
1114
  cell_rn_cn_to_ann_id[(segment_result.row_num, segment_result.col_num)] = segment_result.annotation_id
1115
+
982
1116
  spanning_cell_raw_segments = segment_pubtables(
983
1117
  dp,
984
1118
  table,
@@ -988,7 +1122,15 @@ class PubtablesSegmentationService(PipelineComponent):
988
1122
  self.threshold_rows,
989
1123
  self.threshold_cols,
990
1124
  )
1125
+
991
1126
  for segment_result in spanning_cell_raw_segments:
1127
+ if (
1128
+ (segment_result.rs == 1 and segment_result.cs == 1)
1129
+ or segment_result.rs == 0
1130
+ or segment_result.cs == 0
1131
+ ):
1132
+ self.dp_manager.deactivate_annotation(segment_result.annotation_id)
1133
+ continue
992
1134
  self.dp_manager.set_category_annotation(
993
1135
  CellType.ROW_NUMBER, segment_result.row_num, CellType.ROW_NUMBER, segment_result.annotation_id
994
1136
  )
@@ -1009,6 +1151,19 @@ class PubtablesSegmentationService(PipelineComponent):
1009
1151
  cell_ann_id = cell_rn_cn_to_ann_id[cell_position]
1010
1152
  self.dp_manager.deactivate_annotation(cell_ann_id)
1011
1153
 
1154
+ for segment_result in spanning_cell_raw_segments:
1155
+ if (
1156
+ (segment_result.rs == 1 and segment_result.cs == 1)
1157
+ or segment_result.rs == 0
1158
+ or segment_result.cs == 0
1159
+ ):
1160
+ continue
1161
+ for rs in range(segment_result.rs):
1162
+ for cs in range(segment_result.cs):
1163
+ cell_rn_cn_to_ann_id[
1164
+ (segment_result.row_num + rs, segment_result.col_num + cs)
1165
+ ] = segment_result.annotation_id
1166
+
1012
1167
  cells = []
1013
1168
  if table.image:
1014
1169
  cells = table.image.get_annotation(category_names=self.cell_names)
@@ -1022,6 +1177,28 @@ class PubtablesSegmentationService(PipelineComponent):
1022
1177
  number_of_cols = 0
1023
1178
  max_row_span = 0
1024
1179
  max_col_span = 0
1180
+
1181
+ for idx, item_vals in enumerate(zip(self.item_names, self.item_header_cell_names, self.sub_item_names)):
1182
+ item_obj_type, item_header_cell_name, sub_item_name = item_vals[0], item_vals[1], item_vals[2]
1183
+
1184
+ if has_item_headers[item_obj_type]:
1185
+ items = dp.get_annotation(category_names=item_obj_type)
1186
+
1187
+ for item_ann in items:
1188
+ if item_header_cell_name in item_ann.sub_categories:
1189
+ item_number = item_ann.get_sub_category(sub_item_name).category_id
1190
+ for key, value in cell_rn_cn_to_ann_id.items():
1191
+ if key[idx] == item_number:
1192
+ cell_ann = dp.get_annotation(annotation_ids=value)[0]
1193
+ self.dp_manager.set_category_annotation(
1194
+ item_header_cell_name, None, item_header_cell_name, cell_ann.annotation_id
1195
+ )
1196
+ else:
1197
+ cell_ann = dp.get_annotation(annotation_ids=value)[0]
1198
+ self.dp_manager.set_category_annotation(
1199
+ item_header_cell_name, None, CellType.BODY, cell_ann.annotation_id
1200
+ )
1201
+
1025
1202
  # TODO: the summaries should be sub categories of the underlying ann
1026
1203
  self.dp_manager.set_summary_annotation(
1027
1204
  TableType.NUMBER_OF_ROWS, TableType.NUMBER_OF_ROWS, number_of_rows, annotation_id=table.annotation_id
@@ -1038,7 +1215,7 @@ class PubtablesSegmentationService(PipelineComponent):
1038
1215
  self.dp_manager.set_summary_annotation(
1039
1216
  TableType.MAX_COL_SPAN, TableType.MAX_COL_SPAN, max_col_span, annotation_id=table.annotation_id
1040
1217
  )
1041
- html = generate_html_string(table)
1218
+ html = generate_html_string(table, self.cell_names + self.spanning_cell_names)
1042
1219
  self.dp_manager.set_container_annotation(TableType.HTML, -1, TableType.HTML, table.annotation_id, html)
1043
1220
 
1044
1221
  def clone(self) -> PubtablesSegmentationService:
@@ -1055,6 +1232,8 @@ class PubtablesSegmentationService(PipelineComponent):
1055
1232
  self.spanning_cell_names,
1056
1233
  self.item_names,
1057
1234
  self.sub_item_names,
1235
+ self.item_header_cell_names,
1236
+ self.item_header_thresholds,
1058
1237
  self.cell_to_image,
1059
1238
  self.crop_cell_image,
1060
1239
  self.stretch_rule,