deepdoctection 0.37.3__py3-none-any.whl → 0.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of deepdoctection might be problematic. Click here for more details.
- deepdoctection/__init__.py +1 -1
- deepdoctection/analyzer/_config.py +2 -1
- deepdoctection/analyzer/factory.py +9 -4
- deepdoctection/configs/conf_dd_one.yaml +126 -85
- deepdoctection/datapoint/box.py +2 -4
- deepdoctection/datapoint/image.py +11 -4
- deepdoctection/datapoint/view.py +124 -36
- deepdoctection/extern/hfdetr.py +4 -3
- deepdoctection/pipe/doctectionpipe.py +1 -1
- deepdoctection/pipe/refine.py +6 -13
- deepdoctection/pipe/segment.py +229 -46
- deepdoctection/pipe/sub_layout.py +40 -22
- {deepdoctection-0.37.3.dist-info → deepdoctection-0.38.dist-info}/METADATA +12 -2
- {deepdoctection-0.37.3.dist-info → deepdoctection-0.38.dist-info}/RECORD +17 -17
- {deepdoctection-0.37.3.dist-info → deepdoctection-0.38.dist-info}/WHEEL +1 -1
- {deepdoctection-0.37.3.dist-info → deepdoctection-0.38.dist-info}/LICENSE +0 -0
- {deepdoctection-0.37.3.dist-info → deepdoctection-0.38.dist-info}/top_level.txt +0 -0
deepdoctection/extern/hfdetr.py
CHANGED
|
@@ -41,6 +41,7 @@ with try_import() as tr_import_guard:
|
|
|
41
41
|
from transformers import ( # pylint: disable=W0611
|
|
42
42
|
AutoFeatureExtractor,
|
|
43
43
|
DetrFeatureExtractor,
|
|
44
|
+
DetrImageProcessor,
|
|
44
45
|
PretrainedConfig,
|
|
45
46
|
TableTransformerForObjectDetection,
|
|
46
47
|
)
|
|
@@ -55,7 +56,7 @@ def _detr_post_processing(
|
|
|
55
56
|
def detr_predict_image(
|
|
56
57
|
np_img: PixelValues,
|
|
57
58
|
predictor: TableTransformerForObjectDetection,
|
|
58
|
-
feature_extractor:
|
|
59
|
+
feature_extractor: DetrImageProcessor,
|
|
59
60
|
device: torch.device,
|
|
60
61
|
threshold: float,
|
|
61
62
|
nms_threshold: float,
|
|
@@ -224,13 +225,13 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
|
|
|
224
225
|
)
|
|
225
226
|
|
|
226
227
|
@staticmethod
|
|
227
|
-
def get_pre_processor(path_feature_extractor_config: PathLikeOrStr) ->
|
|
228
|
+
def get_pre_processor(path_feature_extractor_config: PathLikeOrStr) -> DetrImageProcessor:
|
|
228
229
|
"""
|
|
229
230
|
Builds the feature extractor
|
|
230
231
|
|
|
231
232
|
:return: DetrFeatureExtractor
|
|
232
233
|
"""
|
|
233
|
-
return
|
|
234
|
+
return DetrImageProcessor.from_pretrained(
|
|
234
235
|
pretrained_model_name_or_path=os.fspath(path_feature_extractor_config)
|
|
235
236
|
)
|
|
236
237
|
|
|
@@ -188,7 +188,7 @@ class DoctectionPipe(Pipeline):
|
|
|
188
188
|
|
|
189
189
|
df = MapData(df, _proto_process(path, doc_path))
|
|
190
190
|
if dataset_dataflow is None:
|
|
191
|
-
df = MapData(df, _to_image(dpi=os.environ.get("DPI", 300))) # pylint: disable=E1120
|
|
191
|
+
df = MapData(df, _to_image(dpi=int(os.environ.get("DPI", 300)))) # pylint: disable=E1120
|
|
192
192
|
return df
|
|
193
193
|
|
|
194
194
|
@staticmethod
|
deepdoctection/pipe/refine.py
CHANGED
|
@@ -295,28 +295,21 @@ def _html_table(
|
|
|
295
295
|
return html
|
|
296
296
|
|
|
297
297
|
|
|
298
|
-
def generate_html_string(table: ImageAnnotation) -> list[str]:
|
|
298
|
+
def generate_html_string(table: ImageAnnotation, cell_names: Sequence[ObjectTypes]) -> list[str]:
|
|
299
299
|
"""
|
|
300
300
|
Takes the table segmentation by using table cells row number, column numbers etc. and generates a html
|
|
301
301
|
representation.
|
|
302
302
|
|
|
303
303
|
:param table: An annotation that has a not None image and fully segmented cell annotation.
|
|
304
|
+
:param cell_names: List of cell names that are used for the table segmentation. Note: It must be ensured that
|
|
305
|
+
that all cells have a row number, column number, row span and column span and that the dissection
|
|
306
|
+
by rows and columns is completely covered by cells.
|
|
304
307
|
:return: HTML representation of the table
|
|
305
308
|
"""
|
|
306
309
|
if table.image is None:
|
|
307
310
|
raise ImageError("table.image cannot be None")
|
|
308
311
|
table_image = table.image
|
|
309
|
-
cells = table_image.get_annotation(
|
|
310
|
-
category_names=[
|
|
311
|
-
LayoutType.CELL,
|
|
312
|
-
CellType.HEADER,
|
|
313
|
-
CellType.BODY,
|
|
314
|
-
CellType.SPANNING,
|
|
315
|
-
CellType.ROW_HEADER,
|
|
316
|
-
CellType.COLUMN_HEADER,
|
|
317
|
-
CellType.PROJECTED_ROW_HEADER,
|
|
318
|
-
]
|
|
319
|
-
)
|
|
312
|
+
cells = table_image.get_annotation(category_names=cell_names)
|
|
320
313
|
number_of_rows = table_image.summary.get_sub_category(TableType.NUMBER_OF_ROWS).category_id
|
|
321
314
|
number_of_cols = table_image.summary.get_sub_category(TableType.NUMBER_OF_COLUMNS).category_id
|
|
322
315
|
table_list = []
|
|
@@ -485,7 +478,7 @@ class TableSegmentationRefinementService(PipelineComponent):
|
|
|
485
478
|
self.dp_manager.set_summary_annotation(
|
|
486
479
|
TableType.MAX_COL_SPAN, TableType.MAX_COL_SPAN, max_col_span, annotation_id=table.annotation_id
|
|
487
480
|
)
|
|
488
|
-
html = generate_html_string(table)
|
|
481
|
+
html = generate_html_string(table, self.cell_names)
|
|
489
482
|
self.dp_manager.set_container_annotation(TableType.HTML, -1, TableType.HTML, table.annotation_id, html)
|
|
490
483
|
|
|
491
484
|
def clone(self) -> TableSegmentationRefinementService:
|
deepdoctection/pipe/segment.py
CHANGED
|
@@ -28,13 +28,13 @@ from typing import Literal, Optional, Sequence, Union
|
|
|
28
28
|
import numpy as np
|
|
29
29
|
|
|
30
30
|
from ..datapoint.annotation import ImageAnnotation
|
|
31
|
-
from ..datapoint.box import BoundingBox, global_to_local_coords, intersection_boxes, iou
|
|
31
|
+
from ..datapoint.box import BoundingBox, global_to_local_coords, intersection_box, intersection_boxes, iou, merge_boxes
|
|
32
32
|
from ..datapoint.image import Image
|
|
33
33
|
from ..extern.base import DetectionResult
|
|
34
34
|
from ..mapper.maputils import MappingContextManager
|
|
35
35
|
from ..mapper.match import match_anns_by_intersection
|
|
36
36
|
from ..utils.error import ImageError
|
|
37
|
-
from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType
|
|
37
|
+
from ..utils.settings import CellType, LayoutType, ObjectTypes, Relationships, TableType, TypeOrStr, get_type
|
|
38
38
|
from .base import MetaAnnotation, PipelineComponent
|
|
39
39
|
from .refine import generate_html_string
|
|
40
40
|
from .registry import pipeline_component_registry
|
|
@@ -55,6 +55,15 @@ class SegmentationResult:
|
|
|
55
55
|
cs: int
|
|
56
56
|
|
|
57
57
|
|
|
58
|
+
@dataclass
|
|
59
|
+
class ItemHeaderResult:
|
|
60
|
+
"""
|
|
61
|
+
Simple mutable storage for item header results
|
|
62
|
+
"""
|
|
63
|
+
|
|
64
|
+
annotation_id: str
|
|
65
|
+
|
|
66
|
+
|
|
58
67
|
def choose_items_by_iou(
|
|
59
68
|
dp: Image,
|
|
60
69
|
item_proposals: list[ImageAnnotation],
|
|
@@ -314,7 +323,7 @@ def _tile_by_stretching_rows_leftwise_column_downwise(
|
|
|
314
323
|
|
|
315
324
|
|
|
316
325
|
def tile_tables_with_items_per_table(
|
|
317
|
-
dp: Image, table: ImageAnnotation, item_name:
|
|
326
|
+
dp: Image, table: ImageAnnotation, item_name: ObjectTypes, stretch_rule: Literal["left", "equal"] = "left"
|
|
318
327
|
) -> Image:
|
|
319
328
|
"""
|
|
320
329
|
Tiling a table with items (i.e. rows or columns). To ensure that every position in a table can be assigned to a row
|
|
@@ -355,9 +364,9 @@ def tile_tables_with_items_per_table(
|
|
|
355
364
|
|
|
356
365
|
def stretch_items(
|
|
357
366
|
dp: Image,
|
|
358
|
-
table_name:
|
|
359
|
-
row_name:
|
|
360
|
-
col_name:
|
|
367
|
+
table_name: ObjectTypes,
|
|
368
|
+
row_name: ObjectTypes,
|
|
369
|
+
col_name: ObjectTypes,
|
|
361
370
|
remove_iou_threshold_rows: float,
|
|
362
371
|
remove_iou_threshold_cols: float,
|
|
363
372
|
) -> Image:
|
|
@@ -491,7 +500,7 @@ def create_intersection_cells(
|
|
|
491
500
|
cols: Sequence[ImageAnnotation],
|
|
492
501
|
table_annotation_id: str,
|
|
493
502
|
cell_class_id: int,
|
|
494
|
-
sub_item_names: Sequence[
|
|
503
|
+
sub_item_names: Sequence[ObjectTypes],
|
|
495
504
|
) -> tuple[Sequence[DetectionResult], Sequence[SegmentationResult]]:
|
|
496
505
|
"""
|
|
497
506
|
Given rows and columns with row- and column number sub categories, create a list of `DetectionResult` and
|
|
@@ -511,6 +520,7 @@ def create_intersection_cells(
|
|
|
511
520
|
detect_result_cells = []
|
|
512
521
|
segment_result_cells = []
|
|
513
522
|
idx = 0
|
|
523
|
+
break_outer_loop = False
|
|
514
524
|
for row in rows:
|
|
515
525
|
for col in cols:
|
|
516
526
|
detect_result_cells.append(
|
|
@@ -531,17 +541,59 @@ def create_intersection_cells(
|
|
|
531
541
|
)
|
|
532
542
|
)
|
|
533
543
|
idx += 1
|
|
534
|
-
# it is possible to have less intersection boxes, e.g. if one cell has height/width 0
|
|
544
|
+
# it is possible to have less intersection boxes, e.g. if one cell has height/width 0. We need to break both
|
|
545
|
+
# loops.
|
|
535
546
|
if idx >= len(boxes_cells):
|
|
547
|
+
break_outer_loop = True
|
|
536
548
|
break
|
|
549
|
+
if break_outer_loop:
|
|
550
|
+
break
|
|
537
551
|
return detect_result_cells, segment_result_cells
|
|
538
552
|
|
|
539
553
|
|
|
554
|
+
def header_cell_to_item_detect_result(
|
|
555
|
+
dp: Image,
|
|
556
|
+
table: ImageAnnotation,
|
|
557
|
+
item_name: ObjectTypes,
|
|
558
|
+
item_header_name: ObjectTypes,
|
|
559
|
+
segment_rule: Literal["iou", "ioa"],
|
|
560
|
+
threshold: float,
|
|
561
|
+
) -> list[ItemHeaderResult]:
|
|
562
|
+
"""
|
|
563
|
+
Match header cells to items (rows or columns) based on intersection-over-union (iou) or intersection-over-area (ioa)
|
|
564
|
+
and return a list of ItemHeaderResult.
|
|
565
|
+
|
|
566
|
+
:param dp: The image containing the table and items.
|
|
567
|
+
:param table: The table image annotation.
|
|
568
|
+
:param item_name: The type of items (e.g., rows or columns) to match with header cells.
|
|
569
|
+
:param item_header_name: The type of header cells to match with items.
|
|
570
|
+
:param segment_rule: The rule to use for matching, either 'iou' or 'ioa'.
|
|
571
|
+
:param threshold: The iou/ioa threshold for matching header cells with items.
|
|
572
|
+
:return: A list of ItemHeaderResult containing the matched header cells.
|
|
573
|
+
"""
|
|
574
|
+
child_ann_ids = table.get_relationship(Relationships.CHILD)
|
|
575
|
+
item_index, _, items, _ = match_anns_by_intersection(
|
|
576
|
+
dp,
|
|
577
|
+
item_header_name,
|
|
578
|
+
item_name,
|
|
579
|
+
segment_rule,
|
|
580
|
+
threshold,
|
|
581
|
+
True,
|
|
582
|
+
child_ann_ids,
|
|
583
|
+
child_ann_ids,
|
|
584
|
+
)
|
|
585
|
+
item_headers = []
|
|
586
|
+
for idx, item in enumerate(items):
|
|
587
|
+
if idx in item_index:
|
|
588
|
+
item_headers.append(ItemHeaderResult(annotation_id=item.annotation_id))
|
|
589
|
+
return item_headers
|
|
590
|
+
|
|
591
|
+
|
|
540
592
|
def segment_pubtables(
|
|
541
593
|
dp: Image,
|
|
542
594
|
table: ImageAnnotation,
|
|
543
|
-
item_names: Sequence[
|
|
544
|
-
spanning_cell_names: Sequence[
|
|
595
|
+
item_names: Sequence[ObjectTypes],
|
|
596
|
+
spanning_cell_names: Sequence[ObjectTypes],
|
|
545
597
|
segment_rule: Literal["iou", "ioa"],
|
|
546
598
|
threshold_rows: float,
|
|
547
599
|
threshold_cols: float,
|
|
@@ -553,7 +605,7 @@ def segment_pubtables(
|
|
|
553
605
|
|
|
554
606
|
Row and column positions as well as row and column lengths are determined for all types of spanning cells.
|
|
555
607
|
All simple cells that are covered by a spanning cell as well in the table position (double allocation) are then
|
|
556
|
-
|
|
608
|
+
replaced by the spanning cell and deactivated.
|
|
557
609
|
|
|
558
610
|
:param dp: Image
|
|
559
611
|
:param table: table ImageAnnotation
|
|
@@ -566,6 +618,7 @@ def segment_pubtables(
|
|
|
566
618
|
to the column.
|
|
567
619
|
:return: A list of len(number of cells) of SegmentationResult for spanning cells
|
|
568
620
|
"""
|
|
621
|
+
|
|
569
622
|
child_ann_ids = table.get_relationship(Relationships.CHILD)
|
|
570
623
|
cell_index_rows, row_index, _, _ = match_anns_by_intersection(
|
|
571
624
|
dp,
|
|
@@ -600,29 +653,77 @@ def segment_pubtables(
|
|
|
600
653
|
for idx, cell in enumerate(spanning_cells):
|
|
601
654
|
cell_positions_rows = cell_index_rows == idx
|
|
602
655
|
rows_of_cell = [rows[k] for k in row_index[cell_positions_rows]]
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
row_number =
|
|
656
|
+
if rows_of_cell:
|
|
657
|
+
min_row_cell = min(rows_of_cell, key=lambda row: row.get_sub_category(CellType.ROW_NUMBER).category_id)
|
|
658
|
+
max_row_cell = max(rows_of_cell, key=lambda row: row.get_sub_category(CellType.ROW_NUMBER).category_id)
|
|
659
|
+
max_row = max_row_cell.get_sub_category(CellType.ROW_NUMBER).category_id
|
|
660
|
+
min_row = min_row_cell.get_sub_category(CellType.ROW_NUMBER).category_id
|
|
661
|
+
rs = max_row - min_row + 1
|
|
662
|
+
row_number = min_row
|
|
610
663
|
else:
|
|
664
|
+
rs = 0
|
|
611
665
|
row_number = 0
|
|
612
666
|
|
|
613
667
|
cell_positions_cols = cell_index_cols == idx
|
|
614
668
|
cols_of_cell = [columns[k] for k in col_index[cell_positions_cols]]
|
|
615
|
-
cs = (
|
|
616
|
-
max(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
|
|
617
|
-
- min(col.get_sub_category(CellType.COLUMN_NUMBER).category_id for col in cols_of_cell)
|
|
618
|
-
+ 1
|
|
619
|
-
)
|
|
620
669
|
|
|
621
|
-
if
|
|
622
|
-
|
|
670
|
+
if cols_of_cell:
|
|
671
|
+
min_col_cell = min(
|
|
672
|
+
cols_of_cell, key=lambda col: col.get_sub_category(CellType.COLUMN_NUMBER).category_id
|
|
673
|
+
)
|
|
674
|
+
max_col_cell = max(
|
|
675
|
+
cols_of_cell, key=lambda col: col.get_sub_category(CellType.COLUMN_NUMBER).category_id
|
|
676
|
+
)
|
|
677
|
+
max_col = max_col_cell.get_sub_category(CellType.COLUMN_NUMBER).category_id
|
|
678
|
+
min_col = min_col_cell.get_sub_category(CellType.COLUMN_NUMBER).category_id
|
|
679
|
+
cs = max_col - min_col + 1
|
|
680
|
+
col_number = min_col
|
|
623
681
|
else:
|
|
682
|
+
cs = 0
|
|
624
683
|
col_number = 0
|
|
625
684
|
|
|
685
|
+
if rows_of_cell and cols_of_cell:
|
|
686
|
+
# We resize all bounding boxes of spanning cells so that they match with the grid structure, determined
|
|
687
|
+
# by the rows ans columns.
|
|
688
|
+
merge_box_image_row = merge_boxes(
|
|
689
|
+
*[min_row_cell.get_bounding_box(dp.image_id), max_row_cell.get_bounding_box(dp.image_id)]
|
|
690
|
+
)
|
|
691
|
+
merge_box_image_column = merge_boxes(
|
|
692
|
+
*[min_col_cell.get_bounding_box(dp.image_id), max_col_cell.get_bounding_box(dp.image_id)]
|
|
693
|
+
)
|
|
694
|
+
merge_box_image = intersection_box(merge_box_image_row, merge_box_image_column)
|
|
695
|
+
merge_box_table_row = merge_boxes(
|
|
696
|
+
*[
|
|
697
|
+
min_row_cell.get_bounding_box(table.annotation_id),
|
|
698
|
+
max_row_cell.get_bounding_box(table.annotation_id),
|
|
699
|
+
]
|
|
700
|
+
)
|
|
701
|
+
merge_box_table_column = merge_boxes(
|
|
702
|
+
*[
|
|
703
|
+
min_col_cell.get_bounding_box(table.annotation_id),
|
|
704
|
+
max_col_cell.get_bounding_box(table.annotation_id),
|
|
705
|
+
]
|
|
706
|
+
)
|
|
707
|
+
merge_box_table = intersection_box(merge_box_table_row, merge_box_table_column)
|
|
708
|
+
merge_box_spanning_cell_row = merge_boxes(
|
|
709
|
+
*[
|
|
710
|
+
min_row_cell.get_bounding_box(min_row_cell.annotation_id),
|
|
711
|
+
max_row_cell.get_bounding_box(max_row_cell.annotation_id),
|
|
712
|
+
]
|
|
713
|
+
)
|
|
714
|
+
merge_box_spanning_cell_column = merge_boxes(
|
|
715
|
+
*[
|
|
716
|
+
min_col_cell.get_bounding_box(min_col_cell.annotation_id),
|
|
717
|
+
max_col_cell.get_bounding_box(max_col_cell.annotation_id),
|
|
718
|
+
]
|
|
719
|
+
)
|
|
720
|
+
merge_box_spanning_cell = intersection_box(merge_box_spanning_cell_row, merge_box_spanning_cell_column)
|
|
721
|
+
if cell.image is None:
|
|
722
|
+
raise ImageError("cell.image cannot be None")
|
|
723
|
+
cell.image.set_embedding(dp.image_id, merge_box_image)
|
|
724
|
+
cell.image.set_embedding(table.annotation_id, merge_box_table)
|
|
725
|
+
cell.image.set_embedding(cell.annotation_id, merge_box_spanning_cell)
|
|
726
|
+
|
|
626
727
|
raw_table_segments.append(
|
|
627
728
|
SegmentationResult(
|
|
628
729
|
annotation_id=cell.annotation_id,
|
|
@@ -674,10 +775,10 @@ class TableSegmentationService(PipelineComponent):
|
|
|
674
775
|
tile_table_with_items: bool,
|
|
675
776
|
remove_iou_threshold_rows: float,
|
|
676
777
|
remove_iou_threshold_cols: float,
|
|
677
|
-
table_name:
|
|
678
|
-
cell_names: Sequence[
|
|
679
|
-
item_names: Sequence[
|
|
680
|
-
sub_item_names: Sequence[
|
|
778
|
+
table_name: TypeOrStr,
|
|
779
|
+
cell_names: Sequence[TypeOrStr],
|
|
780
|
+
item_names: Sequence[TypeOrStr],
|
|
781
|
+
sub_item_names: Sequence[TypeOrStr],
|
|
681
782
|
stretch_rule: Literal["left", "equal"] = "left",
|
|
682
783
|
):
|
|
683
784
|
"""
|
|
@@ -705,10 +806,10 @@ class TableSegmentationService(PipelineComponent):
|
|
|
705
806
|
self.tile_table = tile_table_with_items
|
|
706
807
|
self.remove_iou_threshold_rows = remove_iou_threshold_rows
|
|
707
808
|
self.remove_iou_threshold_cols = remove_iou_threshold_cols
|
|
708
|
-
self.table_name = table_name
|
|
709
|
-
self.cell_names = cell_names
|
|
710
|
-
self.item_names = item_names # row names must be before column name
|
|
711
|
-
self.sub_item_names = sub_item_names
|
|
809
|
+
self.table_name = get_type(table_name)
|
|
810
|
+
self.cell_names = [get_type(cell_name) for cell_name in cell_names]
|
|
811
|
+
self.item_names = [get_type(item_name) for item_name in item_names] # row names must be before column name
|
|
812
|
+
self.sub_item_names = [get_type(sub_item_name) for sub_item_name in sub_item_names]
|
|
712
813
|
self.stretch_rule = stretch_rule
|
|
713
814
|
self.item_iou_threshold = 0.0001
|
|
714
815
|
super().__init__("table_segment")
|
|
@@ -876,11 +977,13 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
876
977
|
remove_iou_threshold_rows: float,
|
|
877
978
|
remove_iou_threshold_cols: float,
|
|
878
979
|
cell_class_id: int,
|
|
879
|
-
table_name:
|
|
880
|
-
cell_names: Sequence[
|
|
881
|
-
spanning_cell_names: Sequence[
|
|
882
|
-
item_names: Sequence[
|
|
883
|
-
sub_item_names: Sequence[
|
|
980
|
+
table_name: TypeOrStr,
|
|
981
|
+
cell_names: Sequence[TypeOrStr],
|
|
982
|
+
spanning_cell_names: Sequence[TypeOrStr],
|
|
983
|
+
item_names: Sequence[TypeOrStr],
|
|
984
|
+
sub_item_names: Sequence[TypeOrStr],
|
|
985
|
+
item_header_cell_names: Sequence[TypeOrStr],
|
|
986
|
+
item_header_thresholds: Sequence[float],
|
|
884
987
|
cell_to_image: bool = True,
|
|
885
988
|
crop_cell_image: bool = False,
|
|
886
989
|
stretch_rule: Literal["left", "equal"] = "left",
|
|
@@ -900,6 +1003,11 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
900
1003
|
:param spanning_cell_names: layout type of spanning cells
|
|
901
1004
|
:param item_names: layout type of items (e.g. row and column)
|
|
902
1005
|
:param sub_item_names: layout type of sub items (e.g. row number and column number)
|
|
1006
|
+
:param item_header_cell_names: layout type of item header cells (e.g. CellType.COLUMN_HEADER,
|
|
1007
|
+
CellType.ROW_HEADER). Note that column header, resp. row header will be first assigned to rows, resp. columns
|
|
1008
|
+
and then transferred to cells.
|
|
1009
|
+
:param item_header_thresholds: iou/ioa threshold for matching header cells with items. The first threshold
|
|
1010
|
+
corresponds to matching the first entry of item_names.
|
|
903
1011
|
:param cell_to_image: If set to 'True' it will create an 'Image' for LayoutType.cell
|
|
904
1012
|
:param crop_cell_image: If set to 'True' it will crop a numpy array image for LayoutType.cell.
|
|
905
1013
|
Requires 'cell_to_image=True'
|
|
@@ -909,17 +1017,20 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
909
1017
|
self.threshold_rows = threshold_rows
|
|
910
1018
|
self.threshold_cols = threshold_cols
|
|
911
1019
|
self.tile_table = tile_table_with_items
|
|
912
|
-
self.table_name = table_name
|
|
913
|
-
self.cell_names = cell_names
|
|
914
|
-
self.spanning_cell_names = spanning_cell_names
|
|
1020
|
+
self.table_name = get_type(table_name)
|
|
1021
|
+
self.cell_names = [get_type(cell_name) for cell_name in cell_names]
|
|
1022
|
+
self.spanning_cell_names = [get_type(cell_name) for cell_name in spanning_cell_names]
|
|
915
1023
|
self.remove_iou_threshold_rows = remove_iou_threshold_rows
|
|
916
1024
|
self.remove_iou_threshold_cols = remove_iou_threshold_cols
|
|
917
1025
|
self.cell_class_id = cell_class_id
|
|
918
1026
|
self.cell_to_image = cell_to_image
|
|
919
1027
|
self.crop_cell_image = crop_cell_image
|
|
920
|
-
self.item_names = item_names # row names must be before column name
|
|
921
|
-
self.sub_item_names = sub_item_names
|
|
1028
|
+
self.item_names = [get_type(item_name) for item_name in item_names] # row names must be before column name
|
|
1029
|
+
self.sub_item_names = [get_type(item_name) for item_name in sub_item_names]
|
|
922
1030
|
self.stretch_rule = stretch_rule
|
|
1031
|
+
self.item_header_cell_names = [get_type(item_name) for item_name in item_header_cell_names]
|
|
1032
|
+
self.item_header_thresholds = item_header_thresholds
|
|
1033
|
+
|
|
923
1034
|
super().__init__("table_transformer_segment")
|
|
924
1035
|
|
|
925
1036
|
def serve(self, dp: Image) -> None:
|
|
@@ -932,10 +1043,18 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
932
1043
|
self.remove_iou_threshold_cols,
|
|
933
1044
|
)
|
|
934
1045
|
table_anns = dp.get_annotation(category_names=self.table_name)
|
|
1046
|
+
has_item_headers = {item: False for item in self.item_names}
|
|
935
1047
|
for table in table_anns:
|
|
936
1048
|
item_ann_ids = table.get_relationship(Relationships.CHILD)
|
|
937
|
-
for item_sub_item_name in zip(
|
|
938
|
-
|
|
1049
|
+
for item_sub_item_name in zip(
|
|
1050
|
+
self.item_names, self.sub_item_names, self.item_header_cell_names, self.item_header_thresholds
|
|
1051
|
+
): # one pass for rows and one for cols
|
|
1052
|
+
item_name, sub_item_name, item_header_cell_name, item_header_threshold = (
|
|
1053
|
+
item_sub_item_name[0],
|
|
1054
|
+
item_sub_item_name[1],
|
|
1055
|
+
item_sub_item_name[2],
|
|
1056
|
+
item_sub_item_name[3],
|
|
1057
|
+
)
|
|
939
1058
|
if self.tile_table:
|
|
940
1059
|
dp = tile_tables_with_items_per_table(dp, table, item_name, self.stretch_rule)
|
|
941
1060
|
items = dp.get_annotation(category_names=item_name, annotation_ids=item_ann_ids)
|
|
@@ -949,10 +1068,24 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
949
1068
|
)
|
|
950
1069
|
)
|
|
951
1070
|
|
|
1071
|
+
item_headers_detect_results = header_cell_to_item_detect_result(
|
|
1072
|
+
dp, table, item_name, item_header_cell_name, self.segment_rule, item_header_threshold
|
|
1073
|
+
)
|
|
1074
|
+
if item_headers_detect_results:
|
|
1075
|
+
has_item_headers[item_name] = True
|
|
1076
|
+
|
|
952
1077
|
for item_number, item in enumerate(items, 1):
|
|
953
1078
|
self.dp_manager.set_category_annotation(
|
|
954
1079
|
sub_item_name, item_number, sub_item_name, item.annotation_id
|
|
955
1080
|
)
|
|
1081
|
+
for item_header_detect_result in item_headers_detect_results:
|
|
1082
|
+
self.dp_manager.set_category_annotation(
|
|
1083
|
+
category_name=item_header_cell_name,
|
|
1084
|
+
category_id=None,
|
|
1085
|
+
sub_cat_key=item_header_cell_name,
|
|
1086
|
+
annotation_id=item_header_detect_result.annotation_id,
|
|
1087
|
+
)
|
|
1088
|
+
|
|
956
1089
|
rows = dp.get_annotation(category_names=self.item_names[0], annotation_ids=item_ann_ids)
|
|
957
1090
|
columns = dp.get_annotation(category_names=self.item_names[1], annotation_ids=item_ann_ids)
|
|
958
1091
|
detect_result_cells, segment_result_cells = create_intersection_cells(
|
|
@@ -979,6 +1112,7 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
979
1112
|
CellType.COLUMN_SPAN, segment_result.cs, CellType.COLUMN_SPAN, segment_result.annotation_id
|
|
980
1113
|
)
|
|
981
1114
|
cell_rn_cn_to_ann_id[(segment_result.row_num, segment_result.col_num)] = segment_result.annotation_id
|
|
1115
|
+
|
|
982
1116
|
spanning_cell_raw_segments = segment_pubtables(
|
|
983
1117
|
dp,
|
|
984
1118
|
table,
|
|
@@ -988,7 +1122,15 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
988
1122
|
self.threshold_rows,
|
|
989
1123
|
self.threshold_cols,
|
|
990
1124
|
)
|
|
1125
|
+
|
|
991
1126
|
for segment_result in spanning_cell_raw_segments:
|
|
1127
|
+
if (
|
|
1128
|
+
(segment_result.rs == 1 and segment_result.cs == 1)
|
|
1129
|
+
or segment_result.rs == 0
|
|
1130
|
+
or segment_result.cs == 0
|
|
1131
|
+
):
|
|
1132
|
+
self.dp_manager.deactivate_annotation(segment_result.annotation_id)
|
|
1133
|
+
continue
|
|
992
1134
|
self.dp_manager.set_category_annotation(
|
|
993
1135
|
CellType.ROW_NUMBER, segment_result.row_num, CellType.ROW_NUMBER, segment_result.annotation_id
|
|
994
1136
|
)
|
|
@@ -1009,6 +1151,19 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
1009
1151
|
cell_ann_id = cell_rn_cn_to_ann_id[cell_position]
|
|
1010
1152
|
self.dp_manager.deactivate_annotation(cell_ann_id)
|
|
1011
1153
|
|
|
1154
|
+
for segment_result in spanning_cell_raw_segments:
|
|
1155
|
+
if (
|
|
1156
|
+
(segment_result.rs == 1 and segment_result.cs == 1)
|
|
1157
|
+
or segment_result.rs == 0
|
|
1158
|
+
or segment_result.cs == 0
|
|
1159
|
+
):
|
|
1160
|
+
continue
|
|
1161
|
+
for rs in range(segment_result.rs):
|
|
1162
|
+
for cs in range(segment_result.cs):
|
|
1163
|
+
cell_rn_cn_to_ann_id[
|
|
1164
|
+
(segment_result.row_num + rs, segment_result.col_num + cs)
|
|
1165
|
+
] = segment_result.annotation_id
|
|
1166
|
+
|
|
1012
1167
|
cells = []
|
|
1013
1168
|
if table.image:
|
|
1014
1169
|
cells = table.image.get_annotation(category_names=self.cell_names)
|
|
@@ -1022,6 +1177,32 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
1022
1177
|
number_of_cols = 0
|
|
1023
1178
|
max_row_span = 0
|
|
1024
1179
|
max_col_span = 0
|
|
1180
|
+
|
|
1181
|
+
for idx, item_vals in enumerate(zip(self.item_names, self.item_header_cell_names, self.sub_item_names)):
|
|
1182
|
+
item_obj_type, item_header_cell_name, sub_item_name = item_vals[0], item_vals[1], item_vals[2]
|
|
1183
|
+
|
|
1184
|
+
if has_item_headers[item_obj_type]:
|
|
1185
|
+
items = dp.get_annotation(category_names=item_obj_type)
|
|
1186
|
+
|
|
1187
|
+
for item_ann in items:
|
|
1188
|
+
if item_header_cell_name in item_ann.sub_categories:
|
|
1189
|
+
item_number = item_ann.get_sub_category(sub_item_name).category_id
|
|
1190
|
+
for key, value in cell_rn_cn_to_ann_id.items():
|
|
1191
|
+
if key[idx] == item_number:
|
|
1192
|
+
cell_ann = dp.get_annotation(annotation_ids=value)[0]
|
|
1193
|
+
self.dp_manager.set_category_annotation(
|
|
1194
|
+
item_header_cell_name,
|
|
1195
|
+
None,
|
|
1196
|
+
item_header_cell_name,
|
|
1197
|
+
cell_ann.annotation_id
|
|
1198
|
+
)
|
|
1199
|
+
else:
|
|
1200
|
+
cell_ann = dp.get_annotation(annotation_ids=value)[0]
|
|
1201
|
+
self.dp_manager.set_category_annotation(item_header_cell_name,
|
|
1202
|
+
None,
|
|
1203
|
+
CellType.BODY,
|
|
1204
|
+
cell_ann.annotation_id)
|
|
1205
|
+
|
|
1025
1206
|
# TODO: the summaries should be sub categories of the underlying ann
|
|
1026
1207
|
self.dp_manager.set_summary_annotation(
|
|
1027
1208
|
TableType.NUMBER_OF_ROWS, TableType.NUMBER_OF_ROWS, number_of_rows, annotation_id=table.annotation_id
|
|
@@ -1038,7 +1219,7 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
1038
1219
|
self.dp_manager.set_summary_annotation(
|
|
1039
1220
|
TableType.MAX_COL_SPAN, TableType.MAX_COL_SPAN, max_col_span, annotation_id=table.annotation_id
|
|
1040
1221
|
)
|
|
1041
|
-
html = generate_html_string(table)
|
|
1222
|
+
html = generate_html_string(table, self.cell_names + self.spanning_cell_names)
|
|
1042
1223
|
self.dp_manager.set_container_annotation(TableType.HTML, -1, TableType.HTML, table.annotation_id, html)
|
|
1043
1224
|
|
|
1044
1225
|
def clone(self) -> PubtablesSegmentationService:
|
|
@@ -1055,6 +1236,8 @@ class PubtablesSegmentationService(PipelineComponent):
|
|
|
1055
1236
|
self.spanning_cell_names,
|
|
1056
1237
|
self.item_names,
|
|
1057
1238
|
self.sub_item_names,
|
|
1239
|
+
self.item_header_cell_names,
|
|
1240
|
+
self.item_header_thresholds,
|
|
1058
1241
|
self.cell_to_image,
|
|
1059
1242
|
self.crop_cell_image,
|
|
1060
1243
|
self.stretch_rule,
|
|
@@ -49,27 +49,27 @@ class DetectResultGenerator:
|
|
|
49
49
|
|
|
50
50
|
def __init__(
|
|
51
51
|
self,
|
|
52
|
-
|
|
53
|
-
group_categories: Optional[list[list[
|
|
54
|
-
|
|
52
|
+
categories_name_as_key: Mapping[ObjectTypes, int],
|
|
53
|
+
group_categories: Optional[list[list[ObjectTypes]]] = None,
|
|
54
|
+
exclude_category_names: Optional[Sequence[ObjectTypes]] = None,
|
|
55
55
|
absolute_coords: bool = True,
|
|
56
56
|
) -> None:
|
|
57
57
|
"""
|
|
58
|
-
:param
|
|
58
|
+
:param categories_name_as_key: The dict of all possible detection categories
|
|
59
59
|
:param group_categories: If you only want to generate only one DetectResult for a group of categories, provided
|
|
60
60
|
that the sum of the group is less than one, then you can pass a list of list for
|
|
61
61
|
grouping category ids.
|
|
62
62
|
:param absolute_coords: 'absolute_coords' value to be set in 'DetectionResult'
|
|
63
63
|
"""
|
|
64
|
-
self.
|
|
64
|
+
self.categories_name_as_key = MappingProxyType(dict(categories_name_as_key.items()))
|
|
65
65
|
self.width: Optional[int] = None
|
|
66
66
|
self.height: Optional[int] = None
|
|
67
67
|
if group_categories is None:
|
|
68
|
-
group_categories = [[
|
|
68
|
+
group_categories = [[cat_name] for cat_name in self.categories_name_as_key]
|
|
69
69
|
self.group_categories = group_categories
|
|
70
|
-
if
|
|
71
|
-
|
|
72
|
-
self.
|
|
70
|
+
if exclude_category_names is None:
|
|
71
|
+
exclude_category_names = []
|
|
72
|
+
self.exclude_category_names = exclude_category_names
|
|
73
73
|
self.dummy_for_group_generated = [False for _ in self.group_categories]
|
|
74
74
|
self.absolute_coords = absolute_coords
|
|
75
75
|
|
|
@@ -83,17 +83,17 @@ class DetectResultGenerator:
|
|
|
83
83
|
|
|
84
84
|
if self.width is None and self.height is None:
|
|
85
85
|
raise ValueError("Initialize height and width first")
|
|
86
|
-
|
|
86
|
+
detect_result_list = self._detection_result_sanity_check(detect_result_list)
|
|
87
87
|
count = self._create_condition(detect_result_list)
|
|
88
|
-
for
|
|
89
|
-
if
|
|
90
|
-
if count[
|
|
91
|
-
if not self._dummy_for_group_generated(
|
|
88
|
+
for category_name in self.categories_name_as_key:
|
|
89
|
+
if category_name not in self.exclude_category_names:
|
|
90
|
+
if count[category_name] < 1:
|
|
91
|
+
if not self._dummy_for_group_generated(category_name):
|
|
92
92
|
detect_result_list.append(
|
|
93
93
|
DetectionResult(
|
|
94
94
|
box=[0.0, 0.0, float(self.width), float(self.height)], # type: ignore
|
|
95
|
-
class_id=
|
|
96
|
-
class_name=
|
|
95
|
+
class_id=self.categories_name_as_key[category_name],
|
|
96
|
+
class_name=category_name,
|
|
97
97
|
score=0.0,
|
|
98
98
|
absolute_coords=self.absolute_coords,
|
|
99
99
|
)
|
|
@@ -102,8 +102,8 @@ class DetectResultGenerator:
|
|
|
102
102
|
self.dummy_for_group_generated = self._initialize_dummy_for_group_generated()
|
|
103
103
|
return detect_result_list
|
|
104
104
|
|
|
105
|
-
def _create_condition(self, detect_result_list: list[DetectionResult]) -> dict[
|
|
106
|
-
count = Counter([ann.
|
|
105
|
+
def _create_condition(self, detect_result_list: list[DetectionResult]) -> dict[ObjectTypes, int]:
|
|
106
|
+
count = Counter([ann.class_name for ann in detect_result_list])
|
|
107
107
|
cat_to_group_sum = {}
|
|
108
108
|
for group in self.group_categories:
|
|
109
109
|
group_sum = 0
|
|
@@ -113,9 +113,25 @@ class DetectResultGenerator:
|
|
|
113
113
|
cat_to_group_sum[el] = group_sum
|
|
114
114
|
return cat_to_group_sum
|
|
115
115
|
|
|
116
|
-
|
|
116
|
+
@staticmethod
|
|
117
|
+
def _detection_result_sanity_check(detect_result_list: list[DetectionResult]) -> list[DetectionResult]:
|
|
118
|
+
"""
|
|
119
|
+
Go through each detect_result in the list and check if the box argument has sensible coordinates:
|
|
120
|
+
ulx >= 0 and lrx - ulx >= 0 (same for y coordinate). Remove the detection result if this condition is not
|
|
121
|
+
satisfied. We need this check because if some detection results are not sane, we might end up with some
|
|
122
|
+
none existing categories.
|
|
123
|
+
"""
|
|
124
|
+
sane_detect_results = []
|
|
125
|
+
for detect_result in detect_result_list:
|
|
126
|
+
if detect_result.box:
|
|
127
|
+
ulx, uly, lrx, lry = detect_result.box
|
|
128
|
+
if ulx >= 0 and lrx - ulx >= 0 and uly >= 0 and lry - uly >= 0:
|
|
129
|
+
sane_detect_results.append(detect_result)
|
|
130
|
+
return sane_detect_results
|
|
131
|
+
|
|
132
|
+
def _dummy_for_group_generated(self, category_name: ObjectTypes) -> bool:
|
|
117
133
|
for idx, group in enumerate(self.group_categories):
|
|
118
|
-
if
|
|
134
|
+
if category_name in group:
|
|
119
135
|
is_generated = self.dummy_for_group_generated[idx]
|
|
120
136
|
self.dummy_for_group_generated[idx] = True
|
|
121
137
|
return is_generated
|
|
@@ -176,10 +192,12 @@ class SubImageLayoutService(PipelineComponent):
|
|
|
176
192
|
self.predictor = sub_image_detector
|
|
177
193
|
super().__init__(self._get_name(sub_image_detector.name), self.predictor.model_id)
|
|
178
194
|
if self.detect_result_generator is not None:
|
|
179
|
-
if self.detect_result_generator.
|
|
195
|
+
if self.detect_result_generator.categories_name_as_key != self.predictor.categories.get_categories(
|
|
196
|
+
as_dict=True, name_as_key=True
|
|
197
|
+
):
|
|
180
198
|
raise ValueError(
|
|
181
199
|
f"The categories of the 'detect_result_generator' must be the same as the categories of the "
|
|
182
|
-
f"'sub_image_detector'. Got {self.detect_result_generator.
|
|
200
|
+
f"'sub_image_detector'. Got {self.detect_result_generator.categories_name_as_key} #"
|
|
183
201
|
f"and {self.predictor.categories.get_categories()}."
|
|
184
202
|
)
|
|
185
203
|
|